# Source-viewer listing metadata: 193 lines, 9.7 KiB, Python.
import asyncio
import unittest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, Mock, patch

from data.collector.collector_connection_manager import ConnectionManager
from data.collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
from data.common.data_types import DataType  # Import for DataType enum


class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``ConnectionManager`` connect, disconnect and reconnect handling.

    The collector state/telemetry dependency is an ``AsyncMock`` spec'd on
    ``CollectorStateAndTelemetry``. The synchronous methods these tests
    inspect are replaced with plain ``Mock`` objects so their calls are
    recorded directly rather than producing coroutines.
    """

    def setUp(self):
        self.mock_logger = Mock()
        self.mock_state_telemetry = AsyncMock(spec=CollectorStateAndTelemetry)
        self.mock_state_telemetry.logger = self.mock_logger
        # Override with synchronous mocks so calls can be asserted without awaiting.
        self.mock_state_telemetry.update_status = Mock()
        self.mock_state_telemetry.set_should_be_running = Mock()
        self.mock_state_telemetry.set_connection_uptime_start = Mock()
        self.mock_state_telemetry._log_info = Mock()
        self.mock_state_telemetry._log_debug = Mock()
        self.mock_state_telemetry._log_warning = Mock()
        self.mock_state_telemetry._log_error = Mock()
        self.mock_state_telemetry._log_critical = Mock()

        self.manager = ConnectionManager(
            exchange_name="test_exchange",
            component_name="test_collector",
            logger=self.mock_logger,
            state_telemetry=self.mock_state_telemetry,
        )

    def _assert_awaited_once(self, mock, *args, **kwargs):
        """Assert that ``mock`` was called exactly once and that the call was awaited.

        ``assert_called_once`` alone also passes when the code under test
        creates the coroutine but never awaits it; checking the await count
        catches that. If ``args``/``kwargs`` are given, the single awaited
        call must match them; with none, only the counts are checked.
        """
        mock.assert_called_once()
        if args or kwargs:
            mock.assert_awaited_once_with(*args, **kwargs)
        else:
            mock.assert_awaited_once()

    async def _run_handle_connection_error(self, connect_logic, subscribe_logic):
        """Call ``handle_connection_error`` with the BTC/USDT candle subscription used by these tests."""
        return await self.manager.handle_connection_error(
            connect_logic,
            subscribe_logic,
            symbols=["BTC/USDT"],
            data_types=[DataType.CANDLE],
        )

    async def test_init(self):
        """Constructor stores its arguments and the default reconnect settings."""
        self.assertEqual(self.manager.exchange_name, "test_exchange")
        self.assertEqual(self.manager.component_name, "test_collector")
        self.assertEqual(self.manager._max_reconnect_attempts, 5)
        self.assertEqual(self.manager._reconnect_delay, 5.0)
        # The injected collaborators must be stored as-is, not copied or wrapped.
        self.assertIs(self.manager.logger, self.mock_logger)
        self.assertIs(self.manager._state_telemetry, self.mock_state_telemetry)
        self.assertIsNone(self.manager._connection)
        self.assertEqual(self.manager._reconnect_attempts, 0)

    async def test_connect_success(self):
        """A connect callback returning True yields True, sets the connection and starts uptime tracking."""
        mock_connect_logic = AsyncMock(return_value=True)

        result = await self.manager.connect(mock_connect_logic)

        self.assertTrue(result)
        self._assert_awaited_once(mock_connect_logic)
        self.assertTrue(self.manager._connection)
        self.mock_state_telemetry.set_connection_uptime_start.assert_called_once()
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Successfully connected to test_exchange"
        )

    async def test_connect_failure(self):
        """A connect callback returning False yields False, leaves no connection and logs an error."""
        mock_connect_logic = AsyncMock(return_value=False)

        result = await self.manager.connect(mock_connect_logic)

        self.assertFalse(result)
        self._assert_awaited_once(mock_connect_logic)
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_error.assert_any_call(
            "test_collector: Failed to connect to test_exchange", exc_info=False
        )

    async def test_connect_exception(self):
        """An exception from the connect callback is caught: False is returned and it is logged with exc_info."""
        mock_connect_logic = AsyncMock(side_effect=Exception("Connection error"))

        result = await self.manager.connect(mock_connect_logic)

        self.assertFalse(result)
        self._assert_awaited_once(mock_connect_logic)
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Error during connection to test_exchange: Connection error",
            exc_info=True,
        )

    async def test_disconnect_success(self):
        """Disconnecting an active connection awaits the callback and clears the connection."""
        self.manager._connection = True  # Simulate active connection
        mock_disconnect_logic = AsyncMock()

        await self.manager.disconnect(mock_disconnect_logic)

        self._assert_awaited_once(mock_disconnect_logic)
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Disconnected from test_exchange"
        )

    async def test_disconnect_no_active_connection(self):
        """Disconnecting with no active connection does not invoke the callback at all."""
        mock_disconnect_logic = AsyncMock()

        await self.manager.disconnect(mock_disconnect_logic)

        mock_disconnect_logic.assert_not_called()
        self.assertIsNone(self.manager._connection)
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Disconnecting from test_exchange data source"
        )

    async def test_disconnect_exception(self):
        """An exception from the disconnect callback is logged and the connection state is left untouched."""
        self.manager._connection = True
        mock_disconnect_logic = AsyncMock(side_effect=Exception("Disconnect error"))

        await self.manager.disconnect(mock_disconnect_logic)

        self._assert_awaited_once(mock_disconnect_logic)
        # Connection state remains unchanged on exception: still exactly the value set above.
        self.assertIs(self.manager._connection, True)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Error during disconnection from test_exchange: Disconnect error",
            exc_info=True,
        )

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_reconnect_success(self, mock_sleep):
        """After one delay, a successful reconnect and resubscribe returns True with the attempt counter at 0."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0

        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertTrue(result)
        self.assertEqual(self.manager._reconnect_attempts, 0)
        self._assert_awaited_once(mock_connect_logic)
        self._assert_awaited_once(mock_subscribe_logic, ["BTC/USDT"], [DataType.CANDLE])
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RUNNING)
        self.mock_state_telemetry._log_info.assert_any_call(
            "test_collector: Reconnection successful for test_exchange"
        )
        self._assert_awaited_once(mock_sleep, self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_max_attempts_exceeded(self, mock_sleep):
        """With the attempt budget already used up, nothing is retried and the collector is flagged as errored.

        The counter is still incremented past the maximum, ERROR status is
        reported, ``should_be_running`` is cleared and no delay is awaited.
        """
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = self.manager._max_reconnect_attempts

        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(
            self.manager._reconnect_attempts, self.manager._max_reconnect_attempts + 1
        )
        mock_connect_logic.assert_not_called()
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.ERROR)
        self.mock_state_telemetry.set_should_be_running.assert_called_once_with(False)
        self.mock_state_telemetry._log_error.assert_any_call(
            f"test_collector: Max reconnection attempts ({self.manager._max_reconnect_attempts}) exceeded for test_exchange",
            exc_info=False,
        )
        mock_sleep.assert_not_called()

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_connect_fails(self, mock_sleep):
        """A failed reconnect returns False after one delay, skips resubscription and counts one attempt."""
        mock_connect_logic = AsyncMock(return_value=False)
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0

        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        self._assert_awaited_once(mock_connect_logic)
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Failed to connect to test_exchange", exc_info=False
        )
        self._assert_awaited_once(mock_sleep, self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_subscribe_fails(self, mock_sleep):
        """A reconnect whose resubscription fails returns False and counts one attempt."""
        mock_connect_logic = AsyncMock(return_value=True)
        mock_subscribe_logic = AsyncMock(return_value=False)
        self.manager._reconnect_attempts = 0

        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        self._assert_awaited_once(mock_connect_logic)
        self._assert_awaited_once(mock_subscribe_logic, ["BTC/USDT"], [DataType.CANDLE])
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self._assert_awaited_once(mock_sleep, self.manager._reconnect_delay)

    @patch('asyncio.sleep', new_callable=AsyncMock)
    async def test_handle_connection_error_exception_during_reconnect(self, mock_sleep):
        """An exception from the connect callback during reconnect is logged with exc_info and reported as failure."""
        mock_connect_logic = AsyncMock(side_effect=Exception("Reconnect error"))
        mock_subscribe_logic = AsyncMock(return_value=True)
        self.manager._reconnect_attempts = 0

        result = await self._run_handle_connection_error(mock_connect_logic, mock_subscribe_logic)

        self.assertFalse(result)
        self.assertEqual(self.manager._reconnect_attempts, 1)
        self._assert_awaited_once(mock_connect_logic)
        mock_subscribe_logic.assert_not_called()
        self.mock_state_telemetry.update_status.assert_any_call(CollectorStatus.RECONNECTING)
        self.mock_state_telemetry._log_error.assert_called_once_with(
            "test_collector: Error during connection to test_exchange: Reconnect error",
            exc_info=True,
        )
        self._assert_awaited_once(mock_sleep, self.manager._reconnect_delay)