Refactor BaseDataCollector to integrate ConnectionManager for connection handling
- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and enhancing modularity. - Updated `BaseDataCollector` to utilize the `ConnectionManager` for connection, disconnection, and reconnection processes, improving code clarity and maintainability. - Refactored connection-related methods and attributes, ensuring consistent error handling and logging practices. - Enhanced the `OKXCollector` to implement the new connection management approach, streamlining its connection logic. - Added unit tests for the `ConnectionManager` to validate its functionality and ensure robust error handling. These changes improve the architecture of the data collector, aligning with project standards for maintainability and performance.
This commit is contained in:
parent
60434afd5d
commit
41f0e8e6b6
@ -7,8 +7,9 @@ processing and validating the data, and storing it in the database.
|
|||||||
|
|
||||||
from .base_collector import (
|
from .base_collector import (
|
||||||
BaseDataCollector, DataCollectorError, DataValidationError,
|
BaseDataCollector, DataCollectorError, DataValidationError,
|
||||||
DataType, CollectorStatus, MarketDataPoint, OHLCVData
|
CollectorStatus, OHLCVData
|
||||||
)
|
)
|
||||||
|
from .common.data_types import DataType, MarketDataPoint
|
||||||
from .collector_manager import CollectorManager, ManagerStatus, CollectorConfig
|
from .collector_manager import CollectorManager, ManagerStatus, CollectorConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -15,30 +15,8 @@ from enum import Enum
|
|||||||
|
|
||||||
from utils.logger import get_logger
|
from utils.logger import get_logger
|
||||||
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||||
|
from .collector.collector_connection_manager import ConnectionManager
|
||||||
|
from .common.data_types import DataType, MarketDataPoint
|
||||||
class DataType(Enum):
|
|
||||||
"""Types of data that can be collected."""
|
|
||||||
TICKER = "ticker"
|
|
||||||
TRADE = "trade"
|
|
||||||
ORDERBOOK = "orderbook"
|
|
||||||
CANDLE = "candle"
|
|
||||||
BALANCE = "balance"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MarketDataPoint:
|
|
||||||
"""Standardized market data structure."""
|
|
||||||
exchange: str
|
|
||||||
symbol: str
|
|
||||||
timestamp: datetime
|
|
||||||
data_type: DataType
|
|
||||||
data: Dict[str, Any]
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Validate data after initialization."""
|
|
||||||
if not self.timestamp.tzinfo:
|
|
||||||
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -149,6 +127,16 @@ class BaseDataCollector(ABC):
|
|||||||
)
|
)
|
||||||
self.component_name = component # Keep for external access
|
self.component_name = component # Keep for external access
|
||||||
|
|
||||||
|
# Initialize connection manager
|
||||||
|
self._connection_manager = ConnectionManager(
|
||||||
|
exchange_name=self.exchange_name,
|
||||||
|
component_name=component,
|
||||||
|
max_reconnect_attempts=5, # Default, can be made configurable later
|
||||||
|
reconnect_delay=5.0, # Default, can be made configurable later
|
||||||
|
logger=self.logger,
|
||||||
|
state_telemetry=self._state_telemetry
|
||||||
|
)
|
||||||
|
|
||||||
# Collector state (now managed by _state_telemetry)
|
# Collector state (now managed by _state_telemetry)
|
||||||
self._tasks: Set[asyncio.Task] = set()
|
self._tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
@ -157,12 +145,6 @@ class BaseDataCollector(ABC):
|
|||||||
data_type: [] for data_type in DataType
|
data_type: [] for data_type in DataType
|
||||||
}
|
}
|
||||||
|
|
||||||
# Connection management
|
|
||||||
self._connection = None
|
|
||||||
self._reconnect_attempts = 0
|
|
||||||
self._max_reconnect_attempts = 5
|
|
||||||
self._reconnect_delay = 5.0 # seconds
|
|
||||||
|
|
||||||
# Log initialization if logger is available
|
# Log initialization if logger is available
|
||||||
if self._state_telemetry.logger:
|
if self._state_telemetry.logger:
|
||||||
if not self._state_telemetry.log_errors_only:
|
if not self._state_telemetry.log_errors_only:
|
||||||
@ -262,7 +244,7 @@ class BaseDataCollector(ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Connect to data source
|
# Connect to data source
|
||||||
if not await self.connect():
|
if not await self._connection_manager.connect(self._actual_connect):
|
||||||
self._log_error("Failed to connect to data source")
|
self._log_error("Failed to connect to data source")
|
||||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||||
return False
|
return False
|
||||||
@ -271,13 +253,12 @@ class BaseDataCollector(ABC):
|
|||||||
if not await self.subscribe_to_data(list(self.symbols), self.data_types):
|
if not await self.subscribe_to_data(list(self.symbols), self.data_types):
|
||||||
self._log_error("Failed to subscribe to data streams")
|
self._log_error("Failed to subscribe to data streams")
|
||||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||||
await self.disconnect()
|
await self._connection_manager.disconnect(self._actual_disconnect)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Start background tasks
|
# Start background tasks
|
||||||
self._state_telemetry.set_running_state(True)
|
self._state_telemetry.set_running_state(True)
|
||||||
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
||||||
self._state_telemetry.set_connection_uptime_start() # Record connection uptime start
|
|
||||||
|
|
||||||
# Start message processing task
|
# Start message processing task
|
||||||
message_task = asyncio.create_task(self._message_loop())
|
message_task = asyncio.create_task(self._message_loop())
|
||||||
@ -332,7 +313,7 @@ class BaseDataCollector(ABC):
|
|||||||
|
|
||||||
# Unsubscribe and disconnect
|
# Unsubscribe and disconnect
|
||||||
await self.unsubscribe_from_data(list(self.symbols), self.data_types)
|
await self.unsubscribe_from_data(list(self.symbols), self.data_types)
|
||||||
await self.disconnect()
|
await self._connection_manager.disconnect(self._actual_disconnect)
|
||||||
|
|
||||||
self._state_telemetry.update_status(CollectorStatus.STOPPED)
|
self._state_telemetry.update_status(CollectorStatus.STOPPED)
|
||||||
self._log_info(f"{self.exchange_name} data collector stopped")
|
self._log_info(f"{self.exchange_name} data collector stopped")
|
||||||
@ -355,7 +336,7 @@ class BaseDataCollector(ABC):
|
|||||||
await self.stop()
|
await self.stop()
|
||||||
|
|
||||||
# Wait a bit before restarting
|
# Wait a bit before restarting
|
||||||
await asyncio.sleep(self._reconnect_delay)
|
await asyncio.sleep(self._connection_manager._reconnect_delay)
|
||||||
|
|
||||||
# Start again
|
# Start again
|
||||||
return await self.start()
|
return await self.start()
|
||||||
@ -447,34 +428,26 @@ class BaseDataCollector(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
True if reconnection successful, False if max attempts exceeded
|
True if reconnection successful, False if max attempts exceeded
|
||||||
"""
|
"""
|
||||||
self._reconnect_attempts += 1
|
return await self._connection_manager.handle_connection_error(
|
||||||
|
connect_logic=self._actual_connect,
|
||||||
|
subscribe_logic=self.subscribe_to_data,
|
||||||
|
symbols=list(self.symbols),
|
||||||
|
data_types=self.data_types
|
||||||
|
)
|
||||||
|
|
||||||
if self._reconnect_attempts > self._max_reconnect_attempts:
|
@abstractmethod
|
||||||
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded")
|
async def _actual_connect(self) -> bool:
|
||||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
"""
|
||||||
self._state_telemetry.set_should_be_running(False)
|
Abstract method for subclasses to implement actual connection logic.
|
||||||
return False
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
|
@abstractmethod
|
||||||
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts}")
|
async def _actual_disconnect(self) -> None:
|
||||||
|
"""
|
||||||
# Disconnect and wait before retrying
|
Abstract method for subclasses to implement actual disconnection logic.
|
||||||
await self.disconnect()
|
"""
|
||||||
await asyncio.sleep(self._reconnect_delay)
|
pass
|
||||||
|
|
||||||
# Attempt to reconnect
|
|
||||||
try:
|
|
||||||
if await self.connect():
|
|
||||||
if await self.subscribe_to_data(list(self.symbols), self.data_types):
|
|
||||||
self._log_info("Reconnection successful")
|
|
||||||
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
|
||||||
self._reconnect_attempts = 0
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._log_error(f"Reconnection attempt failed: {e}")
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
|
def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
148
data/collector/collector_connection_manager.py
Normal file
148
data/collector/collector_connection_manager.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
Module for managing network connection and reconnection logic for data collectors.
|
||||||
|
|
||||||
|
This module encapsulates the complexities of connecting, disconnecting,
|
||||||
|
and handling reconnection attempts to a data source, promoting a clean
|
||||||
|
separation of concerns within the data collector architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# from ..base_collector import DataType # Import from base_collector for now, will refactor later
|
||||||
|
from .collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||||
|
from data.common.data_types import DataType
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
|
||||||
|
"""
|
||||||
|
Manages the connection, disconnection, and reconnection logic for a data collector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
exchange_name: str,
|
||||||
|
component_name: str,
|
||||||
|
max_reconnect_attempts: int = 5,
|
||||||
|
reconnect_delay: float = 5.0,
|
||||||
|
logger=None,
|
||||||
|
state_telemetry: CollectorStateAndTelemetry = None):
|
||||||
|
self.exchange_name = exchange_name
|
||||||
|
self.component_name = component_name
|
||||||
|
self._max_reconnect_attempts = max_reconnect_attempts
|
||||||
|
self._reconnect_delay = reconnect_delay
|
||||||
|
self.logger = logger
|
||||||
|
self._state_telemetry = state_telemetry
|
||||||
|
|
||||||
|
self._connection = None # Placeholder for the actual connection object
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
|
||||||
|
def _log_debug(self, message: str) -> None:
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry._log_debug(f"{self.component_name}: {message}")
|
||||||
|
elif self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: {message}")
|
||||||
|
|
||||||
|
def _log_info(self, message: str) -> None:
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry._log_info(f"{self.component_name}: {message}")
|
||||||
|
elif self.logger:
|
||||||
|
self.logger.info(f"{self.component_name}: {message}")
|
||||||
|
|
||||||
|
def _log_warning(self, message: str) -> None:
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry._log_warning(f"{self.component_name}: {message}")
|
||||||
|
elif self.logger:
|
||||||
|
self.logger.warning(f"{self.component_name}: {message}")
|
||||||
|
|
||||||
|
def _log_error(self, message: str, exc_info: bool = False) -> None:
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry._log_error(f"{self.component_name}: {message}", exc_info=exc_info)
|
||||||
|
elif self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: {message}", exc_info=exc_info)
|
||||||
|
|
||||||
|
async def connect(self, connect_logic: callable) -> bool:
|
||||||
|
"""
|
||||||
|
Establish connection to the data source using provided logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connect_logic: A callable (async function) that performs the actual connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful, False otherwise
|
||||||
|
"""
|
||||||
|
self._log_info(f"Connecting to {self.exchange_name} data source")
|
||||||
|
try:
|
||||||
|
success = await connect_logic()
|
||||||
|
if success:
|
||||||
|
self._connection = True # Indicate connection is established
|
||||||
|
self._state_telemetry.set_connection_uptime_start()
|
||||||
|
self._log_info(f"Successfully connected to {self.exchange_name}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self._log_error(f"Failed to connect to {self.exchange_name}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self._log_error(f"Error during connection to {self.exchange_name}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self, disconnect_logic: callable) -> None:
|
||||||
|
"""
|
||||||
|
Disconnect from the data source using provided logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
disconnect_logic: A callable (async function) that performs the actual disconnection.
|
||||||
|
"""
|
||||||
|
self._log_info(f"Disconnecting from {self.exchange_name} data source")
|
||||||
|
try:
|
||||||
|
if self._connection:
|
||||||
|
await disconnect_logic()
|
||||||
|
self._connection = None
|
||||||
|
self._log_info(f"Disconnected from {self.exchange_name}")
|
||||||
|
except Exception as e:
|
||||||
|
self._log_error(f"Error during disconnection from {self.exchange_name}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def handle_connection_error(self, connect_logic: callable, subscribe_logic: callable, symbols: List[str], data_types: List[DataType]) -> bool:
|
||||||
|
"""
|
||||||
|
Handle connection errors and attempt reconnection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connect_logic: Callable for connecting.
|
||||||
|
subscribe_logic: Callable for subscribing.
|
||||||
|
symbols: List of symbols to re-subscribe to.
|
||||||
|
data_types: List of data types to re-subscribe to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if reconnection successful, False if max attempts exceeded
|
||||||
|
"""
|
||||||
|
self._reconnect_attempts += 1
|
||||||
|
|
||||||
|
if self._reconnect_attempts > self._max_reconnect_attempts:
|
||||||
|
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||||
|
self._state_telemetry.set_should_be_running(False)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
|
||||||
|
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts} for {self.exchange_name}")
|
||||||
|
|
||||||
|
# Disconnect and wait before retrying
|
||||||
|
await self.disconnect(lambda: None) # Pass a no-op disconnect for internal use, actual disconnect handled by caller
|
||||||
|
await asyncio.sleep(self._reconnect_delay)
|
||||||
|
|
||||||
|
# Attempt to reconnect
|
||||||
|
try:
|
||||||
|
if await self.connect(connect_logic):
|
||||||
|
if await subscribe_logic(symbols, data_types):
|
||||||
|
self._log_info(f"Reconnection successful for {self.exchange_name}")
|
||||||
|
if self._state_telemetry:
|
||||||
|
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._log_error(f"Reconnection attempt failed for {self.exchange_name}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
return False
|
||||||
@ -10,7 +10,8 @@ from .data_types import (
|
|||||||
OHLCVCandle,
|
OHLCVCandle,
|
||||||
MarketDataPoint,
|
MarketDataPoint,
|
||||||
DataValidationResult,
|
DataValidationResult,
|
||||||
CandleProcessingConfig
|
CandleProcessingConfig,
|
||||||
|
DataType
|
||||||
)
|
)
|
||||||
|
|
||||||
from .transformation.trade import (
|
from .transformation.trade import (
|
||||||
@ -44,6 +45,7 @@ __all__ = [
|
|||||||
'MarketDataPoint',
|
'MarketDataPoint',
|
||||||
'DataValidationResult',
|
'DataValidationResult',
|
||||||
'CandleProcessingConfig',
|
'CandleProcessingConfig',
|
||||||
|
'DataType',
|
||||||
|
|
||||||
# Trade transformation
|
# Trade transformation
|
||||||
'TradeTransformer',
|
'TradeTransformer',
|
||||||
|
|||||||
@ -10,8 +10,33 @@ from decimal import Decimal
|
|||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from ..base_collector import DataType, MarketDataPoint # Import from base
|
# from ..base_collector import DataType, MarketDataPoint # Import from base
|
||||||
|
|
||||||
|
|
||||||
|
class DataType(Enum):
|
||||||
|
"""Types of data that can be collected."""
|
||||||
|
TICKER = "ticker"
|
||||||
|
TRADE = "trade"
|
||||||
|
ORDERBOOK = "orderbook"
|
||||||
|
CANDLE = "candle"
|
||||||
|
BALANCE = "balance"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketDataPoint:
|
||||||
|
"""Standardized market data structure."""
|
||||||
|
exchange: str
|
||||||
|
symbol: str
|
||||||
|
timestamp: datetime
|
||||||
|
data_type: DataType
|
||||||
|
data: Dict[str, Any]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Validate data after initialization."""
|
||||||
|
if not self.timestamp.tzinfo:
|
||||||
|
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -109,7 +109,7 @@ class OKXCollector(BaseDataCollector):
|
|||||||
symbol,
|
symbol,
|
||||||
config=candle_config or CandleProcessingConfig(timeframes=self.timeframes), # Use provided config or create new one
|
config=candle_config or CandleProcessingConfig(timeframes=self.timeframes), # Use provided config or create new one
|
||||||
component_name=f"{component_name}_processor",
|
component_name=f"{component_name}_processor",
|
||||||
logger=logger
|
logger=self.logger
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add callbacks for processed data
|
# Add callbacks for processed data
|
||||||
@ -140,6 +140,21 @@ class OKXCollector(BaseDataCollector):
|
|||||||
"""
|
"""
|
||||||
Establish connection to OKX WebSocket API.
|
Establish connection to OKX WebSocket API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful, False otherwise
|
||||||
|
"""
|
||||||
|
return await self._connection_manager.connect(self._actual_connect)
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""
|
||||||
|
Disconnect from OKX WebSocket API.
|
||||||
|
"""
|
||||||
|
await self._connection_manager.disconnect(self._actual_disconnect)
|
||||||
|
|
||||||
|
async def _actual_connect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Implement the actual connection logic for OKX WebSocket API.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if connection successful, False otherwise
|
True if connection successful, False otherwise
|
||||||
"""
|
"""
|
||||||
@ -157,7 +172,7 @@ class OKXCollector(BaseDataCollector):
|
|||||||
pong_timeout=10.0,
|
pong_timeout=10.0,
|
||||||
max_reconnect_attempts=5,
|
max_reconnect_attempts=5,
|
||||||
reconnect_delay=5.0,
|
reconnect_delay=5.0,
|
||||||
logger=self.logger # Pass the logger to enable ping/pong logging
|
logger=self.logger
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add message callback
|
# Add message callback
|
||||||
@ -175,9 +190,9 @@ class OKXCollector(BaseDataCollector):
|
|||||||
self._log_error(f"Error connecting OKX collector for {self.symbol}: {e}")
|
self._log_error(f"Error connecting OKX collector for {self.symbol}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def _actual_disconnect(self) -> None:
|
||||||
"""
|
"""
|
||||||
Disconnect from OKX WebSocket API.
|
Implement the actual disconnection logic for OKX WebSocket API.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._log_info(f"Disconnecting OKX collector for {self.symbol}")
|
self._log_info(f"Disconnecting OKX collector for {self.symbol}")
|
||||||
|
|||||||
@ -28,13 +28,13 @@
|
|||||||
- [x] 1.6 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_state_telemetry.py`.
|
- [x] 1.6 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_state_telemetry.py`.
|
||||||
- [x] 1.7 Create `tests/data/collector/test_collector_state_telemetry.py` and add initial tests for the new class.
|
- [x] 1.7 Create `tests/data/collector/test_collector_state_telemetry.py` and add initial tests for the new class.
|
||||||
|
|
||||||
- [ ] 2.0 Extract `ConnectionManager` Class
|
- [x] 2.0 Extract `ConnectionManager` Class
|
||||||
- [ ] 2.1 Create `data/collector/collector_connection_manager.py`.
|
- [x] 2.1 Create `data/collector/collector_connection_manager.py`.
|
||||||
- [ ] 2.2 Move connection-related attributes (`_connection`, `_reconnect_attempts`, `_max_reconnect_attempts`, `_reconnect_delay`) to `ConnectionManager`.
|
- [x] 2.2 Move connection-related attributes (`_connection`, `_reconnect_attempts`, `_max_reconnect_attempts`, `_reconnect_delay`) to `ConnectionManager`.
|
||||||
- [ ] 2.3 Move `connect`, `disconnect`, `_handle_connection_error` methods to `ConnectionManager`.
|
- [x] 2.3 Move `connect`, `disconnect`, `_handle_connection_error` methods to `ConnectionManager`.
|
||||||
- [ ] 2.4 Implement a constructor for `ConnectionManager` to receive logger and other necessary parameters.
|
- [x] 2.4 Implement a constructor for `ConnectionManager` to receive logger and other necessary parameters.
|
||||||
- [ ] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`.
|
- [x] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`.
|
||||||
- [ ] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class.
|
- [x] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class.
|
||||||
|
|
||||||
- [ ] 3.0 Extract `CallbackDispatcher` Class
|
- [ ] 3.0 Extract `CallbackDispatcher` Class
|
||||||
- [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`.
|
- [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`.
|
||||||
|
|||||||
193
tests/data/collector/test_collector_connection_manager.py
Normal file
193
tests/data/collector/test_collector_connection_manager.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from data.collector.collector_connection_manager import ConnectionManager
|
||||||
|
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||||
|
from data.common.data_types import DataType # Import for DataType enum
|
||||||
|
|
||||||
|
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mock_logger = Mock()
|
||||||
|
self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
|
||||||
|
self.mock_state_telemetry.logger = self.mock_logger
|
||||||
|
self.mock_state_telemetry.update_status = Mock()
|
||||||
|
self.mock_state_telemetry.set_should_be_running = Mock()
|
||||||
|
self.mock_state_telemetry.set_connection_uptime_start = Mock()
|
||||||
|
self.mock_state_telemetry._log_info = Mock()
|
||||||
|
self.mock_state_telemetry._log_debug = Mock()
|
||||||
|
self.mock_state_telemetry._log_warning = Mock()
|
||||||
|
self.mock_state_telemetry._log_error = Mock()
|
||||||
|
self.mock_state_telemetry._log_critical = Mock()
|
||||||
|
self.manager = ConnectionManager(
|
||||||
|
exchange_name="test_exchange",
|
||||||
|
component_name="test_collector",
|
||||||
|
logger=self.mock_logger,
|
||||||
|
state_telemetry=self.mock_state_telemetry
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_init(self):
|
||||||
|
self.assertEqual(self.manager.exchange_name, "test_exchange")
|
||||||
|
self.assertEqual(self.manager.component_name, "test_collector")
|
||||||
|
self.assertEqual(self.manager._max_reconnect_attempts, 5)
|
||||||
|
self.assertEqual(self.manager._reconnect_delay, 5.0)
|
||||||
|
self.assertEqual(self.manager.logger, self.mock_logger)
|
||||||
|
self.assertEqual(self.manager._state_telemetry, self.mock_state_telemetry)
|
||||||
|
self.assertIsNone(self.manager._connection)
|
||||||
|
self.assertEqual(self.manager._reconnect_attempts, 0)
|
||||||
|
|
||||||
|
async def test_connect_success(self):
|
||||||
|
mock_connect_logic = AsyncMock(return_value=True)
|
||||||
|
result = await self.manager.connect(mock_connect_logic)
|
||||||
|
self.assertTrue(result)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
self.assertTrue(self.manager._connection)
|
||||||
|
self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
|
||||||
|
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Successfully connected to test_exchange")
|
||||||
|
|
||||||
|
async def test_connect_failure(self):
|
||||||
|
mock_connect_logic = AsyncMock(return_value=False)
|
||||||
|
result = await self.manager.connect(mock_connect_logic)
|
||||||
|
self.assertFalse(result)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
self.assertIsNone(self.manager._connection)
|
||||||
|
self.mock_state_telemetry._log_error.assert_any_call("test_collector: Failed to connect to test_exchange", exc_info=False)
|
||||||
|
|
||||||
|
async def test_connect_exception(self):
|
||||||
|
mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))
|
||||||
|
result = await self.manager.connect(mock_connect_logic)
|
||||||
|
self.assertFalse(result)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
self.assertIsNone(self.manager._connection)
|
||||||
|
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Connection error", exc_info=True)
|
||||||
|
|
||||||
|
async def test_disconnect_success(self):
|
||||||
|
self.manager._connection = True # Simulate active connection
|
||||||
|
mock_disconnect_logic = AsyncMock()
|
||||||
|
await self.manager.disconnect(mock_disconnect_logic)
|
||||||
|
mock_disconnect_logic.assert_called_once()
|
||||||
|
self.assertIsNone(self.manager._connection)
|
||||||
|
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnected from test_exchange")
|
||||||
|
|
||||||
|
async def test_disconnect_no_active_connection(self):
|
||||||
|
mock_disconnect_logic = AsyncMock()
|
||||||
|
await self.manager.disconnect(mock_disconnect_logic)
|
||||||
|
mock_disconnect_logic.assert_not_called()
|
||||||
|
self.assertIsNone(self.manager._connection)
|
||||||
|
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnecting from test_exchange data source")
|
||||||
|
|
||||||
|
async def test_disconnect_exception(self):
|
||||||
|
self.manager._connection = True
|
||||||
|
mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))
|
||||||
|
await self.manager.disconnect(mock_disconnect_logic)
|
||||||
|
mock_disconnect_logic.assert_called_once()
|
||||||
|
self.assertTrue(self.manager._connection) # Connection state remains unchanged on exception
|
||||||
|
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during disconnection from test_exchange: Disconnect error", exc_info=True)
|
||||||
|
|
||||||
|
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||||
|
async def test_handle_connection_error_reconnect_success(self, mock_sleep):
|
||||||
|
mock_connect_logic = AsyncMock(return_value=True)
|
||||||
|
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
self.manager._reconnect_attempts = 0
|
||||||
|
result = await self.manager.handle_connection_error(
|
||||||
|
mock_connect_logic,
|
||||||
|
mock_subscribe_logic,
|
||||||
|
symbols=["BTC/USDT"],
|
||||||
|
data_types=[DataType.CANDLE]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(result)
|
||||||
|
self.assertEqual(self.manager._reconnect_attempts, 0)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
|
||||||
|
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||||
|
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
|
||||||
|
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Reconnection successful for test_exchange")
|
||||||
|
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||||
|
|
||||||
|
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||||
|
async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
|
||||||
|
mock_connect_logic = AsyncMock(return_value=True)
|
||||||
|
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
self.manager._reconnect_attempts = self.manager._max_reconnect_attempts
|
||||||
|
result = await self.manager.handle_connection_error(
|
||||||
|
mock_connect_logic,
|
||||||
|
mock_subscribe_logic,
|
||||||
|
symbols=["BTC/USDT"],
|
||||||
|
data_types=[DataType.CANDLE]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(result)
|
||||||
|
self.assertEqual(self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1)
|
||||||
|
mock_connect_logic.assert_not_called()
|
||||||
|
mock_subscribe_logic.assert_not_called()
|
||||||
|
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
|
||||||
|
self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
|
||||||
|
self.mock_state_telemetry._log_error.assert_any_call(f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange", exc_info=False)
|
||||||
|
mock_sleep.assert_not_called()
|
||||||
|
|
||||||
|
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||||
|
async def test_handle_connection_error_connect_fails(self, mock_sleep):
|
||||||
|
mock_connect_logic = AsyncMock(return_value=False)
|
||||||
|
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
self.manager._reconnect_attempts = 0
|
||||||
|
result = await self.manager.handle_connection_error(
|
||||||
|
mock_connect_logic,
|
||||||
|
mock_subscribe_logic,
|
||||||
|
symbols=["BTC/USDT"],
|
||||||
|
data_types=[DataType.CANDLE]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(result)
|
||||||
|
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
mock_subscribe_logic.assert_not_called()
|
||||||
|
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||||
|
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Failed to connect to test_exchange", exc_info=False)
|
||||||
|
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||||
|
|
||||||
|
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||||
|
async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
|
||||||
|
mock_connect_logic = AsyncMock(return_value=True)
|
||||||
|
mock_subscribe_logic = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
self.manager._reconnect_attempts = 0
|
||||||
|
result = await self.manager.handle_connection_error(
|
||||||
|
mock_connect_logic,
|
||||||
|
mock_subscribe_logic,
|
||||||
|
symbols=["BTC/USDT"],
|
||||||
|
data_types=[DataType.CANDLE]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(result)
|
||||||
|
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
|
||||||
|
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||||
|
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||||
|
|
||||||
|
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||||
|
async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
|
||||||
|
mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
|
||||||
|
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
self.manager._reconnect_attempts = 0
|
||||||
|
result = await self.manager.handle_connection_error(
|
||||||
|
mock_connect_logic,
|
||||||
|
mock_subscribe_logic,
|
||||||
|
symbols=["BTC/USDT"],
|
||||||
|
data_types=[DataType.CANDLE]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(result)
|
||||||
|
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||||
|
mock_connect_logic.assert_called_once()
|
||||||
|
mock_subscribe_logic.assert_not_called()
|
||||||
|
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||||
|
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Reconnect error", exc_info=True)
|
||||||
|
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||||
Loading…
x
Reference in New Issue
Block a user