""" OKX WebSocket Client for low-level WebSocket management. This module provides a robust WebSocket client specifically designed for OKX API, handling connection management, authentication, keepalive, and message parsing. """ import asyncio import json import time import ssl from datetime import datetime, timezone from typing import Dict, List, Optional, Any, Callable, Union from enum import Enum from dataclasses import dataclass import websockets from websockets.exceptions import ConnectionClosed, InvalidHandshake, InvalidURI class OKXChannelType(Enum): """OKX WebSocket channel types.""" TRADES = "trades" BOOKS5 = "books5" BOOKS50 = "books50" BOOKS_TBT = "books-l2-tbt" TICKERS = "tickers" CANDLE1M = "candle1m" CANDLE5M = "candle5m" CANDLE15M = "candle15m" CANDLE1H = "candle1H" CANDLE4H = "candle4H" CANDLE1D = "candle1D" class ConnectionState(Enum): """WebSocket connection states.""" DISCONNECTED = "disconnected" CONNECTING = "connecting" CONNECTED = "connected" AUTHENTICATED = "authenticated" RECONNECTING = "reconnecting" ERROR = "error" @dataclass class OKXSubscription: """OKX subscription configuration.""" channel: str inst_id: str enabled: bool = True def to_dict(self) -> Dict[str, str]: """Convert to OKX subscription format.""" return { "channel": self.channel, "instId": self.inst_id } class OKXWebSocketError(Exception): """Base exception for OKX WebSocket errors.""" pass class OKXAuthenticationError(OKXWebSocketError): """Exception raised when authentication fails.""" pass class OKXConnectionError(OKXWebSocketError): """Exception raised when connection fails.""" pass class OKXWebSocketClient: """ OKX WebSocket client for handling real-time market data. This client manages WebSocket connections to OKX, handles authentication, subscription management, and provides robust error handling with reconnection logic. """ PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public" PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private" def __init__(self, component_name: str = "okx_websocket", ping_interval: float = 25.0, pong_timeout: float = 10.0, max_reconnect_attempts: int = 5, reconnect_delay: float = 5.0, logger = None): """ Initialize OKX WebSocket client. Args: component_name: Name for logging ping_interval: Seconds between ping messages (must be < 30 for OKX) pong_timeout: Seconds to wait for pong response max_reconnect_attempts: Maximum reconnection attempts reconnect_delay: Initial delay between reconnection attempts """ self.component_name = component_name self.ping_interval = ping_interval self.pong_timeout = pong_timeout self.max_reconnect_attempts = max_reconnect_attempts self.reconnect_delay = reconnect_delay # Initialize logger self.logger = logger # Connection management self._websocket: Optional[Any] = None # Changed to Any to handle different websocket types self._connection_state = ConnectionState.DISCONNECTED self._is_authenticated = False self._reconnect_attempts = 0 self._last_ping_time = 0.0 self._last_pong_time = 0.0 # Message handling self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = [] self._subscriptions: Dict[str, OKXSubscription] = {} # Enhanced task management self._ping_task: Optional[asyncio.Task] = None self._message_handler_task: Optional[asyncio.Task] = None self._reconnection_lock = asyncio.Lock() # Prevent concurrent reconnections self._tasks_stopping = False # Flag to prevent task overlap # Statistics self._stats = { 'messages_received': 0, 'messages_sent': 0, 'pings_sent': 0, 'pongs_received': 0, 'reconnections': 0, 'connection_time': None, 'last_message_time': None } if self.logger: self.logger.info(f"{self.component_name}: Initialized OKX WebSocket client") @property def is_connected(self) -> bool: """Check if WebSocket is connected.""" return (self._websocket is not None and self._connection_state == ConnectionState.CONNECTED and self._websocket_is_open()) def _websocket_is_open(self) -> bool: """Check if the WebSocket connection is open.""" if not self._websocket: return False try: # For websockets 11.0+, check the state if hasattr(self._websocket, 'state'): from websockets.protocol import State return self._websocket.state == State.OPEN # Fallback for older versions elif hasattr(self._websocket, 'closed'): return not self._websocket.closed elif hasattr(self._websocket, 'open'): return self._websocket.open else: # If we can't determine the state, assume it's closed return False except Exception: return False @property def connection_state(self) -> ConnectionState: """Get current connection state.""" return self._connection_state async def connect(self, use_public: bool = True) -> bool: """ Connect to OKX WebSocket API. Args: use_public: Use public endpoint (True) or private endpoint (False) Returns: True if connection successful, False otherwise """ if self.is_connected: if self.logger: self.logger.warning("Already connected to OKX WebSocket") return True url = self.PUBLIC_WS_URL if use_public else self.PRIVATE_WS_URL # Try connection with retry logic for attempt in range(self.max_reconnect_attempts): self._connection_state = ConnectionState.CONNECTING try: if self.logger: self.logger.info(f"{self.component_name}: Connecting to OKX WebSocket (attempt {attempt + 1}/{self.max_reconnect_attempts}): {url}") # Create SSL context for secure connection ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE # Connect to WebSocket self._websocket = await websockets.connect( url, ssl=ssl_context, ping_interval=None, # We'll handle ping manually ping_timeout=None, close_timeout=10, max_size=2**20, # 1MB max message size compression=None # Disable compression for better performance ) self._connection_state = ConnectionState.CONNECTED self._stats['connection_time'] = datetime.now(timezone.utc) self._reconnect_attempts = 0 # Start background tasks await self._start_background_tasks() if self.logger: self.logger.info(f"{self.component_name}: Successfully connected to OKX WebSocket") return True except (InvalidURI, InvalidHandshake) as e: if self.logger: self.logger.error(f"{self.component_name}: Invalid WebSocket configuration: {e}") self._connection_state = ConnectionState.ERROR return False except Exception as e: attempt_num = attempt + 1 if self.logger: self.logger.error(f"{self.component_name}: Connection attempt {attempt_num} failed: {e}") if attempt_num < self.max_reconnect_attempts: # Exponential backoff with jitter delay = self.reconnect_delay * (2 ** attempt) + (0.1 * attempt) if self.logger: self.logger.info(f"{self.component_name}: Retrying connection in {delay:.1f} seconds...") await asyncio.sleep(delay) else: if self.logger: self.logger.error(f"{self.component_name}: All {self.max_reconnect_attempts} connection attempts failed") self._connection_state = ConnectionState.ERROR return False return False async def disconnect(self) -> None: """Disconnect from WebSocket.""" if not self._websocket: return if self.logger: self.logger.info(f"{self.component_name}: Disconnecting from OKX WebSocket") self._connection_state = ConnectionState.DISCONNECTED # Cancel background tasks await self._stop_background_tasks() # Close WebSocket connection try: await self._websocket.close() except Exception as e: if self.logger: self.logger.warning(f"{self.component_name}: Error closing WebSocket: {e}") self._websocket = None self._is_authenticated = False if self.logger: self.logger.info(f"{self.component_name}: Disconnected from OKX WebSocket") async def subscribe(self, subscriptions: List[OKXSubscription]) -> bool: """ Subscribe to channels. Args: subscriptions: List of subscription configurations Returns: True if subscription successful, False otherwise """ if not self.is_connected: if self.logger: self.logger.error("Cannot subscribe: WebSocket not connected") return False try: # Build subscription message args = [sub.to_dict() for sub in subscriptions] message = { "op": "subscribe", "args": args } # Send subscription await self._send_message(message) # Store subscriptions for sub in subscriptions: key = f"{sub.channel}:{sub.inst_id}" self._subscriptions[key] = sub if self.logger: self.logger.info(f"{self.component_name}: Subscribed to {len(subscriptions)} channels") return True except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Failed to subscribe to channels: {e}") return False async def unsubscribe(self, subscriptions: List[OKXSubscription]) -> bool: """ Unsubscribe from channels. Args: subscriptions: List of subscription configurations Returns: True if unsubscription successful, False otherwise """ if not self.is_connected: if self.logger: self.logger.error("Cannot unsubscribe: WebSocket not connected") return False try: # Build unsubscription message args = [sub.to_dict() for sub in subscriptions] message = { "op": "unsubscribe", "args": args } # Send unsubscription await self._send_message(message) # Remove subscriptions for sub in subscriptions: key = f"{sub.channel}:{sub.inst_id}" self._subscriptions.pop(key, None) if self.logger: self.logger.info(f"{self.component_name}: Unsubscribed from {len(subscriptions)} channels") return True except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Failed to unsubscribe from channels: {e}") return False def add_message_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None: """ Add callback function for processing messages. Args: callback: Function to call when message received """ self._message_callbacks.append(callback) if self.logger: self.logger.debug(f"{self.component_name}: Added message callback: {callback.__name__}") def remove_message_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None: """ Remove message callback. Args: callback: Function to remove """ if callback in self._message_callbacks: self._message_callbacks.remove(callback) if self.logger: self.logger.debug(f"{self.component_name}: Removed message callback: {callback.__name__}") async def _start_background_tasks(self) -> None: """Start background tasks for ping and message handling.""" # Ensure no tasks are currently stopping if self._tasks_stopping: if self.logger: self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress") return # Cancel any existing tasks first await self._stop_background_tasks() # Start ping task self._ping_task = asyncio.create_task(self._ping_loop()) # Start message handler task self._message_handler_task = asyncio.create_task(self._message_handler()) if self.logger: self.logger.debug(f"{self.component_name}: Started background tasks") async def _stop_background_tasks(self) -> None: """Stop background tasks with proper synchronization.""" self._tasks_stopping = True try: tasks = [] # Collect tasks to cancel if self._ping_task and not self._ping_task.done(): tasks.append(self._ping_task) if self._message_handler_task and not self._message_handler_task.done(): tasks.append(self._message_handler_task) if not tasks: if self.logger: self.logger.debug(f"{self.component_name}: No background tasks to stop") return if self.logger: self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks") # Cancel all tasks for task in tasks: task.cancel() # Wait for all tasks to complete with timeout if tasks: try: await asyncio.wait_for( asyncio.gather(*tasks, return_exceptions=True), timeout=5.0 ) except asyncio.TimeoutError: if self.logger: self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running") except Exception as e: if self.logger: self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}") # Clear task references self._ping_task = None self._message_handler_task = None if self.logger: self.logger.debug(f"{self.component_name}: Background tasks stopped successfully") finally: self._tasks_stopping = False async def _ping_loop(self) -> None: """Background task for sending ping messages.""" while self.is_connected: try: current_time = time.time() # Send ping if interval elapsed if current_time - self._last_ping_time >= self.ping_interval: await self._send_ping() self._last_ping_time = current_time # Check for pong timeout if (self._last_ping_time > self._last_pong_time and current_time - self._last_ping_time > self.pong_timeout): if self.logger: self.logger.warning(f"{self.component_name}: Pong timeout - connection may be stale") # Don't immediately disconnect, let connection error handling deal with it await asyncio.sleep(1) # Check every second except asyncio.CancelledError: break except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Error in ping loop: {e}") await asyncio.sleep(5) async def _message_handler(self) -> None: """Background task for handling incoming messages with enhanced error handling.""" if self.logger: self.logger.debug(f"{self.component_name}: Message handler started") try: while self.is_connected and not self._tasks_stopping: try: if not self._websocket or self._tasks_stopping: break # Receive message with timeout try: message = await asyncio.wait_for( self._websocket.recv(), timeout=1.0 ) except asyncio.TimeoutError: continue # No message received, continue loop # Check if we're still supposed to be running if self._tasks_stopping: break # Process message await self._process_message(message) except ConnectionClosed as e: if self._tasks_stopping: break # Expected during shutdown if self.logger: self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}") self._connection_state = ConnectionState.DISCONNECTED # Use lock to prevent concurrent reconnection attempts async with self._reconnection_lock: # Double-check we still need to reconnect if (self._connection_state == ConnectionState.DISCONNECTED and self._reconnect_attempts < self.max_reconnect_attempts and not self._tasks_stopping): self._reconnect_attempts += 1 if self.logger: self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})") # Stop current tasks properly await self._stop_background_tasks() # Attempt reconnection with stored subscriptions stored_subscriptions = list(self._subscriptions.values()) if await self.reconnect(): if self.logger: self.logger.info(f"{self.component_name}: Automatic reconnection successful") # The reconnect method will restart tasks, so we exit this handler break else: if self.logger: self.logger.error(f"{self.component_name}: Automatic reconnection failed") break else: if self.logger: self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress") break except asyncio.CancelledError: if self.logger: self.logger.debug(f"{self.component_name}: Message handler cancelled") break except Exception as e: if self._tasks_stopping: break if self.logger: self.logger.error(f"{self.component_name}: Error in message handler: {e}") await asyncio.sleep(1) except asyncio.CancelledError: if self.logger: self.logger.debug(f"{self.component_name}: Message handler task cancelled") except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Fatal error in message handler: {e}") finally: if self.logger: self.logger.debug(f"{self.component_name}: Message handler exiting") async def _send_message(self, message: Dict[str, Any]) -> None: """ Send message to WebSocket. Args: message: Message to send """ if not self.is_connected or not self._websocket: raise OKXConnectionError("WebSocket not connected") try: message_str = json.dumps(message) await self._websocket.send(message_str) self._stats['messages_sent'] += 1 if self.logger: self.logger.debug(f"{self.component_name}: Sent message: {message}") except ConnectionClosed as e: if self.logger: self.logger.error(f"{self.component_name}: Connection closed while sending message: {e}") self._connection_state = ConnectionState.DISCONNECTED raise OKXConnectionError(f"Connection closed: {e}") except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Failed to send message: {e}") raise OKXConnectionError(f"Failed to send message: {e}") async def _send_ping(self) -> None: """Send ping message to OKX.""" if not self.is_connected or not self._websocket: raise OKXConnectionError("WebSocket not connected") try: # OKX expects a simple "ping" string, not JSON await self._websocket.send("ping") self._stats['pings_sent'] += 1 if self.logger: self.logger.debug(f"{self.component_name}: Sent ping to OKX") except ConnectionClosed as e: if self.logger: self.logger.error(f"{self.component_name}: Connection closed while sending ping: {e}") self._connection_state = ConnectionState.DISCONNECTED raise OKXConnectionError(f"Connection closed: {e}") except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Failed to send ping: {e}") raise OKXConnectionError(f"Failed to send ping: {e}") async def _process_message(self, message: str) -> None: """ Process incoming message. Args: message: Raw message string """ try: # Update statistics first self._stats['messages_received'] += 1 self._stats['last_message_time'] = datetime.now(timezone.utc) # Handle simple pong response (OKX sends "pong" as plain string) if message.strip() == "pong": self._last_pong_time = time.time() self._stats['pongs_received'] += 1 if self.logger: self.logger.debug(f"{self.component_name}: Received pong from OKX") return # Parse JSON message for all other responses data = json.loads(message) # Handle special messages if data.get('event') == 'pong': self._last_pong_time = time.time() self._stats['pongs_received'] += 1 if self.logger: self.logger.debug(f"{self.component_name}: Received pong from OKX (JSON format)") return # Handle subscription confirmations if data.get('event') == 'subscribe': if self.logger: self.logger.info(f"{self.component_name}: Subscription confirmed: {data}") return if data.get('event') == 'unsubscribe': if self.logger: self.logger.info(f"{self.component_name}: Unsubscription confirmed: {data}") return # Handle error messages if data.get('event') == 'error': if self.logger: self.logger.error(f"{self.component_name}: OKX error: {data}") return # Process data messages if 'data' in data and 'arg' in data: # Notify callbacks for callback in self._message_callbacks: try: callback(data) except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Error in message callback {callback.__name__}: {e}") except json.JSONDecodeError as e: # Check if it's a simple string response we haven't handled if message.strip() in ["ping", "pong"]: if self.logger: self.logger.debug(f"{self.component_name}: Received simple message: {message.strip()}") if message.strip() == "pong": self._last_pong_time = time.time() self._stats['pongs_received'] += 1 else: if self.logger: self.logger.error(f"{self.component_name}: Failed to parse JSON message: {e}, message: {message}") except Exception as e: if self.logger: self.logger.error(f"{self.component_name}: Error processing message: {e}") def get_stats(self) -> Dict[str, Any]: """Get connection statistics.""" return { **self._stats, 'connection_state': self._connection_state.value, 'is_connected': self.is_connected, 'subscriptions_count': len(self._subscriptions), 'reconnect_attempts': self._reconnect_attempts } def get_subscriptions(self) -> List[Dict[str, str]]: """Get current subscriptions.""" return [sub.to_dict() for sub in self._subscriptions.values()] async def reconnect(self) -> bool: """ Reconnect to WebSocket with enhanced synchronization. Returns: True if reconnection successful, False otherwise """ async with self._reconnection_lock: if self.logger: self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket") self._connection_state = ConnectionState.RECONNECTING self._stats['reconnections'] += 1 # Store current subscriptions before disconnect stored_subscriptions = list(self._subscriptions.values()) # Disconnect first with proper cleanup await self.disconnect() # Wait a moment before reconnecting await asyncio.sleep(1) # Attempt to reconnect success = await self.connect() if success: # Re-subscribe to previous subscriptions if stored_subscriptions: if self.logger: self.logger.info(f"{self.component_name}: Re-subscribing to {len(stored_subscriptions)} channels") await self.subscribe(stored_subscriptions) # Reset reconnect attempts on successful reconnection self._reconnect_attempts = 0 return success def __repr__(self) -> str: return f""