Refactor BaseDataCollector to integrate ConnectionManager for connection handling

- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and enhancing modularity.
- Updated `BaseDataCollector` to utilize the `ConnectionManager` for connection, disconnection, and reconnection processes, improving code clarity and maintainability.
- Refactored connection-related methods and attributes, ensuring consistent error handling and logging practices.
- Enhanced the `OKXCollector` to implement the new connection management approach, streamlining its connection logic.
- Added unit tests for the `ConnectionManager` to validate its functionality and ensure robust error handling.

These changes improve the architecture of the data collector, aligning with project standards for maintainability and performance.
This commit is contained in:
Vasily.onl 2025-06-09 17:42:06 +08:00
parent 60434afd5d
commit 41f0e8e6b6
8 changed files with 434 additions and 77 deletions

View File

@ -7,8 +7,9 @@ processing and validating the data, and storing it in the database.
from .base_collector import ( from .base_collector import (
BaseDataCollector, DataCollectorError, DataValidationError, BaseDataCollector, DataCollectorError, DataValidationError,
DataType, CollectorStatus, MarketDataPoint, OHLCVData CollectorStatus, OHLCVData
) )
from .common.data_types import DataType, MarketDataPoint
from .collector_manager import CollectorManager, ManagerStatus, CollectorConfig from .collector_manager import CollectorManager, ManagerStatus, CollectorConfig
__all__ = [ __all__ = [

View File

@ -15,30 +15,8 @@ from enum import Enum
from utils.logger import get_logger from utils.logger import get_logger
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
from .collector.collector_connection_manager import ConnectionManager
from .common.data_types import DataType, MarketDataPoint
class DataType(Enum):
"""Types of data that can be collected."""
TICKER = "ticker"
TRADE = "trade"
ORDERBOOK = "orderbook"
CANDLE = "candle"
BALANCE = "balance"
@dataclass
class MarketDataPoint:
"""Standardized market data structure."""
exchange: str
symbol: str
timestamp: datetime
data_type: DataType
data: Dict[str, Any]
def __post_init__(self):
"""Validate data after initialization."""
if not self.timestamp.tzinfo:
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
@dataclass @dataclass
@ -149,6 +127,16 @@ class BaseDataCollector(ABC):
) )
self.component_name = component # Keep for external access self.component_name = component # Keep for external access
# Initialize connection manager
self._connection_manager = ConnectionManager(
exchange_name=self.exchange_name,
component_name=component,
max_reconnect_attempts=5, # Default, can be made configurable later
reconnect_delay=5.0, # Default, can be made configurable later
logger=self.logger,
state_telemetry=self._state_telemetry
)
# Collector state (now managed by _state_telemetry) # Collector state (now managed by _state_telemetry)
self._tasks: Set[asyncio.Task] = set() self._tasks: Set[asyncio.Task] = set()
@ -157,12 +145,6 @@ class BaseDataCollector(ABC):
data_type: [] for data_type in DataType data_type: [] for data_type in DataType
} }
# Connection management
self._connection = None
self._reconnect_attempts = 0
self._max_reconnect_attempts = 5
self._reconnect_delay = 5.0 # seconds
# Log initialization if logger is available # Log initialization if logger is available
if self._state_telemetry.logger: if self._state_telemetry.logger:
if not self._state_telemetry.log_errors_only: if not self._state_telemetry.log_errors_only:
@ -262,7 +244,7 @@ class BaseDataCollector(ABC):
try: try:
# Connect to data source # Connect to data source
if not await self.connect(): if not await self._connection_manager.connect(self._actual_connect):
self._log_error("Failed to connect to data source") self._log_error("Failed to connect to data source")
self._state_telemetry.update_status(CollectorStatus.ERROR) self._state_telemetry.update_status(CollectorStatus.ERROR)
return False return False
@ -271,13 +253,12 @@ class BaseDataCollector(ABC):
if not await self.subscribe_to_data(list(self.symbols), self.data_types): if not await self.subscribe_to_data(list(self.symbols), self.data_types):
self._log_error("Failed to subscribe to data streams") self._log_error("Failed to subscribe to data streams")
self._state_telemetry.update_status(CollectorStatus.ERROR) self._state_telemetry.update_status(CollectorStatus.ERROR)
await self.disconnect() await self._connection_manager.disconnect(self._actual_disconnect)
return False return False
# Start background tasks # Start background tasks
self._state_telemetry.set_running_state(True) self._state_telemetry.set_running_state(True)
self._state_telemetry.update_status(CollectorStatus.RUNNING) self._state_telemetry.update_status(CollectorStatus.RUNNING)
self._state_telemetry.set_connection_uptime_start() # Record connection uptime start
# Start message processing task # Start message processing task
message_task = asyncio.create_task(self._message_loop()) message_task = asyncio.create_task(self._message_loop())
@ -332,7 +313,7 @@ class BaseDataCollector(ABC):
# Unsubscribe and disconnect # Unsubscribe and disconnect
await self.unsubscribe_from_data(list(self.symbols), self.data_types) await self.unsubscribe_from_data(list(self.symbols), self.data_types)
await self.disconnect() await self._connection_manager.disconnect(self._actual_disconnect)
self._state_telemetry.update_status(CollectorStatus.STOPPED) self._state_telemetry.update_status(CollectorStatus.STOPPED)
self._log_info(f"{self.exchange_name} data collector stopped") self._log_info(f"{self.exchange_name} data collector stopped")
@ -355,7 +336,7 @@ class BaseDataCollector(ABC):
await self.stop() await self.stop()
# Wait a bit before restarting # Wait a bit before restarting
await asyncio.sleep(self._reconnect_delay) await asyncio.sleep(self._connection_manager._reconnect_delay)
# Start again # Start again
return await self.start() return await self.start()
@ -447,34 +428,26 @@ class BaseDataCollector(ABC):
Returns: Returns:
True if reconnection successful, False if max attempts exceeded True if reconnection successful, False if max attempts exceeded
""" """
self._reconnect_attempts += 1 return await self._connection_manager.handle_connection_error(
connect_logic=self._actual_connect,
if self._reconnect_attempts > self._max_reconnect_attempts: subscribe_logic=self.subscribe_to_data,
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded") symbols=list(self.symbols),
self._state_telemetry.update_status(CollectorStatus.ERROR) data_types=self.data_types
self._state_telemetry.set_should_be_running(False) )
return False
@abstractmethod
self._state_telemetry.update_status(CollectorStatus.RECONNECTING) async def _actual_connect(self) -> bool:
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts}") """
Abstract method for subclasses to implement actual connection logic.
# Disconnect and wait before retrying """
await self.disconnect() pass
await asyncio.sleep(self._reconnect_delay)
@abstractmethod
# Attempt to reconnect async def _actual_disconnect(self) -> None:
try: """
if await self.connect(): Abstract method for subclasses to implement actual disconnection logic.
if await self.subscribe_to_data(list(self.symbols), self.data_types): """
self._log_info("Reconnection successful") pass
self._state_telemetry.update_status(CollectorStatus.RUNNING)
self._reconnect_attempts = 0
return True
except Exception as e:
self._log_error(f"Reconnection attempt failed: {e}")
return False
def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None: def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
""" """

