Refactor data module to enhance modularity and maintainability
- Extracted `OHLCVData` and validation logic into a new `common/ohlcv_data.py` module, promoting better organization and reusability.
- Updated `BaseDataCollector` to use the new `validate_ohlcv_data` function for data validation, improving code clarity and maintainability.
- Refactored imports in `data/__init__.py` to reflect the new structure, ensuring consistent access to common data types and exceptions.
- Removed redundant data validation logic from `BaseDataCollector`, streamlining its responsibilities.
- Added unit tests for `OHLCVData` and the validation functions to ensure correctness and reliability.

These changes improve the architecture of the data module, aligning with project standards for maintainability and performance.
This commit is contained in:
342
tests/data/collector/test_base_collector.py
Normal file
342
tests/data/collector/test_base_collector.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Unit tests for the BaseDataCollector abstract class.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from data.base_collector import (
|
||||
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
|
||||
OHLCVData, DataValidationError, DataCollectorError
|
||||
)
|
||||
|
||||
|
||||
class TestDataCollector(BaseDataCollector):
    """Concrete BaseDataCollector stand-in used by the tests in this module.

    Connection and subscription are simulated with in-memory flags, and
    incoming traffic is simulated by a list of queued messages that is
    consumed oldest-first.
    """

    # The class name matches pytest's default ``Test*`` discovery pattern, but
    # this is a helper with an ``__init__`` constructor, not a test case.
    # Opting out explicitly avoids a PytestCollectionWarning on every run.
    __test__ = False

    def __init__(self, exchange_name: str, symbols: list, data_types=None):
        """Initialise the base collector and reset the simulated state."""
        super().__init__(exchange_name, symbols, data_types)
        self.connected = False
        self.subscribed = False
        self.messages = []  # queued test messages, consumed from the front

    async def _actual_connect(self) -> bool:
        """Simulate a connection that always succeeds after a short delay."""
        await asyncio.sleep(0.01)  # Simulate connection delay
        self.connected = True
        return True

    async def _actual_disconnect(self) -> None:
        """Simulate a disconnection; clears both the connected and subscribed flags."""
        await asyncio.sleep(0.01)  # Simulate disconnection delay
        self.connected = False
        self.subscribed = False

    async def connect(self) -> bool:
        """Connect by handing ``_actual_connect`` to the connection manager."""
        return await self._connection_manager.connect(self._actual_connect)

    async def disconnect(self) -> None:
        """Disconnect by handing ``_actual_disconnect`` to the connection manager."""
        await self._connection_manager.disconnect(self._actual_disconnect)

    async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
        """Mark the collector as subscribed; return False while disconnected."""
        if not self.connected:
            return False
        self.subscribed = True
        return True

    async def unsubscribe_from_data(self, symbols: list, data_types: list) -> bool:
        """Clear the subscribed flag; always reports success."""
        self.subscribed = False
        return True

    async def _process_message(self, message) -> MarketDataPoint:
        """Count the message as received and wrap it in a ticker MarketDataPoint.

        The symbol is read from ``message['symbol']`` and falls back to
        ``'BTC-USDT'``; the timestamp is the current UTC time.
        """
        self._state_telemetry.increment_messages_received()
        return MarketDataPoint(
            exchange=self.exchange_name,
            symbol=message.get('symbol', 'BTC-USDT'),
            timestamp=datetime.now(timezone.utc),
            data_type=DataType.TICKER,
            data=message
        )

    async def _handle_messages(self) -> None:
        """Process one queued message, or idle briefly when the queue is empty."""
        # Simulate receiving messages
        if self.messages:
            message = self.messages.pop(0)
            data_point = await self._process_message(message)
            # Note: increment_messages_processed() is called in _notify_callbacks()
            await self._notify_callbacks(data_point)
        else:
            await asyncio.sleep(0.1)  # Wait for messages

    def add_test_message(self, message: dict):
        """Queue a message for a later ``_handle_messages`` call to process."""
        self.messages.append(message)
