- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and modularity.
- Updated `BaseDataCollector` to use the `ConnectionManager` for connection, disconnection, and reconnection, improving code clarity and maintainability.
- Refactored connection-related methods and attributes to ensure consistent error handling and logging practices.
- Updated `OKXCollector` to use the new connection management approach, streamlining its connection logic.
- Added unit tests for `ConnectionManager` to validate its functionality and its error handling.

These changes improve the architecture of the data collector and align it with the project's standards for maintainability and performance.
193 lines
9.7 KiB
Python
193 lines
9.7 KiB
Python
import asyncio
|
|
import unittest
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from data.collector.collector_connection_manager import ConnectionManager
|
|
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
|
from data.common.data_types import DataType # Import for DataType enum
|
|
|
|
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``ConnectionManager``.

    The exchange-specific connect / disconnect / subscribe callbacks are
    ``AsyncMock`` objects. ``CollectorStateAndTelemetry`` is replaced by a
    mock whose status and logging hooks record their calls, so each test can
    assert on both the manager's return value and its side effects.

    Async collaborators (the callbacks and the patched ``asyncio.sleep``) are
    checked with the ``assert_awaited_*`` family. Merely *calling* an
    ``AsyncMock`` only creates a coroutine, so ``assert_called_*`` would keep
    passing if the manager forgot to ``await`` one of them. For example, an
    un-awaited ``asyncio.sleep`` would silently drop the reconnect back-off.
    """

    def setUp(self):
        """Create a ``ConnectionManager`` wired to a mock logger and mock state/telemetry."""
        self.mock_logger = Mock()

        # spec= makes reads of attributes that CollectorStateAndTelemetry does
        # not define raise AttributeError instead of auto-creating them.
        self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
        self.mock_state_telemetry.logger = self.mock_logger

        # Status and logging hooks are swapped for plain Mocks. The manager
        # presumably calls them without awaiting (TODO confirm), and plain
        # Mocks record those calls without creating coroutines.
        self.mock_state_telemetry.update_status = Mock()
        self.mock_state_telemetry.set_should_be_running = Mock()
        self.mock_state_telemetry.set_connection_uptime_start = Mock()
        self.mock_state_telemetry._log_info = Mock()
        self.mock_state_telemetry._log_debug = Mock()
        self.mock_state_telemetry._log_warning = Mock()
        self.mock_state_telemetry._log_error = Mock()
        self.mock_state_telemetry._log_critical = Mock()

        self.manager = ConnectionManager(
            exchange_name="test_exchange",
            component_name="test_collector",
            logger=self.mock_logger,
            state_telemetry=self.mock_state_telemetry,
        )

    async def _handle_connection_error(self, connect_logic, subscribe_logic):
        """Run ``handle_connection_error`` with the symbol/data-type set shared by these tests.

        Fresh lists are built on every call so no test can observe another
        test's arguments.
        """
        return await self.manager.handle_connection_error(
            connect_logic,
            subscribe_logic,
            symbols=["BTC/USDT"],
            data_types=[DataType.CANDLE],
        )

    async def test_init(self):
        """The constructor stores its arguments and starts disconnected with the default retry settings."""
        self.assertEqual(self.manager.exchange_name, "test_exchange")
        self.assertEqual(self.manager.component_name, "test_collector")
        self.assertEqual(self.manager._max_reconnect_attempts, 5)
        self.assertEqual(self.manager._reconnect_delay, 5.0)
        self.assertEqual(self.manager.logger, self.mock_logger)
        self.assertEqual(self.manager._state_telemetry, self.mock_state_telemetry)
        self.assertIsNone(self.manager._connection)
        self.assertEqual(self.manager._reconnect_attempts, 0)

    async def test_connect_success(self):
        """A truthy connect callback marks the manager connected and starts the uptime clock."""
        mock_connect_logic = AsyncMock(return_value=True)

        result = await self.manager.connect(mock_connect_logic)

        self.assertTrue(result)
        mock_connect_logic.assert_awaited_once()
        self.assertTrue(self.manager._connection)
        self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Successfully connected to test_exchange"
        )

    async def test_connect_failure(self):
        """A falsy connect callback leaves the manager disconnected and logs the failure."""
        mock_connect_logic = AsyncMock(return_value=False)

        result = await self.manager.connect(mock_connect_logic)

        self.assertFalse(result)
        mock_connect_logic.assert_awaited_once()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_error.assert_any_call(
            "test_collector: Failed to connect to test_exchange", exc_info=False
        )

    async def test_connect_exception(self):
        """An exception from the connect callback is caught, logged with traceback, and reported as failure."""
        # AsyncMock raises side_effect when awaited, so this also proves the callback is awaited.
        mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))

        result = await self.manager.connect(mock_connect_logic)

        self.assertFalse(result)
        mock_connect_logic.assert_awaited_once()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Error during connection to test_exchange: Connection error",
            exc_info=True,
        )

    async def test_disconnect_success(self):
        """With an active connection, the disconnect callback runs and the connection is cleared."""
        self.manager._connection = True  # Simulate active connection
        mock_disconnect_logic = AsyncMock()

        await self.manager.disconnect(mock_disconnect_logic)

        mock_disconnect_logic.assert_awaited_once()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Disconnected from test_exchange"
        )

    async def test_disconnect_no_active_connection(self):
        """Without an active connection, the disconnect callback is never invoked."""
        mock_disconnect_logic = AsyncMock()

        await self.manager.disconnect(mock_disconnect_logic)

        # Stronger than assert_not_awaited: the callback must not even be called.
        mock_disconnect_logic.assert_not_called()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Disconnecting from test_exchange data source"
        )

    async def test_disconnect_exception(self):
        """An exception from the disconnect callback is logged and the connection is left as-is."""
        self.manager._connection = True
        mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))

        await self.manager.disconnect(mock_disconnect_logic)

        mock_disconnect_logic.assert_awaited_once()
        self.assertTrue(self.manager._connection)  # Connection state remains unchanged on exception
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Error during disconnection from test_exchange: Disconnect error",
            exc_info=True,
        )

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_reconnect_success(self, mock_sleep):
        """A successful reconnect + resubscribe resets the attempt counter and returns to RUNNING."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0

        result = await self._handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertTrue(result)
        self.assertEqual(self.manager._reconnect_attempts, 0)
        mock_connect_logic.assert_awaited_once()
        mock_subscribe_logic.assert_awaited_once_with(["BTC/USDT"], [DataType.CANDLE])
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Reconnection successful for test_exchange"
        )
        # Awaited, not just called: an un-awaited sleep would skip the back-off entirely.
        mock_sleep.assert_awaited_once_with(self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
        """Once the attempt budget is spent, the manager gives up without sleeping or reconnecting."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = self.manager._max_reconnect_attempts

        result = await self._handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        # The counter is expected to be bumped past the limit before the manager gives up.
        self.assertEqual(self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1)
        mock_connect_logic.assert_not_called()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
        self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
        self.mock_state_telemetry._log_error.assert_any_call(
            f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange",
            exc_info=False,
        )
        mock_sleep.assert_not_called()

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_connect_fails(self, mock_sleep):
        """A failed reconnect consumes one attempt and skips resubscription."""
        mock_connect_logic = AsyncMock(return_value=False)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0

        result = await self._handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        mock_connect_logic.assert_awaited_once()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Failed to connect to test_exchange", exc_info=False
        )
        mock_sleep.assert_awaited_once_with(self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
        """A reconnect whose resubscription fails is still reported as a failed attempt."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=False)
        self.manager._reconnect_attempts = 0

        result = await self._handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        mock_connect_logic.assert_awaited_once()
        mock_subscribe_logic.assert_awaited_once_with(["BTC/USDT"], [DataType.CANDLE])
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        mock_sleep.assert_awaited_once_with(self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
        """An exception from the connect callback during reconnect is logged and counted as a failed attempt."""
        mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0

        result = await self._handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        mock_connect_logic.assert_awaited_once()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Error during connection to test_exchange: Reconnect error",
            exc_info=True,
        )
        mock_sleep.assert_awaited_once_with(self.manager._reconnect_delay)