Refactor data module to enhance modularity and maintainability

- Extracted `OHLCVData` and validation logic into a new `common/ohlcv_data.py` module, promoting better organization and reusability.
- Updated `BaseDataCollector` to utilize the new `validate_ohlcv_data` function for improved data validation, enhancing code clarity and maintainability.
- Refactored imports in `data/__init__.py` to reflect the new structure, ensuring consistent access to common data types and exceptions.
- Removed redundant data validation logic from `BaseDataCollector`, streamlining its responsibilities.
- Added unit tests for `OHLCVData` and validation functions to ensure correctness and reliability.

These changes improve the architecture of the data module, aligning with project standards for maintainability and performance.
This commit is contained in:
Vasily.onl
2025-06-10 12:04:58 +08:00
parent 3db8fb1c41
commit 33f2110f19
15 changed files with 511 additions and 1009 deletions

View File

@@ -23,16 +23,26 @@ class TestDataCollector(BaseDataCollector):
self.subscribed = False
self.messages = []
async def connect(self) -> bool:
async def _actual_connect(self) -> bool:
"""Implementation of actual connection logic for testing."""
await asyncio.sleep(0.01) # Simulate connection delay
self.connected = True
return True
async def disconnect(self) -> None:
async def _actual_disconnect(self) -> None:
"""Implementation of actual disconnection logic for testing."""
await asyncio.sleep(0.01) # Simulate disconnection delay
self.connected = False
self.subscribed = False
async def connect(self) -> bool:
"""Connect using the connection manager."""
return await self._connection_manager.connect(self._actual_connect)
async def disconnect(self) -> None:
"""Disconnect using the connection manager."""
await self._connection_manager.disconnect(self._actual_disconnect)
async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
if not self.connected:
return False
@@ -44,7 +54,7 @@ class TestDataCollector(BaseDataCollector):
return True
async def _process_message(self, message) -> MarketDataPoint:
self._stats['messages_received'] += 1
self._state_telemetry.increment_messages_received()
return MarketDataPoint(
exchange=self.exchange_name,
symbol=message.get('symbol', 'BTC-USDT'),
@@ -58,8 +68,7 @@ class TestDataCollector(BaseDataCollector):
if self.messages:
message = self.messages.pop(0)
data_point = await self._process_message(message)
self._stats['messages_processed'] += 1
self._stats['last_message_time'] = datetime.now(timezone.utc)
# Note: increment_messages_processed() is called in _notify_callbacks()
await self._notify_callbacks(data_point)
else:
await asyncio.sleep(0.1) # Wait for messages
@@ -83,7 +92,7 @@ class TestBaseDataCollector:
assert collector.symbols == {"BTC-USDT", "ETH-USDT"}
assert collector.data_types == [DataType.TICKER]
assert collector.status == CollectorStatus.STOPPED
assert not collector._running
assert not collector._state_telemetry._running
@pytest.mark.asyncio
async def test_start_stop_cycle(self, collector):
@@ -94,7 +103,7 @@ class TestBaseDataCollector:
assert collector.status == CollectorStatus.RUNNING
assert collector.connected
assert collector.subscribed
assert collector._running
assert collector._state_telemetry._running
# Wait a bit for the message loop to start
await asyncio.sleep(0.1)
@@ -102,7 +111,7 @@ class TestBaseDataCollector:
# Test stop
await collector.stop()
assert collector.status == CollectorStatus.STOPPED
assert not collector._running
assert not collector._state_telemetry._running
assert not collector.connected
assert not collector.subscribed
@@ -131,8 +140,8 @@ class TestBaseDataCollector:
assert len(received_data) == 1
assert received_data[0].symbol == "BTC-USDT"
assert received_data[0].data_type == DataType.TICKER
assert collector._stats['messages_received'] == 1
assert collector._stats['messages_processed'] == 1
assert collector._state_telemetry._stats['messages_received'] == 1
assert collector._state_telemetry._stats['messages_processed'] == 1
def test_symbol_management(self, collector):
"""Test adding and removing symbols."""
@@ -160,12 +169,12 @@ class TestBaseDataCollector:
# Add callbacks
collector.add_data_callback(DataType.TICKER, callback1)
collector.add_data_callback(DataType.TICKER, callback2)
assert len(collector._data_callbacks[DataType.TICKER]) == 2
assert len(collector._callback_dispatcher._data_callbacks[DataType.TICKER]) == 2
# Remove callback
collector.remove_data_callback(DataType.TICKER, callback1)
assert len(collector._data_callbacks[DataType.TICKER]) == 1
assert callback2 in collector._data_callbacks[DataType.TICKER]
assert len(collector._callback_dispatcher._data_callbacks[DataType.TICKER]) == 1
assert callback2 in collector._callback_dispatcher._data_callbacks[DataType.TICKER]
def test_get_status(self, collector):
"""Test status reporting."""
@@ -302,11 +311,11 @@ async def test_connection_error_handling():
self.connect_attempts = 0
self.should_fail = True
async def connect(self) -> bool:
async def _actual_connect(self) -> bool:
self.connect_attempts += 1
if self.should_fail and self.connect_attempts < 3:
return False # Fail first 2 attempts
return await super().connect()
return await super()._actual_connect()
collector = FailingCollector()
@@ -316,8 +325,8 @@ async def test_connection_error_handling():
assert collector.status == CollectorStatus.ERROR
# Reset for retry and allow success
collector._reconnect_attempts = 0
collector.status = CollectorStatus.STOPPED
collector._connection_manager._reconnect_attempts = 0
collector._state_telemetry.update_status(CollectorStatus.STOPPED)
collector.connect_attempts = 0 # Reset connection attempts
collector.should_fail = False # Allow connection to succeed

View File

@@ -7,6 +7,7 @@ import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from utils.logger import get_logger
from data.collector_manager import CollectorManager, ManagerStatus, CollectorConfig
from data.base_collector import BaseDataCollector, DataType, CollectorStatus
@@ -22,7 +23,8 @@ class MockDataCollector(BaseDataCollector):
self.should_fail_subscribe = False
self.fail_count = 0
async def connect(self) -> bool:
async def _actual_connect(self) -> bool:
"""Implementation of actual connection logic for testing."""
if self.should_fail_connect and self.fail_count < 2:
self.fail_count += 1
return False
@@ -30,10 +32,19 @@ class MockDataCollector(BaseDataCollector):
self.connected = True
return True
async def disconnect(self) -> None:
async def _actual_disconnect(self) -> None:
"""Implementation of actual disconnection logic for testing."""
await asyncio.sleep(0.01)
self.connected = False
self.subscribed = False
async def connect(self) -> bool:
"""Connect using the connection manager."""
return await self._connection_manager.connect(self._actual_connect)
async def disconnect(self) -> None:
"""Disconnect using the connection manager."""
await self._connection_manager.disconnect(self._actual_disconnect)
async def subscribe_to_data(self, symbols: list, data_types: list) -> bool:
if self.should_fail_subscribe:
@@ -62,7 +73,8 @@ class TestCollectorManager:
@pytest.fixture
def manager(self):
"""Create a test manager instance."""
return CollectorManager("test_manager", global_health_check_interval=1.0)
test_logger = get_logger("test_manager_logger")
return CollectorManager("test_manager", global_health_check_interval=1.0, logger=test_logger)
@pytest.fixture
def mock_collector(self):

View File

@@ -0,0 +1,230 @@
"""
Unit tests for the OHLCVData module.
"""
import pytest
from datetime import datetime, timezone
from decimal import Decimal
from data.common.ohlcv_data import OHLCVData, DataValidationError, validate_ohlcv_data
class TestOHLCVData:
    """Construction-time checks for ``OHLCVData``.

    Every test builds a BTC-USDT one-minute candle and asserts either on the
    attributes of the resulting object or on the ``DataValidationError``
    raised while building it.
    """

    @staticmethod
    def _candle(**fields):
        """Build a BTC-USDT ``1m`` ``OHLCVData`` from *fields*.

        ``timestamp`` falls back to the current UTC time when the caller does
        not provide one; all other keyword arguments are forwarded unchanged.
        """
        fields.setdefault("timestamp", datetime.now(timezone.utc))
        return OHLCVData(symbol="BTC-USDT", timeframe="1m", **fields)

    def test_valid_ohlcv_data(self):
        """A consistent candle is accepted and exposes what it was given."""
        candle = self._candle(
            open=Decimal("50000"),
            high=Decimal("50100"),
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
            trades_count=100,
        )

        assert candle.symbol == "BTC-USDT"
        assert candle.timeframe == "1m"
        assert isinstance(candle.open, Decimal)
        assert candle.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """A high below the open (and the low) is rejected."""
        with pytest.raises(DataValidationError):
            self._candle(
                open=Decimal("50000"),
                high=Decimal("49000"),  # High is less than open
                low=Decimal("49900"),
                close=Decimal("50050"),
                volume=Decimal("1.5"),
            )

    def test_ohlcv_decimal_conversion(self):
        """Float and int inputs end up stored as ``Decimal``."""
        candle = self._candle(
            open=50000.0,  # float
            high=50100,  # int
            low=49900,  # int
            close=50050.0,  # float
            volume=1.5,  # float
        )

        for attr in ("open", "high", "low", "close", "volume"):
            assert isinstance(getattr(candle, attr), Decimal)

    def test_timezone_handling(self):
        """A naive timestamp comes back tagged as UTC."""
        naive_timestamp = datetime(2023, 1, 1, 12, 0, 0)

        candle = self._candle(
            timestamp=naive_timestamp,
            open=50000,
            high=50100,
            low=49900,
            close=50050,
            volume=1.5,
        )

        assert candle.timestamp.tzinfo == timezone.utc

    def test_invalid_price_types(self):
        """A non-numeric price raises the price-specific validation error."""
        expected = "All OHLCV prices must be numeric"
        with pytest.raises(DataValidationError, match=expected):
            self._candle(
                open="invalid",  # Invalid type
                high=50100,
                low=49900,
                close=50050,
                volume=1.5,
            )

    def test_invalid_volume_type(self):
        """A non-numeric volume raises the volume-specific validation error."""
        expected = "Volume must be numeric"
        with pytest.raises(DataValidationError, match=expected):
            self._candle(
                open=50000,
                high=50100,
                low=49900,
                close=50050,
                volume="invalid",  # Invalid type
            )
class TestValidateOhlcvData:
    """Checks for ``validate_ohlcv_data`` fed with raw dictionary payloads.

    All payloads are validated for symbol ``BTC-USDT`` and timeframe ``1m``.
    """

    # Unix timestamp in ms used by most payloads below.
    _EPOCH_MS = 1609459200000

    @staticmethod
    def _price_fields():
        """Return a fresh dict holding the five string-valued price/volume fields."""
        return {
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

    def test_validate_success(self):
        """A complete payload yields an object carrying the parsed values."""
        payload = {
            "timestamp": self._EPOCH_MS,
            **self._price_fields(),
            "trades_count": 100,
        }

        result = validate_ohlcv_data(payload, "BTC-USDT", "1m")

        assert result.symbol == "BTC-USDT"
        assert result.timeframe == "1m"
        assert result.trades_count == 100
        assert isinstance(result.open, Decimal)
        assert result.open == Decimal("50000")

    def test_validate_missing_field(self):
        """Leaving out ``low`` is reported as a missing required field."""
        payload = {"timestamp": self._EPOCH_MS, **self._price_fields()}
        del payload["low"]  # Missing 'low' field

        with pytest.raises(DataValidationError, match="Missing required field: low"):
            validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_invalid_timestamp_string(self):
        """An unparseable timestamp string is rejected."""
        payload = {"timestamp": "invalid_timestamp", **self._price_fields()}

        with pytest.raises(DataValidationError):
            validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_timestamp_formats(self):
        """Each accepted timestamp representation produces a ``datetime``."""
        accepted_timestamps = (
            1609459200000,  # Unix timestamp in milliseconds
            1609459200.5,  # Unix timestamp in seconds (float)
            "2021-01-01T00:00:00Z",  # ISO format string
            datetime.now(timezone.utc),  # Already a datetime object
        )

        for raw_timestamp in accepted_timestamps:
            payload = {**self._price_fields(), "timestamp": raw_timestamp}
            result = validate_ohlcv_data(payload, "BTC-USDT", "1m")
            assert isinstance(result.timestamp, datetime)

    def test_validate_invalid_numeric_data(self):
        """A non-numeric price string surfaces as an invalid-data error."""
        payload = {
            "timestamp": self._EPOCH_MS,
            **self._price_fields(),
            "open": "invalid_number",
        }

        with pytest.raises(DataValidationError, match="Invalid OHLCV data for BTC-USDT"):
            validate_ohlcv_data(payload, "BTC-USDT", "1m")

    def test_validate_with_optional_fields(self):
        """``trades_count`` is ``None`` when absent and kept when supplied."""
        # No trades_count
        payload = {"timestamp": self._EPOCH_MS, **self._price_fields()}
        result = validate_ohlcv_data(payload, "BTC-USDT", "1m")
        assert result.trades_count is None

        # With trades_count
        payload["trades_count"] = 250
        result = validate_ohlcv_data(payload, "BTC-USDT", "1m")
        assert result.trades_count == 250
# When executed as a script, run this file's tests through pytest in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])