""" OKX Data Collector implementation. This module provides the main OKX data collector class that extends BaseDataCollector, handling real-time market data collection for a single trading pair with robust error handling, health monitoring, and database integration. """ import asyncio from datetime import datetime, timezone from decimal import Decimal from typing import Dict, List, Optional, Any, Set from dataclasses import dataclass from ...base_collector import ( BaseDataCollector, DataType, CollectorStatus, MarketDataPoint, OHLCVData, DataValidationError, ConnectionError ) from .websocket import ( OKXWebSocketClient, OKXSubscription, OKXChannelType, ConnectionState, OKXWebSocketError ) from database.connection import get_db_manager, get_raw_data_manager from database.models import MarketData, RawTrade from utils.logger import get_logger @dataclass class OKXMarketData: """OKX-specific market data structure.""" symbol: str timestamp: datetime data_type: str channel: str raw_data: Dict[str, Any] class OKXCollector(BaseDataCollector): """ OKX data collector for real-time market data. This collector handles a single trading pair and collects real-time data including trades, orderbook, and ticker information from OKX exchange. """ def __init__(self, symbol: str, data_types: Optional[List[DataType]] = None, component_name: Optional[str] = None, auto_restart: bool = True, health_check_interval: float = 30.0, store_raw_data: bool = True): """ Initialize OKX collector for a single trading pair. Args: symbol: Trading symbol (e.g., 'BTC-USDT') data_types: Types of data to collect (default: [DataType.TRADE, DataType.ORDERBOOK]) component_name: Name for logging (default: f'okx_collector_{symbol}') auto_restart: Enable automatic restart on failures health_check_interval: Seconds between health checks store_raw_data: Whether to store raw data for debugging """ # Default data types if not specified if data_types is None: data_types = [DataType.TRADE, DataType.ORDERBOOK] # Component name for logging if component_name is None: component_name = f"okx_collector_{symbol.replace('-', '_').lower()}" # Initialize base collector super().__init__( exchange_name="okx", symbols=[symbol], data_types=data_types, component_name=component_name, auto_restart=auto_restart, health_check_interval=health_check_interval ) # OKX-specific settings self.symbol = symbol self.store_raw_data = store_raw_data # WebSocket client self._ws_client: Optional[OKXWebSocketClient] = None # Database managers self._db_manager = None self._raw_data_manager = None # Data processing self._message_buffer: List[Dict[str, Any]] = [] self._last_trade_id: Optional[str] = None self._last_orderbook_ts: Optional[int] = None # OKX channel mapping self._channel_mapping = { DataType.TRADE: OKXChannelType.TRADES.value, DataType.ORDERBOOK: OKXChannelType.BOOKS5.value, DataType.TICKER: OKXChannelType.TICKERS.value } self.logger.info(f"Initialized OKX collector for {symbol} with data types: {[dt.value for dt in data_types]}") async def connect(self) -> bool: """ Establish connection to OKX WebSocket API. Returns: True if connection successful, False otherwise """ try: self.logger.info(f"Connecting OKX collector for {self.symbol}") # Initialize database managers self._db_manager = get_db_manager() if self.store_raw_data: self._raw_data_manager = get_raw_data_manager() # Create WebSocket client ws_component_name = f"okx_ws_{self.symbol.replace('-', '_').lower()}" self._ws_client = OKXWebSocketClient( component_name=ws_component_name, ping_interval=25.0, pong_timeout=10.0, max_reconnect_attempts=5, reconnect_delay=5.0 ) # Add message callback self._ws_client.add_message_callback(self._on_message) # Connect to WebSocket if not await self._ws_client.connect(use_public=True): self.logger.error("Failed to connect to OKX WebSocket") return False self.logger.info(f"Successfully connected OKX collector for {self.symbol}") return True except Exception as e: self.logger.error(f"Error connecting OKX collector for {self.symbol}: {e}") return False async def disconnect(self) -> None: """Disconnect from OKX WebSocket API.""" try: self.logger.info(f"Disconnecting OKX collector for {self.symbol}") if self._ws_client: await self._ws_client.disconnect() self._ws_client = None self.logger.info(f"Disconnected OKX collector for {self.symbol}") except Exception as e: self.logger.error(f"Error disconnecting OKX collector for {self.symbol}: {e}") async def subscribe_to_data(self, symbols: List[str], data_types: List[DataType]) -> bool: """ Subscribe to data streams for specified symbols and data types. Args: symbols: Trading symbols to subscribe to (should contain self.symbol) data_types: Types of data to subscribe to Returns: True if subscription successful, False otherwise """ if not self._ws_client or not self._ws_client.is_connected: self.logger.error("WebSocket client not connected") return False # Validate symbol if self.symbol not in symbols: self.logger.warning(f"Symbol {self.symbol} not in subscription list: {symbols}") return False try: # Build subscriptions subscriptions = [] for data_type in data_types: if data_type in self._channel_mapping: channel = self._channel_mapping[data_type] subscription = OKXSubscription( channel=channel, inst_id=self.symbol, enabled=True ) subscriptions.append(subscription) self.logger.debug(f"Added subscription: {channel} for {self.symbol}") else: self.logger.warning(f"Unsupported data type: {data_type}") if not subscriptions: self.logger.warning("No valid subscriptions to create") return False # Subscribe to channels success = await self._ws_client.subscribe(subscriptions) if success: self.logger.info(f"Successfully subscribed to {len(subscriptions)} channels for {self.symbol}") else: self.logger.error(f"Failed to subscribe to channels for {self.symbol}") return success except Exception as e: self.logger.error(f"Error subscribing to data for {self.symbol}: {e}") return False async def unsubscribe_from_data(self, symbols: List[str], data_types: List[DataType]) -> bool: """ Unsubscribe from data streams for specified symbols and data types. Args: symbols: Trading symbols to unsubscribe from data_types: Types of data to unsubscribe from Returns: True if unsubscription successful, False otherwise """ if not self._ws_client or not self._ws_client.is_connected: self.logger.warning("WebSocket client not connected for unsubscription") return True # Consider it successful if already disconnected try: # Build unsubscriptions subscriptions = [] for data_type in data_types: if data_type in self._channel_mapping: channel = self._channel_mapping[data_type] subscription = OKXSubscription( channel=channel, inst_id=self.symbol, enabled=False ) subscriptions.append(subscription) if not subscriptions: return True # Unsubscribe from channels success = await self._ws_client.unsubscribe(subscriptions) if success: self.logger.info(f"Successfully unsubscribed from {len(subscriptions)} channels for {self.symbol}") else: self.logger.warning(f"Failed to unsubscribe from channels for {self.symbol}") return success except Exception as e: self.logger.error(f"Error unsubscribing from data for {self.symbol}: {e}") return False async def _process_message(self, message: Any) -> Optional[MarketDataPoint]: """ Process incoming message from OKX WebSocket. Args: message: Raw message from WebSocket Returns: Processed MarketDataPoint or None if processing failed """ try: if not isinstance(message, dict): self.logger.warning(f"Unexpected message type: {type(message)}") return None # Extract channel and data arg = message.get('arg', {}) channel = arg.get('channel') inst_id = arg.get('instId') data_list = message.get('data', []) # Validate message structure if not channel or not inst_id or not data_list: self.logger.debug(f"Incomplete message structure: {message}") return None # Check if this message is for our symbol if inst_id != self.symbol: self.logger.debug(f"Message for different symbol: {inst_id} (expected: {self.symbol})") return None # Process each data item market_data_points = [] for data_item in data_list: data_point = await self._process_data_item(channel, data_item) if data_point: market_data_points.append(data_point) # Store raw data if enabled if self.store_raw_data and self._raw_data_manager: await self._store_raw_data(channel, message) # Return the first processed data point (for the base class interface) return market_data_points[0] if market_data_points else None except Exception as e: self.logger.error(f"Error processing message for {self.symbol}: {e}") return None async def _handle_messages(self) -> None: """ Handle incoming messages from WebSocket. This is called by the base class message loop. """ # The actual message handling is done through the WebSocket client callback # This method satisfies the abstract method requirement if self._ws_client and self._ws_client.is_connected: # Just sleep briefly to yield control await asyncio.sleep(0.1) else: # If not connected, sleep longer to avoid busy loop await asyncio.sleep(1.0) async def _process_data_item(self, channel: str, data_item: Dict[str, Any]) -> Optional[MarketDataPoint]: """ Process individual data item from OKX message. Args: channel: OKX channel name data_item: Individual data item Returns: Processed MarketDataPoint or None """ try: # Determine data type from channel data_type = None for dt, ch in self._channel_mapping.items(): if ch == channel: data_type = dt break if not data_type: self.logger.warning(f"Unknown channel: {channel}") return None # Extract timestamp timestamp_ms = data_item.get('ts') if timestamp_ms: timestamp = datetime.fromtimestamp(int(timestamp_ms) / 1000, tz=timezone.utc) else: timestamp = datetime.now(timezone.utc) # Create MarketDataPoint market_data_point = MarketDataPoint( exchange="okx", symbol=self.symbol, timestamp=timestamp, data_type=data_type, data=data_item ) # Store processed data to database await self._store_processed_data(market_data_point) # Update statistics self._stats['messages_processed'] += 1 self._stats['last_message_time'] = timestamp return market_data_point except Exception as e: self.logger.error(f"Error processing data item for {self.symbol}: {e}") self._stats['errors'] += 1 return None async def _store_processed_data(self, data_point: MarketDataPoint) -> None: """ Store processed data to MarketData table. Args: data_point: Processed market data point """ try: # For now, we'll focus on trade data storage # Orderbook and ticker storage can be added later if data_point.data_type == DataType.TRADE: await self._store_trade_data(data_point) except Exception as e: self.logger.error(f"Error storing processed data for {self.symbol}: {e}") async def _store_trade_data(self, data_point: MarketDataPoint) -> None: """ Store trade data to database. Args: data_point: Trade data point """ try: if not self._db_manager: return trade_data = data_point.data # Extract trade information trade_id = trade_data.get('tradeId') price = Decimal(str(trade_data.get('px', '0'))) size = Decimal(str(trade_data.get('sz', '0'))) side = trade_data.get('side', 'unknown') # Skip duplicate trades if trade_id == self._last_trade_id: return self._last_trade_id = trade_id # For now, we'll log the trade data # Actual database storage will be implemented in the next phase self.logger.debug(f"Trade: {self.symbol} - {side} {size} @ {price} (ID: {trade_id})") except Exception as e: self.logger.error(f"Error storing trade data for {self.symbol}: {e}") async def _store_raw_data(self, channel: str, raw_message: Dict[str, Any]) -> None: """ Store raw data for debugging and compliance. Args: channel: OKX channel name raw_message: Complete raw message """ try: if not self._raw_data_manager: return # Store raw data using the raw data manager self._raw_data_manager.store_raw_data( exchange="okx", symbol=self.symbol, data_type=channel, raw_data=raw_message, timestamp=datetime.now(timezone.utc) ) except Exception as e: self.logger.error(f"Error storing raw data for {self.symbol}: {e}") def _on_message(self, message: Dict[str, Any]) -> None: """ Callback function for WebSocket messages. Args: message: Message received from WebSocket """ try: # Add message to buffer for processing self._message_buffer.append(message) # Process message asynchronously asyncio.create_task(self._process_message(message)) except Exception as e: self.logger.error(f"Error in message callback for {self.symbol}: {e}") def get_status(self) -> Dict[str, Any]: """Get collector status including WebSocket client status.""" base_status = super().get_status() # Add OKX-specific status okx_status = { 'symbol': self.symbol, 'websocket_connected': self._ws_client.is_connected if self._ws_client else False, 'websocket_state': self._ws_client.connection_state.value if self._ws_client else 'disconnected', 'last_trade_id': self._last_trade_id, 'message_buffer_size': len(self._message_buffer), 'store_raw_data': self.store_raw_data } # Add WebSocket stats if available if self._ws_client: okx_status['websocket_stats'] = self._ws_client.get_stats() return {**base_status, **okx_status} def __repr__(self) -> str: return f""