TCPDashboard/data/base_collector.py
Vasily.onl 60434afd5d Refactor BaseDataCollector to utilize CollectorStateAndTelemetry for improved state management
- Introduced a new `CollectorStateAndTelemetry` class to encapsulate the status, health checks, and statistics of the data collector, promoting modularity and separation of concerns.
- Updated `BaseDataCollector` to replace direct status management with calls to the new telemetry class, enhancing maintainability and readability.
- Refactored logging methods to utilize the telemetry class, ensuring consistent logging practices.
- Modified the `OKXCollector` to integrate with the new telemetry system for improved status reporting and error handling.
- Added comprehensive tests for the `CollectorStateAndTelemetry` class to ensure functionality and reliability.

These changes streamline the data collector's architecture, aligning with project standards for maintainability and performance.
2025-06-09 17:27:29 +08:00

631 lines
24 KiB
Python

"""
Abstract base class for data collectors.
This module provides a common interface for all data collection implementations,
ensuring consistency across different exchange connectors and data sources.
"""
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable, Set
from dataclasses import dataclass
from enum import Enum
from utils.logger import get_logger
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
class DataType(Enum):
    """Kinds of market/account data a collector can be asked to gather.

    The string values are the canonical lowercase identifiers for each kind
    and are what ``DataType.<member>.value`` yields (e.g. in log messages).
    """

    TICKER = "ticker"        # ticker updates
    TRADE = "trade"          # trades
    ORDERBOOK = "orderbook"  # order book data
    CANDLE = "candle"        # candle data
    BALANCE = "balance"      # balance data
@dataclass
class MarketDataPoint:
    """Exchange-agnostic envelope for a single piece of market data.

    Attributes are documented inline; ``timestamp`` is normalised in
    ``__post_init__`` so that it always carries timezone information.
    """

    exchange: str          # name of the exchange the data came from
    symbol: str            # trading symbol the data refers to
    timestamp: datetime    # time of the data point (UTC assumed when naive)
    data_type: DataType    # which kind of data ``data`` holds
    data: Dict[str, Any]   # payload dictionary

    def __post_init__(self):
        """Attach UTC to ``timestamp`` when it has no timezone set."""
        if self.timestamp.tzinfo:
            # Already timezone-aware: leave untouched.
            return
        self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
@dataclass
class OHLCVData:
    """OHLCV (Open, High, Low, Close, Volume) data structure.

    On construction the instance is validated and normalised:

    * a naive ``timestamp`` is tagged as UTC;
    * prices and volume must be ``Decimal``, ``float`` or ``int`` and are
      converted to ``Decimal`` (via ``str`` to avoid binary-float artefacts);
    * prices and volume must be finite (no NaN / Infinity);
    * ``low <= open <= high`` and ``low <= close <= high`` must hold.

    Raises:
        DataValidationError: If any of the checks above fails.
    """

    symbol: str
    timeframe: str
    timestamp: datetime
    open: Decimal
    high: Decimal
    low: Decimal
    close: Decimal
    volume: Decimal
    trades_count: Optional[int] = None

    def __post_init__(self):
        """Validate OHLCV data after initialization."""
        if not self.timestamp.tzinfo:
            self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)

        # Validate price data
        numeric_types = (Decimal, float, int)
        prices = (self.open, self.high, self.low, self.close)
        if not all(isinstance(price, numeric_types) for price in prices):
            raise DataValidationError("All OHLCV prices must be numeric")
        if not isinstance(self.volume, numeric_types):
            raise DataValidationError("Volume must be numeric")

        # Convert to Decimal for precision
        self.open = Decimal(str(self.open))
        self.high = Decimal(str(self.high))
        self.low = Decimal(str(self.low))
        self.close = Decimal(str(self.close))
        self.volume = Decimal(str(self.volume))

        # Reject NaN/Infinity explicitly. Without this check a NaN makes the
        # ordering comparison below raise decimal.InvalidOperation (not a
        # DataValidationError), and an infinite value can pass validation.
        values = (self.open, self.high, self.low, self.close, self.volume)
        if not all(value.is_finite() for value in values):
            raise DataValidationError(
                f"Invalid OHLCV data: non-finite price or volume for {self.symbol}"
            )

        # Validate price relationships
        if not (self.low <= self.open <= self.high and self.low <= self.close <= self.high):
            raise DataValidationError(f"Invalid OHLCV data: prices don't match expected relationships for {self.symbol}")