View File

@ -0,0 +1,148 @@
"""
Module for managing network connection and reconnection logic for data collectors.
This module encapsulates the complexities of connecting, disconnecting,
and handling reconnection attempts to a data source, promoting a clean
separation of concerns within the data collector architecture.
"""
import asyncio
from typing import List, Any
from datetime import datetime
# from ..base_collector import DataType # Import from base_collector for now, will refactor later
from .collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
from data.common.data_types import DataType
class ConnectionManager:
"""
Manages the connection, disconnection, and reconnection logic for a data collector.
"""
def __init__(self,
exchange_name: str,
component_name: str,
max_reconnect_attempts: int = 5,
reconnect_delay: float = 5.0,
logger=None,
state_telemetry: CollectorStateAndTelemetry = None):
self.exchange_name = exchange_name
self.component_name = component_name
self._max_reconnect_attempts = max_reconnect_attempts
self._reconnect_delay = reconnect_delay
self.logger = logger
self._state_telemetry = state_telemetry
self._connection = None # Placeholder for the actual connection object
self._reconnect_attempts = 0
def _log_debug(self, message: str) -> None:
if self._state_telemetry:
self._state_telemetry._log_debug(f"{self.component_name}: {message}")
elif self.logger:
self.logger.debug(f"{self.component_name}: {message}")
def _log_info(self, message: str) -> None:
if self._state_telemetry:
self._state_telemetry._log_info(f"{self.component_name}: {message}")
elif self.logger:
self.logger.info(f"{self.component_name}: {message}")
def _log_warning(self, message: str) -> None:
if self._state_telemetry:
self._state_telemetry._log_warning(f"{self.component_name}: {message}")
elif self.logger:
self.logger.warning(f"{self.component_name}: {message}")
def _log_error(self, message: str, exc_info: bool = False) -> None:
if self._state_telemetry:
self._state_telemetry._log_error(f"{self.component_name}: {message}", exc_info=exc_info)
elif self.logger:
self.logger.error(f"{self.component_name}: {message}", exc_info=exc_info)
async def connect(self, connect_logic: callable) -> bool:
"""
Establish connection to the data source using provided logic.
Args:
connect_logic: A callable (async function) that performs the actual connection.
Returns:
True if connection successful, False otherwise
"""
self._log_info(f"Connecting to {self.exchange_name} data source")
try:
success = await connect_logic()
if success:
self._connection = True # Indicate connection is established
self._state_telemetry.set_connection_uptime_start()
self._log_info(f"Successfully connected to {self.exchange_name}")
return True
else:
self._log_error(f"Failed to connect to {self.exchange_name}")
return False
except Exception as e:
self._log_error(f"Error during connection to {self.exchange_name}: {e}", exc_info=True)
return False
async def disconnect(self, disconnect_logic: callable) -> None:
"""
Disconnect from the data source using provided logic.
Args:
disconnect_logic: A callable (async function) that performs the actual disconnection.
"""
self._log_info(f"Disconnecting from {self.exchange_name} data source")
try:
if self._connection:
await disconnect_logic()
self._connection = None
self._log_info(f"Disconnected from {self.exchange_name}")
except Exception as e:
self._log_error(f"Error during disconnection from {self.exchange_name}: {e}", exc_info=True)
async def handle_connection_error(self, connect_logic: callable, subscribe_logic: callable, symbols: List[str], data_types: List[DataType]) -> bool:
"""
Handle connection errors and attempt reconnection.
Args:
connect_logic: Callable for connecting.
subscribe_logic: Callable for subscribing.
symbols: List of symbols to re-subscribe to.
data_types: List of data types to re-subscribe to.
Returns:
True if reconnection successful, False if max attempts exceeded
"""
self._reconnect_attempts += 1
if self._reconnect_attempts > self._max_reconnect_attempts:
self._log_error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
if self._state_telemetry:
self._state_telemetry.update_status(CollectorStatus.ERROR)
self._state_telemetry.set_should_be_running(False)
return False
if self._state_telemetry:
self._state_telemetry.update_status(CollectorStatus.RECONNECTING)
self._log_warning(f"Connection lost. Attempting reconnection {self._reconnect_attempts}/{self._max_reconnect_attempts} for {self.exchange_name}")
# Disconnect and wait before retrying
await self.disconnect(lambda: None) # Pass a no-op disconnect for internal use, actual disconnect handled by caller
await asyncio.sleep(self._reconnect_delay)
# Attempt to reconnect
try:
if await self.connect(connect_logic):
if await subscribe_logic(symbols, data_types):
self._log_info(f"Reconnection successful for {self.exchange_name}")
if self._state_telemetry:
self._state_telemetry.update_status(CollectorStatus.RUNNING)
self._reconnect_attempts = 0
return True
except Exception as e:
self._log_error(f"Reconnection attempt failed for {self.exchange_name}: {e}", exc_info=True)
return False

