- Extracted `OHLCVData` and its validation logic into a new `common/ohlcv_data.py` module for better organization and reusability.
- Updated `BaseDataCollector` to use the new `validate_ohlcv_data` function for data validation, improving code clarity and maintainability.
- Refactored imports in `data/__init__.py` to reflect the new structure, ensuring consistent access to common data types and exceptions.
- Removed redundant data validation logic from `BaseDataCollector`, streamlining its responsibilities.
- Added unit tests for `OHLCVData` and the validation functions to ensure correctness and reliability.

These changes improve the architecture of the data module and align it with project standards for maintainability and performance.
342 lines
12 KiB
Python
342 lines
12 KiB
Python
"""
|
|
Unit tests for the BaseDataCollector abstract class.
|
|
"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from decimal import Decimal
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from data.base_collector import (
|
|
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
|
|
OHLCVData, DataValidationError, DataCollectorError
|
|
)
|
|
|
|
|
|
class TestDataCollector(BaseDataCollector):
    """Test implementation of BaseDataCollector for testing.

    Connection and disconnection are simulated with short sleeps, and
    messages queued with ``add_test_message`` are consumed one per
    ``_handle_messages`` call.
    """

    # Not a test container: without this, pytest tries to collect the class
    # because its name starts with "Test", and warns that it cannot because
    # the class defines __init__.
    __test__ = False

    def __init__(self, exchange_name: str, symbols: list, data_types=None):
        super().__init__(exchange_name, symbols, data_types)
        # Simulated connection/subscription state inspected by the tests.
        self.connected = False
        self.subscribed = False
        # FIFO queue of raw messages waiting to be handled.
        self.messages = []

    async def _actual_connect(self) -> bool:
        """Implementation of actual connection logic for testing."""
        await asyncio.sleep(0.01)  # Simulate connection delay
        self.connected = True
        return True

    async def _actual_disconnect(self) -> None:
        """Implementation of actual disconnection logic for testing."""
        await asyncio.sleep(0.01)  # Simulate disconnection delay
        self.connected = False
        self.subscribed = False

    async def connect(self) -> bool:
        """Connect using the connection manager."""
        return await self._connection_manager.connect(self._actual_connect)

    async def disconnect(self) -> None:
        """Disconnect using the connection manager."""
        await self._connection_manager.disconnect(self._actual_disconnect)

    async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
        """Mark the collector as subscribed; return False when not connected."""
        if not self.connected:
            return False
        self.subscribed = True
        return True

    async def unsubscribe_from_data(self, symbols: list, data_types: list) -> bool:
        """Mark the collector as unsubscribed; always returns True."""
        self.subscribed = False
        return True

    async def _process_message(self, message) -> MarketDataPoint:
        """Wrap a raw message dict in a TICKER MarketDataPoint.

        Counts the message as received and stamps it with the current UTC
        time. The symbol falls back to 'BTC-USDT' when the message has none.
        """
        self._state_telemetry.increment_messages_received()
        return MarketDataPoint(
            exchange=self.exchange_name,
            symbol=message.get('symbol', 'BTC-USDT'),
            timestamp=datetime.now(timezone.utc),
            data_type=DataType.TICKER,
            data=message,
        )

    async def _handle_messages(self) -> None:
        """Process one queued message, or idle briefly when the queue is empty."""
        if self.messages:
            message = self.messages.pop(0)
            data_point = await self._process_message(message)
            # Note: increment_messages_processed() is called in _notify_callbacks()
            await self._notify_callbacks(data_point)
        else:
            await asyncio.sleep(0.1)  # Wait for messages

    def add_test_message(self, message: dict):
        """Add a test message to be processed."""
        self.messages.append(message)
|
|
|
|
|
|
class TestBaseDataCollector:
    """Test cases for BaseDataCollector."""

    @pytest.fixture
    def collector(self):
        """Create a test collector for two OKX symbols with ticker data."""
        return TestDataCollector("okx", ["BTC-USDT", "ETH-USDT"], [DataType.TICKER])

    @staticmethod
    async def _wait_until(predicate, timeout: float = 2.0, interval: float = 0.01) -> bool:
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds pass.

        Args:
            predicate: Zero-argument callable checked between short sleeps.
            timeout: Maximum number of seconds to wait.
            interval: Seconds to sleep between checks.

        Returns:
            True if the predicate became truthy, False on timeout. Callers
            assert on the resulting state afterwards, so a timeout shows up
            as an ordinary assertion failure instead of a hang.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(interval)
        return True

    def test_initialization(self, collector):
        """Test collector initialization."""
        assert collector.exchange_name == "okx"
        assert collector.symbols == {"BTC-USDT", "ETH-USDT"}
        assert collector.data_types == [DataType.TICKER]
        assert collector.status == CollectorStatus.STOPPED
        assert not collector._state_telemetry._running

    @pytest.mark.asyncio
    async def test_start_stop_cycle(self, collector):
        """Test starting and stopping the collector."""
        success = await collector.start()
        # Always stop, even if an assertion below fails, so the collector's
        # message loop is not left running into later tests.
        try:
            assert success
            assert collector.status == CollectorStatus.RUNNING
            assert collector.connected
            assert collector.subscribed
            assert collector._state_telemetry._running

            # Wait a bit for the message loop to start
            await asyncio.sleep(0.1)
        finally:
            await collector.stop()

        assert collector.status == CollectorStatus.STOPPED
        assert not collector._state_telemetry._running
        assert not collector.connected
        assert not collector.subscribed

    @pytest.mark.asyncio
    async def test_message_processing(self, collector):
        """Test message processing and callbacks."""
        received_data = []

        def callback(data_point: MarketDataPoint):
            received_data.append(data_point)

        collector.add_data_callback(DataType.TICKER, callback)

        await collector.start()
        try:
            test_message = {"symbol": "BTC-USDT", "price": "50000"}
            collector.add_test_message(test_message)

            # Poll instead of one fixed sleep: the message loop may be idle in
            # its own 0.1 s wait when the message arrives, and slow machines
            # can exceed any fixed delay. Wait for both the callback and the
            # processed counter, since both are asserted below.
            await self._wait_until(
                lambda: bool(received_data)
                and collector._state_telemetry._stats['messages_processed'] >= 1
            )
        finally:
            await collector.stop()

        # Verify message was processed exactly once
        assert len(received_data) == 1
        assert received_data[0].symbol == "BTC-USDT"
        assert received_data[0].data_type == DataType.TICKER
        assert collector._state_telemetry._stats['messages_received'] == 1
        assert collector._state_telemetry._stats['messages_processed'] == 1

    def test_symbol_management(self, collector):
        """Test adding and removing symbols."""
        initial_count = len(collector.symbols)

        # Add new symbol
        collector.add_symbol("LTC-USDT")
        assert "LTC-USDT" in collector.symbols
        assert len(collector.symbols) == initial_count + 1

        # Remove symbol
        collector.remove_symbol("BTC-USDT")
        assert "BTC-USDT" not in collector.symbols
        assert len(collector.symbols) == initial_count

        # Try to add existing symbol (should not duplicate)
        collector.add_symbol("ETH-USDT")
        assert len(collector.symbols) == initial_count

    def test_callback_management(self, collector):
        """Test adding and removing callbacks."""
        def callback1(data): pass

        def callback2(data): pass

        # Add callbacks
        collector.add_data_callback(DataType.TICKER, callback1)
        collector.add_data_callback(DataType.TICKER, callback2)
        assert len(collector._callback_dispatcher._data_callbacks[DataType.TICKER]) == 2

        # Remove callback
        collector.remove_data_callback(DataType.TICKER, callback1)
        assert len(collector._callback_dispatcher._data_callbacks[DataType.TICKER]) == 1
        assert callback2 in collector._callback_dispatcher._data_callbacks[DataType.TICKER]

    def test_get_status(self, collector):
        """Test status reporting."""
        status = collector.get_status()

        assert status['exchange'] == 'okx'
        assert status['status'] == 'stopped'
        assert set(status['symbols']) == {"BTC-USDT", "ETH-USDT"}
        assert status['data_types'] == ['ticker']
        assert 'statistics' in status
        assert status['statistics']['messages_received'] == 0
|
|
|
|
|
|
class TestOHLCVData:
    """Test cases for OHLCVData validation."""

    def test_valid_ohlcv_data(self):
        """Test creating valid OHLCV data."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=Decimal("50000"),
            high=Decimal("50100"),
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
            trades_count=100,
        )

        assert ohlcv.symbol == "BTC-USDT"
        assert ohlcv.timeframe == "1m"
        assert isinstance(ohlcv.open, Decimal)
        assert ohlcv.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """Test OHLCV validation for invalid price relationships."""
        with pytest.raises(DataValidationError):
            OHLCVData(
                symbol="BTC-USDT",
                timeframe="1m",
                timestamp=datetime.now(timezone.utc),
                open=Decimal("50000"),
                high=Decimal("49000"),  # High is below open, low and close
                low=Decimal("49900"),
                close=Decimal("50050"),
                volume=Decimal("1.5"),
            )

    def test_ohlcv_decimal_conversion(self):
        """Test automatic conversion of float/int prices to Decimal.

        Checks both the resulting type and the numeric value, so a conversion
        that returns a Decimal with the wrong value is also caught.
        """
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=50000.0,  # float
            high=50100,  # int
            low=49900,  # int
            close=50050.0,  # float
            volume=1.5,  # float (exactly representable in binary)
        )

        assert isinstance(ohlcv.open, Decimal)
        assert isinstance(ohlcv.high, Decimal)
        assert isinstance(ohlcv.low, Decimal)
        assert isinstance(ohlcv.close, Decimal)
        assert isinstance(ohlcv.volume, Decimal)

        # Decimal equality is numeric, so "50000" == "50000.0" here.
        assert ohlcv.open == Decimal("50000")
        assert ohlcv.high == Decimal("50100")
        assert ohlcv.low == Decimal("49900")
        assert ohlcv.close == Decimal("50050")
        assert ohlcv.volume == Decimal("1.5")
