"""Real-time data streaming utilities.
Provides a WebSocket client for consuming streaming market data and a
tick buffer for aggregating raw ticks into OHLCV bars.
"""
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Any, Callable
import pandas as pd
# Public API: the names exported by a star-import of this module.
# Both classes are defined later in this file.
__all__ = [
"WebSocketClient",
"TickBuffer",
]
[docs]
class WebSocketClient:
    """Async WebSocket client for streaming market data.

    Wraps the ``websockets`` library to provide a simple interface for
    subscribing to real-time data feeds.

    Channels passed to :meth:`subscribe` are remembered on the client, so
    channels requested before a connection exists are sent to the server as
    soon as :meth:`connect` succeeds.

    Parameters:
        url: WebSocket server URL (e.g., ``"wss://stream.example.com"``).
        on_message: Optional callback invoked with each received message.
        on_error: Optional callback invoked when an error occurs while
            listening. When it is not set, the error is re-raised instead.

    Example:
        >>> client = WebSocketClient("wss://stream.example.com/v1/ws")
        >>> client.on_message = lambda msg: print(msg)
        >>> client.run()  # blocks until disconnected
    """

    def __init__(
        self,
        url: str,
        on_message: Callable[[str], None] | None = None,
        on_error: Callable[[Exception], None] | None = None,
    ) -> None:
        self.url = url
        self.on_message = on_message
        self.on_error = on_error
        # Live connection object, or None while disconnected.
        self._ws: Any = None
        # Cleared by disconnect() so the listener stops dispatching.
        self._running: bool = False
        # Channels the caller currently wants; replayed by connect().
        self._subscriptions: set[str] = set()

    async def connect(self) -> None:
        """Open the WebSocket connection and replay recorded subscriptions.

        Requires the ``websockets`` package (part of the ``ingestion``
        extra).

        Raises:
            Exception: Whatever ``websockets.connect`` or the subscription
                replay raises. If the replay fails, the half-open connection
                is closed before the error propagates.
        """
        import websockets

        self._ws = await websockets.connect(self.url)
        self._running = True

        # subscribe() only records channels while there is no connection, so
        # they have to be sent here or the server never hears about them.
        if self._subscriptions:
            try:
                await self._send_action("subscribe", sorted(self._subscriptions))
            except Exception:
                await self.disconnect()
                raise

    async def disconnect(self) -> None:
        """Close the WebSocket connection gracefully.

        Safe to call when no connection is open.
        """
        self._running = False
        # Drop the reference first so the client is left in the disconnected
        # state even if close() raises.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def subscribe(self, channels: list[str]) -> None:
        """Subscribe to one or more data channels.

        The channels are always recorded; the subscribe message is sent
        immediately only when a connection is open.

        Parameters:
            channels: List of channel identifiers to subscribe to.
        """
        self._subscriptions.update(channels)
        await self._send_action("subscribe", channels)

    async def unsubscribe(self, channels: list[str]) -> None:
        """Unsubscribe from one or more data channels.

        The channels are always removed from the recorded set; the
        unsubscribe message is sent only when a connection is open.

        Parameters:
            channels: List of channel identifiers to unsubscribe from.
        """
        self._subscriptions.difference_update(channels)
        await self._send_action("unsubscribe", channels)

    async def _send_action(self, action: str, channels: list[str]) -> None:
        """Send an ``action`` control message for ``channels`` if connected."""
        import json

        if self._ws is None:
            return
        message = json.dumps({"action": action, "channels": channels})
        await self._ws.send(message)

    async def _listen(self) -> None:
        """Internal listener that dispatches incoming messages.

        Any exception raised while iterating the connection or inside the
        ``on_message`` callback ends the loop; it is handed to ``on_error``
        when that callback is set and re-raised otherwise.
        """
        try:
            async for message in self._ws:
                if not self._running:
                    break
                if self.on_message is not None:
                    self.on_message(message)
        except Exception as exc:
            if self.on_error is None:
                raise
            self.on_error(exc)

    def run(self) -> None:
        """Start the WebSocket event loop (blocking).

        Connects to the server and listens for messages until the
        connection is closed or :meth:`disconnect` is called.

        Raises:
            RuntimeError: If called while an event loop is already running
                in the current thread.
        """
        # asyncio.run() creates and tears down its own loop, so this works
        # even when the thread has no current event loop.
        asyncio.run(self._run_async())

    async def _run_async(self) -> None:
        """Internal coroutine that manages the connection lifecycle."""
        try:
            await self.connect()
            await self._listen()
        finally:
            # disconnect() is a no-op when connect() never produced a socket.
            await self.disconnect()