View File

@ -10,7 +10,8 @@ from .data_types import (
OHLCVCandle, OHLCVCandle,
MarketDataPoint, MarketDataPoint,
DataValidationResult, DataValidationResult,
CandleProcessingConfig CandleProcessingConfig,
DataType
) )
from .transformation.trade import ( from .transformation.trade import (
@ -44,6 +45,7 @@ __all__ = [
'MarketDataPoint', 'MarketDataPoint',
'DataValidationResult', 'DataValidationResult',
'CandleProcessingConfig', 'CandleProcessingConfig',
'DataType',
# Trade transformation # Trade transformation
'TradeTransformer', 'TradeTransformer',

View File

@ -10,8 +10,33 @@ from decimal import Decimal
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
import asyncio
from ..base_collector import DataType, MarketDataPoint # Import from base # from ..base_collector import DataType, MarketDataPoint # Import from base
class DataType(Enum):
"""Types of data that can be collected."""
TICKER = "ticker"
TRADE = "trade"
ORDERBOOK = "orderbook"
CANDLE = "candle"
BALANCE = "balance"
@dataclass
class MarketDataPoint:
"""Standardized market data structure."""
exchange: str
symbol: str
timestamp: datetime
data_type: DataType
data: Dict[str, Any]
def __post_init__(self):
"""Validate data after initialization."""
if not self.timestamp.tzinfo:
self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
@dataclass @dataclass

View File

@ -109,7 +109,7 @@ class OKXCollector(BaseDataCollector):
symbol, symbol,
config=candle_config or CandleProcessingConfig(timeframes=self.timeframes), # Use provided config or create new one config=candle_config or CandleProcessingConfig(timeframes=self.timeframes), # Use provided config or create new one
component_name=f"{component_name}_processor", component_name=f"{component_name}_processor",
logger=logger logger=self.logger
) )
# Add callbacks for processed data # Add callbacks for processed data
@ -140,6 +140,21 @@ class OKXCollector(BaseDataCollector):
""" """
Establish connection to OKX WebSocket API. Establish connection to OKX WebSocket API.
Returns:
True if connection successful, False otherwise
"""
return await self._connection_manager.connect(self._actual_connect)
async def disconnect(self) -> None:
"""
Disconnect from OKX WebSocket API.
"""
await self._connection_manager.disconnect(self._actual_disconnect)
async def _actual_connect(self) -> bool:
"""
Implement the actual connection logic for OKX WebSocket API.
Returns: Returns:
True if connection successful, False otherwise True if connection successful, False otherwise
""" """
@ -157,7 +172,7 @@ class OKXCollector(BaseDataCollector):
pong_timeout=10.0, pong_timeout=10.0,
max_reconnect_attempts=5, max_reconnect_attempts=5,
reconnect_delay=5.0, reconnect_delay=5.0,
logger=self.logger # Pass the logger to enable ping/pong logging logger=self.logger
) )
# Add message callback # Add message callback
@ -175,9 +190,9 @@ class OKXCollector(BaseDataCollector):
self._log_error(f"Error connecting OKX collector for {self.symbol}: {e}") self._log_error(f"Error connecting OKX collector for {self.symbol}: {e}")
return False return False
async def disconnect(self) -> None: async def _actual_disconnect(self) -> None:
""" """
Disconnect from OKX WebSocket API. Implement the actual disconnection logic for OKX WebSocket API.
""" """
try: try:
self._log_info(f"Disconnecting OKX collector for {self.symbol}") self._log_info(f"Disconnecting OKX collector for {self.symbol}")

