Add OKX data collector implementation and modular exchange architecture

- Introduced the `OKXCollector` and `OKXWebSocketClient` classes for real-time market data collection from the OKX exchange.
- Implemented a factory pattern for creating exchange-specific collectors, enhancing modularity and scalability.
- Added configuration support for the OKX collector in `config/okx_config.json`.
- Updated documentation to reflect the new modular architecture and provide guidance on using the OKX collector.
- Created unit tests for the OKX collector and exchange factory to ensure functionality and reliability.
- Enhanced logging and error handling throughout the new implementation for improved monitoring and debugging.
This commit is contained in:
Vasily.onl
2025-05-31 20:49:31 +08:00
parent 4936e5cd73
commit 4510181b39
16 changed files with 3221 additions and 109 deletions

View File

@@ -0,0 +1,39 @@
"""
Exchange-specific data collectors.
This package contains implementations for different cryptocurrency exchanges,
each organized in its own subfolder with standardized interfaces.
"""
from .okx import OKXCollector, OKXWebSocketClient
from .factory import ExchangeFactory, ExchangeCollectorConfig, create_okx_collector
from .registry import get_supported_exchanges, get_exchange_info
# Names re-exported as the public API of this package.
__all__ = [
'OKXCollector',
'OKXWebSocketClient',
'ExchangeFactory',
'ExchangeCollectorConfig',
'create_okx_collector',
'get_supported_exchanges',
'get_exchange_info',
]
# Exchange registry for factory pattern.
# Keys are lower-case exchange identifiers. Each entry holds the dotted
# import paths of the collector and WebSocket client classes, a display
# name, and the trading pairs / data-type names listed as supported.
# NOTE(review): get_supported_exchanges/get_exchange_info are also imported
# from .registry above and redefined below — confirm which copy is intended.
EXCHANGE_REGISTRY = {
'okx': {
'collector': 'data.exchanges.okx.collector.OKXCollector',
'websocket': 'data.exchanges.okx.websocket.OKXWebSocketClient',
'name': 'OKX',
'supported_pairs': ['BTC-USDT', 'ETH-USDT', 'SOL-USDT', 'DOGE-USDT', 'TON-USDT'],
'supported_data_types': ['trade', 'orderbook', 'ticker', 'candles']
}
}
def get_supported_exchanges():
    """Return the registered exchange keys as a new list."""
    return [exchange_key for exchange_key in EXCHANGE_REGISTRY]
def get_exchange_info(exchange_name: str):
    """
    Look up the registry entry for an exchange.

    The name is lower-cased before the lookup. Returns the entry dict, or
    None when the exchange is not registered.
    """
    registry_key = exchange_name.lower()
    return EXCHANGE_REGISTRY.get(registry_key)

196
data/exchanges/factory.py Normal file
View File

@@ -0,0 +1,196 @@
"""
Exchange Factory for creating data collectors.
This module provides a factory pattern for creating data collectors
from different exchanges based on configuration.
"""
import importlib
from typing import Dict, List, Optional, Any, Type
from dataclasses import dataclass
from ..base_collector import BaseDataCollector, DataType
from .registry import EXCHANGE_REGISTRY, get_supported_exchanges, get_exchange_info
@dataclass
class ExchangeCollectorConfig:
    """
    Settings the factory needs to build one exchange collector.

    ``exchange``, ``symbol`` and ``data_types`` are required; the remaining
    fields have defaults. ``custom_params`` is an optional dict that
    ExchangeFactory.create_collector merges into the constructor arguments.
    """
    # Required settings
    exchange: str
    symbol: str
    data_types: List[DataType]
    # Optional settings
    auto_restart: bool = True
    health_check_interval: float = 30.0
    store_raw_data: bool = True
    custom_params: Optional[Dict[str, Any]] = None
class ExchangeFactory:
    """Factory for creating exchange-specific data collectors."""

    @staticmethod
    def create_collector(config: ExchangeCollectorConfig) -> BaseDataCollector:
        """
        Create a data collector for the specified exchange.

        The collector class is resolved from the dotted path stored in the
        exchange registry and instantiated with the settings from ``config``
        (``custom_params`` entries override the standard arguments).

        Args:
            config: Configuration for the collector

        Returns:
            Instance of the appropriate collector class

        Raises:
            ValueError: If exchange is not supported
            ImportError: If the collector module cannot be imported
            RuntimeError: If the class lookup or instantiation fails
        """
        exchange_name = config.exchange.lower()
        if exchange_name not in EXCHANGE_REGISTRY:
            supported = get_supported_exchanges()
            raise ValueError(f"Exchange '{config.exchange}' not supported. "
                             f"Supported exchanges: {supported}")

        exchange_info = get_exchange_info(exchange_name)
        collector_class_path = exchange_info['collector']

        # Split "package.module.ClassName" into module path and class name.
        module_path, class_name = collector_class_path.rsplit('.', 1)

        try:
            module = importlib.import_module(module_path)
            collector_class = getattr(module, class_name)

            collector_args = {
                'symbol': config.symbol,
                'data_types': config.data_types,
                'auto_restart': config.auto_restart,
                'health_check_interval': config.health_check_interval,
                'store_raw_data': config.store_raw_data
            }
            if config.custom_params:
                collector_args.update(config.custom_params)

            return collector_class(**collector_args)
        except ImportError as e:
            # Chain the original error so the failing import stays visible.
            raise ImportError(f"Failed to import collector class '{collector_class_path}': {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to create collector for '{config.exchange}': {e}") from e

    @staticmethod
    def create_multiple_collectors(configs: List[ExchangeCollectorConfig]) -> List[BaseDataCollector]:
        """
        Create multiple collectors from a list of configurations.

        A configuration that fails is logged and skipped, so the result may
        contain fewer collectors than configurations.

        Args:
            configs: List of collector configurations

        Returns:
            List of collector instances that were created successfully
        """
        import logging

        log = logging.getLogger(__name__)
        collectors = []
        for config in configs:
            try:
                collectors.append(ExchangeFactory.create_collector(config))
            except Exception as e:
                # Best effort: report through logging (not stdout) and carry on.
                log.error(f"Failed to create collector for {config.exchange} {config.symbol}: {e}")
        return collectors

    @staticmethod
    def get_supported_pairs(exchange: str) -> List[str]:
        """
        Get supported trading pairs for an exchange.

        Args:
            exchange: Exchange name

        Returns:
            List of supported trading pairs (empty if the exchange is unknown)
        """
        exchange_info = get_exchange_info(exchange)
        if exchange_info:
            return exchange_info.get('supported_pairs', [])
        return []

    @staticmethod
    def get_supported_data_types(exchange: str) -> List[str]:
        """
        Get supported data types for an exchange.

        Args:
            exchange: Exchange name

        Returns:
            List of supported data types (empty if the exchange is unknown)
        """
        exchange_info = get_exchange_info(exchange)
        if exchange_info:
            return exchange_info.get('supported_data_types', [])
        return []

    @staticmethod
    def validate_config(config: ExchangeCollectorConfig) -> bool:
        """
        Validate collector configuration against the exchange registry.

        Args:
            config: Configuration to validate

        Returns:
            True if the exchange, symbol and every data type are supported,
            False otherwise
        """
        if config.exchange.lower() not in EXCHANGE_REGISTRY:
            return False

        # An empty supported list means "no restriction" for that dimension.
        supported_pairs = ExchangeFactory.get_supported_pairs(config.exchange)
        if supported_pairs and config.symbol not in supported_pairs:
            return False

        supported_data_types = ExchangeFactory.get_supported_data_types(config.exchange)
        if supported_data_types:
            for data_type in config.data_types:
                if data_type.value not in supported_data_types:
                    return False
        return True
