Refactor BaseDataCollector to integrate ConnectionManager for connection handling

- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and enhancing modularity.
- Updated `BaseDataCollector` to utilize the `ConnectionManager` for connection, disconnection, and reconnection processes, improving code clarity and maintainability.
- Refactored connection-related methods and attributes, ensuring consistent error handling and logging practices.
- Enhanced the `OKXCollector` to implement the new connection management approach, streamlining its connection logic.
- Added unit tests for the `ConnectionManager` to validate its functionality and ensure robust error handling.

These changes improve the architecture of the data collector, aligning with project standards for maintainability and performance.
This commit is contained in:
Vasily.onl
2025-06-09 17:42:06 +08:00
parent 60434afd5d
commit 41f0e8e6b6
8 changed files with 434 additions and 77 deletions

View File

@@ -7,8 +7,9 @@ processing and validating the data, and storing it in the database.
from .base_collector import (
BaseDataCollector, DataCollectorError, DataValidationError,
DataType, CollectorStatus, MarketDataPoint, OHLCVData
CollectorStatus, OHLCVData
)
from .common.data_types import DataType, MarketDataPoint
from .collector_manager import CollectorManager, ManagerStatus, CollectorConfig
__all__ = [

View File

@@ -15,30 +15,8 @@ from enum import Enum
from utils.logger import get_logger
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
class DataType(Enum):
"""Types of data that can be collected."""
TICKER = "ticker"
TRADE = "trade"
ORDERBOOK = "orderbook"
CANDLE = "candle"
BALANCE = "balance"
@dataclass
class MarketDataPoint:
"""Standardized market data structure."""
exchange: str
symbol: str
timestamp: datetime
data_type: DataType
data: Dict[str, Any]
def __post_init__(self):
"""Validate data after initialization."""
if not self.timestamp.tzinfo:
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
from .collector.collector_connection_manager import ConnectionManager
from .common.data_types import DataType, MarketDataPoint
@dataclass
@@ -149,6 +127,16 @@ class BaseDataCollector(ABC):
)
self.component_name = component # Keep for external access
# Initialize connection manager
self._connection_manager = ConnectionManager(
exchange_name=self.exchange_name,
component_name=component,
max_reconnect_attempts=5, # Default, can be made configurable later
reconnect_delay=5.0, # Default, can be made configurable later
logger=self.logger,
state_telemetry=self._state_telemetry
)
# Collector state (now managed by _state_telemetry)
self._tasks: Set[asyncio.Task] = set()
@@ -157,12 +145,6 @@ class BaseDataCollector(ABC):
data_type: [] for data_type in DataType
}
# Connection management
self._connection = None
self._reconnect_attempts = 0
self._max_reconnect_attempts = 5
self._reconnect_delay = 5.0 # seconds
# Log initialization if logger is available
if self._state_telemetry.logger:
if not self._state_telemetry.log_errors_only:
@@ -262,7 +244,7 @@ class BaseDataCollector(ABC):
try:
# Connect to data source
if not await self.connect():
if not await self._connection_manager.connect(self._actual_connect):
self._log_error("Failed to connect to data source")
self._state_telemetry.update_status(CollectorStatus.ERROR)
return False
@@ -271,13 +253,12 @@ class BaseDataCollector(ABC):
if not await self.subscribe_to_data(list(self.symbols), self.data_types):
self._log_error("Failed to subscribe to data streams")
self._state_telemetry.update_status(CollectorStatus.ERROR)
await self.disconnect()
await self._connection_manager.disconnect(self._actual_disconnect)
return False
# Start background tasks
self._state_telemetry.set_running_state(True)
self._state_telemetry.update_status(CollectorStatus.RUNNING)
self._state_telemetry.set_connection_uptime_start() # Record connection uptime start
# Start message processing task
message_task = asyncio.create_task(self._message_loop())
@@ -332,7 +313,7 @@ class BaseDataCollector(ABC):
# Unsubscribe and disconnect
await self.unsubscribe_from_data(list(self.symbols), self.data_types)
await self.disconnect()
await self._connection_manager.disconnect(self._actual_disconnect)
self._state_telemetry.update_status(CollectorStatus.STOPPED)
self._log_info(f"{self.exchange_name} data collector stopped")
@@ -355,7 +336,7 @@ class BaseDataCollector(ABC):
await self.stop()
# Wait a bit before restarting
await asyncio.sleep(self._reconnect_delay)
await asyncio.sleep(self._connection_manager._reconnect_delay)
# Start again
return await self.start()
@@ -447,34 +428,26 @@ class BaseDataCollector(ABC):
Returns:
True if reconnection successful, False if max attempts exceeded
"""
self._reconnect_attempts += 1
if self._reconnect_attempts > self._max_reconnect_attempts:
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded")
self._state_telemetry.update_status(CollectorStatus.ERROR)
self._state_telemetry.set_should_be_running(False)
return False
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts}")
# Disconnect and wait before retrying
await self.disconnect()
await asyncio.sleep(self._reconnect_delay)
# Attempt to reconnect
try:
if await self.connect():
if await self.subscribe_to_data(list(self.symbols), self.data_types):
self._log_info("Reconnection successful")
self._state_telemetry.update_status(CollectorStatus.RUNNING)
self._reconnect_attempts = 0
return True
except Exception as e:
self._log_error(f"Reconnection attempt failed: {e}")
return False
return await self._connection_manager.handle_connection_error(
connect_logic=self._actual_connect,
subscribe_logic=self.subscribe_to_data,
symbols=list(self.symbols),
data_types=self.data_types
)
@abstractmethod
async def _actual_connect(self) -> bool:
"""
Abstract method for subclasses to implement actual connection logic.
"""
pass
@abstractmethod
async def _actual_disconnect(self) -> None:
"""
Abstract method for subclasses to implement actual disconnection logic.
"""
pass
def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
"""