|
||||
|
||||
|
||||
class TestBaseDataCollector:
    """Behavioural tests for BaseDataCollector, exercised through TestDataCollector."""

    @pytest.fixture
    def collector(self):
        """Provide a fresh "okx" collector tracking two symbols and ticker data."""
        tracked_symbols = ["BTC-USDT", "ETH-USDT"]
        return TestDataCollector("okx", tracked_symbols, [DataType.TICKER])

    def test_initialization(self, collector):
        """A newly built collector exposes its configuration and is stopped."""
        assert collector.exchange_name == "okx"
        assert collector.symbols == {"BTC-USDT", "ETH-USDT"}
        assert collector.data_types == [DataType.TICKER]
        assert collector.status == CollectorStatus.STOPPED
        assert not collector._state_telemetry._running

    @pytest.mark.asyncio
    async def test_start_stop_cycle(self, collector):
        """start() brings the collector up; stop() tears everything down again."""
        started = await collector.start()

        assert started
        assert collector.status == CollectorStatus.RUNNING
        assert collector.connected
        assert collector.subscribed
        assert collector._state_telemetry._running

        # Give the message loop a moment to spin up before stopping.
        await asyncio.sleep(0.1)

        await collector.stop()

        assert collector.status == CollectorStatus.STOPPED
        assert not collector._state_telemetry._running
        assert not collector.connected
        assert not collector.subscribed

    @pytest.mark.asyncio
    async def test_message_processing(self, collector):
        """A queued message reaches the registered callback and updates the counters."""
        delivered = []

        def record(data_point: MarketDataPoint):
            delivered.append(data_point)

        collector.add_data_callback(DataType.TICKER, record)
        await collector.start()

        # Queue a single ticker message and leave time for it to be handled.
        collector.add_test_message({"symbol": "BTC-USDT", "price": "50000"})
        await asyncio.sleep(0.2)

        await collector.stop()

        # Exactly one data point must have been delivered and counted.
        assert len(delivered) == 1
        only_point = delivered[0]
        assert only_point.symbol == "BTC-USDT"
        assert only_point.data_type == DataType.TICKER
        assert collector._state_telemetry._stats['messages_received'] == 1
        assert collector._state_telemetry._stats['messages_processed'] == 1

    def test_symbol_management(self, collector):
        """Symbols can be added and removed, and re-adding one does not duplicate it."""
        starting_size = len(collector.symbols)

        # Adding an unseen symbol grows the set by one.
        collector.add_symbol("LTC-USDT")
        assert "LTC-USDT" in collector.symbols
        assert len(collector.symbols) == starting_size + 1

        # Removing a tracked symbol shrinks it back.
        collector.remove_symbol("BTC-USDT")
        assert "BTC-USDT" not in collector.symbols
        assert len(collector.symbols) == starting_size

        # Adding a symbol that is already tracked leaves the size unchanged.
        collector.add_symbol("ETH-USDT")
        assert len(collector.symbols) == starting_size

    def test_callback_management(self, collector):
        """Callbacks registered for a data type can be removed individually."""
        def first(data): pass
        def second(data): pass

        def ticker_callbacks():
            # Look the collection up on every call in case it is rebuilt on removal.
            return collector._callback_dispatcher._data_callbacks[DataType.TICKER]

        collector.add_data_callback(DataType.TICKER, first)
        collector.add_data_callback(DataType.TICKER, second)
        assert len(ticker_callbacks()) == 2

        collector.remove_data_callback(DataType.TICKER, first)
        assert len(ticker_callbacks()) == 1
        assert second in ticker_callbacks()

    def test_get_status(self, collector):
        """get_status() reports the configuration and zeroed statistics when idle."""
        report = collector.get_status()

        assert report['exchange'] == 'okx'
        assert report['status'] == 'stopped'
        assert set(report['symbols']) == {"BTC-USDT", "ETH-USDT"}
        assert report['data_types'] == ['ticker']
        assert 'statistics' in report
        assert report['statistics']['messages_received'] == 0
