TCPDashboard/tests/test_base_collector.py
Vasily.onl 4936e5cd73 Implement enhanced data collection system with health monitoring and management
- Introduced `BaseDataCollector` and `CollectorManager` classes for standardized data collection and centralized management.
- Added health monitoring features, including auto-restart capabilities and detailed status reporting for collectors.
- Updated `env.template` to include new logging and health check configurations.
- Enhanced documentation in `docs/data_collectors.md` to provide comprehensive guidance on the new data collection system.
- Added unit tests for `BaseDataCollector` and `CollectorManager` to ensure reliability and functionality.
2025-05-30 20:33:56 +08:00

333 lines
11 KiB
Python

"""
Unit tests for the BaseDataCollector abstract class.

Exercises the collector lifecycle (start/stop), message processing and
data callbacks, symbol/callback management and status reporting through a
minimal concrete subclass, plus the OHLCVData model and the
``validate_ohlcv_data`` helper.
"""
import asyncio
import pytest
from datetime import datetime, timezone
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
from data.base_collector import (
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
OHLCVData, DataValidationError, DataCollectorError
)
class TestDataCollector(BaseDataCollector):
    """In-memory BaseDataCollector implementation used as a test double.

    Connection and subscription state are tracked in plain boolean flags,
    and incoming "exchange" messages are fed in through
    :meth:`add_test_message` instead of a real network connection.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it as a test class and, because it defines __init__, emit a
    # PytestCollectionWarning. Mark it explicitly as a helper, not a test.
    __test__ = False

    def __init__(self, exchange_name: str, symbols: list, data_types=None):
        super().__init__(exchange_name, symbols, data_types)
        self.connected = False
        self.subscribed = False
        # Raw messages waiting to be handled, oldest first.
        self.messages = []

    async def connect(self) -> bool:
        """Pretend to connect; always succeeds after a short delay."""
        await asyncio.sleep(0.01)  # Simulate connection delay
        self.connected = True
        return True

    async def disconnect(self) -> None:
        """Pretend to disconnect, clearing connection and subscription flags."""
        await asyncio.sleep(0.01)  # Simulate disconnection delay
        self.connected = False
        self.subscribed = False

    async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
        """Mark the collector as subscribed.

        Returns:
            False if not connected, otherwise True.
        """
        if not self.connected:
            return False
        self.subscribed = True
        return True

    async def unsubscribe_from_data(self, symbols: list, data_types: list) -> bool:
        """Clear the subscription flag; always returns True."""
        self.subscribed = False
        return True

    async def _process_message(self, message) -> MarketDataPoint:
        """Wrap a raw message dict in a TICKER MarketDataPoint.

        Counts the message as received. The symbol falls back to
        ``BTC-USDT`` when the message does not carry one.
        """
        self._stats['messages_received'] += 1
        return MarketDataPoint(
            exchange=self.exchange_name,
            symbol=message.get('symbol', 'BTC-USDT'),
            timestamp=datetime.now(timezone.utc),
            data_type=DataType.TICKER,
            data=message,
        )

    async def _handle_messages(self) -> None:
        """Handle at most one queued message, or idle briefly if none.

        A handled message is processed, counted as processed, timestamped
        in the stats and forwarded to the registered callbacks.
        """
        if not self.messages:
            await asyncio.sleep(0.1)  # Wait for messages
            return
        message = self.messages.pop(0)
        data_point = await self._process_message(message)
        self._stats['messages_processed'] += 1
        self._stats['last_message_time'] = datetime.now(timezone.utc)
        await self._notify_callbacks(data_point)

    def add_test_message(self, message: dict):
        """Add a test message to be processed."""
        self.messages.append(message)
class TestBaseDataCollector:
    """Test cases for BaseDataCollector lifecycle, callbacks and bookkeeping."""

    @pytest.fixture
    def collector(self):
        """Create a test collector for two OKX symbols, TICKER data only."""
        return TestDataCollector("okx", ["BTC-USDT", "ETH-USDT"], [DataType.TICKER])

    @staticmethod
    async def _wait_until(predicate, timeout=2.0, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds pass.

        Polling avoids the flakiness of a single fixed sleep when the event
        loop is slow (e.g. on a loaded CI machine).

        Returns:
            The final truthiness of ``predicate()``.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                return bool(predicate())
            await asyncio.sleep(interval)
        return True

    def test_initialization(self, collector):
        """Test collector initialization."""
        assert collector.exchange_name == "okx"
        # Symbols are exposed as a set, so order does not matter.
        assert collector.symbols == {"BTC-USDT", "ETH-USDT"}
        assert collector.data_types == [DataType.TICKER]
        assert collector.status == CollectorStatus.STOPPED
        assert not collector._running

    @pytest.mark.asyncio
    async def test_start_stop_cycle(self, collector):
        """Test starting and stopping the collector."""
        success = await collector.start()
        try:
            assert success
            assert collector.status == CollectorStatus.RUNNING
            assert collector.connected
            assert collector.subscribed
            assert collector._running
            # Give the message loop a chance to start before stopping it.
            await asyncio.sleep(0.1)
        finally:
            # Always stop, even if an assertion above failed, so no
            # background task outlives the test.
            await collector.stop()

        assert collector.status == CollectorStatus.STOPPED
        assert not collector._running
        assert not collector.connected
        assert not collector.subscribed

    @pytest.mark.asyncio
    async def test_message_processing(self, collector):
        """Test message processing and callbacks."""
        received_data = []

        def callback(data_point: MarketDataPoint):
            received_data.append(data_point)

        collector.add_data_callback(DataType.TICKER, callback)
        await collector.start()
        try:
            collector.add_test_message({"symbol": "BTC-USDT", "price": "50000"})
            # Wait for the message loop to deliver the callback rather than
            # relying on one fixed-length sleep.
            await self._wait_until(lambda: received_data)
        finally:
            await collector.stop()

        # Verify message was processed
        assert len(received_data) == 1
        assert received_data[0].symbol == "BTC-USDT"
        assert received_data[0].data_type == DataType.TICKER
        assert collector._stats['messages_received'] == 1
        assert collector._stats['messages_processed'] == 1

    def test_symbol_management(self, collector):
        """Test adding and removing symbols."""
        initial_count = len(collector.symbols)

        # Add new symbol
        collector.add_symbol("LTC-USDT")
        assert "LTC-USDT" in collector.symbols
        assert len(collector.symbols) == initial_count + 1

        # Remove symbol
        collector.remove_symbol("BTC-USDT")
        assert "BTC-USDT" not in collector.symbols
        assert len(collector.symbols) == initial_count

        # Adding an existing symbol must not duplicate it
        collector.add_symbol("ETH-USDT")
        assert len(collector.symbols) == initial_count

    def test_callback_management(self, collector):
        """Test adding and removing callbacks."""
        def callback1(data):
            pass

        def callback2(data):
            pass

        # Add callbacks
        collector.add_data_callback(DataType.TICKER, callback1)
        collector.add_data_callback(DataType.TICKER, callback2)
        assert len(collector._data_callbacks[DataType.TICKER]) == 2

        # Remove callback
        collector.remove_data_callback(DataType.TICKER, callback1)
        assert len(collector._data_callbacks[DataType.TICKER]) == 1
        assert callback2 in collector._data_callbacks[DataType.TICKER]

    def test_get_status(self, collector):
        """Test status reporting."""
        status = collector.get_status()
        assert status['exchange'] == 'okx'
        assert status['status'] == 'stopped'
        assert set(status['symbols']) == {"BTC-USDT", "ETH-USDT"}
        assert status['data_types'] == ['ticker']
        assert 'statistics' in status
        assert status['statistics']['messages_received'] == 0
class TestOHLCVData:
    """Test cases for OHLCVData construction and validation."""

    def test_valid_ohlcv_data(self):
        """Test creating valid OHLCV data."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=Decimal("50000"),
            high=Decimal("50100"),
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
            trades_count=100,
        )
        assert ohlcv.symbol == "BTC-USDT"
        assert ohlcv.timeframe == "1m"
        assert isinstance(ohlcv.open, Decimal)
        assert ohlcv.close == Decimal("50050")
        assert ohlcv.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """Test OHLCV validation for invalid price relationships."""
        with pytest.raises(DataValidationError):
            OHLCVData(
                symbol="BTC-USDT",
                timeframe="1m",
                timestamp=datetime.now(timezone.utc),
                open=Decimal("50000"),
                # High is below open, close and even low: an impossible candle.
                high=Decimal("49000"),
                low=Decimal("49900"),
                close=Decimal("50050"),
                volume=Decimal("1.5"),
            )

    def test_ohlcv_decimal_conversion(self):
        """Test automatic conversion of float/int inputs to Decimal."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=50000.0,  # float
            high=50100,  # int
            low=49900,  # int
            close=50050.0,  # float
            volume=1.5,  # float
        )
        assert isinstance(ohlcv.open, Decimal)
        assert isinstance(ohlcv.high, Decimal)
        assert isinstance(ohlcv.low, Decimal)
        assert isinstance(ohlcv.close, Decimal)
        assert isinstance(ohlcv.volume, Decimal)
        # Every input above is exactly representable, so the converted values
        # must compare equal to their decimal literals.
        assert ohlcv.open == Decimal("50000")
        assert ohlcv.high == Decimal("50100")
        assert ohlcv.low == Decimal("49900")
        assert ohlcv.close == Decimal("50050")
        assert ohlcv.volume == Decimal("1.5")
class TestDataValidation:
    """Tests for the collector's ``validate_ohlcv_data`` helper."""

    @pytest.fixture
    def collector(self):
        """A minimal single-symbol collector exposing the validation helper."""
        return TestDataCollector("test", ["BTC-USDT"])

    @staticmethod
    def _candle(**overrides):
        """Build a raw candle dict: ms timestamp, prices/volume as strings.

        Keyword arguments replace (or append) fields while keeping the
        base key order.
        """
        candle = {
            "timestamp": 1609459200000,  # Unix timestamp in ms
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }
        candle.update(overrides)
        return candle

    def test_validate_ohlcv_data_success(self, collector):
        """A complete candle validates into a Decimal-typed OHLCVData."""
        result = collector.validate_ohlcv_data(
            self._candle(trades_count=100), "BTC-USDT", "1m"
        )
        assert result.symbol == "BTC-USDT"
        assert result.timeframe == "1m"
        assert result.trades_count == 100
        assert isinstance(result.open, Decimal)

    def test_validate_ohlcv_data_missing_field(self, collector):
        """A candle without 'low' is rejected with a descriptive error."""
        incomplete = self._candle()
        del incomplete["low"]
        with pytest.raises(DataValidationError, match="Missing required field: low"):
            collector.validate_ohlcv_data(incomplete, "BTC-USDT", "1m")

    def test_validate_ohlcv_data_invalid_timestamp(self, collector):
        """A non-numeric timestamp is rejected."""
        bad_ts = self._candle(timestamp="invalid_timestamp")
        with pytest.raises(DataValidationError):
            collector.validate_ohlcv_data(bad_ts, "BTC-USDT", "1m")