class DataCollectorError(Exception):
    """Root of the data-collector exception hierarchy.

    Catch this type to handle any error raised by the collector layer.
    """
class DataValidationError(DataCollectorError):
    """Signals that incoming data failed a validation check."""
# NOTE(review): this name shadows the builtin ``ConnectionError`` inside this
# module (and for anyone importing it by name); it is NOT a subclass of the
# builtin, so ``except OSError`` will not catch it.
class ConnectionError(DataCollectorError):
    """Signals that connecting to the data source failed."""
class BaseDataCollector(ABC):
    """
    Abstract base class for all data collectors.

    This class defines the interface that all data collection implementations
    must follow, providing consistency across different exchanges and data sources.

    Subclasses implement the transport-specific pieces (``connect``,
    ``disconnect``, ``subscribe_to_data``, ``unsubscribe_from_data``,
    ``_process_message`` and ``_handle_messages``); this base class provides the
    start/stop/restart lifecycle, the background message loop, health
    monitoring with optional auto-restart, callback fan-out and OHLCV
    validation. Status, statistics and logging are delegated to a
    ``CollectorStateAndTelemetry`` instance held in ``_state_telemetry``.
    """

    def __init__(self,
                 exchange_name: str,
                 symbols: List[str],
                 data_types: Optional[List[DataType]] = None,
                 timeframes: Optional[List[str]] = None,
                 component_name: Optional[str] = None,
                 auto_restart: bool = True,
                 health_check_interval: float = 30.0,
                 logger=None,
                 log_errors_only: bool = False):
        """
        Initialize the base data collector.

        Args:
            exchange_name: Name of the exchange (e.g., 'okx', 'binance'); stored lower-cased
            symbols: List of trading symbols to collect data for
            data_types: Types of data to collect (default: [DataType.CANDLE])
            timeframes: List of timeframes to collect (default: ['1m', '5m'])
            component_name: Name for logging (default: "<exchange_name>_collector")
            auto_restart: Enable automatic restart on failures (default: True)
            health_check_interval: Seconds between health checks (default: 30.0);
                the health monitor task is only started when this is > 0
            logger: Logger instance. If None, a logger named after the exchange
                is obtained via ``get_logger``.
            log_errors_only: If True, only error-level messages are logged
        """
        self.exchange_name = exchange_name.lower()
        self.symbols = set(symbols)
        self.data_types = data_types or [DataType.CANDLE]
        self.timeframes = timeframes or ['1m', '5m']  # Default timeframes if none provided
        self.auto_restart = auto_restart

        # Ensure a logger is always available
        if logger is not None:
            self.logger = logger
        else:
            self.logger = get_logger(self.exchange_name)

        # Initialize state and telemetry manager
        component = component_name or f"{self.exchange_name}_collector"
        self._state_telemetry = CollectorStateAndTelemetry(
            exchange_name=self.exchange_name,
            component_name=component,
            health_check_interval=health_check_interval,
            logger=self.logger,
            log_errors_only=log_errors_only
        )
        self.component_name = component  # Keep for external access

        # Background tasks owned by this collector (message loop, health monitor)
        self._tasks: Set[asyncio.Task] = set()
        # Strong reference to an in-flight automatic restart. Kept separate from
        # _tasks because restart() calls stop(), which cancels everything in _tasks.
        self._restart_task: Optional[asyncio.Task] = None

        # Data callbacks, one list per data type
        self._data_callbacks: Dict[DataType, List[Callable]] = {
            data_type: [] for data_type in DataType
        }

        # Connection management
        self._connection = None
        self._reconnect_attempts = 0
        self._max_reconnect_attempts = 5
        self._reconnect_delay = 5.0  # seconds

        # Log initialization if logger is available
        if self._state_telemetry.logger:
            if not self._state_telemetry.log_errors_only:
                self._state_telemetry._log_info(f"{self.component_name}: Initialized {self.exchange_name} data collector for symbols: {', '.join(symbols)}")
                self._state_telemetry._log_info(f"{self.component_name}: Using timeframes: {', '.join(self.timeframes)}")

    @property
    def status(self) -> CollectorStatus:
        """Current collector status as tracked by the telemetry object."""
        return self._state_telemetry.status

    # --- Logging helpers: thin pass-throughs to the telemetry object ---------

    def _log_debug(self, message: str) -> None:
        """Forward a debug message to the telemetry logger."""
        self._state_telemetry._log_debug(message)

    def _log_info(self, message: str) -> None:
        """Forward an info message to the telemetry logger."""
        self._state_telemetry._log_info(message)

    def _log_warning(self, message: str) -> None:
        """Forward a warning message to the telemetry logger."""
        self._state_telemetry._log_warning(message)

    def _log_error(self, message: str, exc_info: bool = False) -> None:
        """Forward an error message (optionally with exc_info) to the telemetry logger."""
        self._state_telemetry._log_error(message, exc_info=exc_info)

    def _log_critical(self, message: str, exc_info: bool = False) -> None:
        """Forward a critical message (optionally with exc_info) to the telemetry logger."""
        self._state_telemetry._log_critical(message, exc_info=exc_info)

    # --- Interface to be implemented by subclasses ---------------------------

    @abstractmethod
    async def connect(self) -> bool:
        """
        Establish connection to the data source.

        Returns:
            True if connection successful, False otherwise
        """
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Disconnect from the data source."""
        pass

    @abstractmethod
    async def subscribe_to_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
        """
        Subscribe to data streams for specified symbols and data types.

        Args:
            symbols: Trading symbols to subscribe to
            data_types: Types of data to subscribe to

        Returns:
            True if subscription successful, False otherwise
        """
        pass

    @abstractmethod
    async def unsubscribe_from_data(self, symbols: List[str], data_types: List[DataType]) -> bool:
        """
        Unsubscribe from data streams.

        Args:
            symbols: Trading symbols to unsubscribe from
            data_types: Types of data to unsubscribe from

        Returns:
            True if unsubscription successful, False otherwise
        """
        pass

    @abstractmethod
    async def _process_message(self, message: Any) -> Optional[MarketDataPoint]:
        """
        Process incoming message from the data source.

        Args:
            message: Raw message from the data source

        Returns:
            Processed MarketDataPoint or None if message should be ignored
        """
        pass

    # --- Lifecycle ------------------------------------------------------------

    async def start(self) -> bool:
        """
        Start the data collector.

        Connects, subscribes to ``self.symbols`` / ``self.data_types`` and then
        launches the message loop and (if the health check interval is > 0)
        the health monitor as background tasks.

        Returns:
            True if started successfully (or already running/starting), False otherwise
        """
        # Check if already running or starting
        if self._state_telemetry.status in [CollectorStatus.RUNNING, CollectorStatus.STARTING]:
            self._log_warning("Data collector is already running or starting")
            return True

        self._log_info(f"Starting {self.exchange_name} data collector")
        self._state_telemetry.update_status(CollectorStatus.STARTING)
        self._state_telemetry.set_should_be_running(True)

        try:
            # Connect to data source
            if not await self.connect():
                self._log_error("Failed to connect to data source")
                self._state_telemetry.update_status(CollectorStatus.ERROR)
                return False

            # Subscribe to data streams
            if not await self.subscribe_to_data(list(self.symbols), self.data_types):
                self._log_error("Failed to subscribe to data streams")
                self._state_telemetry.update_status(CollectorStatus.ERROR)
                await self.disconnect()
                return False

            # Mark as running before spawning loops: both loops poll the running flag
            self._state_telemetry.set_running_state(True)
            self._state_telemetry.update_status(CollectorStatus.RUNNING)
            self._state_telemetry.set_connection_uptime_start()  # Record connection uptime start

            # Start message processing task
            message_task = asyncio.create_task(self._message_loop())
            self._tasks.add(message_task)
            message_task.add_done_callback(self._tasks.discard)

            # Start health monitoring task
            if self._state_telemetry.health_check_interval > 0:
                health_task = asyncio.create_task(self._health_monitor())
                self._tasks.add(health_task)
                health_task.add_done_callback(self._tasks.discard)

            self._log_info(f"{self.exchange_name} data collector started successfully")
            return True

        except Exception as e:
            self._log_error(f"Failed to start data collector: {e}")
            self._state_telemetry.update_status(CollectorStatus.ERROR)
            self._state_telemetry.set_should_be_running(False)
            return False

    async def stop(self, force: bool = False) -> None:
        """
        Stop the data collector and cleanup resources.

        Background tasks are cancelled, then the collector unsubscribes and
        disconnects. Any exception during shutdown is logged and leaves the
        collector in ERROR status; it is not re-raised.

        Args:
            force: If True, background tasks are cancelled but not awaited
        """
        if self._state_telemetry.status == CollectorStatus.STOPPED:
            self._log_warning("Data collector is already stopped")
            return

        self._log_info(f"Stopping {self.exchange_name} data collector")
        self._state_telemetry.update_status(CollectorStatus.STOPPING)
        self._state_telemetry.set_should_be_running(False)

        try:
            # Stop background tasks
            self._state_telemetry.set_running_state(False)

            # Cancel everything first, then wait for all of them together so
            # that a failure in one task cannot prevent the others from being
            # cancelled. return_exceptions=True collects the tasks' own
            # CancelledError/exceptions instead of raising them here, while a
            # cancellation of stop() itself still propagates to the caller.
            pending = [task for task in self._tasks if not task.done()]
            for task in pending:
                task.cancel()
            if pending and not force:
                await asyncio.gather(*pending, return_exceptions=True)
            self._tasks.clear()

            # Unsubscribe and disconnect; always attempt the disconnect so the
            # connection is not leaked when unsubscribing fails.
            try:
                await self.unsubscribe_from_data(list(self.symbols), self.data_types)
            finally:
                await self.disconnect()

            self._state_telemetry.update_status(CollectorStatus.STOPPED)
            self._log_info(f"{self.exchange_name} data collector stopped")

        except Exception as e:
            self._log_error(f"Error stopping data collector: {e}")
            self._state_telemetry.update_status(CollectorStatus.ERROR)

    async def restart(self) -> bool:
        """
        Restart the data collector.

        Stops the collector, waits ``_reconnect_delay`` seconds, then starts it again.

        Returns:
            True if restarted successfully, False otherwise
        """
        self._log_info(f"Restarting {self.exchange_name} data collector")
        self._state_telemetry.increment_restarts()

        # Stop first
        await self.stop()

        # Wait a bit before restarting
        await asyncio.sleep(self._reconnect_delay)

        # Start again
        return await self.start()

    # --- Background loops -----------------------------------------------------

    async def _message_loop(self) -> None:
        """Main message processing loop; runs until the running flag is cleared or the task is cancelled."""
        try:
            self._log_debug("Starting message processing loop")

            while self._state_telemetry._running:
                try:
                    await self._handle_messages()
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    self._state_telemetry.increment_errors(str(e))
                    self._log_error(f"Error processing messages: {e}")
                    # Small delay to prevent tight error loops
                    await asyncio.sleep(0.1)

        except asyncio.CancelledError:
            self._log_debug("Message loop cancelled")
            raise
        except Exception as e:
            self._log_error(f"Error in message loop: {e}")
            self._state_telemetry.update_status(CollectorStatus.ERROR)

    def _schedule_restart(self) -> None:
        """Run ``restart()`` in the background if ``auto_restart`` is enabled.

        The event loop only keeps weak references to tasks, so the task is
        stored on the instance to prevent it from being garbage-collected
        mid-restart. A new restart is not scheduled while a previous one is
        still in flight.
        """
        if not self.auto_restart:
            return
        # getattr keeps this safe for instances created without __init__
        in_flight = getattr(self, "_restart_task", None)
        if in_flight is not None and not in_flight.done():
            return
        self._restart_task = asyncio.create_task(self.restart())
        self._restart_task.add_done_callback(self._on_restart_done)

    def _on_restart_done(self, task: "asyncio.Task") -> None:
        """Log an exception raised by a background restart so it is not silently dropped."""
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            self._log_error(f"Automatic restart failed: {error}")

    async def _health_monitor(self) -> None:
        """Monitor collector health and restart if needed.

        Every ``health_check_interval`` seconds the following are checked, in
        order: status vs. the should-be-running flag, heartbeat age (limit:
        twice the interval), time since the last data (limit:
        ``_max_silence_duration``) and ERROR status. Each failed check logs a
        warning and requests a restart (honoured only when ``auto_restart``
        is True).
        """
        try:
            self._log_debug("Starting health monitor")

            while self._state_telemetry._running:
                try:
                    await asyncio.sleep(self._state_telemetry.health_check_interval)

                    # Check if collector should be running but isn't
                    if self._state_telemetry._should_be_running and self._state_telemetry.status != CollectorStatus.RUNNING:
                        self._log_warning("Collector should be running but isn't - restarting")
                        self._schedule_restart()
                        continue

                    # Check heartbeat
                    time_since_heartbeat = datetime.now(timezone.utc) - self._state_telemetry._last_heartbeat
                    if time_since_heartbeat > timedelta(seconds=self._state_telemetry.health_check_interval * 2):
                        self._log_warning(f"No heartbeat for {time_since_heartbeat.total_seconds():.1f}s - restarting")
                        self._schedule_restart()
                        continue

                    # Check data reception (skipped until the first data has arrived)
                    if self._state_telemetry._last_data_received:
                        time_since_data = datetime.now(timezone.utc) - self._state_telemetry._last_data_received
                        if time_since_data > self._state_telemetry._max_silence_duration:
                            self._log_warning(f"No data received for {time_since_data.total_seconds():.1f}s - restarting")
                            self._schedule_restart()
                            continue

                    # Check for error status
                    if self._state_telemetry.status == CollectorStatus.ERROR:
                        self._log_warning(f"Collector in {self._state_telemetry.status.value} status - restarting")
                        self._schedule_restart()

                except asyncio.CancelledError:
                    break

        except asyncio.CancelledError:
            self._log_debug("Health monitor cancelled")
            raise
        except Exception as e:
            self._log_error(f"Error in health monitor: {e}")

    @abstractmethod
    async def _handle_messages(self) -> None:
        """
        Handle incoming messages from the data source.

        This method should be implemented by subclasses to handle their specific message format.
        It is called repeatedly by the message loop while the collector is running.
        """
        pass

    async def _handle_connection_error(self) -> bool:
        """
        Handle connection errors and attempt reconnection.

        Each call counts as one attempt. Once ``_max_reconnect_attempts`` is
        exceeded the collector is put in ERROR status and marked as
        not-supposed-to-run. The attempt counter is reset on success.

        Returns:
            True if reconnection successful, False if this attempt failed or max attempts exceeded
        """
        self._reconnect_attempts += 1

        if self._reconnect_attempts > self._max_reconnect_attempts:
            self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded")
            self._state_telemetry.update_status(CollectorStatus.ERROR)
            self._state_telemetry.set_should_be_running(False)
            return False

        self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
        self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts}")

        # Disconnect and wait before retrying
        await self.disconnect()
        await asyncio.sleep(self._reconnect_delay)

        # Attempt to reconnect
        try:
            if await self.connect():
                if await self.subscribe_to_data(list(self.symbols), self.data_types):
                    self._log_info("Reconnection successful")
                    self._state_telemetry.update_status(CollectorStatus.RUNNING)
                    self._reconnect_attempts = 0
                    return True
        except Exception as e:
            self._log_error(f"Reconnection attempt failed: {e}")

        return False

    # --- Callbacks ------------------------------------------------------------

    def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
        """
        Add a callback function for specific data type.

        Adding the same callback twice for one data type is a no-op.

        Args:
            data_type: Type of data to monitor
            callback: Function (sync or async) to call when data is received
        """
        if callback not in self._data_callbacks[data_type]:
            self._data_callbacks[data_type].append(callback)
            self._log_debug(f"Added callback for {data_type.value} data")

    def remove_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
        """
        Remove a callback function for specific data type.

        Removing a callback that is not registered is a no-op.

        Args:
            data_type: Type of data to stop monitoring
            callback: Function to remove
        """
        if callback in self._data_callbacks[data_type]:
            self._data_callbacks[data_type].remove(callback)
            self._log_debug(f"Removed callback for {data_type.value} data")

    async def _notify_callbacks(self, data_point: MarketDataPoint) -> None:
        """
        Notify all registered callbacks for a data point.

        An exception in one callback is logged and does not stop the
        remaining callbacks. Statistics, the data-received timestamp and the
        heartbeat are updated afterwards.

        Args:
            data_point: Market data to distribute
        """
        callbacks = self._data_callbacks.get(data_point.data_type, [])

        for callback in callbacks:
            try:
                # Call first, then await if a coroutine came back. This covers
                # plain functions, ``async def`` functions, and callables that
                # merely return a coroutine (e.g. objects with an async
                # __call__), which a check on the callable itself would miss.
                outcome = callback(data_point)
                if asyncio.iscoroutine(outcome):
                    await outcome
            except Exception as e:
                self._log_error(f"Error in data callback: {e}")

        # Update statistics
        self._state_telemetry.increment_messages_processed()
        self._state_telemetry._stats['last_message_time'] = data_point.timestamp  # Direct update for now, will refactor
        self._state_telemetry.update_data_received_timestamp()
        self._state_telemetry.update_heartbeat()

    # --- Status ---------------------------------------------------------------

    def get_status(self) -> Dict[str, Any]:
        """
        Get current collector status and statistics.

        Returns:
            Dictionary containing status information (as produced by the telemetry object)
        """
        return self._state_telemetry.get_status()

    def get_health_status(self) -> Dict[str, Any]:
        """
        Get detailed health status for monitoring.

        Returns:
            Dictionary containing health information (as produced by the telemetry object)
        """
        return self._state_telemetry.get_health_status()

    # --- Symbol management ----------------------------------------------------

    def add_symbol(self, symbol: str) -> None:
        """
        Add a new symbol to collect data for.

        Only the local symbol set is updated; no subscription is made here,
        even when the collector is running.

        Args:
            symbol: Trading symbol to add
        """
        if symbol not in self.symbols:
            self.symbols.add(symbol)
            self._log_info(f"Added symbol: {symbol}")

            # If collector is running, subscribe to new symbol
            if self._state_telemetry.status == CollectorStatus.RUNNING:
                # Note: This needs to be called from an async context
                # Users should handle this appropriately
                pass

    def remove_symbol(self, symbol: str) -> None:
        """
        Remove a symbol from data collection.

        Only the local symbol set is updated; no unsubscription is made here,
        even when the collector is running.

        Args:
            symbol: Trading symbol to remove
        """
        if symbol in self.symbols:
            self.symbols.remove(symbol)
            self._log_info(f"Removed symbol: {symbol}")

            # If collector is running, unsubscribe from symbol
            if self._state_telemetry.status == CollectorStatus.RUNNING:
                # Note: This needs to be called from an async context
                # Users should handle this appropriately
                pass

    # --- Validation -----------------------------------------------------------

    def validate_ohlcv_data(self, data: Dict[str, Any], symbol: str, timeframe: str) -> OHLCVData:
        """
        Validate and convert raw OHLCV data to standardized format.

        Args:
            data: Raw OHLCV data dictionary; must contain 'timestamp', 'open',
                'high', 'low', 'close' and 'volume' ('trades_count' is optional).
                A numeric timestamp is treated as Unix milliseconds, a string
                as ISO 8601 (a trailing 'Z' is accepted).
            symbol: Trading symbol
            timeframe: Timeframe (e.g., '1m', '5m', '1h')

        Returns:
            Validated OHLCVData object

        Raises:
            DataValidationError: If data validation fails
        """
        required_fields = ['timestamp', 'open', 'high', 'low', 'close', 'volume']

        # Check required fields
        for field in required_fields:
            if field not in data:
                raise DataValidationError(f"Missing required field: {field}")

        try:
            # Parse timestamp
            timestamp = data['timestamp']
            if isinstance(timestamp, (int, float)):
                # Assume Unix timestamp in milliseconds
                timestamp = datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc)
            elif isinstance(timestamp, str):
                timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
            elif not isinstance(timestamp, datetime):
                raise DataValidationError(f"Invalid timestamp format: {type(timestamp)}")

            return OHLCVData(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                open=Decimal(str(data['open'])),
                high=Decimal(str(data['high'])),
                low=Decimal(str(data['low'])),
                close=Decimal(str(data['close'])),
                volume=Decimal(str(data['volume'])),
                trades_count=data.get('trades_count')
            )

        except (ValueError, TypeError, KeyError, ArithmeticError, OSError) as e:
            # ArithmeticError covers decimal.InvalidOperation (malformed numeric
            # strings, NaN comparisons) and OverflowError; OSError covers
            # platform failures for out-of-range timestamps. Without these the
            # documented DataValidationError contract would be broken.
            raise DataValidationError(f"Invalid OHLCV data for {symbol}: {e}") from e

    def __repr__(self) -> str:
        """String representation of the collector."""
        return f"<{self.__class__.__name__}({self.exchange_name}, {len(self.symbols)} symbols, {self.status.value})>"