|
||||
|
||||
|
||||
class TestOHLCVData:
    """Construction and validation tests for OHLCVData."""

    def test_valid_ohlcv_data(self):
        """A consistent candle is accepted and keeps the values it was given."""
        candle_fields = {
            "symbol": "BTC-USDT",
            "timeframe": "1m",
            "timestamp": datetime.now(timezone.utc),
            "open": Decimal("50000"),
            "high": Decimal("50100"),
            "low": Decimal("49900"),
            "close": Decimal("50050"),
            "volume": Decimal("1.5"),
            "trades_count": 100,
        }

        candle = OHLCVData(**candle_fields)

        assert candle.symbol == "BTC-USDT"
        assert candle.timeframe == "1m"
        assert isinstance(candle.open, Decimal)
        assert candle.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """A candle whose high is below its open is rejected."""
        inconsistent_fields = {
            "symbol": "BTC-USDT",
            "timeframe": "1m",
            "timestamp": datetime.now(timezone.utc),
            "open": Decimal("50000"),
            "high": Decimal("49000"),  # High is less than open
            "low": Decimal("49900"),
            "close": Decimal("50050"),
            "volume": Decimal("1.5"),
        }

        with pytest.raises(DataValidationError):
            OHLCVData(**inconsistent_fields)

    def test_ohlcv_decimal_conversion(self):
        """Float and int inputs end up stored as Decimal on every numeric field."""
        candle = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=50000.0,   # float
            high=50100,     # int
            low=49900,      # int
            close=50050.0,  # float
            volume=1.5,     # float
        )

        for field_name in ("open", "high", "low", "close", "volume"):
            assert isinstance(getattr(candle, field_name), Decimal)
