- Introduced `BaseDataCollector` and `CollectorManager` classes for standardized data collection and centralized management.
- Added health-monitoring features, including auto-restart capabilities and detailed status reporting for collectors.
- Updated `env.template` to include new logging and health-check configuration options.
- Enhanced `docs/data_collectors.md` with comprehensive guidance on the new data collection system.
- Added unit tests for `BaseDataCollector` and `CollectorManager` to ensure reliability and functionality.
333 lines · 11 KiB · Python
"""
|
|
Unit tests for the BaseDataCollector abstract class.
|
|
"""
|
|
|
|
import asyncio
from datetime import datetime, timezone
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock

import pytest

from data.base_collector import (
    BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
    OHLCVData, DataValidationError, DataCollectorError
)

class TestDataCollector(BaseDataCollector):
    """In-memory BaseDataCollector implementation used by the tests.

    Connecting and subscribing only flip boolean flags (after a short
    simulated delay). "Incoming" messages are queued with
    :meth:`add_test_message` and consumed one per ``_handle_messages`` call.
    """

    # The name matches pytest's default ``Test*`` class pattern, but this is a
    # helper class with an ``__init__``. Without this flag pytest tries to
    # collect it as a test class and emits a PytestCollectionWarning.
    __test__ = False

    def __init__(self, exchange_name: str, symbols: list, data_types=None):
        """Create a disconnected, unsubscribed collector with an empty queue.

        Args:
            exchange_name: Exchange name forwarded to ``BaseDataCollector``.
            symbols: Symbols forwarded to ``BaseDataCollector``.
            data_types: Optional data types forwarded to ``BaseDataCollector``.
        """
        super().__init__(exchange_name, symbols, data_types)
        self.connected = False
        self.subscribed = False
        self.messages = []  # raw messages waiting to be processed, oldest first

    async def connect(self) -> bool:
        """Simulate a successful connection; always returns True."""
        await asyncio.sleep(0.01)  # Simulate connection delay
        self.connected = True
        return True

    async def disconnect(self) -> None:
        """Simulate disconnecting; also clears the subscription flag."""
        await asyncio.sleep(0.01)  # Simulate disconnection delay
        self.connected = False
        self.subscribed = False

    async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
        """Mark the collector subscribed; returns False if not connected."""
        if not self.connected:
            return False
        self.subscribed = True
        return True

    async def unsubscribe_from_data(self, symbols: list, data_types: list) -> bool:
        """Clear the subscription flag; always returns True."""
        self.subscribed = False
        return True

    async def _process_message(self, message) -> MarketDataPoint:
        """Wrap a raw message dict in a TICKER ``MarketDataPoint``.

        Increments the ``messages_received`` statistic. The symbol falls back
        to ``"BTC-USDT"`` when the message does not carry one.
        """
        self._stats['messages_received'] += 1
        return MarketDataPoint(
            exchange=self.exchange_name,
            symbol=message.get('symbol', 'BTC-USDT'),
            timestamp=datetime.now(timezone.utc),
            data_type=DataType.TICKER,
            data=message,
        )

    async def _handle_messages(self) -> None:
        """Process one queued message, or sleep briefly if the queue is empty."""
        if self.messages:
            message = self.messages.pop(0)
            data_point = await self._process_message(message)
            self._stats['messages_processed'] += 1
            self._stats['last_message_time'] = datetime.now(timezone.utc)
            await self._notify_callbacks(data_point)
        else:
            await asyncio.sleep(0.1)  # Wait for messages

    def add_test_message(self, message: dict):
        """Queue a raw message for a later ``_handle_messages`` call."""
        self.messages.append(message)