View File

@@ -0,0 +1,148 @@
"""
Module for managing network connection and reconnection logic for data collectors.
This module encapsulates the complexities of connecting, disconnecting,
and handling reconnection attempts to a data source, promoting a clean
separation of concerns within the data collector architecture.
"""
import asyncio
from typing import List, Any
from datetime import datetime
# from ..base_collector import DataType # Import from base_collector for now, will refactor later
from .collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
from data.common.data_types import DataType
class ConnectionManager:
"""
Manages the connection, disconnection, and reconnection logic for a data collector.
"""
def __init__(self,
exchange_name: str,
component_name: str,
max_reconnect_attempts: int = 5,
reconnect_delay: float = 5.0,
logger=None,
state_telemetry: CollectorStateAndTelemetry = None):
self.exchange_name = exchange_name
self.component_name = component_name
self._max_reconnect_attempts = max_reconnect_attempts
self._reconnect_delay = reconnect_delay
self.logger = logger
self._state_telemetry = state_telemetry
self._connection = None # Placeholder for the actual connection object
self._reconnect_attempts = 0
def _log_debug(self, message: str) -> None:
if self._state_telemetry:
self._state_telemetry._log_debug(f"{self.component_name}: {message}")
elif self.logger:
self.logger.debug(f"{self.component_name}: {message}")
def _log_info(self, message: str) -> None:
if self._state_telemetry:
self._state_telemetry._log_info(f"{self.component_name}: {message}")
elif self.logger:
self.logger.info(f"{self.component_name}: {message}")
def _log_warning(self, message: str) -> None:
if self._state_telemetry:
self._state_telemetry._log_warning(f"{self.component_name}: {message}")
elif self.logger:
self.logger.warning(f"{self.component_name}: {message}")
def _log_error(self, message: str, exc_info: bool = False) -> None:
if self._state_telemetry:
self._state_telemetry._log_error(f"{self.component_name}: {message}", exc_info=exc_info)
elif self.logger:
self.logger.error(f"{self.component_name}: {message}", exc_info=exc_info)
async def connect(self, connect_logic: callable) -> bool:
"""
Establish connection to the data source using provided logic.
Args:
connect_logic: A callable (async function) that performs the actual connection.
Returns:
True if connection successful, False otherwise
"""
self._log_info(f"Connecting to {self.exchange_name} data source")
try:
success = await connect_logic()
if success:
self._connection = True # Indicate connection is established
self._state_telemetry.set_connection_uptime_start()
self._log_info(f"Successfully connected to {self.exchange_name}")
return True
else:
self._log_error(f"Failed to connect to {self.exchange_name}")
return False
except Exception as e:
self._log_error(f"Error during connection to {self.exchange_name}: {e}", exc_info=True)
return False
async def disconnect(self, disconnect_logic: callable) -> None:
"""
Disconnect from the data source using provided logic.
Args:
disconnect_logic: A callable (async function) that performs the actual disconnection.
"""
self._log_info(f"Disconnecting from {self.exchange_name} data source")
try:
if self._connection:
await disconnect_logic()
self._connection = None
self._log_info(f"Disconnected from {self.exchange_name}")
except Exception as e:
self._log_error(f"Error during disconnection from {self.exchange_name}: {e}", exc_info=True)
async def handle_connection_error(self, connect_logic: callable, subscribe_logic: callable, symbols: List[str], data_types: List[DataType]) -> bool:
"""
Handle connection errors and attempt reconnection.
Args:
connect_logic: Callable for connecting.
subscribe_logic: Callable for subscribing.
symbols: List of symbols to re-subscribe to.
data_types: List of data types to re-subscribe to.
Returns:
True if reconnection successful, False if max attempts exceeded
"""
self._reconnect_attempts += 1
if self._reconnect_attempts > self._max_reconnect_attempts:
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
if self._state_telemetry:
self._state_telemetry.update_status(CollectorStatus.ERROR)
self._state_telemetry.set_should_be_running(False)
return False
if self._state_telemetry:
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts} for {self.exchange_name}")
# Disconnect and wait before retrying
await self.disconnect(lambda: None) # Pass a no-op disconnect for internal use, actual disconnect handled by caller
await asyncio.sleep(self._reconnect_delay)
# Attempt to reconnect
try:
if await self.connect(connect_logic):
if await subscribe_logic(symbols, data_types):
self._log_info(f"Reconnection successful for {self.exchange_name}")
if self._state_telemetry:
self._state_telemetry.update_status(CollectorStatus.RUNNING)
self._reconnect_attempts = 0
return True
except Exception as e:
self._log_error(f"Reconnection attempt failed for {self.exchange_name}: {e}", exc_info=True)
return False

