diff --git a/data/base_collector.py b/data/base_collector.py index cf9b949..3c0b18e 100644 --- a/data/base_collector.py +++ b/data/base_collector.py @@ -16,6 +16,7 @@ from enum import Enum from utils.logger import get_logger from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry from .collector.collector_connection_manager import ConnectionManager +from .collector.collector_callback_dispatcher import CallbackDispatcher from .common.data_types import DataType, MarketDataPoint @@ -137,14 +138,15 @@ class BaseDataCollector(ABC): state_telemetry=self._state_telemetry ) + # Initialize callback dispatcher + self._callback_dispatcher = CallbackDispatcher( + component_name=component, + logger=self.logger + ) + # Collector state (now managed by _state_telemetry) self._tasks: Set[asyncio.Task] = set() - # Data callbacks - self._data_callbacks: Dict[DataType, List[Callable]] = { - data_type: [] for data_type in DataType - } - # Log initialization if logger is available if self._state_telemetry.logger: if not self._state_telemetry.log_errors_only: @@ -457,9 +459,7 @@ class BaseDataCollector(ABC): data_type: Type of data to monitor callback: Function to call when data is received """ - if callback not in self._data_callbacks[data_type]: - self._data_callbacks[data_type].append(callback) - self._log_debug(f"Added callback for {data_type.value} data") + self._callback_dispatcher.add_data_callback(data_type, callback) def remove_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None: """ @@ -469,9 +469,7 @@ class BaseDataCollector(ABC): data_type: Type of data to stop monitoring callback: Function to remove """ - if callback in self._data_callbacks[data_type]: - self._data_callbacks[data_type].remove(callback) - self._log_debug(f"Removed callback for {data_type.value} data") + self._callback_dispatcher.remove_data_callback(data_type, callback) async def _notify_callbacks(self, data_point: MarketDataPoint) -> None: """ @@ -480,18 +478,7 @@ class BaseDataCollector(ABC): Args: data_point: Market data to distribute """ - callbacks = self._data_callbacks.get(data_point.data_type, []) - - for callback in callbacks: - try: - # Handle both sync and async callbacks - if asyncio.iscoroutinefunction(callback): - await callback(data_point) - else: - callback(data_point) - - except Exception as e: - self._log_error(f"Error in data callback: {e}") + await self._callback_dispatcher.notify_callbacks(data_point) # Update statistics self._state_telemetry.increment_messages_processed() diff --git a/data/collector/collector_callback_dispatcher.py b/data/collector/collector_callback_dispatcher.py new file mode 100644 index 0000000..4dd0e79 --- /dev/null +++ b/data/collector/collector_callback_dispatcher.py @@ -0,0 +1,85 @@ +""" +Module for managing data callbacks and notifications for data collectors. + +This module encapsulates the logic for registering, removing, and notifying +callback functions when new market data points are received, promoting a +clean separation of concerns within the data collector architecture. +""" + +import asyncio +from typing import Dict, List, Optional, Any, Callable + +from data.common.data_types import DataType, MarketDataPoint + + +class CallbackDispatcher: + """ + Manages the dispatching of market data points to registered callbacks. + """ + + def __init__(self, component_name: str, logger=None): + self.component_name = component_name + self.logger = logger + self._data_callbacks: Dict[DataType, List[Callable]] = { + data_type: [] for data_type in DataType + } + + def _log_debug(self, message: str) -> None: + if self.logger: + self.logger.debug(f"{self.component_name}: {message}") + + def _log_info(self, message: str) -> None: + if self.logger: + self.logger.info(f"{self.component_name}: {message}") + + def _log_warning(self, message: str) -> None: + if self.logger: + self.logger.warning(f"{self.component_name}: {message}") + + def _log_error(self, message: str, exc_info: bool = False) -> None: + if self.logger: + self.logger.error(f"{self.component_name}: {message}", exc_info=exc_info) + + def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None: + """ + Add a callback function for specific data type. + + Args: + data_type: Type of data to monitor + callback: Function to call when data is received + """ + if callback not in self._data_callbacks[data_type]: + self._data_callbacks[data_type].append(callback) + self._log_debug(f"Added callback for {data_type.value} data") + + def remove_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None: + """ + Remove a callback function for specific data type. + + Args: + data_type: Type of data to stop monitoring + callback: Function to remove + """ + if callback in self._data_callbacks[data_type]: + self._data_callbacks[data_type].remove(callback) + self._log_debug(f"Removed callback for {data_type.value} data") + + async def notify_callbacks(self, data_point: MarketDataPoint) -> None: + """ + Notify all registered callbacks for a data point. + + Args: + data_point: Market data to distribute + """ + callbacks = self._data_callbacks.get(data_point.data_type, []) + + for callback in callbacks: + try: + # Handle both sync and async callbacks + if asyncio.iscoroutinefunction(callback): + await callback(data_point) + else: + callback(data_point) + + except Exception as e: + self._log_error(f"Error in data callback for {data_point.data_type.value} {data_point.symbol}: {e}", exc_info=True) \ No newline at end of file diff --git a/tasks/tasks-base-collector-refactoring.md b/tasks/tasks-base-collector-refactoring.md index d442cd1..523c4af 100644 --- a/tasks/tasks-base-collector-refactoring.md +++ b/tasks/tasks-base-collector-refactoring.md @@ -36,13 +36,13 @@ - [x] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`. - [x] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class. -- [ ] 3.0 Extract `CallbackDispatcher` Class - - [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`. - - [ ] 3.2 Move `_data_callbacks` attribute to `CallbackDispatcher`. - - [ ] 3.3 Move `add_data_callback`, `remove_data_callback`, `_notify_callbacks` methods to `CallbackDispatcher`. - - [ ] 3.4 Implement a constructor for `CallbackDispatcher` to receive logger. - - [ ] 3.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_callback_dispatcher.py`. - - [ ] 3.6 Create `tests/data/collector/test_collector_callback_dispatcher.py` and add initial tests for the new class. +- [x] 3.0 Extract `CallbackDispatcher` Class + - [x] 3.1 Create `data/collector/collector_callback_dispatcher.py`. + - [x] 3.2 Move `_data_callbacks` attribute to `CallbackDispatcher`. + - [x] 3.3 Move `add_data_callback`, `remove_data_callback`, `_notify_callbacks` methods to `CallbackDispatcher`. + - [x] 3.4 Implement a constructor for `CallbackDispatcher` to receive logger. + - [x] 3.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_callback_dispatcher.py`. + - [x] 3.6 Create `tests/data/collector/test_collector_callback_dispatcher.py` and add initial tests for the new class. - [ ] 4.0 Refactor `BaseDataCollector` to use new components - [ ] 4.1 Update `BaseDataCollector.__init__` to instantiate and use `CollectorStateAndTelemetry`, `ConnectionManager`, and `CallbackDispatcher` instances. diff --git a/tests/data/collector/test_collector_callback_dispatcher.py b/tests/data/collector/test_collector_callback_dispatcher.py new file mode 100644 index 0000000..ed81388 --- /dev/null +++ b/tests/data/collector/test_collector_callback_dispatcher.py @@ -0,0 +1,112 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, Mock +from datetime import datetime + +from data.collector.collector_callback_dispatcher import CallbackDispatcher +from data.common.data_types import DataType, MarketDataPoint + + +class TestCallbackDispatcher(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + self.mock_logger = Mock() + self.dispatcher = CallbackDispatcher( + component_name="test_dispatcher", + logger=self.mock_logger + ) + + async def test_init(self): + self.assertEqual(self.dispatcher.component_name, "test_dispatcher") + self.assertEqual(self.dispatcher.logger, self.mock_logger) + self.assertIsInstance(self.dispatcher._data_callbacks, dict) + self.assertGreater(len(self.dispatcher._data_callbacks), 0) # Ensure all DataType enums are initialized + + async def test_add_data_callback(self): + mock_callback = Mock() + data_type = DataType.CANDLE + + self.dispatcher.add_data_callback(data_type, mock_callback) + self.assertIn(mock_callback, self.dispatcher._data_callbacks[data_type]) + self.mock_logger.debug.assert_called_with(f"test_dispatcher: Added callback for {data_type.value} data") + + # Test adding same callback twice (should not add) + self.dispatcher.add_data_callback(data_type, mock_callback) + self.assertEqual(self.dispatcher._data_callbacks[data_type].count(mock_callback), 1) + + async def test_remove_data_callback(self): + mock_callback = Mock() + data_type = DataType.TRADE + + self.dispatcher.add_data_callback(data_type, mock_callback) + self.assertIn(mock_callback, self.dispatcher._data_callbacks[data_type]) + + self.dispatcher.remove_data_callback(data_type, mock_callback) + self.assertNotIn(mock_callback, self.dispatcher._data_callbacks[data_type]) + self.mock_logger.debug.assert_called_with(f"test_dispatcher: Removed callback for {data_type.value} data") + + # Test removing non-existent callback (should do nothing) + self.dispatcher.remove_data_callback(data_type, Mock()) + # No error should be raised and log should not be called again for removal + + async def test_notify_callbacks_sync(self): + mock_sync_callback = Mock() + data_type = DataType.TICKER + data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {"price": 100}) + + self.dispatcher.add_data_callback(data_type, mock_sync_callback) + await self.dispatcher.notify_callbacks(data_point) + + mock_sync_callback.assert_called_once_with(data_point) + + async def test_notify_callbacks_async(self): + mock_async_callback = AsyncMock() + data_type = DataType.ORDERBOOK + data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {"bids": [], "asks": []}) + + self.dispatcher.add_data_callback(data_type, mock_async_callback) + await self.dispatcher.notify_callbacks(data_point) + + mock_async_callback.assert_called_once_with(data_point) + + async def test_notify_callbacks_mixed(self): + mock_sync_callback = Mock() + mock_async_callback = AsyncMock() + data_type = DataType.BALANCE + data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {"asset": "BTC", "balance": 0.5}) + + self.dispatcher.add_data_callback(data_type, mock_sync_callback) + self.dispatcher.add_data_callback(data_type, mock_async_callback) + await self.dispatcher.notify_callbacks(data_point) + + mock_sync_callback.assert_called_once_with(data_point) + mock_async_callback.assert_called_once_with(data_point) + + async def test_notify_callbacks_exception_handling(self): + def failing_sync_callback(data): raise ValueError("Sync error") + async def failing_async_callback(data): raise TypeError("Async error") + + mock_successful_callback = Mock() + + data_type = DataType.CANDLE + data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {}) + + self.dispatcher.add_data_callback(data_type, failing_sync_callback) + self.dispatcher.add_data_callback(data_type, failing_async_callback) + self.dispatcher.add_data_callback(data_type, mock_successful_callback) + + await self.dispatcher.notify_callbacks(data_point) + + mock_successful_callback.assert_called_once_with(data_point) + self.assertEqual(self.mock_logger.error.call_count, 2) + self.mock_logger.error.assert_any_call(f"test_dispatcher: Error in data callback for {data_type.value} {data_point.symbol}: Sync error", exc_info=True) + self.mock_logger.error.assert_any_call(f"test_dispatcher: Error in data callback for {data_type.value} {data_point.symbol}: Async error", exc_info=True) + + async def test_notify_callbacks_no_callbacks(self): + data_type = DataType.TICKER + data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {}) + + # No callbacks added + await self.dispatcher.notify_callbacks(data_point) + self.mock_logger.error.assert_not_called() # No errors should be logged + self.mock_logger.debug.assert_not_called() # No debug logs from notify \ No newline at end of file