class TestBaseDataCollector:
    """Test cases for BaseDataCollector lifecycle, callbacks and status."""

    @pytest.fixture
    def collector(self):
        """Create a ticker-only test collector for two OKX symbols."""
        return TestDataCollector("okx", ["BTC-USDT", "ETH-USDT"], [DataType.TICKER])

    @staticmethod
    async def _wait_until(predicate, timeout: float = 2.0, interval: float = 0.02) -> None:
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds elapse.

        Returns silently on timeout; the caller's assertions then report
        what was actually observed. Polling instead of a fixed sleep keeps
        the tests fast when the loop is quick and tolerant when it is slow.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate() and loop.time() < deadline:
            await asyncio.sleep(interval)

    def test_initialization(self, collector):
        """A new collector stores its configuration and starts STOPPED."""
        assert collector.exchange_name == "okx"
        assert collector.symbols == {"BTC-USDT", "ETH-USDT"}
        assert collector.data_types == [DataType.TICKER]
        assert collector.status == CollectorStatus.STOPPED
        assert not collector._running

    @pytest.mark.asyncio
    async def test_start_stop_cycle(self, collector):
        """Test starting and stopping the collector."""
        success = await collector.start()
        try:
            assert success
            assert collector.status == CollectorStatus.RUNNING
            assert collector.connected
            assert collector.subscribed
            assert collector._running

            # Wait a bit for the message loop to start
            await asyncio.sleep(0.1)
        finally:
            # Always stop, so a failed assertion above does not leave the
            # collector's background work running into later tests.
            await collector.stop()

        assert collector.status == CollectorStatus.STOPPED
        assert not collector._running
        assert not collector.connected
        assert not collector.subscribed

    @pytest.mark.asyncio
    async def test_message_processing(self, collector):
        """A queued message is processed and delivered to TICKER callbacks."""
        received_data = []

        def callback(data_point: MarketDataPoint):
            received_data.append(data_point)

        collector.add_data_callback(DataType.TICKER, callback)

        await collector.start()
        try:
            collector.add_test_message({"symbol": "BTC-USDT", "price": "50000"})
            # Wait (bounded) for the message loop to deliver the data point.
            await self._wait_until(lambda: received_data)
        finally:
            await collector.stop()

        assert len(received_data) == 1
        assert received_data[0].symbol == "BTC-USDT"
        assert received_data[0].data_type == DataType.TICKER
        assert collector._stats['messages_received'] == 1
        assert collector._stats['messages_processed'] == 1

    def test_symbol_management(self, collector):
        """Symbols can be added and removed, and duplicates are not re-added."""
        initial_count = len(collector.symbols)

        # Add new symbol
        collector.add_symbol("LTC-USDT")
        assert "LTC-USDT" in collector.symbols
        assert len(collector.symbols) == initial_count + 1

        # Remove symbol
        collector.remove_symbol("BTC-USDT")
        assert "BTC-USDT" not in collector.symbols
        assert len(collector.symbols) == initial_count

        # Adding an existing symbol must not duplicate it
        collector.add_symbol("ETH-USDT")
        assert len(collector.symbols) == initial_count

    def test_callback_management(self, collector):
        """Callbacks can be registered and removed per data type."""
        def callback1(data):
            pass

        def callback2(data):
            pass

        # Add callbacks
        collector.add_data_callback(DataType.TICKER, callback1)
        collector.add_data_callback(DataType.TICKER, callback2)
        assert len(collector._data_callbacks[DataType.TICKER]) == 2

        # Remove callback
        collector.remove_data_callback(DataType.TICKER, callback1)
        assert len(collector._data_callbacks[DataType.TICKER]) == 1
        assert callback2 in collector._data_callbacks[DataType.TICKER]

    def test_get_status(self, collector):
        """get_status() reports configuration and zeroed statistics."""
        status = collector.get_status()

        assert status['exchange'] == 'okx'
        assert status['status'] == 'stopped'
        assert set(status['symbols']) == {"BTC-USDT", "ETH-USDT"}
        assert status['data_types'] == ['ticker']
        assert 'statistics' in status
        assert status['statistics']['messages_received'] == 0


