""" Abstract base class for data collectors. This module provides a common interface for all data collection implementations, ensuring consistency across different exchange connectors and data sources. """ import asyncio from abc import ABC, abstractmethod from datetime import datetime, timezone, timedelta from decimal import Decimal from typing import Dict, List, Optional, Any, Callable, Set from dataclasses import dataclass from enum import Enum from utils.logger import get_logger class DataType(Enum): """Types of data that can be collected.""" TICKER = "ticker" TRADE = "trade" ORDERBOOK = "orderbook" CANDLE = "candle" BALANCE = "balance" class CollectorStatus(Enum): """Status of the data collector.""" STOPPED = "stopped" STARTING = "starting" RUNNING = "running" STOPPING = "stopping" ERROR = "error" RECONNECTING = "reconnecting" UNHEALTHY = "unhealthy" # Added for health monitoring @dataclass class MarketDataPoint: """Standardized market data structure.""" exchange: str symbol: str timestamp: datetime data_type: DataType data: Dict[str, Any] def __post_init__(self): """Validate data after initialization.""" if not self.timestamp.tzinfo: self.timestamp = self.timestamp.replace(tzinfo=timezone.utc) @dataclass class OHLCVData: """OHLCV (Open, High, Low, Close, Volume) data structure.""" symbol: str timeframe: str timestamp: datetime open: Decimal high: Decimal low: Decimal close: Decimal volume: Decimal trades_count: Optional[int] = None def __post_init__(self): """Validate OHLCV data after initialization.""" if not self.timestamp.tzinfo: self.timestamp = self.timestamp.replace(tzinfo=timezone.utc) # Validate price data if not all(isinstance(price, (Decimal, float, int)) for price in [self.open, self.high, self.low, self.close]): raise DataValidationError("All OHLCV prices must be numeric") if not isinstance(self.volume, (Decimal, float, int)): raise DataValidationError("Volume must be numeric") # Convert to Decimal for precision self.open = Decimal(str(self.open)) self.high = Decimal(str(self.high)) self.low = Decimal(str(self.low)) self.close = Decimal(str(self.close)) self.volume = Decimal(str(self.volume)) # Validate price relationships if not (self.low <= self.open <= self.high and self.low <= self.close <= self.high): raise DataValidationError(f"Invalid OHLCV data: prices don't match expected relationships for {self.symbol}") class DataCollectorError(Exception): """Base exception for data collector errors.""" pass class DataValidationError(DataCollectorError): """Exception raised when data validation fails.""" pass class ConnectionError(DataCollectorError): """Exception raised when connection to data source fails.""" pass class BaseDataCollector(ABC): """ Abstract base class for all data collectors. This class defines the interface that all data collection implementations must follow, providing consistency across different exchanges and data sources. """ def __init__(self, exchange_name: str, symbols: List[str], data_types: Optional[List[DataType]] = None, component_name: Optional[str] = None, auto_restart: bool = True, health_check_interval: float = 30.0): """ Initialize the base data collector. Args: exchange_name: Name of the exchange (e.g., 'okx', 'binance') symbols: List of trading symbols to collect data for data_types: Types of data to collect (default: [DataType.CANDLE]) component_name: Name for logging (default: based on exchange_name) auto_restart: Enable automatic restart on failures (default: True) health_check_interval: Seconds between health checks (default: 30.0) """ self.exchange_name = exchange_name.lower() self.symbols = set(symbols) self.data_types = data_types or [DataType.CANDLE] self.auto_restart = auto_restart self.health_check_interval = health_check_interval # Initialize logger component = component_name or f"{self.exchange_name}_collector" self.logger = get_logger(component, verbose=True) # Collector state self.status = CollectorStatus.STOPPED self._running = False self._should_be_running = False # Track desired state self._tasks: Set[asyncio.Task] = set() # Data callbacks self._data_callbacks: Dict[DataType, List[Callable]] = { data_type: [] for data_type in DataType } # Connection management self._connection = None self._reconnect_attempts = 0 self._max_reconnect_attempts = 5 self._reconnect_delay = 5.0 # seconds # Health monitoring self._last_heartbeat = datetime.now(timezone.utc) self._last_data_received = None self._health_check_task = None self._max_silence_duration = timedelta(minutes=5) # Max time without data before unhealthy # Statistics self._stats = { 'messages_received': 0, 'messages_processed': 0, 'errors': 0, 'restarts': 0, 'last_message_time': None, 'connection_uptime': None, 'last_error': None, 'last_restart_time': None } self.logger.info(f"Initialized {self.exchange_name} data collector for symbols: {', '.join(symbols)}") @abstractmethod async def connect(self) -> bool: """ Establish connection to the data source. Returns: True if connection successful, False otherwise """ pass @abstractmethod async def disconnect(self) -> None: """Disconnect from the data source.""" pass @abstractmethod async def subscribe_to_data(self, symbols: List[str], data_types: List[DataType]) -> bool: """ Subscribe to data streams for specified symbols and data types. Args: symbols: Trading symbols to subscribe to data_types: Types of data to subscribe to Returns: True if subscription successful, False otherwise """ pass @abstractmethod async def unsubscribe_from_data(self, symbols: List[str], data_types: List[DataType]) -> bool: """ Unsubscribe from data streams. Args: symbols: Trading symbols to unsubscribe from data_types: Types of data to unsubscribe from Returns: True if unsubscription successful, False otherwise """ pass @abstractmethod async def _process_message(self, message: Any) -> Optional[MarketDataPoint]: """ Process incoming message from the data source. Args: message: Raw message from the data source Returns: Processed MarketDataPoint or None if message should be ignored """ pass async def start(self) -> bool: """ Start the data collector. Returns: True if started successfully, False otherwise """ if self.status in [CollectorStatus.RUNNING, CollectorStatus.STARTING]: self.logger.warning("Data collector is already running or starting") return True self.logger.info(f"Starting {self.exchange_name} data collector") self.status = CollectorStatus.STARTING self._should_be_running = True try: # Connect to data source if not await self.connect(): self.status = CollectorStatus.ERROR self.logger.error("Failed to connect to data source") return False # Subscribe to data streams if not await self.subscribe_to_data(list(self.symbols), self.data_types): self.status = CollectorStatus.ERROR self.logger.error("Failed to subscribe to data streams") await self.disconnect() return False # Start message processing self._running = True self.status = CollectorStatus.RUNNING self._stats['connection_uptime'] = datetime.now(timezone.utc) self._last_heartbeat = datetime.now(timezone.utc) # Create background task for message processing message_task = asyncio.create_task(self._message_loop()) self._tasks.add(message_task) message_task.add_done_callback(self._tasks.discard) # Start health monitoring if self.auto_restart: health_task = asyncio.create_task(self._health_monitor()) self._tasks.add(health_task) health_task.add_done_callback(self._tasks.discard) self.logger.info(f"{self.exchange_name} data collector started successfully") return True except Exception as e: self.status = CollectorStatus.ERROR self._stats['last_error'] = str(e) self.logger.error(f"Failed to start data collector: {e}") await self.disconnect() return False async def stop(self, force: bool = False) -> None: """ Stop the data collector. Args: force: If True, don't restart automatically even if auto_restart is enabled """ if self.status == CollectorStatus.STOPPED: self.logger.warning("Data collector is already stopped") return self.logger.info(f"Stopping {self.exchange_name} data collector") self.status = CollectorStatus.STOPPING self._running = False if force: self._should_be_running = False try: # Cancel all tasks for task in list(self._tasks): task.cancel() # Wait for tasks to complete if self._tasks: await asyncio.gather(*self._tasks, return_exceptions=True) # Unsubscribe and disconnect await self.unsubscribe_from_data(list(self.symbols), self.data_types) await self.disconnect() self.status = CollectorStatus.STOPPED self.logger.info(f"{self.exchange_name} data collector stopped") except Exception as e: self.status = CollectorStatus.ERROR self._stats['last_error'] = str(e) self.logger.error(f"Error stopping data collector: {e}") async def restart(self) -> bool: """ Restart the data collector. Returns: True if restart successful, False otherwise """ self.logger.info(f"Restarting {self.exchange_name} data collector") self._stats['restarts'] += 1 self._stats['last_restart_time'] = datetime.now(timezone.utc) # Stop without disabling auto-restart await self.stop(force=False) # Wait a bit before restart await asyncio.sleep(2.0) # Reset reconnection attempts self._reconnect_attempts = 0 # Start again return await self.start() async def _message_loop(self) -> None: """Main message processing loop.""" self.logger.debug("Starting message processing loop") while self._running: try: # This should be implemented by subclasses to handle their specific message loop await self._handle_messages() # Update heartbeat self._last_heartbeat = datetime.now(timezone.utc) except asyncio.CancelledError: self.logger.debug("Message loop cancelled") break except Exception as e: self._stats['errors'] += 1 self._stats['last_error'] = str(e) self.logger.error(f"Error in message loop: {e}") # Attempt reconnection if connection lost if not await self._handle_connection_error(): break await asyncio.sleep(1) # Brief pause before retrying async def _health_monitor(self) -> None: """Monitor collector health and restart if needed.""" self.logger.debug("Starting health monitor") while self._running and self.auto_restart: try: await asyncio.sleep(self.health_check_interval) # Check if we should be running but aren't if self._should_be_running and not self._running: self.logger.warning("Collector should be running but isn't - restarting") await self.restart() continue # Check heartbeat freshness time_since_heartbeat = datetime.now(timezone.utc) - self._last_heartbeat if time_since_heartbeat > timedelta(seconds=self.health_check_interval * 2): self.logger.warning(f"No heartbeat for {time_since_heartbeat.total_seconds():.1f}s - restarting") self.status = CollectorStatus.UNHEALTHY await self.restart() continue # Check data freshness (if we've received data before) if self._last_data_received: time_since_data = datetime.now(timezone.utc) - self._last_data_received if time_since_data > self._max_silence_duration: self.logger.warning(f"No data received for {time_since_data.total_seconds():.1f}s - restarting") self.status = CollectorStatus.UNHEALTHY await self.restart() continue # Check if status indicates failure if self.status in [CollectorStatus.ERROR, CollectorStatus.UNHEALTHY]: self.logger.warning(f"Collector in {self.status.value} status - restarting") await self.restart() continue except asyncio.CancelledError: self.logger.debug("Health monitor cancelled") break except Exception as e: self.logger.error(f"Error in health monitor: {e}") await asyncio.sleep(self.health_check_interval) @abstractmethod async def _handle_messages(self) -> None: """ Handle incoming messages from the data source. This method should be implemented by subclasses to handle their specific message format. """ pass async def _handle_connection_error(self) -> bool: """ Handle connection errors and attempt reconnection. Returns: True if reconnection successful, False if max attempts exceeded """ if self._reconnect_attempts >= self._max_reconnect_attempts: self.logger.error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded") self.status = CollectorStatus.ERROR return False self._reconnect_attempts += 1 self.status = CollectorStatus.RECONNECTING self.logger.warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts}") await asyncio.sleep(self._reconnect_delay) try: if await self.connect(): if await self.subscribe_to_data(list(self.symbols), self.data_types): self.status = CollectorStatus.RUNNING self._reconnect_attempts = 0 self._stats['connection_uptime'] = datetime.now(timezone.utc) self.logger.info("Reconnection successful") return True return False except Exception as e: self._stats['last_error'] = str(e) self.logger.error(f"Reconnection attempt failed: {e}") return False def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None: """ Add a callback function to be called when data of specified type is received. Args: data_type: Type of data to register callback for callback: Function to call with MarketDataPoint data """ self._data_callbacks[data_type].append(callback) self.logger.debug(f"Added callback for {data_type.value} data") def remove_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None: """ Remove a data callback. Args: data_type: Type of data to remove callback for callback: Callback function to remove """ if callback in self._data_callbacks[data_type]: self._data_callbacks[data_type].remove(callback) self.logger.debug(f"Removed callback for {data_type.value} data") async def _notify_callbacks(self, data_point: MarketDataPoint) -> None: """ Notify all registered callbacks for the data type. Args: data_point: Market data to send to callbacks """ # Update data received timestamp self._last_data_received = datetime.now(timezone.utc) self._stats['last_message_time'] = self._last_data_received callbacks = self._data_callbacks.get(data_point.data_type, []) for callback in callbacks: try: if asyncio.iscoroutinefunction(callback): await callback(data_point) else: callback(data_point) except Exception as e: self.logger.error(f"Error in data callback: {e}") def get_status(self) -> Dict[str, Any]: """ Get current collector status and statistics. Returns: Dictionary containing status information """ uptime_seconds = None if self._stats['connection_uptime']: uptime_seconds = (datetime.now(timezone.utc) - self._stats['connection_uptime']).total_seconds() time_since_heartbeat = None if self._last_heartbeat: time_since_heartbeat = (datetime.now(timezone.utc) - self._last_heartbeat).total_seconds() time_since_data = None if self._last_data_received: time_since_data = (datetime.now(timezone.utc) - self._last_data_received).total_seconds() return { 'exchange': self.exchange_name, 'status': self.status.value, 'should_be_running': self._should_be_running, 'symbols': list(self.symbols), 'data_types': [dt.value for dt in self.data_types], 'auto_restart': self.auto_restart, 'health': { 'time_since_heartbeat': time_since_heartbeat, 'time_since_data': time_since_data, 'max_silence_duration': self._max_silence_duration.total_seconds() }, 'statistics': { **self._stats, 'uptime_seconds': uptime_seconds, 'reconnect_attempts': self._reconnect_attempts } } def get_health_status(self) -> Dict[str, Any]: """ Get detailed health status for monitoring. Returns: Dictionary containing health information """ now = datetime.now(timezone.utc) is_healthy = True health_issues = [] # Check if should be running but isn't if self._should_be_running and not self._running: is_healthy = False health_issues.append("Should be running but is stopped") # Check heartbeat if self._last_heartbeat: time_since_heartbeat = now - self._last_heartbeat if time_since_heartbeat > timedelta(seconds=self.health_check_interval * 2): is_healthy = False health_issues.append(f"No heartbeat for {time_since_heartbeat.total_seconds():.1f}s") # Check data freshness if self._last_data_received: time_since_data = now - self._last_data_received if time_since_data > self._max_silence_duration: is_healthy = False health_issues.append(f"No data for {time_since_data.total_seconds():.1f}s") # Check status if self.status in [CollectorStatus.ERROR, CollectorStatus.UNHEALTHY]: is_healthy = False health_issues.append(f"Status: {self.status.value}") return { 'is_healthy': is_healthy, 'issues': health_issues, 'status': self.status.value, 'last_heartbeat': self._last_heartbeat.isoformat() if self._last_heartbeat else None, 'last_data_received': self._last_data_received.isoformat() if self._last_data_received else None, 'should_be_running': self._should_be_running, 'is_running': self._running } def add_symbol(self, symbol: str) -> None: """ Add a new symbol to collect data for. Args: symbol: Trading symbol to add """ if symbol not in self.symbols: self.symbols.add(symbol) self.logger.info(f"Added symbol: {symbol}") def remove_symbol(self, symbol: str) -> None: """ Remove a symbol from data collection. Args: symbol: Trading symbol to remove """ if symbol in self.symbols: self.symbols.remove(symbol) self.logger.info(f"Removed symbol: {symbol}") def validate_ohlcv_data(self, data: Dict[str, Any], symbol: str, timeframe: str) -> OHLCVData: """ Validate and convert raw OHLCV data to standardized format. Args: data: Raw OHLCV data dictionary symbol: Trading symbol timeframe: Timeframe (e.g., '1m', '5m', '1h') Returns: Validated OHLCVData object Raises: DataValidationError: If data validation fails """ required_fields = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] # Check required fields for field in required_fields: if field not in data: raise DataValidationError(f"Missing required field: {field}") try: # Parse timestamp timestamp = data['timestamp'] if isinstance(timestamp, (int, float)): # Assume Unix timestamp in milliseconds timestamp = datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc) elif isinstance(timestamp, str): timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) elif not isinstance(timestamp, datetime): raise DataValidationError(f"Invalid timestamp format: {type(timestamp)}") return OHLCVData( symbol=symbol, timeframe=timeframe, timestamp=timestamp, open=Decimal(str(data['open'])), high=Decimal(str(data['high'])), low=Decimal(str(data['low'])), close=Decimal(str(data['close'])), volume=Decimal(str(data['volume'])), trades_count=data.get('trades_count') ) except (ValueError, TypeError, KeyError) as e: raise DataValidationError(f"Invalid OHLCV data for {symbol}: {e}") def __repr__(self) -> str: """String representation of the collector.""" return f"<{self.__class__.__name__}({self.exchange_name}, {len(self.symbols)} symbols, {self.status.value})>"