Refactor BaseDataCollector to integrate CallbackDispatcher for improved callback management
- Extracted callback management logic into a new `CallbackDispatcher` class, promoting separation of concerns and enhancing modularity. - Updated `BaseDataCollector` to utilize the `CallbackDispatcher` for adding, removing, and notifying data callbacks, improving code clarity and maintainability. - Refactored related methods to ensure consistent error handling and logging practices. - Added unit tests for the `CallbackDispatcher` to validate its functionality and ensure robust error handling. These changes streamline the callback management architecture, aligning with project standards for maintainability and performance.
This commit is contained in:
parent
41f0e8e6b6
commit
3db8fb1c41
@ -16,6 +16,7 @@ from enum import Enum
|
||||
from utils.logger import get_logger
|
||||
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
from .collector.collector_connection_manager import ConnectionManager
|
||||
from .collector.collector_callback_dispatcher import CallbackDispatcher
|
||||
from .common.data_types import DataType, MarketDataPoint
|
||||
|
||||
|
||||
@ -137,14 +138,15 @@ class BaseDataCollector(ABC):
|
||||
state_telemetry=self._state_telemetry
|
||||
)
|
||||
|
||||
# Initialize callback dispatcher
|
||||
self._callback_dispatcher = CallbackDispatcher(
|
||||
component_name=component,
|
||||
logger=self.logger
|
||||
)
|
||||
|
||||
# Collector state (now managed by _state_telemetry)
|
||||
self._tasks: Set[asyncio.Task] = set()
|
||||
|
||||
# Data callbacks
|
||||
self._data_callbacks: Dict[DataType, List[Callable]] = {
|
||||
data_type: [] for data_type in DataType
|
||||
}
|
||||
|
||||
# Log initialization if logger is available
|
||||
if self._state_telemetry.logger:
|
||||
if not self._state_telemetry.log_errors_only:
|
||||
@ -457,9 +459,7 @@ class BaseDataCollector(ABC):
|
||||
data_type: Type of data to monitor
|
||||
callback: Function to call when data is received
|
||||
"""
|
||||
if callback not in self._data_callbacks[data_type]:
|
||||
self._data_callbacks[data_type].append(callback)
|
||||
self._log_debug(f"Added callback for {data_type.value} data")
|
||||
self._callback_dispatcher.add_data_callback(data_type, callback)
|
||||
|
||||
def remove_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
|
||||
"""
|
||||
@ -469,9 +469,7 @@ class BaseDataCollector(ABC):
|
||||
data_type: Type of data to stop monitoring
|
||||
callback: Function to remove
|
||||
"""
|
||||
if callback in self._data_callbacks[data_type]:
|
||||
self._data_callbacks[data_type].remove(callback)
|
||||
self._log_debug(f"Removed callback for {data_type.value} data")
|
||||
self._callback_dispatcher.remove_data_callback(data_type, callback)
|
||||
|
||||
async def _notify_callbacks(self, data_point: MarketDataPoint) -> None:
|
||||
"""
|
||||
@ -480,18 +478,7 @@ class BaseDataCollector(ABC):
|
||||
Args:
|
||||
data_point: Market data to distribute
|
||||
"""
|
||||
callbacks = self._data_callbacks.get(data_point.data_type, [])
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
# Handle both sync and async callbacks
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(data_point)
|
||||
else:
|
||||
callback(data_point)
|
||||
|
||||
except Exception as e:
|
||||
self._log_error(f"Error in data callback: {e}")
|
||||
await self._callback_dispatcher.notify_callbacks(data_point)
|
||||
|
||||
# Update statistics
|
||||
self._state_telemetry.increment_messages_processed()
|
||||
|
||||
85
data/collector/collector_callback_dispatcher.py
Normal file
85
data/collector/collector_callback_dispatcher.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Module for managing data callbacks and notifications for data collectors.
|
||||
|
||||
This module encapsulates the logic for registering, removing, and notifying
|
||||
callback functions when new market data points are received, promoting a
|
||||
clean separation of concerns within the data collector architecture.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
|
||||
from data.common.data_types import DataType, MarketDataPoint
|
||||
|
||||
|
||||
class CallbackDispatcher:
|
||||
"""
|
||||
Manages the dispatching of market data points to registered callbacks.
|
||||
"""
|
||||
|
||||
def __init__(self, component_name: str, logger=None):
|
||||
self.component_name = component_name
|
||||
self.logger = logger
|
||||
self._data_callbacks: Dict[DataType, List[Callable]] = {
|
||||
data_type: [] for data_type in DataType
|
||||
}
|
||||
|
||||
def _log_debug(self, message: str) -> None:
|
||||
if self.logger:
|
||||
self.logger.debug(f"{self.component_name}: {message}")
|
||||
|
||||
def _log_info(self, message: str) -> None:
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.component_name}: {message}")
|
||||
|
||||
def _log_warning(self, message: str) -> None:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.component_name}: {message}")
|
||||
|
||||
def _log_error(self, message: str, exc_info: bool = False) -> None:
|
||||
if self.logger:
|
||||
self.logger.error(f"{self.component_name}: {message}", exc_info=exc_info)
|
||||
|
||||
def add_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
|
||||
"""
|
||||
Add a callback function for specific data type.
|
||||
|
||||
Args:
|
||||
data_type: Type of data to monitor
|
||||
callback: Function to call when data is received
|
||||
"""
|
||||
if callback not in self._data_callbacks[data_type]:
|
||||
self._data_callbacks[data_type].append(callback)
|
||||
self._log_debug(f"Added callback for {data_type.value} data")
|
||||
|
||||
def remove_data_callback(self, data_type: DataType, callback: Callable[[MarketDataPoint], None]) -> None:
|
||||
"""
|
||||
Remove a callback function for specific data type.
|
||||
|
||||
Args:
|
||||
data_type: Type of data to stop monitoring
|
||||
callback: Function to remove
|
||||
"""
|
||||
if callback in self._data_callbacks[data_type]:
|
||||
self._data_callbacks[data_type].remove(callback)
|
||||
self._log_debug(f"Removed callback for {data_type.value} data")
|
||||
|
||||
async def notify_callbacks(self, data_point: MarketDataPoint) -> None:
|
||||
"""
|
||||
Notify all registered callbacks for a data point.
|
||||
|
||||
Args:
|
||||
data_point: Market data to distribute
|
||||
"""
|
||||
callbacks = self._data_callbacks.get(data_point.data_type, [])
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
# Handle both sync and async callbacks
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(data_point)
|
||||
else:
|
||||
callback(data_point)
|
||||
|
||||
except Exception as e:
|
||||
self._log_error(f"Error in data callback for {data_point.data_type.value} {data_point.symbol}: {e}", exc_info=True)
|
||||
@ -36,13 +36,13 @@
|
||||
- [x] 2.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_connection_manager.py`.
|
||||
- [x] 2.6 Create `tests/data/collector/test_collector_connection_manager.py` and add initial tests for the new class.
|
||||
|
||||
- [ ] 3.0 Extract `CallbackDispatcher` Class
|
||||
- [ ] 3.1 Create `data/collector/collector_callback_dispatcher.py`.
|
||||
- [ ] 3.2 Move `_data_callbacks` attribute to `CallbackDispatcher`.
|
||||
- [ ] 3.3 Move `add_data_callback`, `remove_data_callback`, `_notify_callbacks` methods to `CallbackDispatcher`.
|
||||
- [ ] 3.4 Implement a constructor for `CallbackDispatcher` to receive logger.
|
||||
- [ ] 3.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_callback_dispatcher.py`.
|
||||
- [ ] 3.6 Create `tests/data/collector/test_collector_callback_dispatcher.py` and add initial tests for the new class.
|
||||
- [x] 3.0 Extract `CallbackDispatcher` Class
|
||||
- [x] 3.1 Create `data/collector/collector_callback_dispatcher.py`.
|
||||
- [x] 3.2 Move `_data_callbacks` attribute to `CallbackDispatcher`.
|
||||
- [x] 3.3 Move `add_data_callback`, `remove_data_callback`, `_notify_callbacks` methods to `CallbackDispatcher`.
|
||||
- [x] 3.4 Implement a constructor for `CallbackDispatcher` to receive logger.
|
||||
- [x] 3.5 Add necessary imports to both `data/base_collector.py` and `data/collector/collector_callback_dispatcher.py`.
|
||||
- [x] 3.6 Create `tests/data/collector/test_collector_callback_dispatcher.py` and add initial tests for the new class.
|
||||
|
||||
- [ ] 4.0 Refactor `BaseDataCollector` to use new components
|
||||
- [ ] 4.1 Update `BaseDataCollector.__init__` to instantiate and use `CollectorStateAndTelemetry`, `ConnectionManager`, and `CallbackDispatcher` instances.
|
||||
|
||||
112
tests/data/collector/test_collector_callback_dispatcher.py
Normal file
112
tests/data/collector/test_collector_callback_dispatcher.py
Normal file
@ -0,0 +1,112 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from datetime import datetime
|
||||
|
||||
from data.collector.collector_callback_dispatcher import CallbackDispatcher
|
||||
from data.common.data_types import DataType, MarketDataPoint
|
||||
|
||||
|
||||
class TestCallbackDispatcher(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_logger = Mock()
|
||||
self.dispatcher = CallbackDispatcher(
|
||||
component_name="test_dispatcher",
|
||||
logger=self.mock_logger
|
||||
)
|
||||
|
||||
async def test_init(self):
|
||||
self.assertEqual(self.dispatcher.component_name, "test_dispatcher")
|
||||
self.assertEqual(self.dispatcher.logger, self.mock_logger)
|
||||
self.assertIsInstance(self.dispatcher._data_callbacks, dict)
|
||||
self.assertGreater(len(self.dispatcher._data_callbacks), 0) # Ensure all DataType enums are initialized
|
||||
|
||||
async def test_add_data_callback(self):
|
||||
mock_callback = Mock()
|
||||
data_type = DataType.CANDLE
|
||||
|
||||
self.dispatcher.add_data_callback(data_type, mock_callback)
|
||||
self.assertIn(mock_callback, self.dispatcher._data_callbacks[data_type])
|
||||
self.mock_logger.debug.assert_called_with(f"test_dispatcher: Added callback for {data_type.value} data")
|
||||
|
||||
# Test adding same callback twice (should not add)
|
||||
self.dispatcher.add_data_callback(data_type, mock_callback)
|
||||
self.assertEqual(self.dispatcher._data_callbacks[data_type].count(mock_callback), 1)
|
||||
|
||||
async def test_remove_data_callback(self):
|
||||
mock_callback = Mock()
|
||||
data_type = DataType.TRADE
|
||||
|
||||
self.dispatcher.add_data_callback(data_type, mock_callback)
|
||||
self.assertIn(mock_callback, self.dispatcher._data_callbacks[data_type])
|
||||
|
||||
self.dispatcher.remove_data_callback(data_type, mock_callback)
|
||||
self.assertNotIn(mock_callback, self.dispatcher._data_callbacks[data_type])
|
||||
self.mock_logger.debug.assert_called_with(f"test_dispatcher: Removed callback for {data_type.value} data")
|
||||
|
||||
# Test removing non-existent callback (should do nothing)
|
||||
self.dispatcher.remove_data_callback(data_type, Mock())
|
||||
# No error should be raised and log should not be called again for removal
|
||||
|
||||
async def test_notify_callbacks_sync(self):
|
||||
mock_sync_callback = Mock()
|
||||
data_type = DataType.TICKER
|
||||
data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {"price": 100})
|
||||
|
||||
self.dispatcher.add_data_callback(data_type, mock_sync_callback)
|
||||
await self.dispatcher.notify_callbacks(data_point)
|
||||
|
||||
mock_sync_callback.assert_called_once_with(data_point)
|
||||
|
||||
async def test_notify_callbacks_async(self):
|
||||
mock_async_callback = AsyncMock()
|
||||
data_type = DataType.ORDERBOOK
|
||||
data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {"bids": [], "asks": []})
|
||||
|
||||
self.dispatcher.add_data_callback(data_type, mock_async_callback)
|
||||
await self.dispatcher.notify_callbacks(data_point)
|
||||
|
||||
mock_async_callback.assert_called_once_with(data_point)
|
||||
|
||||
async def test_notify_callbacks_mixed(self):
|
||||
mock_sync_callback = Mock()
|
||||
mock_async_callback = AsyncMock()
|
||||
data_type = DataType.BALANCE
|
||||
data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {"asset": "BTC", "balance": 0.5})
|
||||
|
||||
self.dispatcher.add_data_callback(data_type, mock_sync_callback)
|
||||
self.dispatcher.add_data_callback(data_type, mock_async_callback)
|
||||
await self.dispatcher.notify_callbacks(data_point)
|
||||
|
||||
mock_sync_callback.assert_called_once_with(data_point)
|
||||
mock_async_callback.assert_called_once_with(data_point)
|
||||
|
||||
async def test_notify_callbacks_exception_handling(self):
|
||||
def failing_sync_callback(data): raise ValueError("Sync error")
|
||||
async def failing_async_callback(data): raise TypeError("Async error")
|
||||
|
||||
mock_successful_callback = Mock()
|
||||
|
||||
data_type = DataType.CANDLE
|
||||
data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {})
|
||||
|
||||
self.dispatcher.add_data_callback(data_type, failing_sync_callback)
|
||||
self.dispatcher.add_data_callback(data_type, failing_async_callback)
|
||||
self.dispatcher.add_data_callback(data_type, mock_successful_callback)
|
||||
|
||||
await self.dispatcher.notify_callbacks(data_point)
|
||||
|
||||
mock_successful_callback.assert_called_once_with(data_point)
|
||||
self.assertEqual(self.mock_logger.error.call_count, 2)
|
||||
self.mock_logger.error.assert_any_call(f"test_dispatcher: Error in data callback for {data_type.value} {data_point.symbol}: Sync error", exc_info=True)
|
||||
self.mock_logger.error.assert_any_call(f"test_dispatcher: Error in data callback for {data_type.value} {data_point.symbol}: Async error", exc_info=True)
|
||||
|
||||
async def test_notify_callbacks_no_callbacks(self):
|
||||
data_type = DataType.TICKER
|
||||
data_point = MarketDataPoint("exchange", "symbol", datetime.now(), data_type, {})
|
||||
|
||||
# No callbacks added
|
||||
await self.dispatcher.notify_callbacks(data_point)
|
||||
self.mock_logger.error.assert_not_called() # No errors should be logged
|
||||
self.mock_logger.debug.assert_not_called() # No debug logs from notify
|
||||
Loading…
x
Reference in New Issue
Block a user