""" Unit tests for the BaseDataCollector abstract class. """ import asyncio import pytest from datetime import datetime, timezone from decimal import Decimal from unittest.mock import AsyncMock, MagicMock from data.base_collector import ( BaseDataCollector, DataType, CollectorStatus, MarketDataPoint, OHLCVData, DataValidationError, DataCollectorError ) class TestDataCollector(BaseDataCollector): """Test implementation of BaseDataCollector for testing.""" def __init__(self, exchange_name: str, symbols: list, data_types=None): super().__init__(exchange_name, symbols, data_types) self.connected = False self.subscribed = False self.messages = [] async def connect(self) -> bool: await asyncio.sleep(0.01) # Simulate connection delay self.connected = True return True async def disconnect(self) -> None: await asyncio.sleep(0.01) # Simulate disconnection delay self.connected = False self.subscribed = False async def subscribe_to_data(self, symbols: list, data_types: list) -> bool: if not self.connected: return False self.subscribed = True return True async def unsubscribe_from_data(self, symbols: list, data_types: list) -> bool: self.subscribed = False return True async def _process_message(self, message) -> MarketDataPoint: self._stats['messages_received'] += 1 return MarketDataPoint( exchange=self.exchange_name, symbol=message.get('symbol', 'BTC-USDT'), timestamp=datetime.now(timezone.utc), data_type=DataType.TICKER, data=message ) async def _handle_messages(self) -> None: # Simulate receiving messages if self.messages: message = self.messages.pop(0) data_point = await self._process_message(message) self._stats['messages_processed'] += 1 self._stats['last_message_time'] = datetime.now(timezone.utc) await self._notify_callbacks(data_point) else: await asyncio.sleep(0.1) # Wait for messages def add_test_message(self, message: dict): """Add a test message to be processed.""" self.messages.append(message) class TestBaseDataCollector: """Test cases for BaseDataCollector.""" @pytest.fixture def collector(self): """Create a test collector instance.""" return TestDataCollector("okx", ["BTC-USDT", "ETH-USDT"], [DataType.TICKER]) def test_initialization(self, collector): """Test collector initialization.""" assert collector.exchange_name == "okx" assert collector.symbols == {"BTC-USDT", "ETH-USDT"} assert collector.data_types == [DataType.TICKER] assert collector.status == CollectorStatus.STOPPED assert not collector._running @pytest.mark.asyncio async def test_start_stop_cycle(self, collector): """Test starting and stopping the collector.""" # Test start success = await collector.start() assert success assert collector.status == CollectorStatus.RUNNING assert collector.connected assert collector.subscribed assert collector._running # Wait a bit for the message loop to start await asyncio.sleep(0.1) # Test stop await collector.stop() assert collector.status == CollectorStatus.STOPPED assert not collector._running assert not collector.connected assert not collector.subscribed @pytest.mark.asyncio async def test_message_processing(self, collector): """Test message processing and callbacks.""" received_data = [] def callback(data_point: MarketDataPoint): received_data.append(data_point) collector.add_data_callback(DataType.TICKER, callback) await collector.start() # Add test message test_message = {"symbol": "BTC-USDT", "price": "50000"} collector.add_test_message(test_message) # Wait for message processing await asyncio.sleep(0.2) await collector.stop() # Verify message was processed assert len(received_data) == 1 assert received_data[0].symbol == "BTC-USDT" assert received_data[0].data_type == DataType.TICKER assert collector._stats['messages_received'] == 1 assert collector._stats['messages_processed'] == 1 def test_symbol_management(self, collector): """Test adding and removing symbols.""" initial_count = len(collector.symbols) # Add new symbol collector.add_symbol("LTC-USDT") assert "LTC-USDT" in collector.symbols assert len(collector.symbols) == initial_count + 1 # Remove symbol collector.remove_symbol("BTC-USDT") assert "BTC-USDT" not in collector.symbols assert len(collector.symbols) == initial_count # Try to add existing symbol (should not duplicate) collector.add_symbol("ETH-USDT") assert len(collector.symbols) == initial_count def test_callback_management(self, collector): """Test adding and removing callbacks.""" def callback1(data): pass def callback2(data): pass # Add callbacks collector.add_data_callback(DataType.TICKER, callback1) collector.add_data_callback(DataType.TICKER, callback2) assert len(collector._data_callbacks[DataType.TICKER]) == 2 # Remove callback collector.remove_data_callback(DataType.TICKER, callback1) assert len(collector._data_callbacks[DataType.TICKER]) == 1 assert callback2 in collector._data_callbacks[DataType.TICKER] def test_get_status(self, collector): """Test status reporting.""" status = collector.get_status() assert status['exchange'] == 'okx' assert status['status'] == 'stopped' assert set(status['symbols']) == {"BTC-USDT", "ETH-USDT"} assert status['data_types'] == ['ticker'] assert 'statistics' in status assert status['statistics']['messages_received'] == 0 class TestOHLCVData: """Test cases for OHLCVData validation.""" def test_valid_ohlcv_data(self): """Test creating valid OHLCV data.""" ohlcv = OHLCVData( symbol="BTC-USDT", timeframe="1m", timestamp=datetime.now(timezone.utc), open=Decimal("50000"), high=Decimal("50100"), low=Decimal("49900"), close=Decimal("50050"), volume=Decimal("1.5"), trades_count=100 ) assert ohlcv.symbol == "BTC-USDT" assert ohlcv.timeframe == "1m" assert isinstance(ohlcv.open, Decimal) assert ohlcv.trades_count == 100 def test_invalid_ohlcv_relationships(self): """Test OHLCV validation for invalid price relationships.""" with pytest.raises(DataValidationError): OHLCVData( symbol="BTC-USDT", timeframe="1m", timestamp=datetime.now(timezone.utc), open=Decimal("50000"), high=Decimal("49000"), # High is less than open low=Decimal("49900"), close=Decimal("50050"), volume=Decimal("1.5") ) def test_ohlcv_decimal_conversion(self): """Test automatic conversion to Decimal.""" ohlcv = OHLCVData( symbol="BTC-USDT", timeframe="1m", timestamp=datetime.now(timezone.utc), open=50000.0, # float high=50100, # int low=49900, # int (changed from string to test proper conversion) close=50050.0, # float volume=1.5 # float ) assert isinstance(ohlcv.open, Decimal) assert isinstance(ohlcv.high, Decimal) assert isinstance(ohlcv.low, Decimal) assert isinstance(ohlcv.close, Decimal) assert isinstance(ohlcv.volume, Decimal) class TestDataValidation: """Test cases for data validation methods.""" def test_validate_ohlcv_data_success(self): """Test successful OHLCV data validation.""" collector = TestDataCollector("test", ["BTC-USDT"]) raw_data = { "timestamp": 1609459200000, # Unix timestamp in ms "open": "50000", "high": "50100", "low": "49900", "close": "50050", "volume": "1.5", "trades_count": 100 } ohlcv = collector.validate_ohlcv_data(raw_data, "BTC-USDT", "1m") assert ohlcv.symbol == "BTC-USDT" assert ohlcv.timeframe == "1m" assert ohlcv.trades_count == 100 assert isinstance(ohlcv.open, Decimal) def test_validate_ohlcv_data_missing_field(self): """Test OHLCV validation with missing required field.""" collector = TestDataCollector("test", ["BTC-USDT"]) raw_data = { "timestamp": 1609459200000, "open": "50000", "high": "50100", # Missing 'low' field "close": "50050", "volume": "1.5" } with pytest.raises(DataValidationError, match="Missing required field: low"): collector.validate_ohlcv_data(raw_data, "BTC-USDT", "1m") def test_validate_ohlcv_data_invalid_timestamp(self): """Test OHLCV validation with invalid timestamp.""" collector = TestDataCollector("test", ["BTC-USDT"]) raw_data = { "timestamp": "invalid_timestamp", "open": "50000", "high": "50100", "low": "49900", "close": "50050", "volume": "1.5" } with pytest.raises(DataValidationError): collector.validate_ohlcv_data(raw_data, "BTC-USDT", "1m") @pytest.mark.asyncio async def test_connection_error_handling(): """Test connection error handling and reconnection.""" class FailingCollector(TestDataCollector): def __init__(self): super().__init__("test", ["BTC-USDT"]) self.connect_attempts = 0 self.should_fail = True async def connect(self) -> bool: self.connect_attempts += 1 if self.should_fail and self.connect_attempts < 3: return False # Fail first 2 attempts return await super().connect() collector = FailingCollector() # First start should fail success = await collector.start() assert not success assert collector.status == CollectorStatus.ERROR # Reset for retry and allow success collector._reconnect_attempts = 0 collector.status = CollectorStatus.STOPPED collector.connect_attempts = 0 # Reset connection attempts collector.should_fail = False # Allow connection to succeed # This attempt should succeed success = await collector.start() assert success assert collector.status == CollectorStatus.RUNNING await collector.stop() if __name__ == "__main__": pytest.main([__file__, "-v"])