def create_okx_collector(symbol: str,
                         data_types: Optional[List[DataType]] = None,
                         **kwargs) -> BaseDataCollector:
    """
    Convenience function to create an OKX collector.

    Keyword arguments that match an ExchangeCollectorConfig field are set on
    the config. Any other keyword is forwarded to the collector constructor
    through ``custom_params`` instead of failing config construction.

    Args:
        symbol: Trading pair symbol (e.g., 'BTC-USDT')
        data_types: List of data types to collect
            (default: [DataType.TRADE, DataType.ORDERBOOK])
        **kwargs: Additional collector parameters

    Returns:
        OKX collector instance
    """
    from dataclasses import fields

    if data_types is None:
        data_types = [DataType.TRADE, DataType.ORDERBOOK]

    config_field_names = {f.name for f in fields(ExchangeCollectorConfig)}
    config_kwargs = {}
    collector_extras = {}
    for key, value in kwargs.items():
        if key in config_field_names:
            config_kwargs[key] = value
        else:
            collector_extras[key] = value

    if collector_extras:
        # Explicit keywords take precedence over entries already present in
        # a caller-supplied custom_params dict; the caller's dict is copied.
        merged_params = dict(config_kwargs.get('custom_params') or {})
        merged_params.update(collector_extras)
        config_kwargs['custom_params'] = merged_params

    config = ExchangeCollectorConfig(
        exchange='okx',
        symbol=symbol,
        data_types=data_types,
        **config_kwargs
    )
    return ExchangeFactory.create_collector(config)

View File

@@ -0,0 +1,14 @@
"""
OKX Exchange integration.
This module provides OKX-specific implementations for data collection,
including WebSocket client and data collector classes.
"""
from .collector import OKXCollector
from .websocket import OKXWebSocketClient
# Public API of the OKX package: the collector and its WebSocket client.
__all__ = [
'OKXCollector',
'OKXWebSocketClient',
]

View File

@@ -0,0 +1,485 @@
"""
OKX Data Collector implementation.
This module provides the main OKX data collector class that extends BaseDataCollector,
handling real-time market data collection for a single trading pair with robust
error handling, health monitoring, and database integration.
"""
import asyncio
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Optional, Any, Set
from dataclasses import dataclass
from ...base_collector import (
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
OHLCVData, DataValidationError, ConnectionError
)
from .websocket import (
OKXWebSocketClient, OKXSubscription, OKXChannelType,
ConnectionState, OKXWebSocketError
)
from database.connection import get_db_manager, get_raw_data_manager
from database.models import MarketData, RawTrade
from utils.logger import get_logger
@dataclass
class OKXMarketData:
    """
    Container for one piece of OKX market data.

    Holds the trading symbol, a timestamp, the data type and channel names,
    and the raw payload dict.
    """
    symbol: str
    timestamp: datetime
    data_type: str
    channel: str
    raw_data: Dict[str, Any]