|
|
|
|
|
|
class TestDataValidation:
    """Tests for turning raw candle dicts into OHLCV records via the collector."""

    @staticmethod
    def _make_collector():
        """Build the single-symbol collector used as the validation entry point."""
        return TestDataCollector("test", ["BTC-USDT"])

    @staticmethod
    def _raw_candle(**overrides):
        """Return a complete raw candle dict with ``overrides`` merged on top."""
        candle = {
            "timestamp": 1609459200000,  # Unix timestamp in ms (2021-01-01T00:00:00Z)
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }
        candle.update(overrides)
        return candle

    def test_validate_ohlcv_data_success(self):
        """A complete candle validates into a typed OHLCV record."""
        validator = self._make_collector()
        payload = self._raw_candle(trades_count=100)

        record = validator.validate_ohlcv_data(payload, "BTC-USDT", "1m")

        assert record.symbol == "BTC-USDT"
        assert record.timeframe == "1m"
        assert record.trades_count == 100
        assert isinstance(record.open, Decimal)

    def test_validate_ohlcv_data_missing_field(self):
        """A candle without 'low' is rejected with a field-specific error."""
        validator = self._make_collector()
        payload = self._raw_candle()
        del payload["low"]

        with pytest.raises(DataValidationError, match="Missing required field: low"):
            validator.validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_ohlcv_data_invalid_timestamp(self):
        """A non-numeric timestamp is rejected."""
        validator = self._make_collector()
        payload = self._raw_candle(timestamp="invalid_timestamp")

        with pytest.raises(DataValidationError):
            validator.validate_ohlcv_data(payload, "BTC-USDT", "1m")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connection_error_handling():
    """Test connection error handling and recovery after a failed start.

    The first ``start()`` is expected to fail and leave the collector in the
    ERROR status. After the relevant state is reset and failures are
    disabled, a second ``start()`` is expected to succeed. That successful
    start is always paired with ``stop()``, even when an assertion fails.
    """

    class FailingCollector(TestDataCollector):
        """Collector whose connect fails on its first two attempts while ``should_fail`` is set."""

        def __init__(self):
            super().__init__("test", ["BTC-USDT"])
            self.connect_attempts = 0
            self.should_fail = True

        async def _actual_connect(self) -> bool:
            self.connect_attempts += 1
            if self.should_fail and self.connect_attempts < 3:
                return False  # Fail first 2 attempts
            return await super()._actual_connect()

    collector = FailingCollector()

    # First start should fail
    success = await collector.start()
    assert not success
    assert collector.status == CollectorStatus.ERROR

    # Reset for retry and allow success
    collector._connection_manager._reconnect_attempts = 0
    collector._state_telemetry.update_status(CollectorStatus.STOPPED)
    collector.connect_attempts = 0  # Reset connection attempts
    collector.should_fail = False  # Allow connection to succeed

    # This attempt should succeed; stop in `finally` so a failed assertion
    # does not leave the collector running.
    success = await collector.start()
    try:
        assert success
        assert collector.status == CollectorStatus.RUNNING
    finally:
        await collector.stop()
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))