Vasily.onl 371c0a4591 Eliminates the "coroutine was never awaited" warnings
 Properly handles lock acquisition with timeout
 Maintains the same functionality (timeout protection for lock acquisition)
 Ensures proper lock cleanup in the finally block
2025-06-03 13:11:51 +08:00

780 lines
31 KiB
Python

"""
OKX WebSocket Client for low-level WebSocket management.
This module provides a robust WebSocket client specifically designed for OKX API,
handling connection management, authentication, keepalive, and message parsing.
"""
import asyncio
import json
import time
import ssl
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, Callable, Union
from enum import Enum
from dataclasses import dataclass
import websockets
from websockets.exceptions import ConnectionClosed, InvalidHandshake, InvalidURI
class OKXChannelType(Enum):
    """Channel name strings used when subscribing on the OKX WebSocket API.

    Each member's value is the literal channel name placed in the
    ``"channel"`` field of a subscription argument.
    """

    # Trade and order-book style channels
    TRADES = "trades"
    BOOKS5 = "books5"
    BOOKS50 = "books50"
    BOOKS_TBT = "books-l2-tbt"

    # Ticker channel
    TICKERS = "tickers"

    # Candle channels (the suffix in the value is the bar size)
    CANDLE1M = "candle1m"
    CANDLE5M = "candle5m"
    CANDLE15M = "candle15m"
    CANDLE1H = "candle1H"
    CANDLE4H = "candle4H"
    CANDLE1D = "candle1D"
class ConnectionState(Enum):
    """Lifecycle states a WebSocket connection can be in.

    The value of each member is a lowercase label suitable for logs and
    statistics output.
    """

    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    AUTHENTICATED = "authenticated"
    RECONNECTING = "reconnecting"
    ERROR = "error"
@dataclass
class OKXSubscription:
    """A single channel/instrument pair to subscribe to.

    Attributes are documented inline; ``enabled`` is stored on the object
    but is not part of the wire format produced by :meth:`to_dict`.
    """

    channel: str  # channel name, sent as "channel"
    inst_id: str  # instrument identifier, sent as "instId"
    enabled: bool = True  # local flag only; never serialized

    def to_dict(self) -> Dict[str, str]:
        """Return the ``{"channel": ..., "instId": ...}`` subscription argument."""
        return dict(channel=self.channel, instId=self.inst_id)
class OKXWebSocketError(Exception):
    """Common base class for all OKX WebSocket exceptions."""
class OKXAuthenticationError(OKXWebSocketError):
    """Signals that authentication did not succeed."""
class OKXConnectionError(OKXWebSocketError):
    """Signals that the connection could not be established or used."""