class OKXCollector(BaseDataCollector):
"""
OKX data collector for real-time market data.
This collector handles a single trading pair and collects real-time data
including trades, orderbook, and ticker information from OKX exchange.
"""
def __init__(self,
             symbol: str,
             data_types: Optional[List[DataType]] = None,
             component_name: Optional[str] = None,
             auto_restart: bool = True,
             health_check_interval: float = 30.0,
             store_raw_data: bool = True):
    """
    Set up a collector bound to a single OKX trading pair.

    Args:
        symbol: Trading symbol (e.g., 'BTC-USDT')
        data_types: Data types to collect; defaults to trades and orderbook
        component_name: Logging name; derived from the symbol when omitted
        auto_restart: Passed to the base collector
        health_check_interval: Passed to the base collector (seconds)
        store_raw_data: Whether raw messages are handed to the raw data manager
    """
    if data_types is None:
        data_types = [DataType.TRADE, DataType.ORDERBOOK]

    if component_name is None:
        symbol_slug = symbol.replace('-', '_').lower()
        component_name = f"okx_collector_{symbol_slug}"

    super().__init__(
        exchange_name="okx",
        symbols=[symbol],
        data_types=data_types,
        component_name=component_name,
        auto_restart=auto_restart,
        health_check_interval=health_check_interval
    )

    # Collector configuration
    self.symbol = symbol
    self.store_raw_data = store_raw_data

    # Created lazily in connect()
    self._ws_client: Optional[OKXWebSocketClient] = None
    self._db_manager = None
    self._raw_data_manager = None

    # Processing state
    self._message_buffer: List[Dict[str, Any]] = []
    self._last_trade_id: Optional[str] = None
    self._last_orderbook_ts: Optional[int] = None

    # Which OKX channel serves each generic data type
    self._channel_mapping = {
        DataType.TRADE: OKXChannelType.TRADES.value,
        DataType.ORDERBOOK: OKXChannelType.BOOKS5.value,
        DataType.TICKER: OKXChannelType.TICKERS.value
    }

    self.logger.info(f"Initialized OKX collector for {symbol} with data types: {[dt.value for dt in data_types]}")
async def connect(self) -> bool:
    """
    Open the WebSocket connection used by this collector.

    Obtains the database manager (and the raw data manager when raw storage
    is enabled), creates the WebSocket client, registers the message
    callback and connects to the public endpoint.

    Returns:
        True if connection successful, False otherwise
    """
    try:
        self.logger.info(f"Connecting OKX collector for {self.symbol}")

        self._db_manager = get_db_manager()
        if self.store_raw_data:
            self._raw_data_manager = get_raw_data_manager()

        ws_component_name = f"okx_ws_{self.symbol.replace('-', '_').lower()}"
        client = OKXWebSocketClient(
            component_name=ws_component_name,
            ping_interval=25.0,
            pong_timeout=10.0,
            max_reconnect_attempts=5,
            reconnect_delay=5.0
        )
        self._ws_client = client
        client.add_message_callback(self._on_message)

        connected = await client.connect(use_public=True)
        if not connected:
            self.logger.error("Failed to connect to OKX WebSocket")
            return False

        self.logger.info(f"Successfully connected OKX collector for {self.symbol}")
        return True
    except Exception as e:
        self.logger.error(f"Error connecting OKX collector for {self.symbol}: {e}")
        return False
async def disconnect(self) -> None:
    """Close the WebSocket client, if any, and drop the reference to it."""
    try:
        self.logger.info(f"Disconnecting OKX collector for {self.symbol}")
        if self._ws_client:
            await self._ws_client.disconnect()
            self._ws_client = None
        self.logger.info(f"Disconnected OKX collector for {self.symbol}")
    except Exception as e:
        # Errors are logged, never raised to the caller.
        self.logger.error(f"Error disconnecting OKX collector for {self.symbol}: {e}")
async def subscribe_to_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
    """
    Subscribe this collector's symbol to the channels for the given data types.

    Args:
        symbols: Symbols requested; must include this collector's symbol
        data_types: Data types to subscribe to; unmapped types are skipped
            with a warning

    Returns:
        True if subscription successful, False otherwise
    """
    client = self._ws_client
    if not client or not client.is_connected:
        self.logger.error("WebSocket client not connected")
        return False

    if self.symbol not in symbols:
        self.logger.warning(f"Symbol {self.symbol} not in subscription list: {symbols}")
        return False

    try:
        subscriptions = []
        for data_type in data_types:
            if data_type not in self._channel_mapping:
                self.logger.warning(f"Unsupported data type: {data_type}")
                continue
            channel = self._channel_mapping[data_type]
            subscriptions.append(
                OKXSubscription(channel=channel, inst_id=self.symbol, enabled=True)
            )
            self.logger.debug(f"Added subscription: {channel} for {self.symbol}")

        if not subscriptions:
            self.logger.warning("No valid subscriptions to create")
            return False

        success = await self._ws_client.subscribe(subscriptions)
        if success:
            self.logger.info(f"Successfully subscribed to {len(subscriptions)} channels for {self.symbol}")
        else:
            self.logger.error(f"Failed to subscribe to channels for {self.symbol}")
        return success
    except Exception as e:
        self.logger.error(f"Error subscribing to data for {self.symbol}: {e}")
        return False
async def unsubscribe_from_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
    """
    Unsubscribe this collector's symbol from the channels for the given data types.

    Args:
        symbols: Trading symbols to unsubscribe from (not inspected here)
        data_types: Data types to unsubscribe from; unmapped types are ignored

    Returns:
        True if unsubscription successful (or there was nothing to do),
        False otherwise
    """
    if not self._ws_client or not self._ws_client.is_connected:
        self.logger.warning("WebSocket client not connected for unsubscription")
        # Already disconnected counts as success.
        return True

    try:
        subscriptions = [
            OKXSubscription(
                channel=self._channel_mapping[data_type],
                inst_id=self.symbol,
                enabled=False
            )
            for data_type in data_types
            if data_type in self._channel_mapping
        ]
        if not subscriptions:
            return True

        success = await self._ws_client.unsubscribe(subscriptions)
        if success:
            self.logger.info(f"Successfully unsubscribed from {len(subscriptions)} channels for {self.symbol}")
        else:
            self.logger.warning(f"Failed to unsubscribe from channels for {self.symbol}")
        return success
    except Exception as e:
        self.logger.error(f"Error unsubscribing from data for {self.symbol}: {e}")
        return False
