- Introduced the `OKXCollector` and `OKXWebSocketClient` classes for real-time market data collection from the OKX exchange. - Implemented a factory pattern for creating exchange-specific collectors, enhancing modularity and scalability. - Added configuration support for the OKX collector in `config/okx_config.json`. - Updated documentation to reflect the new modular architecture and provide guidance on using the OKX collector. - Created unit tests for the OKX collector and exchange factory to ensure functionality and reliability. - Enhanced logging and error handling throughout the new implementation for improved monitoring and debugging.
485 lines
18 KiB
Python
485 lines
18 KiB
Python
"""
|
|
OKX Data Collector implementation.
|
|
|
|
This module provides the main OKX data collector class that extends BaseDataCollector,
|
|
handling real-time market data collection for a single trading pair with robust
|
|
error handling, health monitoring, and database integration.
|
|
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime, timezone
|
|
from decimal import Decimal
|
|
from typing import Dict, List, Optional, Any, Set
|
|
from dataclasses import dataclass
|
|
|
|
from ...base_collector import (
|
|
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
|
|
OHLCVData, DataValidationError, ConnectionError
|
|
)
|
|
from .websocket import (
|
|
OKXWebSocketClient, OKXSubscription, OKXChannelType,
|
|
ConnectionState, OKXWebSocketError
|
|
)
|
|
from database.connection import get_db_manager, get_raw_data_manager
|
|
from database.models import MarketData, RawTrade
|
|
from utils.logger import get_logger
|
|
|
|
|
|
@dataclass
class OKXMarketData:
    """
    Plain record describing one piece of OKX market data.

    The dataclass only declares fields; it performs no validation or
    conversion of the values it is given.
    """

    # Trading symbol the record belongs to (presumably an OKX instrument ID
    # such as 'BTC-USDT' -- confirm against callers).
    symbol: str

    # Time associated with the record.
    timestamp: datetime

    # Kind of data carried by the record, as a string.
    data_type: str

    # Name of the OKX channel the record came from.
    channel: str

    # Payload as received, kept as a generic mapping.
    raw_data: Dict[str, Any]
|
|
|
|
|
|
class OKXCollector(BaseDataCollector):
    """
    OKX data collector for real-time market data.

    This collector handles a single trading pair and collects real-time data
    including trades, orderbook, and ticker information from OKX exchange.

    Messages are delivered by an ``OKXWebSocketClient`` through the
    ``_on_message`` callback. Each message is processed in its own asyncio
    task, which this class keeps a reference to until it finishes.
    """

    # Upper bound on the number of messages kept in ``_message_buffer``.
    # Nothing in this class removes entries from the buffer, so without a cap
    # it would grow for as long as the collector keeps receiving messages.
    MAX_BUFFERED_MESSAGES = 1000

    def __init__(self,
                 symbol: str,
                 data_types: Optional[List[DataType]] = None,
                 component_name: Optional[str] = None,
                 auto_restart: bool = True,
                 health_check_interval: float = 30.0,
                 store_raw_data: bool = True):
        """
        Initialize OKX collector for a single trading pair.

        Args:
            symbol: Trading symbol (e.g., 'BTC-USDT')
            data_types: Types of data to collect (default: [DataType.TRADE, DataType.ORDERBOOK])
            component_name: Name for logging (default: f'okx_collector_{symbol}'
                with '-' replaced by '_' and lower-cased)
            auto_restart: Enable automatic restart on failures
            health_check_interval: Seconds between health checks
            store_raw_data: Whether to store raw data for debugging
        """
        # Default data types if not specified
        if data_types is None:
            data_types = [DataType.TRADE, DataType.ORDERBOOK]

        # Component name for logging
        if component_name is None:
            component_name = f"okx_collector_{symbol.replace('-', '_').lower()}"

        # Initialize base collector
        super().__init__(
            exchange_name="okx",
            symbols=[symbol],
            data_types=data_types,
            component_name=component_name,
            auto_restart=auto_restart,
            health_check_interval=health_check_interval
        )

        # OKX-specific settings
        self.symbol = symbol
        self.store_raw_data = store_raw_data

        # WebSocket client (created in connect())
        self._ws_client: Optional[OKXWebSocketClient] = None

        # Database managers (resolved in connect())
        self._db_manager = None
        self._raw_data_manager = None

        # Data processing
        self._message_buffer: List[Dict[str, Any]] = []
        self._last_trade_id: Optional[str] = None
        self._last_orderbook_ts: Optional[int] = None

        # Strong references to in-flight processing tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it completes.
        self._background_tasks: Set[asyncio.Task] = set()

        # OKX channel mapping
        self._channel_mapping = {
            DataType.TRADE: OKXChannelType.TRADES.value,
            DataType.ORDERBOOK: OKXChannelType.BOOKS5.value,
            DataType.TICKER: OKXChannelType.TICKERS.value
        }

        self.logger.info(f"Initialized OKX collector for {symbol} with data types: {[dt.value for dt in data_types]}")

    async def connect(self) -> bool:
        """
        Establish connection to OKX WebSocket API.

        Resolves the database manager (and the raw data manager when
        ``store_raw_data`` is set), creates a new ``OKXWebSocketClient``,
        registers ``_on_message`` as its message callback and connects to the
        public endpoint.

        Returns:
            True if connection successful, False otherwise. Exceptions are
            logged and reported as False rather than propagated.
        """
        try:
            self.logger.info(f"Connecting OKX collector for {self.symbol}")

            # Initialize database managers
            self._db_manager = get_db_manager()
            if self.store_raw_data:
                self._raw_data_manager = get_raw_data_manager()

            # Create WebSocket client
            ws_component_name = f"okx_ws_{self.symbol.replace('-', '_').lower()}"
            self._ws_client = OKXWebSocketClient(
                component_name=ws_component_name,
                ping_interval=25.0,
                pong_timeout=10.0,
                max_reconnect_attempts=5,
                reconnect_delay=5.0
            )

            # Add message callback
            self._ws_client.add_message_callback(self._on_message)

            # Connect to WebSocket
            if not await self._ws_client.connect(use_public=True):
                self.logger.error("Failed to connect to OKX WebSocket")
                return False

            self.logger.info(f"Successfully connected OKX collector for {self.symbol}")
            return True

        except Exception as e:
            self.logger.error(f"Error connecting OKX collector for {self.symbol}: {e}")
            return False

    async def disconnect(self) -> None:
        """
        Disconnect from OKX WebSocket API.

        Disconnects and drops the WebSocket client if one exists. Errors are
        logged and not propagated.
        """
        try:
            self.logger.info(f"Disconnecting OKX collector for {self.symbol}")

            if self._ws_client:
                await self._ws_client.disconnect()
                self._ws_client = None

            self.logger.info(f"Disconnected OKX collector for {self.symbol}")

        except Exception as e:
            self.logger.error(f"Error disconnecting OKX collector for {self.symbol}: {e}")

    async def subscribe_to_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
        """
        Subscribe to data streams for specified symbols and data types.

        Only ``self.symbol`` is ever subscribed; ``symbols`` is used solely to
        check that this collector's symbol was requested.

        Args:
            symbols: Trading symbols to subscribe to (should contain self.symbol)
            data_types: Types of data to subscribe to; types without an entry
                in the channel mapping are skipped with a warning

        Returns:
            True if subscription successful, False otherwise (not connected,
            symbol not requested, no supported data type, or an error)
        """
        if not self._ws_client or not self._ws_client.is_connected:
            self.logger.error("WebSocket client not connected")
            return False

        # Validate symbol
        if self.symbol not in symbols:
            self.logger.warning(f"Symbol {self.symbol} not in subscription list: {symbols}")
            return False

        try:
            # Build subscriptions
            subscriptions = []
            for data_type in data_types:
                if data_type in self._channel_mapping:
                    channel = self._channel_mapping[data_type]
                    subscription = OKXSubscription(
                        channel=channel,
                        inst_id=self.symbol,
                        enabled=True
                    )
                    subscriptions.append(subscription)
                    self.logger.debug(f"Added subscription: {channel} for {self.symbol}")
                else:
                    self.logger.warning(f"Unsupported data type: {data_type}")

            if not subscriptions:
                self.logger.warning("No valid subscriptions to create")
                return False

            # Subscribe to channels
            success = await self._ws_client.subscribe(subscriptions)

            if success:
                self.logger.info(f"Successfully subscribed to {len(subscriptions)} channels for {self.symbol}")
            else:
                self.logger.error(f"Failed to subscribe to channels for {self.symbol}")

            return success

        except Exception as e:
            self.logger.error(f"Error subscribing to data for {self.symbol}: {e}")
            return False

    async def unsubscribe_from_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
        """
        Unsubscribe from data streams for specified symbols and data types.

        Args:
            symbols: Trading symbols to unsubscribe from (not inspected; the
                unsubscription always targets self.symbol)
            data_types: Types of data to unsubscribe from; types without an
                entry in the channel mapping are ignored

        Returns:
            True if unsubscription successful, there was nothing to
            unsubscribe, or the client is already disconnected; False otherwise
        """
        if not self._ws_client or not self._ws_client.is_connected:
            self.logger.warning("WebSocket client not connected for unsubscription")
            return True  # Consider it successful if already disconnected

        try:
            # Build unsubscriptions
            subscriptions = []
            for data_type in data_types:
                if data_type in self._channel_mapping:
                    channel = self._channel_mapping[data_type]
                    subscription = OKXSubscription(
                        channel=channel,
                        inst_id=self.symbol,
                        enabled=False
                    )
                    subscriptions.append(subscription)

            if not subscriptions:
                return True

            # Unsubscribe from channels
            success = await self._ws_client.unsubscribe(subscriptions)

            if success:
                self.logger.info(f"Successfully unsubscribed from {len(subscriptions)} channels for {self.symbol}")
            else:
                self.logger.warning(f"Failed to unsubscribe from channels for {self.symbol}")

            return success

        except Exception as e:
            self.logger.error(f"Error unsubscribing from data for {self.symbol}: {e}")
            return False

    async def _process_message(self, message: Any) -> Optional[MarketDataPoint]:
        """
        Process incoming message from OKX WebSocket.

        Every item in the message's 'data' list is processed, and the raw
        message is stored once when raw storage is enabled.

        Args:
            message: Raw message from WebSocket

        Returns:
            The first processed MarketDataPoint, or None if the message was
            not a dict, was incomplete, was for another symbol, produced no
            data points, or processing failed
        """
        try:
            if not isinstance(message, dict):
                self.logger.warning(f"Unexpected message type: {type(message)}")
                return None

            # Extract channel and data. `or` also covers keys that are present
            # but null, which would otherwise raise on the .get() calls below.
            arg = message.get('arg') or {}
            channel = arg.get('channel')
            inst_id = arg.get('instId')
            data_list = message.get('data') or []

            # Validate message structure
            if not channel or not inst_id or not data_list:
                self.logger.debug(f"Incomplete message structure: {message}")
                return None

            # Check if this message is for our symbol
            if inst_id != self.symbol:
                self.logger.debug(f"Message for different symbol: {inst_id} (expected: {self.symbol})")
                return None

            # Process each data item
            market_data_points = []
            for data_item in data_list:
                data_point = await self._process_data_item(channel, data_item)
                if data_point:
                    market_data_points.append(data_point)

            # Store raw data if enabled
            if self.store_raw_data and self._raw_data_manager:
                await self._store_raw_data(channel, message)

            # Return the first processed data point (for the base class interface)
            return market_data_points[0] if market_data_points else None

        except Exception as e:
            self.logger.error(f"Error processing message for {self.symbol}: {e}")
            return None

    async def _handle_messages(self) -> None:
        """
        Handle incoming messages from WebSocket.

        This is called by the base class message loop. It does no message
        work itself: messages are handled through the WebSocket client
        callback, and this method only sleeps to yield control.
        """
        if self._ws_client and self._ws_client.is_connected:
            # Just sleep briefly to yield control
            await asyncio.sleep(0.1)
        else:
            # If not connected, sleep longer to avoid busy loop
            await asyncio.sleep(1.0)

    async def _process_data_item(self, channel: str, data_item: Dict[str, Any]) -> Optional[MarketDataPoint]:
        """
        Process individual data item from OKX message.

        Builds a MarketDataPoint, hands it to ``_store_processed_data`` and
        updates the 'messages_processed' and 'last_message_time' statistics.

        Args:
            channel: OKX channel name
            data_item: Individual data item

        Returns:
            Processed MarketDataPoint, or None if the channel is unknown or an
            error occurred (errors also increment the 'errors' statistic)
        """
        try:
            # Determine data type from channel (first mapping entry that matches)
            data_type = next(
                (dt for dt, ch in self._channel_mapping.items() if ch == channel),
                None
            )

            if data_type is None:
                self.logger.warning(f"Unknown channel: {channel}")
                return None

            # Extract timestamp; 'ts' is read as milliseconds since the epoch.
            # Fall back to the current UTC time when it is missing or empty.
            timestamp_ms = data_item.get('ts')
            if timestamp_ms:
                timestamp = datetime.fromtimestamp(int(timestamp_ms) / 1000, tz=timezone.utc)
            else:
                timestamp = datetime.now(timezone.utc)

            # Create MarketDataPoint
            market_data_point = MarketDataPoint(
                exchange="okx",
                symbol=self.symbol,
                timestamp=timestamp,
                data_type=data_type,
                data=data_item
            )

            # Store processed data to database
            await self._store_processed_data(market_data_point)

            # Update statistics
            self._stats['messages_processed'] += 1
            self._stats['last_message_time'] = timestamp

            return market_data_point

        except Exception as e:
            self.logger.error(f"Error processing data item for {self.symbol}: {e}")
            self._stats['errors'] += 1
            return None

    async def _store_processed_data(self, data_point: MarketDataPoint) -> None:
        """
        Store processed data to MarketData table.

        Only trade data points are forwarded (to ``_store_trade_data``);
        other data types are currently ignored.

        Args:
            data_point: Processed market data point
        """
        try:
            # For now, we'll focus on trade data storage
            # Orderbook and ticker storage can be added later
            if data_point.data_type == DataType.TRADE:
                await self._store_trade_data(data_point)

        except Exception as e:
            self.logger.error(f"Error storing processed data for {self.symbol}: {e}")

    async def _store_trade_data(self, data_point: MarketDataPoint) -> None:
        """
        Store trade data to database.

        Currently the trade is only logged at debug level; nothing is written
        to the database yet. A trade whose ID equals the previously seen trade
        ID is skipped.

        Args:
            data_point: Trade data point
        """
        try:
            if not self._db_manager:
                return

            trade_data = data_point.data

            # Extract trade information
            trade_id = trade_data.get('tradeId')
            price = Decimal(str(trade_data.get('px', '0')))
            size = Decimal(str(trade_data.get('sz', '0')))
            side = trade_data.get('side', 'unknown')

            # Skip duplicate trades. Only de-duplicate when an ID is present:
            # _last_trade_id starts as None, so a trade without an ID must not
            # be mistaken for a repeat of "no trade seen yet".
            if trade_id is not None and trade_id == self._last_trade_id:
                return
            self._last_trade_id = trade_id

            # For now, we'll log the trade data
            # Actual database storage will be implemented in the next phase
            self.logger.debug(f"Trade: {self.symbol} - {side} {size} @ {price} (ID: {trade_id})")

        except Exception as e:
            self.logger.error(f"Error storing trade data for {self.symbol}: {e}")

    async def _store_raw_data(self, channel: str, raw_message: Dict[str, Any]) -> None:
        """
        Store raw data for debugging and compliance.

        Does nothing when no raw data manager is available. Errors are logged
        and not propagated.

        Args:
            channel: OKX channel name (passed on as the stored data type)
            raw_message: Complete raw message
        """
        try:
            if not self._raw_data_manager:
                return

            # Store raw data using the raw data manager
            self._raw_data_manager.store_raw_data(
                exchange="okx",
                symbol=self.symbol,
                data_type=channel,
                raw_data=raw_message,
                timestamp=datetime.now(timezone.utc)
            )

        except Exception as e:
            self.logger.error(f"Error storing raw data for {self.symbol}: {e}")

    def _on_message(self, message: Dict[str, Any]) -> None:
        """
        Callback function for WebSocket messages.

        Records the message in the bounded buffer and schedules
        ``_process_message`` for it as an asyncio task.

        Args:
            message: Message received from WebSocket
        """
        try:
            # Add message to buffer, discarding the oldest entries once the
            # cap is exceeded so the buffer cannot grow without bound.
            self._message_buffer.append(message)
            overflow = len(self._message_buffer) - self.MAX_BUFFERED_MESSAGES
            if overflow > 0:
                del self._message_buffer[:overflow]

            # Process message asynchronously
            coro = self._process_message(message)
            try:
                task = asyncio.create_task(coro)
            except Exception:
                # Scheduling failed (e.g. no running event loop): close the
                # coroutine so it is not left un-awaited, then let the outer
                # handler log the error.
                coro.close()
                raise

            # Hold a strong reference until the task finishes, then drop it.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

        except Exception as e:
            self.logger.error(f"Error in message callback for {self.symbol}: {e}")

    def get_status(self) -> Dict[str, Any]:
        """
        Get collector status including WebSocket client status.

        Returns:
            The base class status merged with OKX-specific entries; on key
            collisions the OKX-specific value wins.
        """
        base_status = super().get_status()

        # Add OKX-specific status
        okx_status = {
            'symbol': self.symbol,
            'websocket_connected': self._ws_client.is_connected if self._ws_client else False,
            'websocket_state': self._ws_client.connection_state.value if self._ws_client else 'disconnected',
            'last_trade_id': self._last_trade_id,
            'message_buffer_size': len(self._message_buffer),
            'store_raw_data': self.store_raw_data
        }

        # Add WebSocket stats if available
        if self._ws_client:
            okx_status['websocket_stats'] = self._ws_client.get_stats()

        return {**base_status, **okx_status}

    def __repr__(self) -> str:
        """Return a short description with symbol, status and data types."""
        return f"<OKXCollector(symbol={self.symbol}, status={self.status.value}, data_types={[dt.value for dt in self.data_types]})>"