class OKXWebSocketClient:
    """
    OKX WebSocket client for handling real-time market data.

    The client opens a WebSocket connection to OKX (with retries), keeps it
    alive by sending plain-text "ping" messages, tracks the active
    subscriptions, forwards data messages to registered callbacks and
    reconnects (re-subscribing afterwards) when the connection is closed.

    Two background asyncio tasks are run while connected: a ping loop and a
    message handler. All logging is optional and goes through ``logger``.
    """

    PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
    PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"

    def __init__(self,
                 component_name: str = "okx_websocket",
                 ping_interval: float = 25.0,
                 pong_timeout: float = 10.0,
                 max_reconnect_attempts: int = 5,
                 reconnect_delay: float = 5.0,
                 logger=None,
                 verify_ssl: bool = True):
        """
        Initialize OKX WebSocket client.

        Args:
            component_name: Name used as a prefix in log messages
            ping_interval: Seconds between ping messages (must be < 30 for OKX)
            pong_timeout: Seconds to wait for pong response before warning
            max_reconnect_attempts: Maximum connection/reconnection attempts
            reconnect_delay: Base delay (seconds) for the connection backoff
            logger: Optional logger-like object (``debug``/``info``/``warning``/
                ``error``); when None nothing is logged
            verify_ssl: Verify the server certificate and hostname (default).
                Pass False only to restore the former unverified behaviour,
                e.g. behind a TLS-intercepting proxy.
        """
        self.component_name = component_name
        self.ping_interval = ping_interval
        self.pong_timeout = pong_timeout
        self.max_reconnect_attempts = max_reconnect_attempts
        self.reconnect_delay = reconnect_delay
        self.verify_ssl = verify_ssl

        # Initialize logger
        self.logger = logger

        # Connection management
        # Typed as Any because the concrete connection class depends on the
        # installed websockets version.
        self._websocket: Optional[Any] = None
        self._connection_state = ConnectionState.DISCONNECTED
        self._is_authenticated = False
        self._reconnect_attempts = 0
        self._last_ping_time = 0.0
        self._last_pong_time = 0.0
        # Endpoint chosen by the last connect(); reused by reconnect().
        self._use_public = True

        # Message handling
        self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = []
        self._subscriptions: Dict[str, OKXSubscription] = {}

        # Background task management
        self._ping_task: Optional[asyncio.Task] = None
        self._message_handler_task: Optional[asyncio.Task] = None
        self._reconnection_lock = asyncio.Lock()  # Serializes reconnect() calls
        self._tasks_stopping = False  # True while tasks are being torn down

        # Statistics
        self._stats = {
            'messages_received': 0,
            'messages_sent': 0,
            'pings_sent': 0,
            'pongs_received': 0,
            'reconnections': 0,
            'connection_time': None,
            'last_message_time': None
        }

        if self.logger:
            self.logger.info(f"{self.component_name}: Initialized OKX WebSocket client")

    @staticmethod
    def _callback_name(callback: Any) -> str:
        """Return a printable name for *callback*.

        ``functools.partial`` objects and callable instances have no
        ``__name__``; fall back to ``repr`` so logging never raises.
        """
        return getattr(callback, '__name__', None) or repr(callback)

    @property
    def is_connected(self) -> bool:
        """Check if WebSocket is connected (state CONNECTED and socket open)."""
        return (self._websocket is not None and
                self._connection_state == ConnectionState.CONNECTED and
                self._websocket_is_open())

    def _websocket_is_open(self) -> bool:
        """Check if the underlying WebSocket connection is open.

        Probes, in order, the ``state``, ``closed`` and ``open`` attributes of
        the connection object. Returns False when there is no connection, when
        none of the attributes exist, or when probing raises.
        """
        if not self._websocket:
            return False
        try:
            # Connection objects exposing `state` are compared to State.OPEN
            if hasattr(self._websocket, 'state'):
                from websockets.protocol import State
                return self._websocket.state == State.OPEN
            # Fallbacks for connection objects without `state`
            elif hasattr(self._websocket, 'closed'):
                return not self._websocket.closed
            elif hasattr(self._websocket, 'open'):
                return self._websocket.open
            else:
                # If we can't determine the state, assume it's closed
                return False
        except Exception:
            return False

    @property
    def connection_state(self) -> ConnectionState:
        """Get current connection state."""
        return self._connection_state

    async def connect(self, use_public: bool = True) -> bool:
        """
        Connect to OKX WebSocket API.

        Makes up to ``max_reconnect_attempts`` attempts, sleeping with an
        exponentially growing delay between failures. On success the ping
        and message-handler tasks are started.

        Args:
            use_public: Use public endpoint (True) or private endpoint (False)

        Returns:
            True if connection successful (or already connected), False otherwise
        """
        if self.is_connected:
            if self.logger:
                self.logger.warning("Already connected to OKX WebSocket")
            return True

        # Remember the endpoint so reconnect() returns to the same one.
        self._use_public = use_public
        url = self.PUBLIC_WS_URL if use_public else self.PRIVATE_WS_URL

        # Try connection with retry logic
        for attempt in range(self.max_reconnect_attempts):
            self._connection_state = ConnectionState.CONNECTING
            try:
                if self.logger:
                    self.logger.info(f"{self.component_name}: Connecting to OKX WebSocket (attempt {attempt + 1}/{self.max_reconnect_attempts}): {url}")

                # The default context verifies the certificate chain and the
                # hostname; verification is only relaxed on explicit request.
                ssl_context = ssl.create_default_context()
                if not self.verify_ssl:
                    ssl_context.check_hostname = False
                    ssl_context.verify_mode = ssl.CERT_NONE

                # Connect to WebSocket
                self._websocket = await websockets.connect(
                    url,
                    ssl=ssl_context,
                    ping_interval=None,  # We'll handle ping manually
                    ping_timeout=None,
                    close_timeout=10,
                    max_size=2**20,  # 1MB max message size
                    compression=None  # Disable compression
                )
                self._connection_state = ConnectionState.CONNECTED
                self._stats['connection_time'] = datetime.now(timezone.utc)
                self._reconnect_attempts = 0

                # Start background tasks
                await self._start_background_tasks()

                if self.logger:
                    self.logger.info(f"{self.component_name}: Successfully connected to OKX WebSocket")
                return True

            except (InvalidURI, InvalidHandshake) as e:
                # Configuration-level failure: retrying would not help.
                if self.logger:
                    self.logger.error(f"{self.component_name}: Invalid WebSocket configuration: {e}")
                self._connection_state = ConnectionState.ERROR
                return False

            except Exception as e:
                attempt_num = attempt + 1
                if self.logger:
                    self.logger.error(f"{self.component_name}: Connection attempt {attempt_num} failed: {e}")

                if attempt_num < self.max_reconnect_attempts:
                    # Exponential backoff plus a small, deterministic
                    # per-attempt increment (no randomness is involved).
                    delay = self.reconnect_delay * (2 ** attempt) + (0.1 * attempt)
                    if self.logger:
                        self.logger.info(f"{self.component_name}: Retrying connection in {delay:.1f} seconds...")
                    await asyncio.sleep(delay)
                else:
                    if self.logger:
                        self.logger.error(f"{self.component_name}: All {self.max_reconnect_attempts} connection attempts failed")
                    self._connection_state = ConnectionState.ERROR
                    return False

        return False

    async def disconnect(self) -> None:
        """Disconnect from WebSocket.

        Stops the background tasks, closes the socket (errors while closing
        are logged, not raised) and clears the connection reference. Does
        nothing when there is no connection object. Stored subscriptions are
        kept.
        """
        if not self._websocket:
            return

        if self.logger:
            self.logger.info(f"{self.component_name}: Disconnecting from OKX WebSocket")

        self._connection_state = ConnectionState.DISCONNECTED

        # Cancel background tasks
        await self._stop_background_tasks()

        # Close WebSocket connection
        try:
            await self._websocket.close()
        except Exception as e:
            if self.logger:
                self.logger.warning(f"{self.component_name}: Error closing WebSocket: {e}")

        self._websocket = None
        self._is_authenticated = False

        if self.logger:
            self.logger.info(f"{self.component_name}: Disconnected from OKX WebSocket")

    async def subscribe(self, subscriptions: List[OKXSubscription]) -> bool:
        """
        Subscribe to channels.

        Sends one "subscribe" request containing all given subscriptions and
        records them (keyed by ``"<channel>:<inst_id>"``) once the request has
        been sent.

        Args:
            subscriptions: List of subscription configurations

        Returns:
            True if the request was sent, False if not connected or sending failed
        """
        if not self.is_connected:
            if self.logger:
                self.logger.error("Cannot subscribe: WebSocket not connected")
            return False

        try:
            # Build subscription message
            args = [sub.to_dict() for sub in subscriptions]
            message = {
                "op": "subscribe",
                "args": args
            }

            # Send subscription
            await self._send_message(message)

            # Store subscriptions
            for sub in subscriptions:
                key = f"{sub.channel}:{sub.inst_id}"
                self._subscriptions[key] = sub

            if self.logger:
                self.logger.info(f"{self.component_name}: Subscribed to {len(subscriptions)} channels")
            return True

        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Failed to subscribe to channels: {e}")
            return False

    async def unsubscribe(self, subscriptions: List[OKXSubscription]) -> bool:
        """
        Unsubscribe from channels.

        Sends one "unsubscribe" request and forgets the given subscriptions
        once the request has been sent.

        Args:
            subscriptions: List of subscription configurations

        Returns:
            True if the request was sent, False if not connected or sending failed
        """
        if not self.is_connected:
            if self.logger:
                self.logger.error("Cannot unsubscribe: WebSocket not connected")
            return False

        try:
            # Build unsubscription message
            args = [sub.to_dict() for sub in subscriptions]
            message = {
                "op": "unsubscribe",
                "args": args
            }

            # Send unsubscription
            await self._send_message(message)

            # Remove subscriptions
            for sub in subscriptions:
                key = f"{sub.channel}:{sub.inst_id}"
                self._subscriptions.pop(key, None)

            if self.logger:
                self.logger.info(f"{self.component_name}: Unsubscribed from {len(subscriptions)} channels")
            return True

        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Failed to unsubscribe from channels: {e}")
            return False

    def add_message_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Add callback function for processing messages.

        Args:
            callback: Callable invoked synchronously with each parsed data message
        """
        self._message_callbacks.append(callback)
        if self.logger:
            self.logger.debug(f"{self.component_name}: Added message callback: {self._callback_name(callback)}")

    def remove_message_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Remove message callback (no-op if it was never registered).

        Args:
            callback: Function to remove
        """
        if callback in self._message_callbacks:
            self._message_callbacks.remove(callback)
            if self.logger:
                self.logger.debug(f"{self.component_name}: Removed message callback: {self._callback_name(callback)}")

    async def _start_background_tasks(self) -> None:
        """Start background tasks for ping and message handling.

        Does nothing while a stop is in progress, when both tasks are already
        running, or when the client is no longer connected after old tasks
        have been stopped.
        """
        # Ensure no tasks are currently stopping
        if self._tasks_stopping:
            if self.logger:
                self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
            return

        # Check if tasks are already running
        if (self._ping_task and not self._ping_task.done() and
                self._message_handler_task and not self._message_handler_task.done()):
            if self.logger:
                self.logger.debug(f"{self.component_name}: Background tasks already running")
            return

        # Cancel any existing tasks first (safety measure)
        await self._stop_background_tasks()

        # Ensure we're still supposed to start tasks after stopping
        if self._tasks_stopping or not self.is_connected:
            if self.logger:
                self.logger.debug(f"{self.component_name}: Aborting task start - stopping or disconnected")
            return

        try:
            # Start ping task
            self._ping_task = asyncio.create_task(self._ping_loop())
            # Start message handler task
            self._message_handler_task = asyncio.create_task(self._message_handler())
            if self.logger:
                self.logger.debug(f"{self.component_name}: Started background tasks")
        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Error starting background tasks: {e}")
            # Clean up on failure
            await self._stop_background_tasks()

    async def _stop_background_tasks(self) -> None:
        """Cancel the ping and message-handler tasks and wait for them to end.

        The task that is *calling* this method is never cancelled or awaited:
        automatic reconnection runs inside the message-handler task, and
        cancelling it here would abort the reconnection it is performing.
        Each other task is given up to 2 seconds to finish.
        """
        self._tasks_stopping = True
        try:
            current_task = asyncio.current_task()
            candidates = (
                ('ping_task', self._ping_task),
                ('message_handler_task', self._message_handler_task),
            )
            # Collect tasks to cancel (running ones, excluding ourselves)
            tasks_to_cancel = [
                (task_name, task) for task_name, task in candidates
                if task and not task.done() and task is not current_task
            ]

            if not tasks_to_cancel:
                if self.logger:
                    self.logger.debug(f"{self.component_name}: No background tasks to stop")
            else:
                if self.logger:
                    self.logger.debug(f"{self.component_name}: Stopping {len(tasks_to_cancel)} background tasks")

                # Request cancellation of every task first ...
                for task_name, task in tasks_to_cancel:
                    try:
                        if not task.done():
                            task.cancel()
                            if self.logger:
                                self.logger.debug(f"{self.component_name}: Cancelled {task_name}")
                    except Exception as e:
                        if self.logger:
                            self.logger.debug(f"{self.component_name}: Error cancelling {task_name}: {e}")

                # ... then wait for each one with a short timeout
                for task_name, task in tasks_to_cancel:
                    try:
                        await asyncio.wait_for(task, timeout=2.0)
                    except asyncio.TimeoutError:
                        if self.logger:
                            self.logger.warning(f"{self.component_name}: {task_name} shutdown timeout")
                    except asyncio.CancelledError:
                        # Expected when task is cancelled
                        pass
                    except Exception as e:
                        if self.logger:
                            self.logger.debug(f"{self.component_name}: {task_name} shutdown exception: {e}")

                if self.logger:
                    self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")

            # Clear task references (finished, cancelled or the caller itself)
            self._ping_task = None
            self._message_handler_task = None

        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Error in _stop_background_tasks: {e}")
        finally:
            self._tasks_stopping = False

    async def _ping_loop(self) -> None:
        """Background task for sending ping messages.

        Wakes up once per second while connected; sends a ping whenever
        ``ping_interval`` has elapsed and logs a warning when no pong has
        arrived within ``pong_timeout`` of the last ping.
        """
        while self.is_connected:
            try:
                current_time = time.time()

                # Send ping if interval elapsed
                if current_time - self._last_ping_time >= self.ping_interval:
                    await self._send_ping()
                    self._last_ping_time = current_time

                # Check for pong timeout
                if (self._last_ping_time > self._last_pong_time and
                        current_time - self._last_ping_time > self.pong_timeout):
                    if self.logger:
                        self.logger.warning(f"{self.component_name}: Pong timeout - connection may be stale")
                    # Don't immediately disconnect, let connection error handling deal with it

                await asyncio.sleep(1)  # Check every second

            except asyncio.CancelledError:
                break
            except Exception as e:
                if self.logger:
                    self.logger.error(f"{self.component_name}: Error in ping loop: {e}")
                await asyncio.sleep(5)

    async def _message_handler(self) -> None:
        """Background task that receives and processes incoming messages.

        Polls ``recv()`` with a 1 second timeout so the stop flag is checked
        regularly. When the connection is closed unexpectedly, one automatic
        reconnection is attempted and this task then exits (a successful
        reconnect starts a fresh handler task).
        """
        if self.logger:
            self.logger.debug(f"{self.component_name}: Message handler started")

        try:
            while self.is_connected and not self._tasks_stopping:
                try:
                    if not self._websocket or self._tasks_stopping:
                        break

                    # Receive message with timeout
                    try:
                        message = await asyncio.wait_for(
                            self._websocket.recv(),
                            timeout=1.0
                        )
                    except asyncio.TimeoutError:
                        continue  # No message received, continue loop
                    except asyncio.CancelledError:
                        # Exit immediately on cancellation
                        break

                    # Check if we're still supposed to be running
                    if self._tasks_stopping:
                        break

                    # Process message
                    await self._process_message(message)

                except ConnectionClosed as e:
                    if self._tasks_stopping:
                        break  # Expected during shutdown
                    if self.logger:
                        self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}")
                    self._connection_state = ConnectionState.DISCONNECTED
                    await self._handle_connection_lost()
                    # Either way this handler is finished.
                    break

                except asyncio.CancelledError:
                    if self.logger:
                        self.logger.debug(f"{self.component_name}: Message handler cancelled")
                    break

                except Exception as e:
                    if self._tasks_stopping:
                        break
                    if self.logger:
                        self.logger.error(f"{self.component_name}: Error in message handler: {e}")
                    await asyncio.sleep(1)

        except asyncio.CancelledError:
            if self.logger:
                self.logger.debug(f"{self.component_name}: Message handler task cancelled")
        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Fatal error in message handler: {e}")
        finally:
            if self.logger:
                self.logger.debug(f"{self.component_name}: Message handler exiting")

    async def _handle_connection_lost(self) -> None:
        """Attempt one automatic reconnection after an unexpected close.

        ``reconnect()`` takes ``_reconnection_lock`` itself. asyncio locks are
        not reentrant, so the lock must NOT be held here; it is only inspected
        to avoid queueing behind a reconnection that is already running.
        """
        if self._reconnection_lock.locked():
            if self.logger:
                self.logger.debug(f"{self.component_name}: Reconnection already in progress")
            return

        may_reconnect = (
            self._connection_state == ConnectionState.DISCONNECTED and
            self._reconnect_attempts < self.max_reconnect_attempts and
            not self._tasks_stopping
        )
        if not may_reconnect:
            if self.logger:
                self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
            return

        self._reconnect_attempts += 1
        if self.logger:
            self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")

        if await self.reconnect():
            if self.logger:
                self.logger.info(f"{self.component_name}: Automatic reconnection successful")
        else:
            if self.logger:
                self.logger.error(f"{self.component_name}: Automatic reconnection failed")

    async def _send_message(self, message: Dict[str, Any]) -> None:
        """
        Send a JSON-encoded message to the WebSocket.

        Args:
            message: Message to send (serialized with ``json.dumps``)

        Raises:
            OKXConnectionError: If not connected, the connection closes while
                sending, or serialization/sending fails for any other reason
        """
        if not self.is_connected or not self._websocket:
            raise OKXConnectionError("WebSocket not connected")

        try:
            message_str = json.dumps(message)
            await self._websocket.send(message_str)
            self._stats['messages_sent'] += 1
            if self.logger:
                self.logger.debug(f"{self.component_name}: Sent message: {message}")
        except ConnectionClosed as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Connection closed while sending message: {e}")
            self._connection_state = ConnectionState.DISCONNECTED
            raise OKXConnectionError(f"Connection closed: {e}") from e
        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Failed to send message: {e}")
            raise OKXConnectionError(f"Failed to send message: {e}") from e

    async def _send_ping(self) -> None:
        """Send the plain-text "ping" keepalive to OKX.

        Raises:
            OKXConnectionError: If not connected or sending fails
        """
        if not self.is_connected or not self._websocket:
            raise OKXConnectionError("WebSocket not connected")

        try:
            # OKX expects a simple "ping" string, not JSON
            await self._websocket.send("ping")
            self._stats['pings_sent'] += 1
            if self.logger:
                self.logger.debug(f"{self.component_name}: Sent ping to OKX")
        except ConnectionClosed as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Connection closed while sending ping: {e}")
            self._connection_state = ConnectionState.DISCONNECTED
            raise OKXConnectionError(f"Connection closed: {e}") from e
        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Failed to send ping: {e}")
            raise OKXConnectionError(f"Failed to send ping: {e}") from e

    async def _process_message(self, message: str) -> None:
        """
        Process one incoming message.

        Plain "pong" strings and ``{"event": "pong"}`` update the pong
        timestamp; subscribe/unsubscribe/error events are logged; messages
        containing both ``"data"`` and ``"arg"`` are passed to every
        registered callback. A failing callback is logged and does not stop
        delivery to the remaining callbacks. This method never raises.

        Args:
            message: Raw message string
        """
        try:
            # Update statistics first
            self._stats['messages_received'] += 1
            self._stats['last_message_time'] = datetime.now(timezone.utc)

            # Handle simple pong response (OKX sends "pong" as plain string)
            if message.strip() == "pong":
                self._last_pong_time = time.time()
                self._stats['pongs_received'] += 1
                if self.logger:
                    self.logger.debug(f"{self.component_name}: Received pong from OKX")
                return

            # Parse JSON message for all other responses
            data = json.loads(message)
            event = data.get('event')

            # Handle special messages
            if event == 'pong':
                self._last_pong_time = time.time()
                self._stats['pongs_received'] += 1
                if self.logger:
                    self.logger.debug(f"{self.component_name}: Received pong from OKX (JSON format)")
                return

            # Handle subscription confirmations
            if event == 'subscribe':
                if self.logger:
                    self.logger.info(f"{self.component_name}: Subscription confirmed: {data}")
                return

            if event == 'unsubscribe':
                if self.logger:
                    self.logger.info(f"{self.component_name}: Unsubscription confirmed: {data}")
                return

            # Handle error messages
            if event == 'error':
                if self.logger:
                    self.logger.error(f"{self.component_name}: OKX error: {data}")
                return

            # Process data messages
            if 'data' in data and 'arg' in data:
                # Iterate over a snapshot so callbacks may add/remove
                # callbacks without disturbing this delivery round.
                for callback in tuple(self._message_callbacks):
                    try:
                        callback(data)
                    except Exception as e:
                        if self.logger:
                            self.logger.error(f"{self.component_name}: Error in message callback {self._callback_name(callback)}: {e}")

        except json.JSONDecodeError as e:
            # Check if it's a simple string response we haven't handled
            if message.strip() in ["ping", "pong"]:
                if self.logger:
                    self.logger.debug(f"{self.component_name}: Received simple message: {message.strip()}")
                if message.strip() == "pong":
                    self._last_pong_time = time.time()
                    self._stats['pongs_received'] += 1
            else:
                if self.logger:
                    self.logger.error(f"{self.component_name}: Failed to parse JSON message: {e}, message: {message}")
        except Exception as e:
            if self.logger:
                self.logger.error(f"{self.component_name}: Error processing message: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Get connection statistics (counters plus current state)."""
        return {
            **self._stats,
            'connection_state': self._connection_state.value,
            'is_connected': self.is_connected,
            'subscriptions_count': len(self._subscriptions),
            'reconnect_attempts': self._reconnect_attempts
        }

    def get_subscriptions(self) -> List[Dict[str, str]]:
        """Get current subscriptions in OKX wire format."""
        return [sub.to_dict() for sub in self._subscriptions.values()]

    async def reconnect(self) -> bool:
        """
        Reconnect to the WebSocket and restore subscriptions.

        Serialized by ``_reconnection_lock``: callers must not hold that lock
        when awaiting this method. Reconnects to the same endpoint
        (public/private) that the last ``connect()`` call used.

        Returns:
            True if reconnection successful, False otherwise
        """
        async with self._reconnection_lock:
            if self.logger:
                self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket")

            self._connection_state = ConnectionState.RECONNECTING
            self._stats['reconnections'] += 1

            # Store current subscriptions before disconnect
            stored_subscriptions = list(self._subscriptions.values())

            # Disconnect first with proper cleanup
            await self.disconnect()

            # Wait a moment before reconnecting
            await asyncio.sleep(1)

            # Attempt to reconnect to the previously used endpoint
            success = await self.connect(use_public=self._use_public)

            if success:
                # Re-subscribe to previous subscriptions
                if stored_subscriptions:
                    if self.logger:
                        self.logger.info(f"{self.component_name}: Re-subscribing to {len(stored_subscriptions)} channels")
                    await self.subscribe(stored_subscriptions)

                # Reset reconnect attempts on successful reconnection
                self._reconnect_attempts = 0

            return success

    def __repr__(self) -> str:
        return f"<OKXWebSocketClient(state={self._connection_state.value}, subscriptions={len(self._subscriptions)})>"