Refactor BaseDataCollector to integrate ConnectionManager for connection handling
- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and enhancing modularity. - Updated `BaseDataCollector` to utilize the `ConnectionManager` for connection, disconnection, and reconnection processes, improving code clarity and maintainability. - Refactored connection-related methods and attributes, ensuring consistent error handling and logging practices. - Enhanced the `OKXCollector` to implement the new connection management approach, streamlining its connection logic. - Added unit tests for the `ConnectionManager` to validate its functionality and ensure robust error handling. These changes improve the architecture of the data collector, aligning with project standards for maintainability and performance.
This commit is contained in:
193
tests/data/collector/test_collector_connection_manager.py
Normal file
193
tests/data/collector/test_collector_connection_manager.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from data.collector.collector_connection_manager import ConnectionManager
|
||||
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
from data.common.data_types import DataType # Import for DataType enum
|
||||
|
||||
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_logger = Mock()
|
||||
self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
|
||||
self.mock_state_telemetry.logger = self.mock_logger
|
||||
self.mock_state_telemetry.update_status = Mock()
|
||||
self.mock_state_telemetry.set_should_be_running = Mock()
|
||||
self.mock_state_telemetry.set_connection_uptime_start = Mock()
|
||||
self.mock_state_telemetry._log_info = Mock()
|
||||
self.mock_state_telemetry._log_debug = Mock()
|
||||
self.mock_state_telemetry._log_warning = Mock()
|
||||
self.mock_state_telemetry._log_error = Mock()
|
||||
self.mock_state_telemetry._log_critical = Mock()
|
||||
self.manager = ConnectionManager(
|
||||
exchange_name="test_exchange",
|
||||
component_name="test_collector",
|
||||
logger=self.mock_logger,
|
||||
state_telemetry=self.mock_state_telemetry
|
||||
)
|
||||
|
||||
async def test_init(self):
|
||||
self.assertEqual(self.manager.exchange_name, "test_exchange")
|
||||
self.assertEqual(self.manager.component_name, "test_collector")
|
||||
self.assertEqual(self.manager._max_reconnect_attempts, 5)
|
||||
self.assertEqual(self.manager._reconnect_delay, 5.0)
|
||||
self.assertEqual(self.manager.logger, self.mock_logger)
|
||||
self.assertEqual(self.manager._state_telemetry, self.mock_state_telemetry)
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 0)
|
||||
|
||||
async def test_connect_success(self):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
result = await self.manager.connect(mock_connect_logic)
|
||||
self.assertTrue(result)
|
||||
mock_connect_logic.assert_called_once()
|
||||
self.assertTrue(self.manager._connection)
|
||||
self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Successfully connected to test_exchange")
|
||||
|
||||
async def test_connect_failure(self):
|
||||
mock_connect_logic = AsyncMock(return_value=False)
|
||||
result = await self.manager.connect(mock_connect_logic)
|
||||
self.assertFalse(result)
|
||||
mock_connect_logic.assert_called_once()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_error.assert_any_call("test_collector: Failed to connect to test_exchange", exc_info=False)
|
||||
|
||||
async def test_connect_exception(self):
|
||||
mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))
|
||||
result = await self.manager.connect(mock_connect_logic)
|
||||
self.assertFalse(result)
|
||||
mock_connect_logic.assert_called_once()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Connection error", exc_info=True)
|
||||
|
||||
async def test_disconnect_success(self):
|
||||
self.manager._connection = True # Simulate active connection
|
||||
mock_disconnect_logic = AsyncMock()
|
||||
await self.manager.disconnect(mock_disconnect_logic)
|
||||
mock_disconnect_logic.assert_called_once()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnected from test_exchange")
|
||||
|
||||
async def test_disconnect_no_active_connection(self):
|
||||
mock_disconnect_logic = AsyncMock()
|
||||
await self.manager.disconnect(mock_disconnect_logic)
|
||||
mock_disconnect_logic.assert_not_called()
|
||||
self.assertIsNone(self.manager._connection)
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnecting from test_exchange data source")
|
||||
|
||||
async def test_disconnect_exception(self):
|
||||
self.manager._connection = True
|
||||
mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))
|
||||
await self.manager.disconnect(mock_disconnect_logic)
|
||||
mock_disconnect_logic.assert_called_once()
|
||||
self.assertTrue(self.manager._connection) # Connection state remains unchanged on exception
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during disconnection from test_exchange: Disconnect error", exc_info=True)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_reconnect_success(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 0)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
|
||||
self.mock_state_telemetry._log_info.assert_any_call("test_collector: Reconnection successful for test_exchange")
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = self.manager._max_reconnect_attempts
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1)
|
||||
mock_connect_logic.assert_not_called()
|
||||
mock_subscribe_logic.assert_not_called()
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
|
||||
self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
|
||||
self.mock_state_telemetry._log_error.assert_any_call(f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange", exc_info=False)
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_connect_fails(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=False)
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_not_called()
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Failed to connect to test_exchange", exc_info=False)
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(return_value=True)
|
||||
mock_subscribe_logic = AsyncMock(return_value=False)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
|
||||
@patch('asyncio.sleep', new_callable=AsyncMock)
|
||||
async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
|
||||
mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
|
||||
mock_subscribe_logic = AsyncMock(return_value=True)
|
||||
|
||||
self.manager._reconnect_attempts = 0
|
||||
result = await self.manager.handle_connection_error(
|
||||
mock_connect_logic,
|
||||
mock_subscribe_logic,
|
||||
symbols=["BTC/USDT"],
|
||||
data_types=[DataType.CANDLE]
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(self.manager._reconnect_attempts, 1)
|
||||
mock_connect_logic.assert_called_once()
|
||||
mock_subscribe_logic.assert_not_called()
|
||||
self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
|
||||
self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Reconnect error", exc_info=True)
|
||||
mock_sleep.assert_called_once_with(self.manager._reconnect_delay)
|
||||
Reference in New Issue
Block a user