class TestOHLCVData:
    """Test cases for OHLCVData construction and validation."""

    def test_valid_ohlcv_data(self):
        """A consistent candle is accepted and keeps its field values."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=Decimal("50000"),
            high=Decimal("50100"),
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
            trades_count=100,
        )

        assert ohlcv.symbol == "BTC-USDT"
        assert ohlcv.timeframe == "1m"
        assert isinstance(ohlcv.open, Decimal)
        assert ohlcv.open == Decimal("50000")
        assert ohlcv.high == Decimal("50100")
        assert ohlcv.low == Decimal("49900")
        assert ohlcv.close == Decimal("50050")
        assert ohlcv.volume == Decimal("1.5")
        assert ohlcv.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """Inconsistent prices (high below the other prices) are rejected."""
        with pytest.raises(DataValidationError):
            OHLCVData(
                symbol="BTC-USDT",
                timeframe="1m",
                timestamp=datetime.now(timezone.utc),
                open=Decimal("50000"),
                high=Decimal("49000"),  # High is below open, low and close
                low=Decimal("49900"),
                close=Decimal("50050"),
                volume=Decimal("1.5"),
            )

    def test_ohlcv_decimal_conversion(self):
        """Float and int prices are converted to equal-valued Decimals."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=50000.0,   # float
            high=50100,     # int
            low=49900,      # int
            close=50050.0,  # float
            volume=1.5,     # float (exactly representable in binary)
        )

        # Check both the type and the value: a conversion that produced the
        # wrong number would still pass an isinstance-only check.
        assert isinstance(ohlcv.open, Decimal)
        assert isinstance(ohlcv.high, Decimal)
        assert isinstance(ohlcv.low, Decimal)
        assert isinstance(ohlcv.close, Decimal)
        assert isinstance(ohlcv.volume, Decimal)
        assert ohlcv.open == Decimal("50000")
        assert ohlcv.high == Decimal("50100")
        assert ohlcv.low == Decimal("49900")
        assert ohlcv.close == Decimal("50050")
        assert ohlcv.volume == Decimal("1.5")


class TestDataValidation:
    """Tests for the collector's ``validate_ohlcv_data`` helper."""

    # Well-formed raw candle shared by the tests. Each test builds a fresh
    # dict from it, so no test can affect another through shared state.
    _CANDLE = {
        "timestamp": 1609459200000,  # Unix timestamp in ms
        "open": "50000",
        "high": "50100",
        "low": "49900",
        "close": "50050",
        "volume": "1.5",
    }

    @pytest.fixture
    def collector(self):
        """Single-symbol collector on a dummy exchange."""
        return TestDataCollector("test", ["BTC-USDT"])

    def test_validate_ohlcv_data_success(self, collector):
        """A complete payload is turned into an OHLCVData with Decimal prices."""
        payload = {**self._CANDLE, "trades_count": 100}

        candle = collector.validate_ohlcv_data(payload, "BTC-USDT", "1m")

        assert candle.symbol == "BTC-USDT"
        assert candle.timeframe == "1m"
        assert candle.trades_count == 100
        assert isinstance(candle.open, Decimal)

    def test_validate_ohlcv_data_missing_field(self, collector):
        """Dropping the 'low' field yields a descriptive validation error."""
        payload = {key: value for key, value in self._CANDLE.items() if key != "low"}

        with pytest.raises(DataValidationError, match="Missing required field: low"):
            collector.validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_ohlcv_data_invalid_timestamp(self, collector):
        """A non-numeric timestamp is rejected."""
        payload = {**self._CANDLE, "timestamp": "invalid_timestamp"}

        with pytest.raises(DataValidationError):
            collector.validate_ohlcv_data(payload, "BTC-USDT", "1m")


@pytest.mark.asyncio
async def test_connection_error_handling():
    """A failed connect leaves the collector in ERROR; a later start() recovers.

    The first ``start()`` hits a connect() that fails, so it must return
    False with status ERROR. After the failure state is reset and the
    connection is allowed to succeed, a second ``start()`` must reach RUNNING.
    """

    class FailingCollector(TestDataCollector):
        """Collector whose connect() fails until a third attempt, while enabled."""

        def __init__(self):
            super().__init__("test", ["BTC-USDT"])
            self.connect_attempts = 0
            self.should_fail = True

        async def connect(self) -> bool:
            self.connect_attempts += 1
            if self.should_fail and self.connect_attempts < 3:
                return False  # Fail the first 2 attempts
            return await super().connect()

    collector = FailingCollector()

    # First start should fail
    success = await collector.start()
    assert not success
    assert collector.status == CollectorStatus.ERROR

    # Reset the failure state and allow the connection to succeed
    collector._reconnect_attempts = 0
    collector.status = CollectorStatus.STOPPED
    collector.connect_attempts = 0
    collector.should_fail = False

    # This attempt should succeed
    success = await collector.start()
    try:
        assert success
        assert collector.status == CollectorStatus.RUNNING
    finally:
        # Stop even if an assertion fails, so no background work leaks.
        await collector.stop()


if __name__ == "__main__":
    # Propagate pytest's exit status so shell/CI callers see test failures.
    raise SystemExit(pytest.main([__file__, "-v"]))