View File

@ -28,13 +28,13 @@
- [x] 1.6 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_state_telemetry.py`. - [x] 1.6 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_state_telemetry.py`.
- [x] 1.7 Create `tests/data/collector/test_collector_state_telemetry.py` and add initial tests for the new class. - [x] 1.7 Create `tests/data/collector/test_collector_state_telemetry.py` and add initial tests for the new class.
- [ ] 2.0 Extract `ConnectionManager` Class - [x] 2.0 Extract `ConnectionManager` Class
- [ ] 2.1 Create `data/collector/collector_connection_manager.py`. - [x] 2.1 Create `data/collector/collector_connection_manager.py`.
- [ ] 2.2 Move connection-related attributes (`_connection`, `_reconnect_attempts`, `_max_reconnect_attempts`, `_reconnect_delay`) to `ConnectionManager`. - [x] 2.2 Move connection-related attributes (`_connection`, `_reconnect_attempts`, `_max_reconnect_attempts`, `_reconnect_delay`) to `ConnectionManager`.
- [ ] 2.3 Move `connect`, `disconnect`, `_handle_connection_error` methods to `ConnectionManager`. - [x] 2.3 Move `connect`, `disconnect`, `_handle_connection_error` methods to `ConnectionManager`.
- [ ] 2.4 Implement a constructor for `ConnectionManager` to receive logger and other necessary parameters. - [x] 2.4 Implement a constructor for `ConnectionManager` to receive logger and other necessary parameters.
- [ ] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`. - [x] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`.
- [ ] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class. - [x] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class.
- [ ] 3.0 Extract `CallbackDispatcher` Class - [ ] 3.0 Extract `CallbackDispatcher` Class
- [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`. - [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`.

View File

@ -0,0 +1,193 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, Mock, patch
from datetime import datetime, timedelta, timezone
from data.collector.collector_connection_manager import ConnectionManager
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
from data.common.data_types import DataType # Import for DataType enum
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_logger = Mock()
self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
self.mock_state_telemetry.logger = self.mock_logger
self.mock_state_telemetry.update_status = Mock()
self.mock_state_telemetry.set_should_be_running = Mock()
self.mock_state_telemetry.set_connection_uptime_start = Mock()
self.mock_state_telemetry._log_info = Mock()
self.mock_state_telemetry._log_debug = Mock()
self.mock_state_telemetry._log_warning = Mock()
self.mock_state_telemetry._log_error = Mock()
self.mock_state_telemetry._log_critical = Mock()
self.manager = ConnectionManager(
exchange_name="test_exchange",
component_name="test_collector",
logger=self.mock_logger,
state_telemetry=self.mock_state_telemetry
)
async def test_init(self):
self.assertEqual(self.manager.exchange_name, "test_exchange")
self.assertEqual(self.manager.component_name, "test_collector")
self.assertEqual(self.manager._max_reconnect_attempts, 5)
self.assertEqual(self.manager._reconnect_delay, 5.0)
self.assertEqual(self.manager.logger, self.mock_logger)
self.assertEqual(self.manager._state_telemetry, self.mock_state_telemetry)
self.assertIsNone(self.manager._connection)
self.assertEqual(self.manager._reconnect_attempts, 0)
async def test_connect_success(self):
mock_connect_logic = AsyncMock(return_value=True)
result = await self.manager.connect(mock_connect_logic)
self.assertTrue(result)
mock_connect_logic.assert_called_once()
self.assertTrue(self.manager._connection)
self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Successfully connected to test_exchange")
async def test_connect_failure(self):
mock_connect_logic = AsyncMock(return_value=False)
result = await self.manager.connect(mock_connect_logic)
self.assertFalse(result)
mock_connect_logic.assert_called_once()
self.assertIsNone(self.manager._connection)
self.mock_state_telemetry._log_error.assert_any_call("test_collector: Failed to connect to test_exchange", exc_info=False)
async def test_connect_exception(self):
mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))
result = await self.manager.connect(mock_connect_logic)
self.assertFalse(result)
mock_connect_logic.assert_called_once()
self.assertIsNone(self.manager._connection)
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Connection error", exc_info=True)
async def test_disconnect_success(self):
self.manager._connection = True # Simulate active connection
mock_disconnect_logic = AsyncMock()
await self.manager.disconnect(mock_disconnect_logic)
mock_disconnect_logic.assert_called_once()
self.assertIsNone(self.manager._connection)
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnected from test_exchange")
async def test_disconnect_no_active_connection(self):
mock_disconnect_logic = AsyncMock()
await self.manager.disconnect(mock_disconnect_logic)
mock_disconnect_logic.assert_not_called()
self.assertIsNone(self.manager._connection)
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnecting from test_exchange data source")
async def test_disconnect_exception(self):
self.manager._connection = True
mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))
await self.manager.disconnect(mock_disconnect_logic)
mock_disconnect_logic.assert_called_once()
self.assertTrue(self.manager._connection) # Connection state remains unchanged on exception
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during disconnection from test_exchange: Disconnect error", exc_info=True)
@patch('asyncio.sleep', new_callable=AsyncMock)
async def test_handle_connection_error_reconnect_success(self, mock_sleep):
mock_connect_logic = AsyncMock(return_value=True)
mock_subscribe_logic = AsyncMock(return_value=True)
self.manager._reconnect_attempts = 0
result = await self.manager.handle_connection_error(
mock_connect_logic,
mock_subscribe_logic,
symbols=["BTC/USDT"],
data_types=[DataType.CANDLE]
)
self.assertTrue(result)
self.assertEqual(self.manager._reconnect_attempts, 0)
mock_connect_logic.assert_called_once()
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Reconnection successful for test_exchange")
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
@patch('asyncio.sleep', new_callable=AsyncMock)
async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
mock_connect_logic = AsyncMock(return_value=True)
mock_subscribe_logic = AsyncMock(return_value=True)
self.manager._reconnect_attempts = self.manager._max_reconnect_attempts
result = await self.manager.handle_connection_error(
mock_connect_logic,
mock_subscribe_logic,
symbols=["BTC/USDT"],
data_types=[DataType.CANDLE]
)
self.assertFalse(result)
self.assertEqual(self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1)
mock_connect_logic.assert_not_called()
mock_subscribe_logic.assert_not_called()
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
self.mock_state_telemetry._log_error.assert_any_call(f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange", exc_info=False)
mock_sleep.assert_not_called()
@patch('asyncio.sleep', new_callable=AsyncMock)
async def test_handle_connection_error_connect_fails(self, mock_sleep):
mock_connect_logic = AsyncMock(return_value=False)
mock_subscribe_logic = AsyncMock(return_value=True)
self.manager._reconnect_attempts = 0
result = await self.manager.handle_connection_error(
mock_connect_logic,
mock_subscribe_logic,
symbols=["BTC/USDT"],
data_types=[DataType.CANDLE]
)
self.assertFalse(result)
self.assertEqual(self.manager._reconnect_attempts, 1)
mock_connect_logic.assert_called_once()
mock_subscribe_logic.assert_not_called()
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Failed to connect to test_exchange", exc_info=False)
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
@patch('asyncio.sleep', new_callable=AsyncMock)
async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
mock_connect_logic = AsyncMock(return_value=True)
mock_subscribe_logic = AsyncMock(return_value=False)
self.manager._reconnect_attempts = 0
result = await self.manager.handle_connection_error(
mock_connect_logic,
mock_subscribe_logic,
symbols=["BTC/USDT"],
data_types=[DataType.CANDLE]
)
self.assertFalse(result)
self.assertEqual(self.manager._reconnect_attempts, 1)
mock_connect_logic.assert_called_once()
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
@patch('asyncio.sleep', new_callable=AsyncMock)
async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
mock_subscribe_logic = AsyncMock(return_value=True)
self.manager._reconnect_attempts = 0
result = await self.manager.handle_connection_error(
mock_connect_logic,
mock_subscribe_logic,
symbols=["BTC/USDT"],
data_types=[DataType.CANDLE]
)
self.assertFalse(result)
self.assertEqual(self.manager._reconnect_attempts, 1)
mock_connect_logic.assert_called_once()
mock_subscribe_logic.assert_not_called()
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Reconnect error", exc_info=True)
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)