[docs]
class TickBuffer:
    """Buffer incoming ticks and aggregate them into OHLCV bars.

    Stores raw tick data and groups it into time-based bars at the
    requested interval.

    Parameters:
        bar_interval: Pandas-compatible frequency string for bar
            aggregation (e.g., ``'1min'``, ``'5min'``, ``'1h'``).

    Example:
        >>> buf = TickBuffer(bar_interval="1min")
        >>> buf.add_tick(pd.Timestamp("2024-01-02 09:30:00.100"), 150.25, 100)
        >>> buf.add_tick(pd.Timestamp("2024-01-02 09:30:00.500"), 150.50, 200)
        >>> bars = buf.get_bars()
    """

    # Column order of every frame returned by get_bars().
    _BAR_COLUMNS = ("open", "high", "low", "close", "volume")

    def __init__(self, bar_interval: str = "1min") -> None:
        self.bar_interval = bar_interval
        # Three parallel lists: entry i of each describes the i-th tick.
        self._timestamps: list[datetime | pd.Timestamp] = []
        self._prices: list[float] = []
        self._volumes: list[float] = []

    def add_tick(
        self,
        timestamp: datetime | pd.Timestamp,
        price: float,
        volume: float = 0,
    ) -> None:
        """Add a single tick to the buffer.

        Parameters:
            timestamp: Tick timestamp.
            price: Tick price.
            volume: Tick volume. Defaults to 0.
        """
        self._timestamps.append(timestamp)
        self._prices.append(price)
        self._volumes.append(volume)

    def get_bars(self) -> pd.DataFrame:
        """Aggregate buffered ticks into OHLCV bars.

        The buffer itself is left untouched.

        Returns:
            DataFrame with columns ``open``, ``high``, ``low``,
            ``close``, ``volume`` indexed by the bar period start time
            (a ``DatetimeIndex`` named ``timestamp``). Intervals that
            contain no ticks are omitted. If no ticks have been added, the
            frame is empty but keeps the same columns and index type.
        """
        if not self._timestamps:
            # Mirror the non-empty layout (datetime index, numeric columns)
            # so callers can concatenate or index results uniformly.
            return pd.DataFrame(
                columns=list(self._BAR_COLUMNS),
                index=pd.DatetimeIndex([], name="timestamp"),
                dtype="float64",
            )

        ticks = pd.DataFrame(
            {
                "price": self._prices,
                "volume": self._volumes,
            },
            index=pd.DatetimeIndex(self._timestamps, name="timestamp"),
        )
        bars = ticks.groupby(pd.Grouper(freq=self.bar_interval)).agg(
            open=("price", "first"),
            high=("price", "max"),
            low=("price", "min"),
            close=("price", "last"),
            volume=("volume", "sum"),
        )
        # Intervals without ticks come back with a NaN open; drop them.
        return bars.dropna(subset=["open"])

    def flush(self) -> pd.DataFrame:
        """Return bars for every buffered tick and clear the buffer.

        All buffered ticks are aggregated, so the most recent bar is
        included even if its interval has not finished yet.

        Returns:
            DataFrame with OHLCV bars from all buffered ticks.
        """
        bars = self.get_bars()
        self.clear()
        return bars

    def clear(self) -> None:
        """Clear all buffered ticks without returning bars."""
        self._timestamps.clear()
        self._prices.clear()
        self._volumes.clear()

    def __len__(self) -> int:
        """Return the number of buffered ticks."""
        return len(self._timestamps)