async def _process_message(self, message: Any) -> Optional[MarketDataPoint]:
    """
    Turn one WebSocket message into market data points.

    Messages that are not dicts, lack channel/instId/data, or belong to a
    different symbol are ignored. Every data item is processed; the raw
    message is stored afterwards when raw storage is enabled.

    Args:
        message: Raw message from WebSocket

    Returns:
        The first processed MarketDataPoint, or None if nothing was produced
    """
    try:
        if not isinstance(message, dict):
            self.logger.warning(f"Unexpected message type: {type(message)}")
            return None

        arg = message.get('arg', {})
        channel = arg.get('channel')
        inst_id = arg.get('instId')
        data_list = message.get('data', [])

        if not (channel and inst_id and data_list):
            self.logger.debug(f"Incomplete message structure: {message}")
            return None

        if inst_id != self.symbol:
            self.logger.debug(f"Message for different symbol: {inst_id} (expected: {self.symbol})")
            return None

        processed_points = []
        for item in data_list:
            point = await self._process_data_item(channel, item)
            if point:
                processed_points.append(point)

        if self.store_raw_data and self._raw_data_manager:
            await self._store_raw_data(channel, message)

        # The base class interface expects a single data point.
        if not processed_points:
            return None
        return processed_points[0]
    except Exception as e:
        self.logger.error(f"Error processing message for {self.symbol}: {e}")
        return None
async def _handle_messages(self) -> None:
    """
    Yield control to the event loop; messages arrive via the WebSocket callback.

    Exists to satisfy the base class's abstract method. Sleeps 0.1 s while
    connected and 1.0 s otherwise so the calling loop does not spin.
    """
    connected = self._ws_client and self._ws_client.is_connected
    pause_seconds = 0.1 if connected else 1.0
    await asyncio.sleep(pause_seconds)
async def _process_data_item(self, channel: str, data_item: Dict[str, Any]) -> Optional[MarketDataPoint]:
    """
    Convert one data item of an OKX message into a MarketDataPoint.

    Args:
        channel: OKX channel name the item arrived on
        data_item: Individual data item

    Returns:
        The MarketDataPoint, or None for an unknown channel or on error
    """
    try:
        # Reverse lookup: first data type whose mapped channel matches.
        data_type = next(
            (dt for dt, mapped_channel in self._channel_mapping.items() if mapped_channel == channel),
            None
        )
        if not data_type:
            self.logger.warning(f"Unknown channel: {channel}")
            return None

        # 'ts' is read as epoch milliseconds; fall back to the current time.
        timestamp_ms = data_item.get('ts')
        if timestamp_ms:
            timestamp = datetime.fromtimestamp(int(timestamp_ms) / 1000, tz=timezone.utc)
        else:
            timestamp = datetime.now(timezone.utc)

        point = MarketDataPoint(
            exchange="okx",
            symbol=self.symbol,
            timestamp=timestamp,
            data_type=data_type,
            data=data_item
        )
        await self._store_processed_data(point)

        self._stats['messages_processed'] += 1
        self._stats['last_message_time'] = timestamp
        return point
    except Exception as e:
        self.logger.error(f"Error processing data item for {self.symbol}: {e}")
        self._stats['errors'] += 1
        return None
async def _store_processed_data(self, data_point: MarketDataPoint) -> None:
    """
    Dispatch a processed data point to the matching storage routine.

    Only trade data is handled for now; other data types are ignored.

    Args:
        data_point: Processed market data point
    """
    try:
        if data_point.data_type != DataType.TRADE:
            return
        await self._store_trade_data(data_point)
    except Exception as e:
        self.logger.error(f"Error storing processed data for {self.symbol}: {e}")
async def _store_trade_data(self, data_point: MarketDataPoint) -> None:
    """
    Handle a trade data point.

    Currently this only logs the trade at debug level; nothing is written
    to the database yet. A trade whose ID equals the previous one is skipped.

    Args:
        data_point: Trade data point
    """
    try:
        if not self._db_manager:
            return

        payload = data_point.data
        trade_id = payload.get('tradeId')
        price = Decimal(str(payload.get('px', '0')))
        size = Decimal(str(payload.get('sz', '0')))
        side = payload.get('side', 'unknown')

        # Only consecutive duplicates are detected.
        if trade_id == self._last_trade_id:
            return
        self._last_trade_id = trade_id

        self.logger.debug(f"Trade: {self.symbol} - {side} {size} @ {price} (ID: {trade_id})")
    except Exception as e:
        self.logger.error(f"Error storing trade data for {self.symbol}: {e}")
async def _store_raw_data(self, channel: str, raw_message: Dict[str, Any]) -> None:
    """
    Hand the complete raw message to the raw data manager.

    Does nothing when no raw data manager is set. Errors are logged.

    Args:
        channel: OKX channel name (stored as the data type)
        raw_message: Complete raw message
    """
    try:
        manager = self._raw_data_manager
        if not manager:
            return
        manager.store_raw_data(
            exchange="okx",
            symbol=self.symbol,
            data_type=channel,
            raw_data=raw_message,
            timestamp=datetime.now(timezone.utc)
        )
    except Exception as e:
        self.logger.error(f"Error storing raw data for {self.symbol}: {e}")
def _on_message(self, message: Dict[str, Any]) -> None:
    """
    Callback function for WebSocket messages.

    Records the message in the buffer and schedules asynchronous processing
    on the running event loop. Errors are logged, never raised.

    Args:
        message: Message received from WebSocket
    """
    # Nothing in this class drains the buffer, so it is capped to keep memory
    # bounded. NOTE(review): 1000 is an arbitrary limit — adjust if another
    # component is meant to consume the buffer.
    max_buffered_messages = 1000
    try:
        buffer = self._message_buffer
        buffer.append(message)
        overflow = len(buffer) - max_buffered_messages
        if overflow > 0:
            del buffer[:overflow]

        # Resolve the loop first so a missing loop fails before a coroutine
        # object is created (avoids a "never awaited" warning).
        loop = asyncio.get_running_loop()
        task = loop.create_task(self._process_message(message))

        # The event loop keeps only weak references to tasks; hold a strong
        # one until the task finishes so it cannot be garbage-collected.
        pending = getattr(self, '_pending_tasks', None)
        if pending is None:
            pending = set()
            self._pending_tasks = pending
        pending.add(task)
        task.add_done_callback(pending.discard)
    except Exception as e:
        self.logger.error(f"Error in message callback for {self.symbol}: {e}")