View File

@@ -10,7 +10,8 @@ from .data_types import (
OHLCVCandle,
MarketDataPoint,
DataValidationResult,
CandleProcessingConfig
CandleProcessingConfig,
DataType
)
from .transformation.trade import (
@@ -44,6 +45,7 @@ __all__ = [
'MarketDataPoint',
'DataValidationResult',
'CandleProcessingConfig',
'DataType',
# Trade transformation
'TradeTransformer',

View File

@@ -10,8 +10,33 @@ from decimal import Decimal
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
import asyncio
from ..base_collector import DataType, MarketDataPoint # Import from base
# from ..base_collector import DataType, MarketDataPoint # Import from base
class DataType(Enum):
"""Types of data that can be collected."""
TICKER = "ticker"
TRADE = "trade"
ORDERBOOK = "orderbook"
CANDLE = "candle"
BALANCE = "balance"
@dataclass
class MarketDataPoint:
"""Standardized market data structure."""
exchange: str
symbol: str
timestamp: datetime
data_type: DataType
data: Dict[str, Any]
def __post_init__(self):
"""Validate data after initialization."""
if not self.timestamp.tzinfo:
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
@dataclass

View File

@@ -109,7 +109,7 @@ class OKXCollector(BaseDataCollector):
symbol,
config=candle_config or CandleProcessingConfig(timeframes=self.timeframes), # Use provided config or create new one
component_name=f"{component_name}_processor",
logger=logger
logger=self.logger
)
# Add callbacks for processed data
@@ -140,6 +140,21 @@ class OKXCollector(BaseDataCollector):
"""
Establish connection to OKX WebSocket API.
Returns:
True if connection successful, False otherwise
"""
return await self._connection_manager.connect(self._actual_connect)
async def disconnect(self) -> None:
"""
Disconnect from OKX WebSocket API.
"""
await self._connection_manager.disconnect(self._actual_disconnect)
async def _actual_connect(self) -> bool:
"""
Implement the actual connection logic for OKX WebSocket API.
Returns:
True if connection successful, False otherwise
"""
@@ -157,7 +172,7 @@ class OKXCollector(BaseDataCollector):
pong_timeout=10.0,
max_reconnect_attempts=5,
reconnect_delay=5.0,
logger=self.logger # Pass the logger to enable ping/pong logging
logger=self.logger
)
# Add message callback
@@ -175,9 +190,9 @@ class OKXCollector(BaseDataCollector):
self._log_error(f"Error connecting OKX collector for {self.symbol}: {e}")
return False
async def disconnect(self) -> None:
async def _actual_disconnect(self) -> None:
"""
Disconnect from OKX WebSocket API.
Implement the actual disconnection logic for OKX WebSocket API.
"""
try:
self._log_info(f"Disconnecting OKX collector for {self.symbol}")