@pytest.mark.asyncio
async def test_connection_error_handling():
    """A failed start reports ERROR, and a later start can still succeed.

    The first ``start()`` runs while ``connect()`` is rigged to fail. After
    resetting the collector's state and disabling the failure, a second
    ``start()`` must bring the collector to RUNNING.
    """

    class FailingCollector(TestDataCollector):
        """Collector whose connect() fails while ``should_fail`` is set
        and fewer than 3 connection attempts have been made."""

        def __init__(self):
            super().__init__("test", ["BTC-USDT"])
            self.connect_attempts = 0
            self.should_fail = True

        async def connect(self) -> bool:
            self.connect_attempts += 1
            if self.should_fail and self.connect_attempts < 3:
                return False  # Fail the first 2 attempts
            return await super().connect()

    collector = FailingCollector()

    # First start should fail because connect() reports failure.
    success = await collector.start()
    assert not success
    assert collector.status == CollectorStatus.ERROR
    # The failing connect() never reaches the parent, so no connection exists.
    assert not collector.connected

    # Reset for retry and allow the connection to succeed.
    collector._reconnect_attempts = 0
    collector.status = CollectorStatus.STOPPED
    collector.connect_attempts = 0
    collector.should_fail = False

    # This attempt should succeed.
    success = await collector.start()
    try:
        assert success
        assert collector.status == CollectorStatus.RUNNING
    finally:
        # Stop even if an assertion failed so the message loop cannot leak.
        await collector.stop()
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run exits non-zero
    # when this file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))