def get_status(self) -> Dict[str, Any]:
    """Return the base collector status extended with OKX and WebSocket details."""
    status = dict(super().get_status())
    client = self._ws_client

    status.update({
        'symbol': self.symbol,
        'websocket_connected': client.is_connected if client else False,
        'websocket_state': client.connection_state.value if client else 'disconnected',
        'last_trade_id': self._last_trade_id,
        'message_buffer_size': len(self._message_buffer),
        'store_raw_data': self.store_raw_data
    })

    # WebSocket statistics are only present while a client exists.
    if client:
        status['websocket_stats'] = client.get_stats()
    return status
def __repr__(self) -> str:
    """Return a short description with symbol, status and data types."""
    return f"<OKXCollector(symbol={self.symbol}, status={self.status.value}, data_types={[dt.value for dt in self.data_types]})>"

View File

@@ -0,0 +1,614 @@
"""
OKX WebSocket Client for low-level WebSocket management.
This module provides a robust WebSocket client specifically designed for OKX API,
handling connection management, authentication, keepalive, and message parsing.
"""
import asyncio
import json
import time
import ssl
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, Callable, Union
from enum import Enum
from dataclasses import dataclass
import websockets
from websockets.exceptions import ConnectionClosed, InvalidHandshake, InvalidURI
from utils.logger import get_logger
class OKXChannelType(Enum):
    """Channel names used in OKX WebSocket subscriptions."""
    # Trades
    TRADES = "trades"
    # Order books
    BOOKS5 = "books5"
    BOOKS50 = "books50"
    BOOKS_TBT = "books-l2-tbt"
    # Tickers
    TICKERS = "tickers"
    # Candlesticks
    CANDLE1M = "candle1m"
    CANDLE5M = "candle5m"
    CANDLE15M = "candle15m"
    CANDLE1H = "candle1H"
    CANDLE4H = "candle4H"
    CANDLE1D = "candle1D"
class ConnectionState(Enum):
    """Lifecycle states of the WebSocket connection."""
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    AUTHENTICATED = "authenticated"
    RECONNECTING = "reconnecting"
    ERROR = "error"
@dataclass
class OKXSubscription:
    """One channel/instrument pair to subscribe to or unsubscribe from."""
    channel: str
    inst_id: str
    enabled: bool = True

    def to_dict(self) -> Dict[str, str]:
        """Return the ``args`` entry for this subscription (``enabled`` is not included)."""
        payload = {"channel": self.channel}
        payload["instId"] = self.inst_id
        return payload
class OKXWebSocketError(Exception):
    """Root of the OKX WebSocket exception hierarchy."""


class OKXAuthenticationError(OKXWebSocketError):
    """Signals that authentication failed."""


class OKXConnectionError(OKXWebSocketError):
    """Signals that connecting or sending over the WebSocket failed."""
class OKXWebSocketClient:
"""
OKX WebSocket client for handling real-time market data.
This client manages WebSocket connections to OKX, handles authentication,
subscription management, and provides robust error handling with reconnection logic.
"""
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
def __init__(self,
             component_name: str = "okx_websocket",
             ping_interval: float = 25.0,
             pong_timeout: float = 10.0,
             max_reconnect_attempts: int = 5,
             reconnect_delay: float = 5.0):
    """
    Create a disconnected client; no network activity happens here.

    Args:
        component_name: Name for logging
        ping_interval: Seconds between ping messages (must be < 30 for OKX)
        pong_timeout: Seconds to wait for pong response
        max_reconnect_attempts: Maximum connection/reconnection attempts
        reconnect_delay: Base delay in seconds between connection attempts
    """
    # Configuration
    self.component_name = component_name
    self.ping_interval = ping_interval
    self.pong_timeout = pong_timeout
    self.max_reconnect_attempts = max_reconnect_attempts
    self.reconnect_delay = reconnect_delay

    self.logger = get_logger(self.component_name, verbose=True)

    # Connection state (typed Any because the websocket object type varies)
    self._websocket: Optional[Any] = None
    self._connection_state = ConnectionState.DISCONNECTED
    self._is_authenticated = False
    self._reconnect_attempts = 0
    self._last_ping_time = 0.0
    self._last_pong_time = 0.0

    # Registered callbacks and active subscriptions keyed "channel:instId"
    self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = []
    self._subscriptions: Dict[str, OKXSubscription] = {}

    # Background tasks, created on connect
    self._ping_task: Optional[asyncio.Task] = None
    self._message_handler_task: Optional[asyncio.Task] = None

    # Counters and timestamps
    self._stats = {
        'messages_received': 0,
        'messages_sent': 0,
        'pings_sent': 0,
        'pongs_received': 0,
        'reconnections': 0,
        'connection_time': None,
        'last_message_time': None
    }

    self.logger.info(f"Initialized OKX WebSocket client: {component_name}")