|
||||
|
||||
|
||||
class TestDataValidation:
    """Tests for OHLCV validation as exposed on a collector instance.

    NOTE(review): these tests call ``collector.validate_ohlcv_data``. The commit
    message says validation logic was extracted from ``BaseDataCollector`` into a
    module-level function -- confirm the collector still exposes this method.
    """

    def test_validate_ohlcv_data_success(self):
        """A complete raw candle with a millisecond timestamp validates cleanly."""
        collector = TestDataCollector("test", ["BTC-USDT"])
        payload = {
            "timestamp": 1609459200000,  # Unix timestamp in ms
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
            "trades_count": 100,
        }

        candle = collector.validate_ohlcv_data(payload, "BTC-USDT", "1m")

        assert candle.symbol == "BTC-USDT"
        assert candle.timeframe == "1m"
        assert candle.trades_count == 100
        assert isinstance(candle.open, Decimal)

    def test_validate_ohlcv_data_missing_field(self):
        """Leaving out 'low' raises DataValidationError naming the missing field."""
        collector = TestDataCollector("test", ["BTC-USDT"])
        # Deliberately has no 'low' entry.
        payload = {
            "timestamp": 1609459200000,
            "open": "50000",
            "high": "50100",
            "close": "50050",
            "volume": "1.5",
        }

        expectation = pytest.raises(DataValidationError, match="Missing required field: low")
        with expectation:
            collector.validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_ohlcv_data_invalid_timestamp(self):
        """A timestamp string that cannot be parsed raises DataValidationError."""
        collector = TestDataCollector("test", ["BTC-USDT"])
        payload = {
            "timestamp": "invalid_timestamp",
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

        with pytest.raises(DataValidationError):
            collector.validate_ohlcv_data(payload, "BTC-USDT", "1m")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_connection_error_handling():
    """A collector that cannot connect ends in ERROR and can be started after a reset.

    The first ``start()`` is expected to fail while the collector rejects
    connection attempts; after the retry state is reset and failures are
    switched off, a second ``start()`` is expected to succeed.
    """

    class FailingCollector(TestDataCollector):
        """TestDataCollector that rejects its first two connection attempts."""

        def __init__(self):
            super().__init__("test", ["BTC-USDT"])
            self.connect_attempts = 0
            self.should_fail = True

        async def _actual_connect(self) -> bool:
            self.connect_attempts += 1
            still_failing = self.should_fail and self.connect_attempts < 3
            if not still_failing:
                return await super()._actual_connect()
            return False  # Fail first 2 attempts

    flaky = FailingCollector()

    # Phase 1: starting while connections are rejected must fail.
    first_start = await flaky.start()
    assert not first_start
    assert flaky.status == CollectorStatus.ERROR

    # Reset the retry bookkeeping and let connections through.
    flaky._connection_manager._reconnect_attempts = 0
    flaky._state_telemetry.update_status(CollectorStatus.STOPPED)
    flaky.connect_attempts = 0
    flaky.should_fail = False

    # Phase 2: the same collector now starts normally.
    second_start = await flaky.start()
    assert second_start
    assert flaky.status == CollectorStatus.RUNNING

    await flaky.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the calling shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
353
tests/data/collector/test_collector_manager.py
Normal file
353
tests/data/collector/test_collector_manager.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Unit tests for the CollectorManager class.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from utils.logger import get_logger
|
||||
from data.collector_manager import CollectorManager, ManagerStatus, CollectorConfig
|
||||
from data.base_collector import BaseDataCollector, DataType, CollectorStatus
|
||||
|
||||
|
||||
class MockDataCollector(BaseDataCollector):
    """Configurable BaseDataCollector double for the CollectorManager tests.

    Setting ``should_fail_connect`` makes the first two connection attempts
    return False; setting ``should_fail_subscribe`` makes every subscription
    attempt return False.
    """

    def __init__(self, exchange_name: str, symbols: list, auto_restart: bool = True):
        """Build a ticker-only collector with all failure switches turned off."""
        super().__init__(
            exchange_name,
            symbols,
            [DataType.TICKER],
            auto_restart=auto_restart,
        )
        self.connected = False
        self.subscribed = False
        self.should_fail_connect = False
        self.should_fail_subscribe = False
        self.fail_count = 0  # number of simulated connection failures so far

    async def _actual_connect(self) -> bool:
        """Simulate connecting, failing up to twice when failures are enabled."""
        failures_left = self.fail_count < 2
        if self.should_fail_connect and failures_left:
            self.fail_count += 1
            return False

        await asyncio.sleep(0.01)
        self.connected = True
        return True

    async def _actual_disconnect(self) -> None:
        """Simulate disconnecting; clears both the connected and subscribed flags."""
        await asyncio.sleep(0.01)
        self.connected = False
        self.subscribed = False

    async def connect(self) -> bool:
        """Connect by handing ``_actual_connect`` to the connection manager."""
        return await self._connection_manager.connect(self._actual_connect)

    async def disconnect(self) -> None:
        """Disconnect by handing ``_actual_disconnect`` to the connection manager."""
        await self._connection_manager.disconnect(self._actual_disconnect)

    async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
        """Subscribe unless failure is forced or the collector is disconnected."""
        if self.should_fail_subscribe or not self.connected:
            return False
        self.subscribed = True
        return True

    async def unsubscribe_from_data(self, symbols: list, data_types: list) -> bool:
        """Clear the subscribed flag; always reports success."""
        self.subscribed = False
        return True

    async def _process_message(self, message) -> None:
        """Ignore the message; the mock performs no message processing."""
        return None

    async def _handle_messages(self) -> None:
        """Stand in for message handling with a short sleep."""
        await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
class TestCollectorManager:
    """Tests for CollectorManager registration, lifecycle and status reporting."""

    @pytest.fixture
    def manager(self):
        """Provide a manager named "test_manager" with a 1-second health check interval."""
        return CollectorManager(
            "test_manager",
            global_health_check_interval=1.0,
            logger=get_logger("test_manager_logger"),
        )

    @pytest.fixture
    def mock_collector(self):
        """Provide a mock "okx" collector tracking two symbols."""
        return MockDataCollector("okx", ["BTC-USDT", "ETH-USDT"])

    def test_initialization(self, manager):
        """A new manager is stopped and holds no collectors."""
        assert manager.manager_name == "test_manager"
        assert manager.status == ManagerStatus.STOPPED
        assert len(manager._collectors) == 0
        assert len(manager._enabled_collectors) == 0

    def test_add_collector(self, manager, mock_collector):
        """Collectors are registered, and only enabled ones count as enabled."""
        manager.add_collector(mock_collector)

        assert len(manager._collectors) == 1
        assert len(manager._enabled_collectors) == 1

        # The auto-generated name is expected to be prefixed with the exchange.
        names = manager.list_collectors()
        assert len(names) == 1
        assert names[0].startswith("okx_")

        # A second, separate collector registered with an explicit disabled config.
        second_collector = MockDataCollector("binance", ["ETH-USDT"])
        disabled_config = CollectorConfig(
            name="custom_collector",
            exchange="binance",
            symbols=["ETH-USDT"],
            data_types=["ticker"],
            enabled=False,
        )
        manager.add_collector(second_collector, disabled_config)

        assert len(manager._collectors) == 2
        assert len(manager._enabled_collectors) == 1  # Still 1 since second is disabled

    def test_remove_collector(self, manager, mock_collector):
        """Removing a known collector succeeds; removing an unknown one does not."""
        manager.add_collector(mock_collector)
        registered_name = manager.list_collectors()[0]

        removed = manager.remove_collector(registered_name)
        assert removed
        assert len(manager._collectors) == 0
        assert len(manager._enabled_collectors) == 0

        removed = manager.remove_collector("non_existent")
        assert not removed

    def test_enable_disable_collector(self, manager, mock_collector):
        """A collector can be disabled and re-enabled by name."""
        manager.add_collector(mock_collector)
        name = manager.list_collectors()[0]

        # Collectors start out enabled.
        assert name in manager._enabled_collectors

        disabled = manager.disable_collector(name)
        assert disabled
        assert name not in manager._enabled_collectors

        enabled = manager.enable_collector(name)
        assert enabled
        assert name in manager._enabled_collectors

        # Unknown names are reported as a failure.
        enabled = manager.enable_collector("non_existent")
        assert not enabled

    @pytest.mark.asyncio
    async def test_start_stop_manager(self, manager, mock_collector):
        """Starting the manager runs its collector; stopping it halts the collector."""
        manager.add_collector(mock_collector)

        started = await manager.start()
        assert started
        assert manager.status == ManagerStatus.RUNNING

        # Leave time for the collector to come up.
        await asyncio.sleep(0.2)

        assert len(manager.get_running_collectors()) == 1

        await manager.stop()
        assert manager.status == ManagerStatus.STOPPED

        assert len(manager.get_running_collectors()) == 0

    @pytest.mark.asyncio
    async def test_restart_collector(self, manager, mock_collector):
        """Restarting one collector succeeds and is reflected in the statistics."""
        manager.add_collector(mock_collector)
        await manager.start()

        name = manager.list_collectors()[0]

        # Leave time for the collector to come up before restarting it.
        await asyncio.sleep(0.2)

        restarted = await manager.restart_collector(name)
        assert restarted

        report = manager.get_status()
        assert report['statistics']['restarts_performed'] >= 1

        await manager.stop()

    @pytest.mark.asyncio
    async def test_health_monitoring(self, manager):
        """Run the manager through health checks with a collector that fails to connect.

        NOTE(review): the only assertion is ``len(...) >= 0``, which is always
        true, so this test only verifies that nothing raises.
        """
        flaky_collector = MockDataCollector("test", ["BTC-USDT"], auto_restart=True)
        flaky_collector.should_fail_connect = True

        manager.add_collector(flaky_collector)
        await manager.start()

        # Sleep for longer than the health check interval.
        await asyncio.sleep(2.5)

        status = manager.get_status()
        failed_collectors = manager.get_failed_collectors()

        # The collector may already have recovered by now, hence no strict bound.
        assert len(failed_collectors) >= 0  # May have recovered

        await manager.stop()

    def test_get_status(self, manager, mock_collector):
        """The manager status report covers state, counts and per-collector data."""
        manager.add_collector(mock_collector)

        report = manager.get_status()

        assert report['manager_status'] == 'stopped'
        assert report['total_collectors'] == 1
        assert len(report['enabled_collectors']) == 1
        assert 'statistics' in report
        assert 'collectors' in report

    def test_get_collector_status(self, manager, mock_collector):
        """Per-collector status has the expected sections and is None for unknown names."""
        manager.add_collector(mock_collector)
        name = manager.list_collectors()[0]

        details = manager.get_collector_status(name)

        assert details is not None
        assert details['name'] == name
        assert 'config' in details
        assert 'status' in details
        assert 'health' in details

        missing = manager.get_collector_status("non_existent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_restart_all_collectors(self, manager):
        """restart_all_collectors() returns a truthy result for every collector."""
        okx_collector = MockDataCollector("okx", ["BTC-USDT"])
        binance_collector = MockDataCollector("binance", ["ETH-USDT"])

        manager.add_collector(okx_collector)
        manager.add_collector(binance_collector)

        await manager.start()
        await asyncio.sleep(0.2)  # Let them start

        outcomes = await manager.restart_all_collectors()

        assert len(outcomes) == 2
        assert all(ok for ok in outcomes.values())

        await manager.stop()

    def test_get_running_and_failed_collectors(self, manager, mock_collector):
        """Before the manager starts, no collector is reported as running."""
        manager.add_collector(mock_collector)

        running = manager.get_running_collectors()
        failed = manager.get_failed_collectors()

        assert len(running) == 0
        # Note: failed might be empty since collector hasn't started yet

    def test_collector_config(self):
        """CollectorConfig stores the values passed to its constructor."""
        settings = {
            "name": "test_collector",
            "exchange": "okx",
            "symbols": ["BTC-USDT", "ETH-USDT"],
            "data_types": ["ticker", "trade"],
            "auto_restart": True,
            "health_check_interval": 30.0,
            "enabled": True,
        }

        config = CollectorConfig(**settings)

        assert config.name == "test_collector"
        assert config.exchange == "okx"
        assert len(config.symbols) == 2
        assert len(config.data_types) == 2
        assert config.auto_restart is True
        assert config.enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_manager_with_connection_failures():
    """The manager starts and stops cleanly even when a collector cannot connect.

    NOTE(review): the final assertion (``restarts_performed >= 0``) is always
    true for a non-negative counter, so it does not verify that a restart was
    attempted.
    """
    manager = CollectorManager("test_manager", global_health_check_interval=0.5)

    # Register a collector whose initial connection attempts are rejected.
    flaky_collector = MockDataCollector("failing_exchange", ["BTC-USDT"])
    flaky_collector.should_fail_connect = True
    manager.add_collector(flaky_collector)

    started = await manager.start()
    assert started  # Manager should start even if collectors fail

    # Leave time for a few health checks to run.
    await asyncio.sleep(1.5)

    failed_collectors = manager.get_failed_collectors()
    report = manager.get_status()

    # The collector should be in failed state or have restart attempts
    assert report['statistics']['restarts_performed'] >= 0

    await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_manager_graceful_shutdown():
    """A manager running three collectors reaches STOPPED after stop()."""
    manager = CollectorManager("test_manager")

    # Register three independent mock collectors.
    exchange_names = [f"exchange_{index}" for index in range(3)]
    for exchange_name in exchange_names:
        manager.add_collector(MockDataCollector(exchange_name, ["BTC-USDT"]))

    await manager.start()
    await asyncio.sleep(0.2)

    # stop() must return even if the collectors need time to shut down.
    await manager.stop()

    assert manager.status == ManagerStatus.STOPPED
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the calling shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
230
tests/data/test_ohlcv_data.py
Normal file
230
tests/data/test_ohlcv_data.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Unit tests for the OHLCVData module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from data.common.ohlcv_data import OHLCVData, DataValidationError, validate_ohlcv_data
|
||||
|
||||
|
||||
class TestOHLCVData:
    """Construction and validation tests for OHLCVData."""

    @staticmethod
    def _fields(**overrides):
        """Return constructor kwargs: fixed symbol/timeframe, a UTC timestamp, plus overrides."""
        fields = {
            "symbol": "BTC-USDT",
            "timeframe": "1m",
            "timestamp": datetime.now(timezone.utc),
        }
        fields.update(overrides)
        return fields

    def test_valid_ohlcv_data(self):
        """A consistent candle is accepted and keeps the values it was given."""
        candle = OHLCVData(**self._fields(
            open=Decimal("50000"),
            high=Decimal("50100"),
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
            trades_count=100,
        ))

        assert candle.symbol == "BTC-USDT"
        assert candle.timeframe == "1m"
        assert isinstance(candle.open, Decimal)
        assert candle.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """A candle whose high is below its open is rejected."""
        inconsistent = self._fields(
            open=Decimal("50000"),
            high=Decimal("49000"),  # High is less than open
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
        )

        with pytest.raises(DataValidationError):
            OHLCVData(**inconsistent)

    def test_ohlcv_decimal_conversion(self):
        """Float and int inputs end up stored as Decimal on every numeric field."""
        candle = OHLCVData(**self._fields(
            open=50000.0,   # float
            high=50100,     # int
            low=49900,      # int
            close=50050.0,  # float
            volume=1.5,     # float
        ))

        for field_name in ("open", "high", "low", "close", "volume"):
            assert isinstance(getattr(candle, field_name), Decimal)

    def test_timezone_handling(self):
        """A naive timestamp comes back with its tzinfo set to UTC."""
        naive_timestamp = datetime(2023, 1, 1, 12, 0, 0)

        candle = OHLCVData(**self._fields(
            timestamp=naive_timestamp,
            open=50000,
            high=50100,
            low=49900,
            close=50050,
            volume=1.5,
        ))

        assert candle.timestamp.tzinfo == timezone.utc

    def test_invalid_price_types(self):
        """A non-numeric price raises DataValidationError with the price message."""
        bad_price = self._fields(
            open="invalid",  # Invalid type
            high=50100,
            low=49900,
            close=50050,
            volume=1.5,
        )

        with pytest.raises(DataValidationError, match="All OHLCV prices must be numeric"):
            OHLCVData(**bad_price)

    def test_invalid_volume_type(self):
        """A non-numeric volume raises DataValidationError with the volume message."""
        bad_volume = self._fields(
            open=50000,
            high=50100,
            low=49900,
            close=50050,
            volume="invalid",  # Invalid type
        )

        with pytest.raises(DataValidationError, match="Volume must be numeric"):
            OHLCVData(**bad_volume)
|
||||
|
||||
|
||||
class TestValidateOhlcvData:
    """Tests for the module-level validate_ohlcv_data function."""

    def test_validate_success(self):
        """A complete raw candle validates and its prices are parsed as Decimal."""
        payload = {
            "timestamp": 1609459200000,  # Unix timestamp in ms
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
            "trades_count": 100,
        }

        candle = validate_ohlcv_data(payload, "BTC-USDT", "1m")

        assert candle.symbol == "BTC-USDT"
        assert candle.timeframe == "1m"
        assert candle.trades_count == 100
        assert isinstance(candle.open, Decimal)
        assert candle.open == Decimal("50000")

    def test_validate_missing_field(self):
        """Leaving out 'low' raises DataValidationError naming the missing field."""
        # Deliberately has no 'low' entry.
        payload = {
            "timestamp": 1609459200000,
            "open": "50000",
            "high": "50100",
            "close": "50050",
            "volume": "1.5",
        }

        expectation = pytest.raises(DataValidationError, match="Missing required field: low")
        with expectation:
            validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_invalid_timestamp_string(self):
        """A timestamp string that cannot be parsed raises DataValidationError."""
        payload = {
            "timestamp": "invalid_timestamp",
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

        with pytest.raises(DataValidationError):
            validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_timestamp_formats(self):
        """Millisecond ints, float seconds, ISO strings and datetimes all yield a datetime."""
        price_fields = {
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }
        accepted_timestamps = (
            1609459200000,                # Unix timestamp in milliseconds
            1609459200.5,                 # Unix timestamp in seconds (float)
            "2021-01-01T00:00:00Z",       # ISO format string
            datetime.now(timezone.utc),   # Already a datetime object
        )

        for raw_timestamp in accepted_timestamps:
            payload = {**price_fields, "timestamp": raw_timestamp}
            candle = validate_ohlcv_data(payload, "BTC-USDT", "1m")
            assert isinstance(candle.timestamp, datetime)

    def test_validate_invalid_numeric_data(self):
        """A non-numeric price string raises DataValidationError mentioning the symbol."""
        payload = {
            "timestamp": 1609459200000,
            "open": "invalid_number",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

        expectation = pytest.raises(DataValidationError, match="Invalid OHLCV data for BTC-USDT")
        with expectation:
            validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_with_optional_fields(self):
        """trades_count is None when absent and passed through when present."""
        payload = {
            "timestamp": 1609459200000,
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

        # Without the optional key.
        without_count = validate_ohlcv_data(payload, "BTC-USDT", "1m")
        assert without_count.trades_count is None

        # Same payload with the optional key added.
        payload["trades_count"] = 250
        with_count = validate_ohlcv_data(payload, "BTC-USDT", "1m")
        assert with_count.trades_count == 250
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the calling shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user