TCPDashboard/tests/data/collector/test_collector_connection_manager.py
Vasily.onl 41f0e8e6b6 Refactor BaseDataCollector to integrate ConnectionManager for connection handling
- Extracted connection management logic into a new `ConnectionManager` class, promoting separation of concerns and enhancing modularity.
- Updated `BaseDataCollector` to utilize the `ConnectionManager` for connection, disconnection, and reconnection processes, improving code clarity and maintainability.
- Refactored connection-related methods and attributes, ensuring consistent error handling and logging practices.
- Enhanced the `OKXCollector` to implement the new connection management approach, streamlining its connection logic.
- Added unit tests for the `ConnectionManager` to validate its functionality and ensure robust error handling.

These changes improve the architecture of the data collector, aligning with project standards for maintainability and performance.
2025-06-09 17:42:06 +08:00

193 lines
9.7 KiB
Python

import asyncio
import unittest
from unittest.mock import AsyncMock, Mock, patch
from datetime import datetime, timedelta, timezone
from data.collector.collector_connection_manager import ConnectionManager
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
from data.common.data_types import DataType # Import for DataType enum
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``ConnectionManager``.

    The collector state/telemetry dependency is an ``AsyncMock`` specced on
    ``CollectorStateAndTelemetry``. Its status setters and ``_log_*`` helpers
    are replaced with plain ``Mock`` objects, so each test can assert exactly
    which status transitions and log messages the manager produced.
    """

    def setUp(self):
        """Build a ``ConnectionManager`` wired to a mocked logger and telemetry object."""
        self.mock_logger = Mock()
        self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
        self.mock_state_telemetry.logger = self.mock_logger
        # Synchronous setters: plain Mocks so calls are recorded without
        # producing un-awaited coroutines.
        self.mock_state_telemetry.update_status = Mock()
        self.mock_state_telemetry.set_should_be_running = Mock()
        self.mock_state_telemetry.set_connection_uptime_start = Mock()
        # Logging helpers the tests assert on.
        for log_helper in (
            "_log_info",
            "_log_debug",
            "_log_warning",
            "_log_error",
            "_log_critical",
        ):
            setattr(self.mock_state_telemetry, log_helper, Mock())
        self.manager = ConnectionManager(
            exchange_name="test_exchange",
            component_name="test_collector",
            logger=self.mock_logger,
            state_telemetry=self.mock_state_telemetry,
        )

    async def _run_handle_connection_error(self, connect_logic, subscribe_logic):
        """Call ``handle_connection_error`` with the symbols/data types shared by these tests."""
        return await self.manager.handle_connection_error(
            connect_logic,
            subscribe_logic,
            symbols=["BTC/USDT"],
            data_types=[DataType.CANDLE],
        )

    async def test_init(self):
        """Constructor stores its arguments and starts disconnected with default retry settings."""
        self.assertEqual(self.manager.exchange_name, "test_exchange")
        self.assertEqual(self.manager.component_name, "test_collector")
        self.assertEqual(self.manager._max_reconnect_attempts, 5)
        self.assertEqual(self.manager._reconnect_delay, 5.0)
        self.assertEqual(self.manager.logger, self.mock_logger)
        self.assertEqual(self.manager._state_telemetry, self.mock_state_telemetry)
        self.assertIsNone(self.manager._connection)
        self.assertEqual(self.manager._reconnect_attempts, 0)

    async def test_connect_success(self):
        """A truthy connect callback marks the connection active and records uptime."""
        mock_connect_logic = AsyncMock(return_value=True)
        result = await self.manager.connect(mock_connect_logic)
        self.assertTrue(result)
        mock_connect_logic.assert_called_once()
        self.assertTrue(self.manager._connection)
        self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
        self.mock_state_telemetry._log_info.assert_any_call("test_collector: Successfully connected to test_exchange")

    async def test_connect_failure(self):
        """A falsy connect callback leaves the manager disconnected and logs an error."""
        mock_connect_logic = AsyncMock(return_value=False)
        result = await self.manager.connect(mock_connect_logic)
        self.assertFalse(result)
        mock_connect_logic.assert_called_once()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_error.assert_any_call("test_collector: Failed to connect to test_exchange", exc_info=False)

    async def test_connect_exception(self):
        """An exception from the connect callback is caught, logged with traceback, and reported as failure."""
        mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))
        result = await self.manager.connect(mock_connect_logic)
        self.assertFalse(result)
        mock_connect_logic.assert_called_once()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Connection error", exc_info=True)

    async def test_disconnect_success(self):
        """Disconnecting an active connection runs the callback and clears the connection."""
        self.manager._connection = True  # Simulate active connection
        mock_disconnect_logic = AsyncMock()
        await self.manager.disconnect(mock_disconnect_logic)
        mock_disconnect_logic.assert_called_once()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnected from test_exchange")

    async def test_disconnect_no_active_connection(self):
        """Without an active connection the disconnect callback is skipped."""
        mock_disconnect_logic = AsyncMock()
        await self.manager.disconnect(mock_disconnect_logic)
        mock_disconnect_logic.assert_not_called()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_info.assert_any_call("test_collector: Disconnecting from test_exchange data source")

    async def test_disconnect_exception(self):
        """An exception from the disconnect callback is logged and the connection state is kept."""
        self.manager._connection = True
        mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))
        await self.manager.disconnect(mock_disconnect_logic)
        mock_disconnect_logic.assert_called_once()
        self.assertTrue(self.manager._connection)  # Connection state remains unchanged on exception
        self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during disconnection from test_exchange: Disconnect error", exc_info=True)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_reconnect_success(self, mock_sleep):
        """A successful reconnect+resubscribe resets the attempt counter and returns to RUNNING."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0
        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)
        self.assertTrue(result)
        self.assertEqual(self.manager._reconnect_attempts, 0)
        mock_connect_logic.assert_called_once()
        mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
        self.mock_state_telemetry._log_info.assert_any_call("test_collector: Reconnection successful for test_exchange")
        mock_sleep.assert_called_once_with(self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
        """Once the attempt budget is spent the manager gives up without reconnecting or sleeping."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = self.manager._max_reconnect_attempts
        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)
        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1)
        mock_connect_logic.assert_not_called()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
        self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
        self.mock_state_telemetry._log_error.assert_any_call(f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange", exc_info=False)
        mock_sleep.assert_not_called()

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_connect_fails(self, mock_sleep):
        """A failed reconnect counts as an attempt and skips resubscription."""
        mock_connect_logic = AsyncMock(return_value=False)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0
        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)
        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        mock_connect_logic.assert_called_once()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Failed to connect to test_exchange", exc_info=False)
        mock_sleep.assert_called_once_with(self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
        """A successful reconnect followed by a failed resubscribe is reported as failure."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=False)
        self.manager._reconnect_attempts = 0
        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)
        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        mock_connect_logic.assert_called_once()
        mock_subscribe_logic.assert_called_once_with(["BTC/USDT"], [DataType.CANDLE])
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        mock_sleep.assert_called_once_with(self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
        """An exception raised while reconnecting is logged with traceback and reported as failure."""
        mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0
        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)
        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        mock_connect_logic.assert_called_once()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry._log_error.assert_called_once_with("test_collector: Error during connection to test_exchange: Reconnect error", exc_info=True)
        mock_sleep.assert_called_once_with(self.manager._reconnect_delay)