@property
def is_connected(self) -> bool:
"""Check if WebSocket is connected."""
return (self._websocket is not None and
self._connection_state == ConnectionState.CONNECTED and
self._websocket_is_open())
def _websocket_is_open(self) -> bool:
"""Check if the WebSocket connection is open."""
if not self._websocket:
return False
try:
# For websockets 11.0+, check the state
if hasattr(self._websocket, 'state'):
from websockets.protocol import State
return self._websocket.state == State.OPEN
# Fallback for older versions
elif hasattr(self._websocket, 'closed'):
return not self._websocket.closed
elif hasattr(self._websocket, 'open'):
return self._websocket.open
else:
# If we can't determine the state, assume it's closed
return False
except Exception:
return False
@property
def connection_state(self) -> ConnectionState:
"""Get current connection state."""
return self._connection_state
async def connect(self, use_public: bool = True) -> bool:
    """
    Connect to OKX WebSocket API.

    Makes up to ``max_reconnect_attempts`` attempts with exponentially
    growing delays. The TLS connection uses the default SSL context, so the
    server certificate and hostname are verified.

    Args:
        use_public: Use public endpoint (True) or private endpoint (False)

    Returns:
        True if connection successful, False otherwise
    """
    if self.is_connected:
        self.logger.warning("Already connected to OKX WebSocket")
        return True

    url = self.PUBLIC_WS_URL if use_public else self.PRIVATE_WS_URL

    for attempt in range(self.max_reconnect_attempts):
        self._connection_state = ConnectionState.CONNECTING
        try:
            self.logger.info(f"Connecting to OKX WebSocket (attempt {attempt + 1}/{self.max_reconnect_attempts}): {url}")

            # Keep certificate and hostname verification enabled; turning
            # them off would allow man-in-the-middle interception.
            ssl_context = ssl.create_default_context()

            self._websocket = await websockets.connect(
                url,
                ssl=ssl_context,
                ping_interval=None,  # Keepalive is handled by _ping_loop
                ping_timeout=None,
                close_timeout=10,
                max_size=2**20,  # 1MB max message size
                compression=None  # Compression disabled
            )

            self._connection_state = ConnectionState.CONNECTED
            self._stats['connection_time'] = datetime.now(timezone.utc)
            self._reconnect_attempts = 0

            await self._start_background_tasks()

            self.logger.info("Successfully connected to OKX WebSocket")
            return True
        except (InvalidURI, InvalidHandshake) as e:
            # Treated as a configuration problem: retrying is not attempted.
            self.logger.error(f"Invalid WebSocket configuration: {e}")
            self._connection_state = ConnectionState.ERROR
            return False
        except Exception as e:
            attempt_num = attempt + 1
            self.logger.error(f"Connection attempt {attempt_num} failed: {e}")

            if attempt_num < self.max_reconnect_attempts:
                # Exponential backoff plus a small linear increment
                # (deterministic — this is not random jitter).
                delay = self.reconnect_delay * (2 ** attempt) + (0.1 * attempt)
                self.logger.info(f"Retrying connection in {delay:.1f} seconds...")
                await asyncio.sleep(delay)
            else:
                self.logger.error(f"All {self.max_reconnect_attempts} connection attempts failed")
                self._connection_state = ConnectionState.ERROR
                return False

    # Reached only when max_reconnect_attempts allows no attempts at all.
    return False
async def disconnect(self) -> None:
    """
    Stop the background tasks and close the socket.

    Returns immediately when no socket exists. Errors while closing are
    logged as warnings; the client always ends up without a socket.
    """
    if not self._websocket:
        return

    self.logger.info("Disconnecting from OKX WebSocket")
    self._connection_state = ConnectionState.DISCONNECTED

    await self._stop_background_tasks()

    try:
        await self._websocket.close()
    except Exception as e:
        self.logger.warning(f"Error closing WebSocket: {e}")

    self._websocket = None
    self._is_authenticated = False
    self.logger.info("Disconnected from OKX WebSocket")
async def subscribe(self, subscriptions: List[OKXSubscription]) -> bool:
"""
Subscribe to channels.
Args:
subscriptions: List of subscription configurations
Returns:
True if subscription successful, False otherwise
"""
if not self.is_connected:
self.logger.error("Cannot subscribe: WebSocket not connected")
return False
try:
# Build subscription message
args = [sub.to_dict() for sub in subscriptions]
message = {
"op": "subscribe",
"args": args
}
# Send subscription
await self._send_message(message)
# Store subscriptions
for sub in subscriptions:
key = f"{sub.channel}:{sub.inst_id}"
self._subscriptions[key] = sub
self.logger.info(f"Subscribed to {len(subscriptions)} channels")
return True
except Exception as e:
self.logger.error(f"Failed to subscribe to channels: {e}")
return False
async def unsubscribe(self, subscriptions: List[OKXSubscription]) -> bool:
"""
Unsubscribe from channels.
Args:
subscriptions: List of subscription configurations
Returns:
True if unsubscription successful, False otherwise
"""
if not self.is_connected:
self.logger.error("Cannot unsubscribe: WebSocket not connected")
return False
try:
# Build unsubscription message
args = [sub.to_dict() for sub in subscriptions]
message = {
"op": "unsubscribe",
"args": args
}
# Send unsubscription
await self._send_message(message)
# Remove subscriptions
for sub in subscriptions:
key = f"{sub.channel}:{sub.inst_id}"
self._subscriptions.pop(key, None)
self.logger.info(f"Unsubscribed from {len(subscriptions)} channels")
return True
except Exception as e:
self.logger.error(f"Failed to unsubscribe from channels: {e}")
return False
def add_message_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
"""
Add callback function for processing messages.
Args:
callback: Function to call when message received
"""
self._message_callbacks.append(callback)
self.logger.debug(f"Added message callback: {callback.__name__}")
def remove_message_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
"""
Remove message callback.
Args:
callback: Function to remove
"""
if callback in self._message_callbacks:
self._message_callbacks.remove(callback)
self.logger.debug(f"Removed message callback: {callback.__name__}")
async def _start_background_tasks(self) -> None:
    """Launch the ping loop and the message handler as asyncio tasks."""
    ping_task = asyncio.create_task(self._ping_loop())
    self._ping_task = ping_task

    handler_task = asyncio.create_task(self._message_handler())
    self._message_handler_task = handler_task

    self.logger.debug("Started background tasks")
