Refactor BaseDataCollector to integrate ConnectionManager for connection handling
- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and enhancing modularity. - Updated `BaseDataCollector` to utilize the `ConnectionManager` for connection, disconnection, and reconnection processes, improving code clarity and maintainability. - Refactored connection-related methods and attributes, ensuring consistent error handling and logging practices. - Enhanced the `OKXCollector` to implement the new connection management approach, streamlining its connection logic. - Added unit tests for the `ConnectionManager` to validate its functionality and ensure robust error handling. These changes improve the architecture of the data collector, aligning with project standards for maintainability and performance.
This commit is contained in:
parent
60434afd5d
commit
41f0e8e6b6
@ -7,8 +7,9 @@ processing and validating the data, and storing it in the database.
|
||||
|
||||
from .base_collector import (
|
||||
BaseDataCollector, DataCollectorError, DataValidationError,
|
||||
DataType, CollectorStatus, MarketDataPoint, OHLCVData
|
||||
CollectorStatus, OHLCVData
|
||||
)
|
||||
from .common.data_types import DataType, MarketDataPoint
|
||||
from .collector_manager import CollectorManager, ManagerStatus, CollectorConfig
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -15,30 +15,8 @@ from enum import Enum
|
||||
|
||||
from utils.logger import get_logger
|
||||
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
"""Types of data that can be collected."""
|
||||
TICKER = "ticker"
|
||||
TRADE = "trade"
|
||||
ORDERBOOK = "orderbook"
|
||||
CANDLE = "candle"
|
||||
BALANCE = "balance"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketDataPoint:
|
||||
"""Standardized market data structure."""
|
||||
exchange: str
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
data_type: DataType
|
||||
data: Dict[str, Any]
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate data after initialization."""
|
||||
if not self.timestamp.tzinfo:
|
||||
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
|
||||
from .collector.collector_connection_manager import ConnectionManager
|
||||
from .common.data_types import DataType, MarketDataPoint
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -149,6 +127,16 @@ class BaseDataCollector(ABC):
|
||||
)
|
||||
self.component_name = component # Keep for external access
|
||||
|
||||
# Initialize connection manager
|
||||
self._connection_manager = ConnectionManager(
|
||||
exchange_name=self.exchange_name,
|
||||
component_name=component,
|
||||
max_reconnect_attempts=5, # Default, can be made configurable later
|
||||
reconnect_delay=5.0, # Default, can be made configurable later
|
||||
logger=self.logger,
|
||||
state_telemetry=self._state_telemetry
|
||||
)
|
||||
|
||||
# Collector state (now managed by _state_telemetry)
|
||||
self._tasks: Set[asyncio.Task] = set()
|
||||
|
||||
@ -157,12 +145,6 @@ class BaseDataCollector(ABC):
|
||||
data_type: [] for data_type in DataType
|
||||
}
|
||||
|
||||
# Connection management
|
||||
self._connection = None
|
||||
self._reconnect_attempts = 0
|
||||
self._max_reconnect_attempts = 5
|
||||
self._reconnect_delay = 5.0 # seconds
|
||||
|
||||
# Log initialization if logger is available
|
||||
if self._state_telemetry.logger:
|
||||
if not self._state_telemetry.log_errors_only:
|
||||
@ -262,7 +244,7 @@ class BaseDataCollector(ABC):
|
||||
|
||||
try:
|
||||
# Connect to data source
|
||||
if not await self.connect():
|
||||
if not await self._connection_manager.connect(self._actual_connect):
|
||||
self._log_error("Failed to connect to data source")
|
||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||
return False
|
||||
@ -271,13 +253,12 @@ class BaseDataCollector(ABC):
|
||||
if not await self.subscribe_to_data(list(self.symbols), self.data_types):
|
||||
self._log_error("Failed to subscribe to data streams")
|
||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||
await self.disconnect()
|
||||
await self._connection_manager.disconnect(self._actual_disconnect)
|
||||
return False
|
||||
|
||||
# Start background tasks
|
||||
self._state_telemetry.set_running_state(True)
|
||||
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
||||
self._state_telemetry.set_connection_uptime_start() # Record connection uptime start
|
||||
|
||||
# Start message processing task
|
||||
message_task = asyncio.create_task(self._message_loop())
|
||||
@ -332,7 +313,7 @@ class BaseDataCollector(ABC):
|
||||
|
||||
# Unsubscribe and disconnect
|
||||
await self.unsubscribe_from_data(list(self.symbols), self.data_types)
|
||||
await self.disconnect()
|
||||
await self._connection_manager.disconnect(self._actual_disconnect)
|
||||
|
||||
self._state_telemetry.update_status(CollectorStatus.STOPPED)
|
||||
self._log_info(f"{self.exchange_name} data collector stopped")
|
||||
@ -355,7 +336,7 @@ class BaseDataCollector(ABC):
|
||||
await self.stop()
|
||||
|
||||
# Wait a bit before restarting
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
await asyncio.sleep(self._connection_manager._reconnect_delay)
|
||||
|
||||
# Start again
|
||||
return await self.start()
|
||||
@ -447,34 +428,26 @@ class BaseDataCollector(ABC):
|
||||
Returns:
|
||||
True if reconnection successful, False if max attempts exceeded
|
||||
"""
|
||||
self._reconnect_attempts += 1
|
||||
|
||||
if self._reconnect_attempts > self._max_reconnect_attempts:
|
||||
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded")
|
||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||
self._state_telemetry.set_should_be_running(False)
|
||||
return False
|
||||
|
||||
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
|
||||
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts}")
|
||||
|
||||
# Disconnect and wait before retrying
|
||||
await self.disconnect()
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
|
||||
# Attempt to reconnect
|
||||
try:
|
||||
if await self.connect():
|
||||
if await self.subscribe_to_data(list(self.symbols), self.data_types):
|
||||
self._log_info("Reconnection successful")
|
||||
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
||||
self._reconnect_attempts = 0
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._log_error(f"Reconnection attempt failed: {e}")
|
||||
|
||||
return False
|
||||
return await self._connection_manager.handle_connection_error(
|
||||
connect_logic=self._actual_connect,
|
||||
subscribe_logic=self.subscribe_to_data,
|
||||
symbols=list(self.symbols),
|
||||
data_types=self.data_types
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def _actual_connect(self) -> bool:
|
||||
"""
|
||||
Abstract method for subclasses to implement actual connection logic.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _actual_disconnect(self) -> None:
|
||||
"""
|
||||
Abstract method for subclasses to implement actual disconnection logic.
|
||||
"""
|
||||
pass
|
||||
|
||||
def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
|
||||
"""
|
||||
|
||||
148
data/collector/collector_connection_manager.py
Normal file
148
data/collector/collector_connection_manager.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""
|
||||
Module for managing network connection and reconnection logic for data collectors.
|
||||
|
||||
This module encapsulates the complexities of connecting, disconnecting,
|
||||
and handling reconnection attempts to a data source, promoting a clean
|
||||
separation of concerns within the data collector architecture.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
from datetime import datetime
|
||||
|
||||
# from ..base_collector import DataType # Import from base_collector for now, will refactor later
|
||||
from .collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
from data.common.data_types import DataType
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""
|
||||
Manages the connection, disconnection, and reconnection logic for a data collector.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
exchange_name: str,
|
||||
component_name: str,
|
||||
max_reconnect_attempts: int = 5,
|
||||
reconnect_delay: float = 5.0,
|
||||
logger=None,
|
||||
state_telemetry: CollectorStateAndTelemetry = None):
|
||||
self.exchange_name = exchange_name
|
||||
self.component_name = component_name
|
||||
self._max_reconnect_attempts = max_reconnect_attempts
|
||||
self._reconnect_delay = reconnect_delay
|
||||
self.logger = logger
|
||||
self._state_telemetry = state_telemetry
|
||||
|
||||
self._connection = None # Placeholder for the actual connection object
|
||||
self._reconnect_attempts = 0
|
||||
|
||||
def _log_debug(self, message: str) -> None:
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry._log_debug(f"{self.component_name}: {message}")
|
||||
elif self.logger:
|
||||
self.logger.debug(f"{self.component_name}: {message}")
|
||||
|
||||
def _log_info(self, message: str) -> None:
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry._log_info(f"{self.component_name}: {message}")
|
||||
elif self.logger:
|
||||
self.logger.info(f"{self.component_name}: {message}")
|
||||
|
||||
def _log_warning(self, message: str) -> None:
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry._log_warning(f"{self.component_name}: {message}")
|
||||
elif self.logger:
|
||||
self.logger.warning(f"{self.component_name}: {message}")
|
||||
|
||||
def _log_error(self, message: str, exc_info: bool = False) -> None:
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry._log_error(f"{self.component_name}: {message}", exc_info=exc_info)
|
||||
elif self.logger:
|
||||
self.logger.error(f"{self.component_name}: {message}", exc_info=exc_info)
|
||||
|
||||
async def connect(self, connect_logic: callable) -> bool:
|
||||
"""
|
||||
Establish connection to the data source using provided logic.
|
||||
|
||||
Args:
|
||||
connect_logic: A callable (async function) that performs the actual connection.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
self._log_info(f"Connecting to {self.exchange_name} data source")
|
||||
try:
|
||||
success = await connect_logic()
|
||||
if success:
|
||||
self._connection = True # Indicate connection is established
|
||||
self._state_telemetry.set_connection_uptime_start()
|
||||
self._log_info(f"Successfully connected to {self.exchange_name}")
|
||||
return True
|
||||
else:
|
||||
self._log_error(f"Failed to connect to {self.exchange_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self._log_error(f"Error during connection to {self.exchange_name}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def disconnect(self, disconnect_logic: callable) -> None:
|
||||
"""
|
||||
Disconnect from the data source using provided logic.
|
||||
|
||||
Args:
|
||||
disconnect_logic: A callable (async function) that performs the actual disconnection.
|
||||
"""
|
||||
self._log_info(f"Disconnecting from {self.exchange_name} data source")
|
||||
try:
|
||||
if self._connection:
|
||||
await disconnect_logic()
|
||||
self._connection = None
|
||||
self._log_info(f"Disconnected from {self.exchange_name}")
|
||||
except Exception as e:
|
||||
self._log_error(f"Error during disconnection from {self.exchange_name}: {e}", exc_info=True)
|
||||
|
||||
async def handle_connection_error(self, connect_logic: callable, subscribe_logic: callable, symbols: List[str], data_types: List[DataType]) -> bool:
|
||||
"""
|
||||
Handle connection errors and attempt reconnection.
|
||||
|
||||
Args:
|
||||
connect_logic: Callable for connecting.
|
||||
subscribe_logic: Callable for subscribing.
|
||||
symbols: List of symbols to re-subscribe to.
|
||||
data_types: List of data types to re-subscribe to.
|
||||
|
||||
Returns:
|
||||
True if reconnection successful, False if max attempts exceeded
|
||||
"""
|
||||
self._reconnect_attempts += 1
|
||||
|
||||
if self._reconnect_attempts > self._max_reconnect_attempts:
|
||||
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry.update_status(CollectorStatus.ERROR)
|
||||
self._state_telemetry.set_should_be_running(False)
|
||||
return False
|
||||
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
|
||||
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts} for {self.exchange_name}")
|
||||
|
||||
# Disconnect and wait before retrying
|
||||
await self.disconnect(lambda: None) # Pass a no-op disconnect for internal use, actual disconnect handled by caller
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
|
||||
# Attempt to reconnect
|
||||
try:
|
||||
if await self.connect(connect_logic):
|
||||
if await subscribe_logic(symbols, data_types):
|
||||
self._log_info(f"Reconnection successful for {self.exchange_name}")
|
||||
if self._state_telemetry:
|
||||
self._state_telemetry.update_status(CollectorStatus.RUNNING)
|
||||
self._reconnect_attempts = 0
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._log_error(f"Reconnection attempt failed for {self.exchange_name}: {e}", exc_info=True)
|
||||
|
||||
return False
|
||||
@ -10,7 +10,8 @@ from .data_types import (
|
||||
OHLCVCandle,
|
||||
MarketDataPoint,
|
||||
DataValidationResult,
|
||||
CandleProcessingConfig
|
||||
CandleProcessingConfig,
|
||||
DataType
|
||||
)
|
||||
|
||||
from .transformation.trade import (
|
||||
@ -44,6 +45,7 @@ __all__ = [
|
||||
'MarketDataPoint',
|
||||
'DataValidationResult',
|
||||
'CandleProcessingConfig',
|
||||
'DataType',
|
||||
|
||||
# Trade transformation
|
||||
'TradeTransformer',
|
||||
|
||||
@ -10,8 +10,33 @@ from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import asyncio
|
||||
|
||||
from ..base_collector import DataType, MarketDataPoint # Import from base
|
||||
# from ..base_collector import DataType, MarketDataPoint # Import from base
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
"""Types of data that can be collected."""
|
||||
TICKER = "ticker"
|
||||
TRADE = "trade"
|
||||
ORDERBOOK = "orderbook"
|
||||
CANDLE = "candle"
|
||||
BALANCE = "balance"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketDataPoint:
|
||||
"""Standardized market data structure."""
|
||||
exchange: str
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
data_type: DataType
|
||||
data: Dict[str, Any]
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate data after initialization."""
|
||||
if not self.timestamp.tzinfo:
|
||||
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -109,7 +109,7 @@ class OKXCollector(BaseDataCollector):
|
||||
symbol,
|
||||
config=candle_config or CandleProcessingConfig(timeframes=self.timeframes), # Use provided config or create new one
|
||||
component_name=f"{component_name}_processor",
|
||||
logger=logger
|
||||
logger=self.logger
|
||||
)
|
||||
|
||||
# Add callbacks for processed data
|
||||
@ -140,6 +140,21 @@ class OKXCollector(BaseDataCollector):
|
||||
"""
|
||||
Establish connection to OKX WebSocket API.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
return await self._connection_manager.connect(self._actual_connect)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect from OKX WebSocket API.
|
||||
"""
|
||||
await self._connection_manager.disconnect(self._actual_disconnect)
|
||||
|
||||
async def _actual_connect(self) -> bool:
|
||||
"""
|
||||
Implement the actual connection logic for OKX WebSocket API.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
@ -157,7 +172,7 @@ class OKXCollector(BaseDataCollector):
|
||||
pong_timeout=10.0,
|
||||
max_reconnect_attempts=5,
|
||||
reconnect_delay=5.0,
|
||||
logger=self.logger # Pass the logger to enable ping/pong logging
|
||||
logger=self.logger
|
||||
)
|
||||
|
||||
# Add message callback
|
||||
@ -175,9 +190,9 @@ class OKXCollector(BaseDataCollector):
|
||||
self._log_error(f"Error connecting OKX collector for {self.symbol}: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
async def _actual_disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect from OKX WebSocket API.
|
||||
Implement the actual disconnection logic for OKX WebSocket API.
|
||||
"""
|
||||
try:
|
||||
self._log_info(f"Disconnecting OKX collector for {self.symbol}")
|
||||
|
||||
@ -28,13 +28,13 @@
|
||||
- [x] 1.6 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_state_telemetry.py`.
|
||||
- [x] 1.7 Create `tests/data/collector/test_collector_state_telemetry.py` and add initial tests for the new class.
|
||||
|
||||
- [ ] 2.0 Extract `ConnectionManager` Class
|
||||
- [ ] 2.1 Create `data/collector/collector_connection_manager.py`.
|
||||
- [ ] 2.2 Move connection-related attributes (`_connection`, `_reconnect_attempts`, `_max_reconnect_attempts`, `_reconnect_delay`) to `ConnectionManager`.
|
||||
- [ ] 2.3 Move `connect`, `disconnect`, `_handle_connection_error` methods to `ConnectionManager`.
|
||||
- [ ] 2.4 Implement a constructor for `ConnectionManager` to receive logger and other necessary parameters.
|
||||
- [ ] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`.
|
||||
- [ ] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class.
|
||||
- [x] 2.0 Extract `ConnectionManager` Class
|
||||
- [x] 2.1 Create `data/collector/collector_connection_manager.py`.
|
||||
- [x] 2.2 Move connection-related attributes (`_connection`, `_reconnect_attempts`, `_max_reconnect_attempts`, `_reconnect_delay`) to `ConnectionManager`.
|
||||
- [x] 2.3 Move `connect`, `disconnect`, `_handle_connection_error` methods to `ConnectionManager`.
|
||||
- [x] 2.4 Implement a constructor for `ConnectionManager` to receive logger and other necessary parameters.
|
||||
- [x] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`.
|
||||
- [x] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class.
|
||||
|
||||
- [ ] 3.0 Extract `CallbackDispatcher` Class
|
||||
- [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`.
|
||||
|
||||
193
tests/data/collector/test_collector_connection_manager.py
Normal file
193
tests/data/collector/test_collector_connection_manager.py
Normal file
@ -0,0 +1,193 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from data.collector.collector_connection_manager import ConnectionManager
|
||||
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
from data.common.data_types import DataType # Import for DataType enum
|
||||
|
||||
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_logger = Mock()
|
||||
self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
|
||||
self.mock_state_telemetry.logger = self.mock_logger
|
||||
self.mock_state_telemetry.update_status = Mock()
|
||||
self.mock_state_telemetry.set_should_be_running = Mock()
|
||||
self.mock_state_telemetry.set_connection_uptime_start = Mock()
|
||||
self.mock_state_telemetry._log_info = Mock()
|
||||
self.mock_state_telemetry._log_debug = Mock()
|
||||
self.mock_state_telemetry._log_warning = Mock()
|
||||
self.mock_state_telemetry._log_error = Mock()
|
||||
self.mock_state_telemetry._log_critical = Mock()
|
||||
self.manager = ConnectionManager(
|
||||
exchange_name="test_exchange",
|
||||
component_name="test_collector",
|
||||
logger=self.mock_logger,
|
||||
state_telemetry=self.mock_state_telemetry
|
||||
)
|
||||
|
||||
async def test_init(self):
|
||||
self.assertEqual(self.manager.exchange_name, "test_exchange")
|
||||
self.assertEqual(self.manager.component_name, "test_collector")
|
||||
self.assertEqual(self.manager._max_reconnect_attempts, 5)
|
||||
self.assertEqual(self.manager._reconnect_delay, 5.0)
|
||||
self.assertEqual(self.manager.logger, self.mock_logger)
|
||||
self.assertEqual(self.manager._state_telemetry, self.mock_state_telemetry)
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 0)
|
||||
|
||||
async def test_connect_success(self):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
result = await self.manager.connect(mock_connect_logic)
|
||||
self.assertTrue(result)
|
||||
mock_connect_logic.assert_called_once()
|
||||
self.assertTrue(self.manager._connection)
|
||||
self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Successfully connected to test_exchange")
|
||||
|
||||
async def test_connect_failure(self):
|
||||
mock_connect_logic = AsyncMock(return_value=False)
|
||||
result = await self.manager.connect(mock_connect_logic)
|
||||
self.assertFalse(result)
|
||||
mock_connect_logic.assert_called_once()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_error.assert_any_call("test_collector: Failed to connect to test_exchange", exc_info=False)
|
||||
|
||||
async def test_connect_exception(self):
|
||||
mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))
|
||||
result = await self.manager.connect(mock_connect_logic)
|
||||
self.assertFalse(result)
|
||||
mock_connect_logic.assert_called_once()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Connection error", exc_info=True)
|
||||
|
||||
async def test_disconnect_success(self):
|
||||
self.manager._connection = True # Simulate active connection
|
||||
mock_disconnect_logic = AsyncMock()
|
||||
await self.manager.disconnect(mock_disconnect_logic)
|
||||
mock_disconnect_logic.assert_called_once()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnected from test_exchange")
|
||||
|
||||
async def test_disconnect_no_active_connection(self):
|
||||
mock_disconnect_logic = AsyncMock()
|
||||
await self.manager.disconnect(mock_disconnect_logic)
|
||||
mock_disconnect_logic.assert_not_called()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnecting from test_exchange data source")
|
||||
|
||||
async def test_disconnect_exception(self):
|
||||
self.manager._connection = True
|
||||
mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))
|
||||
await self.manager.disconnect(mock_disconnect_logic)
|
||||
mock_disconnect_logic.assert_called_once()
|
||||
self.assertTrue(self.manager._connection) # Connection state remains unchanged on exception
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during disconnection from test_exchange: Disconnect error", exc_info=True)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_reconnect_success(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 0)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Reconnection successful for test_exchange")
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = self.manager._max_reconnect_attempts
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1)
|
||||
mock_connect_logic.assert_not_called()
|
||||
mock_subscribe_logic.assert_not_called()
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
|
||||
self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
|
||||
self.mock_state_telemetry._log_error.assert_any_call(f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange", exc_info=False)
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_connect_fails(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=False)
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_not_called()
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Failed to connect to test_exchange", exc_info=False)
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
mock_subscribe_logic = AsyncMock(return_value=False)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_not_called()
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Reconnect error", exc_info=True)
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
Loading…
x
Reference in New Issue
Block a user