485 lines
18 KiB
Python
Raw Normal View History

"""
OKX Data Collector implementation.
This module provides the main OKX data collector class that extends BaseDataCollector,
handling real-time market data collection for a single trading pair with robust
error handling, health monitoring, and database integration.
"""
import asyncio
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Optional, Any, Set

from database.connection import get_db_manager, get_raw_data_manager
from database.models import MarketData, RawTrade
from utils.logger import get_logger

from ...base_collector import (
    BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
    OHLCVData, DataValidationError, ConnectionError
)
from .websocket import (
    OKXWebSocketClient, OKXSubscription, OKXChannelType,
    ConnectionState, OKXWebSocketError
)
@dataclass
class OKXMarketData:
    """
    Plain record describing one OKX-specific market data message.

    Fields appear below in the order accepted by the generated ``__init__``.
    """

    # Trading pair the message belongs to (e.g. 'BTC-USDT').
    symbol: str
    # Time associated with the message.
    timestamp: datetime
    # Data-type label for the message (free-form string here).
    data_type: str
    # OKX channel name the message arrived on.
    channel: str
    # Unmodified message payload as a dictionary.
    raw_data: Dict[str, Any]
class OKXCollector(BaseDataCollector):
    """
    OKX data collector for real-time market data.

    This collector handles a single trading pair and collects real-time data
    including trades, orderbook, and ticker information from OKX exchange.

    Messages arrive through a synchronous WebSocket callback (``_on_message``)
    and are processed in asyncio tasks. A strong reference to each task is
    kept until it completes, and the buffer of recent raw messages is bounded.
    """

    def __init__(self,
                 symbol: str,
                 data_types: Optional[List[DataType]] = None,
                 component_name: Optional[str] = None,
                 auto_restart: bool = True,
                 health_check_interval: float = 30.0,
                 store_raw_data: bool = True,
                 message_buffer_size: Optional[int] = 1000):
        """
        Initialize OKX collector for a single trading pair.

        Args:
            symbol: Trading symbol (e.g., 'BTC-USDT')
            data_types: Types of data to collect (default: [DataType.TRADE, DataType.ORDERBOOK])
            component_name: Name for logging (default: f'okx_collector_{symbol}')
            auto_restart: Enable automatic restart on failures
            health_check_interval: Seconds between health checks
            store_raw_data: Whether to store raw data for debugging
            message_buffer_size: Maximum number of recent raw messages kept in
                memory; the oldest are discarded first. ``None`` means
                unbounded (the previous behaviour).

        Raises:
            ValueError: If ``message_buffer_size`` is negative (raised by ``deque``).
        """
        # Default data types if not specified
        if data_types is None:
            data_types = [DataType.TRADE, DataType.ORDERBOOK]

        # Component name for logging
        if component_name is None:
            component_name = f"okx_collector_{symbol.replace('-', '_').lower()}"

        # Initialize base collector
        super().__init__(
            exchange_name="okx",
            symbols=[symbol],
            data_types=data_types,
            component_name=component_name,
            auto_restart=auto_restart,
            health_check_interval=health_check_interval
        )

        # OKX-specific settings
        self.symbol = symbol
        self.store_raw_data = store_raw_data

        # WebSocket client
        self._ws_client: Optional[OKXWebSocketClient] = None

        # Database managers
        self._db_manager = None
        self._raw_data_manager = None

        # Data processing.
        # Recent raw messages. Nothing in this class drains the buffer, so it is
        # bounded to keep a long-running collector from growing without limit.
        self._message_buffer: deque = deque(maxlen=message_buffer_size)
        # Processing tasks spawned from the synchronous message callback; held
        # here until done so they cannot be garbage-collected mid-execution.
        self._pending_tasks: Set[asyncio.Task] = set()
        self._last_trade_id: Optional[str] = None
        self._last_orderbook_ts: Optional[int] = None

        # OKX channel mapping
        self._channel_mapping = {
            DataType.TRADE: OKXChannelType.TRADES.value,
            DataType.ORDERBOOK: OKXChannelType.BOOKS5.value,
            DataType.TICKER: OKXChannelType.TICKERS.value
        }

        self.logger.info(f"Initialized OKX collector for {symbol} with data types: {[dt.value for dt in data_types]}")

    async def connect(self) -> bool:
        """
        Establish connection to OKX WebSocket API.

        Initializes the database managers (the raw-data manager only when
        ``store_raw_data`` is set), creates the WebSocket client and registers
        ``_on_message`` as its message callback.

        Returns:
            True if connection successful, False otherwise
        """
        try:
            self.logger.info(f"Connecting OKX collector for {self.symbol}")

            # Initialize database managers
            self._db_manager = get_db_manager()
            if self.store_raw_data:
                self._raw_data_manager = get_raw_data_manager()

            # Create WebSocket client
            ws_component_name = f"okx_ws_{self.symbol.replace('-', '_').lower()}"
            self._ws_client = OKXWebSocketClient(
                component_name=ws_component_name,
                ping_interval=25.0,
                pong_timeout=10.0,
                max_reconnect_attempts=5,
                reconnect_delay=5.0
            )

            # Add message callback
            self._ws_client.add_message_callback(self._on_message)

            # Connect to WebSocket
            if not await self._ws_client.connect(use_public=True):
                self.logger.error("Failed to connect to OKX WebSocket")
                return False

            self.logger.info(f"Successfully connected OKX collector for {self.symbol}")
            return True

        except Exception as e:
            self.logger.error(f"Error connecting OKX collector for {self.symbol}: {e}")
            return False

    async def disconnect(self) -> None:
        """
        Disconnect from OKX WebSocket API.

        Errors from the client's ``disconnect`` are logged, not raised. The
        client reference is always cleared so a failed disconnect does not
        leave a stale client behind.
        """
        try:
            self.logger.info(f"Disconnecting OKX collector for {self.symbol}")
            if self._ws_client:
                await self._ws_client.disconnect()
            self.logger.info(f"Disconnected OKX collector for {self.symbol}")
        except Exception as e:
            self.logger.error(f"Error disconnecting OKX collector for {self.symbol}: {e}")
        finally:
            self._ws_client = None

    async def subscribe_to_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
        """
        Subscribe to data streams for specified symbols and data types.

        Args:
            symbols: Trading symbols to subscribe to (should contain self.symbol)
            data_types: Types of data to subscribe to

        Returns:
            True if subscription successful, False otherwise (not connected,
            self.symbol not in ``symbols``, no supported data types, or the
            client reported failure)
        """
        if not self._ws_client or not self._ws_client.is_connected:
            self.logger.error("WebSocket client not connected")
            return False

        # Validate symbol
        if self.symbol not in symbols:
            self.logger.warning(f"Symbol {self.symbol} not in subscription list: {symbols}")
            return False

        try:
            # Build subscriptions
            subscriptions = []
            for data_type in data_types:
                if data_type in self._channel_mapping:
                    channel = self._channel_mapping[data_type]
                    subscription = OKXSubscription(
                        channel=channel,
                        inst_id=self.symbol,
                        enabled=True
                    )
                    subscriptions.append(subscription)
                    self.logger.debug(f"Added subscription: {channel} for {self.symbol}")
                else:
                    self.logger.warning(f"Unsupported data type: {data_type}")

            if not subscriptions:
                self.logger.warning("No valid subscriptions to create")
                return False

            # Subscribe to channels
            success = await self._ws_client.subscribe(subscriptions)
            if success:
                self.logger.info(f"Successfully subscribed to {len(subscriptions)} channels for {self.symbol}")
            else:
                self.logger.error(f"Failed to subscribe to channels for {self.symbol}")
            return success

        except Exception as e:
            self.logger.error(f"Error subscribing to data for {self.symbol}: {e}")
            return False

    async def unsubscribe_from_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
        """
        Unsubscribe from data streams for specified symbols and data types.

        Args:
            symbols: Trading symbols to unsubscribe from (not inspected; the
                collector always unsubscribes ``self.symbol``)
            data_types: Types of data to unsubscribe from

        Returns:
            True if unsubscription successful (or nothing to do / already
            disconnected), False otherwise
        """
        if not self._ws_client or not self._ws_client.is_connected:
            self.logger.warning("WebSocket client not connected for unsubscription")
            return True  # Consider it successful if already disconnected

        try:
            # Build unsubscriptions (unsupported data types are silently skipped)
            subscriptions = []
            for data_type in data_types:
                if data_type in self._channel_mapping:
                    channel = self._channel_mapping[data_type]
                    subscription = OKXSubscription(
                        channel=channel,
                        inst_id=self.symbol,
                        enabled=False
                    )
                    subscriptions.append(subscription)

            if not subscriptions:
                return True

            # Unsubscribe from channels
            success = await self._ws_client.unsubscribe(subscriptions)
            if success:
                self.logger.info(f"Successfully unsubscribed from {len(subscriptions)} channels for {self.symbol}")
            else:
                self.logger.warning(f"Failed to unsubscribe from channels for {self.symbol}")
            return success

        except Exception as e:
            self.logger.error(f"Error unsubscribing from data for {self.symbol}: {e}")
            return False

    async def _process_message(self, message: Any) -> Optional[MarketDataPoint]:
        """
        Process incoming message from OKX WebSocket.

        Expects a dict with ``arg.channel``, ``arg.instId`` and a non-empty
        ``data`` list; messages for other symbols are ignored. Every item in
        ``data`` is processed, and the whole raw message is stored once when
        raw storage is enabled.

        Args:
            message: Raw message from WebSocket

        Returns:
            The first processed MarketDataPoint, or None if nothing was processed
        """
        try:
            if not isinstance(message, dict):
                self.logger.warning(f"Unexpected message type: {type(message)}")
                return None

            # Extract channel and data
            arg = message.get('arg', {})
            channel = arg.get('channel')
            inst_id = arg.get('instId')
            data_list = message.get('data', [])

            # Validate message structure
            if not channel or not inst_id or not data_list:
                self.logger.debug(f"Incomplete message structure: {message}")
                return None

            # Check if this message is for our symbol
            if inst_id != self.symbol:
                self.logger.debug(f"Message for different symbol: {inst_id} (expected: {self.symbol})")
                return None

            # Process each data item
            market_data_points = []
            for data_item in data_list:
                data_point = await self._process_data_item(channel, data_item)
                if data_point:
                    market_data_points.append(data_point)

            # Store raw data if enabled
            if self.store_raw_data and self._raw_data_manager:
                await self._store_raw_data(channel, message)

            # Return the first processed data point (for the base class interface)
            return market_data_points[0] if market_data_points else None

        except Exception as e:
            self.logger.error(f"Error processing message for {self.symbol}: {e}")
            return None

    async def _handle_messages(self) -> None:
        """
        Handle incoming messages from WebSocket.

        This is called by the base class message loop. Actual message handling
        happens in the WebSocket client callback; this method only yields
        control (0.1 s when connected, 1.0 s otherwise to avoid a busy loop).
        """
        if self._ws_client and self._ws_client.is_connected:
            await asyncio.sleep(0.1)
        else:
            await asyncio.sleep(1.0)

    async def _process_data_item(self, channel: str, data_item: Dict[str, Any]) -> Optional[MarketDataPoint]:
        """
        Process individual data item from OKX message.

        Args:
            channel: OKX channel name
            data_item: Individual data item; ``ts`` (epoch milliseconds) is
                used as the timestamp when present, otherwise the current UTC time

        Returns:
            Processed MarketDataPoint, or None for unknown channels / errors
        """
        try:
            # Determine data type from channel (first match wins)
            data_type = next(
                (dt for dt, ch in self._channel_mapping.items() if ch == channel),
                None
            )
            if data_type is None:
                self.logger.warning(f"Unknown channel: {channel}")
                return None

            # Extract timestamp
            timestamp_ms = data_item.get('ts')
            if timestamp_ms:
                timestamp = datetime.fromtimestamp(int(timestamp_ms) / 1000, tz=timezone.utc)
            else:
                timestamp = datetime.now(timezone.utc)

            # Create MarketDataPoint
            market_data_point = MarketDataPoint(
                exchange="okx",
                symbol=self.symbol,
                timestamp=timestamp,
                data_type=data_type,
                data=data_item
            )

            # Store processed data to database
            await self._store_processed_data(market_data_point)

            # Update statistics
            self._stats['messages_processed'] += 1
            self._stats['last_message_time'] = timestamp

            return market_data_point

        except Exception as e:
            self.logger.error(f"Error processing data item for {self.symbol}: {e}")
            self._stats['errors'] += 1
            return None

    async def _store_processed_data(self, data_point: MarketDataPoint) -> None:
        """
        Store processed data to MarketData table.

        Only trade data is handled so far; orderbook and ticker points are
        ignored. Errors are logged, not raised.

        Args:
            data_point: Processed market data point
        """
        try:
            if data_point.data_type == DataType.TRADE:
                await self._store_trade_data(data_point)
        except Exception as e:
            self.logger.error(f"Error storing processed data for {self.symbol}: {e}")

    async def _store_trade_data(self, data_point: MarketDataPoint) -> None:
        """
        Store trade data to database.

        Consecutive trades carrying the same ``tradeId`` are skipped. Trades
        without a ``tradeId`` are never treated as duplicates.

        Args:
            data_point: Trade data point
        """
        try:
            if not self._db_manager:
                return

            trade_data = data_point.data
            trade_id = trade_data.get('tradeId')

            # Skip duplicate trades. Only compare real ids: with a plain equality
            # check, every id-less trade after the first would match None == None.
            if trade_id is not None:
                if trade_id == self._last_trade_id:
                    return
                self._last_trade_id = trade_id

            # Extract trade information
            price = Decimal(str(trade_data.get('px', '0')))
            size = Decimal(str(trade_data.get('sz', '0')))
            side = trade_data.get('side', 'unknown')

            # For now, we'll log the trade data
            # Actual database storage will be implemented in the next phase
            self.logger.debug(f"Trade: {self.symbol} - {side} {size} @ {price} (ID: {trade_id})")

        except Exception as e:
            self.logger.error(f"Error storing trade data for {self.symbol}: {e}")

    async def _store_raw_data(self, channel: str, raw_message: Dict[str, Any]) -> None:
        """
        Store raw data for debugging and compliance.

        Args:
            channel: OKX channel name (stored as the data type)
            raw_message: Complete raw message
        """
        try:
            if not self._raw_data_manager:
                return

            # Store raw data using the raw data manager
            self._raw_data_manager.store_raw_data(
                exchange="okx",
                symbol=self.symbol,
                data_type=channel,
                raw_data=raw_message,
                timestamp=datetime.now(timezone.utc)
            )
        except Exception as e:
            self.logger.error(f"Error storing raw data for {self.symbol}: {e}")

    def _on_message(self, message: Dict[str, Any]) -> None:
        """
        Callback function for WebSocket messages.

        Records the message in the bounded buffer and schedules
        ``_process_message`` as an asyncio task. Requires a running event loop.
        Failures, including the absence of a loop, are logged.

        Args:
            message: Message received from WebSocket
        """
        try:
            # Add message to (bounded) buffer
            self._message_buffer.append(message)

            # Process message asynchronously. The event loop only keeps weak
            # references to tasks, so hold a strong one until the task is done.
            task = asyncio.create_task(self._process_message(message))
            self._pending_tasks.add(task)
            task.add_done_callback(self._pending_tasks.discard)
        except Exception as e:
            self.logger.error(f"Error in message callback for {self.symbol}: {e}")

    def get_status(self) -> Dict[str, Any]:
        """Get collector status including WebSocket client status."""
        base_status = super().get_status()

        # Add OKX-specific status
        okx_status = {
            'symbol': self.symbol,
            'websocket_connected': self._ws_client.is_connected if self._ws_client else False,
            'websocket_state': self._ws_client.connection_state.value if self._ws_client else 'disconnected',
            'last_trade_id': self._last_trade_id,
            'message_buffer_size': len(self._message_buffer),
            'pending_tasks': len(self._pending_tasks),
            'store_raw_data': self.store_raw_data
        }

        # Add WebSocket stats if available
        if self._ws_client:
            okx_status['websocket_stats'] = self._ws_client.get_stats()

        return {**base_status, **okx_status}

    def __repr__(self) -> str:
        return f"<OKXCollector(symbol={self.symbol}, status={self.status.value}, data_types={[dt.value for dt in self.data_types]})>"