async def _stop_background_tasks(self) -> None:
    """Cancel the ping and message handler tasks, wait for each, then clear the references."""
    for task in (self._ping_task, self._message_handler_task):
        if not task or task.done():
            continue
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected result of cancelling the task.
            pass

    self._ping_task = None
    self._message_handler_task = None
    self.logger.debug("Stopped background tasks")
async def _ping_loop(self) -> None:
    """
    Keepalive loop: send a ping once ``ping_interval`` has elapsed.

    Runs while the client is connected and checks once per second. A missing
    pong only produces a warning; the connection is not closed here.
    """
    while self.is_connected:
        try:
            now = time.time()

            if now - self._last_ping_time >= self.ping_interval:
                await self._send_ping()
                self._last_ping_time = now

            # A pong is outstanding if the last ping is newer than the last pong.
            awaiting_pong = self._last_ping_time > self._last_pong_time
            if awaiting_pong and now - self._last_ping_time > self.pong_timeout:
                self.logger.warning("Pong timeout - connection may be stale")

            await asyncio.sleep(1)
        except asyncio.CancelledError:
            break
        except Exception as e:
            self.logger.error(f"Error in ping loop: {e}")
            await asyncio.sleep(5)
async def _message_handler(self) -> None:
    """
    Receive loop: read messages from the socket and process them.

    Polls ``recv()`` with a one second timeout while connected. When the
    connection closes, reconnection is attempted until
    ``max_reconnect_attempts`` is exhausted.
    """
    while self.is_connected:
        try:
            if not self._websocket:
                break

            try:
                message = await asyncio.wait_for(self._websocket.recv(), timeout=1.0)
            except asyncio.TimeoutError:
                # Nothing arrived within the timeout; poll again.
                continue

            await self._process_message(message)
        except ConnectionClosed as e:
            self.logger.warning(f"WebSocket connection closed: {e}")
            self._connection_state = ConnectionState.DISCONNECTED

            if self._reconnect_attempts >= self.max_reconnect_attempts:
                self.logger.error("Max reconnection attempts exceeded")
                break

            self._reconnect_attempts += 1
            self.logger.info(f"Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")

            # NOTE(review): this runs inside the handler task itself, so the
            # task is cancelling and awaiting itself here — confirm intended.
            await self._stop_background_tasks()

            if not await self.reconnect():
                self.logger.error("Automatic reconnection failed")
                break

            # NOTE(review): connect() starts a new handler task; if reconnect()
            # goes through connect(), this old loop keeps reading alongside
            # the new one — verify against reconnect().
            self.logger.info("Automatic reconnection successful")
        except asyncio.CancelledError:
            break
        except Exception as e:
            self.logger.error(f"Error in message handler: {e}")
            await asyncio.sleep(1)
async def _send_message(self, message: Dict[str, Any]) -> None:
    """
    Send message to WebSocket.

    The message is serialized with ``json.dumps`` before sending, and the
    'messages_sent' counter is incremented after a successful send.

    Args:
        message: Message to send

    Raises:
        OKXConnectionError: If the client is not connected, if the connection
            closes during the send (the state is then set to DISCONNECTED),
            or if serialization/sending fails for any other reason. The
            triggering exception is attached as ``__cause__``.
    """
    if not self.is_connected or not self._websocket:
        raise OKXConnectionError("WebSocket not connected")

    try:
        message_str = json.dumps(message)
        await self._websocket.send(message_str)
        self._stats['messages_sent'] += 1
        self.logger.debug(f"Sent message: {message}")

    except ConnectionClosed as e:
        self.logger.error(f"Connection closed while sending message: {e}")
        self._connection_state = ConnectionState.DISCONNECTED
        # Chain explicitly so the traceback shows the root cause directly.
        raise OKXConnectionError(f"Connection closed: {e}") from e
    except Exception as e:
        self.logger.error(f"Failed to send message: {e}")
        raise OKXConnectionError(f"Failed to send message: {e}") from e
async def _send_ping(self) -> None:
    """
    Send ping message to OKX.

    Sends the literal string "ping" (not JSON) and increments the
    'pings_sent' counter after a successful send.

    Raises:
        OKXConnectionError: If the client is not connected, if the connection
            closes during the send (the state is then set to DISCONNECTED),
            or if the send fails for any other reason. The triggering
            exception is attached as ``__cause__``.
    """
    if not self.is_connected or not self._websocket:
        raise OKXConnectionError("WebSocket not connected")

    try:
        # OKX expects a simple "ping" string, not JSON
        await self._websocket.send("ping")
        self._stats['pings_sent'] += 1
        self.logger.debug("Sent ping to OKX")

    except ConnectionClosed as e:
        self.logger.error(f"Connection closed while sending ping: {e}")
        self._connection_state = ConnectionState.DISCONNECTED
        # Chain explicitly so the traceback shows the root cause directly.
        raise OKXConnectionError(f"Connection closed: {e}") from e
    except Exception as e:
        self.logger.error(f"Failed to send ping: {e}")
        raise OKXConnectionError(f"Failed to send ping: {e}") from e
async def _process_message(self, message: str) -> None:
    """
    Process incoming message.

    Updates the receive statistics, then handles the message by kind:

    * plain-text "pong" or a JSON ``{"event": "pong"}``: records the pong time
    * 'subscribe' / 'unsubscribe' / 'error' events: logged only
    * messages carrying both 'arg' and 'data': passed to every registered
      message callback, synchronously and in registration order

    A callback that raises is logged and does not prevent the remaining
    callbacks from running. No exception escapes this method; unexpected
    failures are logged instead.

    Args:
        message: Raw message string
    """
    try:
        # Update statistics first
        self._stats['messages_received'] += 1
        self._stats['last_message_time'] = datetime.now(timezone.utc)

        text = message.strip()

        # Handle simple pong response (OKX sends "pong" as plain string)
        if text == "pong":
            self._last_pong_time = time.time()
            self._stats['pongs_received'] += 1
            self.logger.debug("Received pong from OKX")
            return

        # Parse JSON message for all other responses
        try:
            data = json.loads(message)
        except json.JSONDecodeError as e:
            # A bare "ping" is the only non-JSON text tolerated here
            # ("pong" has already been handled above).
            if text == "ping":
                self.logger.debug(f"Received simple message: {text}")
            else:
                self.logger.error(f"Failed to parse JSON message: {e}, message: {message}")
            return

        # Valid JSON that is not an object (list, number, ...) has no
        # 'event'/'data' keys to inspect, so skip it explicitly.
        if not isinstance(data, dict):
            self.logger.warning(f"Ignoring non-object JSON message: {message}")
            return

        event = data.get('event')

        # Handle special messages
        if event == 'pong':
            self._last_pong_time = time.time()
            self._stats['pongs_received'] += 1
            self.logger.debug("Received pong from OKX (JSON format)")
            return

        # Handle subscription confirmations
        if event == 'subscribe':
            self.logger.info(f"Subscription confirmed: {data}")
            return

        if event == 'unsubscribe':
            self.logger.info(f"Unsubscription confirmed: {data}")
            return

        # Handle error messages
        if event == 'error':
            self.logger.error(f"OKX error: {data}")
            return

        # Process data messages
        if 'data' in data and 'arg' in data:
            # Notify callbacks
            for callback in self._message_callbacks:
                try:
                    callback(data)
                except Exception as e:
                    # Not every callable has __name__ (e.g. functools.partial);
                    # fall back to repr so logging cannot itself raise and
                    # abort delivery to the remaining callbacks.
                    callback_name = getattr(callback, '__name__', repr(callback))
                    self.logger.error(f"Error in message callback {callback_name}: {e}")

    except Exception as e:
        self.logger.error(f"Error processing message: {e}")
def get_stats(self) -> Dict[str, Any]:
    """
    Return a snapshot of the connection statistics.

    The result is a new dict: a shallow copy of the internal counters plus
    the current connection state value, the connected flag, the number of
    tracked subscriptions and the reconnect attempt count.
    """
    snapshot = dict(self._stats)
    snapshot.update({
        'connection_state': self._connection_state.value,
        'is_connected': self.is_connected,
        'subscriptions_count': len(self._subscriptions),
        'reconnect_attempts': self._reconnect_attempts,
    })
    return snapshot
def get_subscriptions(self) -> List[Dict[str, str]]:
    """Return the ``to_dict()`` form of every subscription currently tracked."""
    serialized = []
    for subscription in self._subscriptions.values():
        serialized.append(subscription.to_dict())
    return serialized
async def reconnect(self) -> bool:
    """
    Drop the current connection and establish a new one.

    Marks the client as RECONNECTING, increments the 'reconnections'
    counter, disconnects, pauses for one second and then connects again.
    When the new connection succeeds, every subscription held in
    ``self._subscriptions`` at that point is re-sent through ``subscribe``.

    Returns:
        True if reconnection successful, False otherwise
    """
    self.logger.info("Attempting to reconnect to OKX WebSocket")
    self._connection_state = ConnectionState.RECONNECTING
    self._stats['reconnections'] += 1

    # Tear down first, then pause briefly before dialing again.
    await self.disconnect()
    await asyncio.sleep(1)

    connected = await self.connect()

    # NOTE(review): the subscription set is read after disconnect(); confirm
    # disconnect() keeps self._subscriptions intact, otherwise nothing is restored.
    if connected and self._subscriptions:
        to_restore = [*self._subscriptions.values()]
        self.logger.info(f"Re-subscribing to {len(to_restore)} channels")
        await self.subscribe(to_restore)

    return connected
def __repr__(self) -> str:
    """Return a short debug string with the connection state and subscription count."""
    state_label = self._connection_state.value
    subscription_total = len(self._subscriptions)
    return f"<OKXWebSocketClient(state={state_label}, subscriptions={subscription_total})>"

View File

@@ -0,0 +1,27 @@
"""
Exchange registry for supported exchanges.
This module contains the registry of supported exchanges and their capabilities,
separated to avoid circular import issues.
"""
# Exchange registry for factory pattern.
#
# Maps an exchange key to a descriptor dict. Keys are lowercase because
# get_exchange_info() lowercases the requested name before the lookup.
# Descriptor fields:
#   'collector' / 'websocket' - dotted paths naming the collector and WebSocket
#                               client classes (presumably resolved by the
#                               factory at creation time; confirm in factory.py)
#   'name'                    - display name of the exchange
#   'supported_pairs'         - trading pair symbols listed for the exchange
#   'supported_data_types'    - data type names listed for the exchange
EXCHANGE_REGISTRY = {
    'okx': {
        'collector': 'data.exchanges.okx.collector.OKXCollector',
        'websocket': 'data.exchanges.okx.websocket.OKXWebSocketClient',
        'name': 'OKX',
        'supported_pairs': ['BTC-USDT', 'ETH-USDT', 'SOL-USDT', 'DOGE-USDT', 'TON-USDT'],
        'supported_data_types': ['trade', 'orderbook', 'ticker', 'candles']
    }
}
def get_supported_exchanges():
    """Return the registered exchange keys as a new list, in registry order."""
    return [exchange_key for exchange_key in EXCHANGE_REGISTRY]
def get_exchange_info(exchange_name: str):
    """
    Get information about a specific exchange.

    Args:
        exchange_name: Exchange key to look up; matching is case-insensitive.

    Returns:
        The registry entry for the exchange (the dict stored in
        EXCHANGE_REGISTRY itself, not a copy), or None when the name is not
        registered or is not a string.
    """
    # Treat None / non-string input like any other unknown exchange instead
    # of failing with AttributeError on .lower().
    if not isinstance(exchange_name, str):
        return None
    return EXCHANGE_REGISTRY.get(exchange_name.lower())