- Extracted `OHLCVData` and its validation logic into a new `data/common/ohlcv_data.py` module for better organization and reuse.
- Updated `BaseDataCollector` to use the new `validate_ohlcv_data` function instead of its own validation code, improving clarity and maintainability.
- Refactored imports in `data/__init__.py` to reflect the new structure, so common data types and exceptions are imported consistently.
- Removed the now-redundant data validation logic from `BaseDataCollector`, narrowing its responsibilities.
- Added unit tests for `OHLCVData` and the `validate_ohlcv_data` function to check correctness.

These changes improve the architecture of the data module and align it with project standards for maintainability and performance.
230 lines
7.5 KiB
Python
230 lines
7.5 KiB
Python
"""
|
|
Unit tests for the OHLCVData module.
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from decimal import Decimal
|
|
|
|
from data.common.ohlcv_data import OHLCVData, DataValidationError, validate_ohlcv_data
|
|
|
|
|
|
class TestOHLCVData:
    """Test cases for constructing and validating ``OHLCVData`` directly."""

    def test_valid_ohlcv_data(self):
        """Test creating valid OHLCV data keeps the supplied values."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=Decimal("50000"),
            high=Decimal("50100"),
            low=Decimal("49900"),
            close=Decimal("50050"),
            volume=Decimal("1.5"),
            trades_count=100,
        )

        assert ohlcv.symbol == "BTC-USDT"
        assert ohlcv.timeframe == "1m"
        assert isinstance(ohlcv.open, Decimal)
        assert ohlcv.open == Decimal("50000")
        assert ohlcv.high == Decimal("50100")
        assert ohlcv.low == Decimal("49900")
        assert ohlcv.close == Decimal("50050")
        assert ohlcv.volume == Decimal("1.5")
        assert ohlcv.trades_count == 100

    def test_invalid_ohlcv_relationships(self):
        """Test OHLCV validation rejects inconsistent price relationships."""
        with pytest.raises(DataValidationError):
            OHLCVData(
                symbol="BTC-USDT",
                timeframe="1m",
                timestamp=datetime.now(timezone.utc),
                open=Decimal("50000"),
                # High (49000) is below open, low AND close, so the bar is
                # inconsistent whichever high-vs-price relationship is checked.
                high=Decimal("49000"),
                low=Decimal("49900"),
                close=Decimal("50050"),
                volume=Decimal("1.5"),
            )

    def test_ohlcv_decimal_conversion(self):
        """Test int/float prices and volume are converted to equal Decimals."""
        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=datetime.now(timezone.utc),
            open=50000.0,  # float
            high=50100,  # int
            low=49900,  # int
            close=50050.0,  # float
            volume=1.5,  # float
        )

        # Check values as well as types: an isinstance check alone would pass
        # even if the conversion produced the wrong number. These floats are
        # exactly representable in binary, so Decimal(x) and Decimal(str(x))
        # both compare equal to the expected values below.
        expected = {
            "open": Decimal("50000"),
            "high": Decimal("50100"),
            "low": Decimal("49900"),
            "close": Decimal("50050"),
            "volume": Decimal("1.5"),
        }
        for field, value in expected.items():
            actual = getattr(ohlcv, field)
            assert isinstance(actual, Decimal), field
            assert actual == value, field

    def test_timezone_handling(self):
        """Test that naive datetimes get UTC timezone."""
        naive_timestamp = datetime(2023, 1, 1, 12, 0, 0)

        ohlcv = OHLCVData(
            symbol="BTC-USDT",
            timeframe="1m",
            timestamp=naive_timestamp,
            open=50000,
            high=50100,
            low=49900,
            close=50050,
            volume=1.5,
        )

        assert ohlcv.timestamp.tzinfo == timezone.utc
        # The naive wall-clock time should be labelled as UTC, not treated as
        # host-local time and shifted; the difference only shows on non-UTC hosts.
        assert ohlcv.timestamp == naive_timestamp.replace(tzinfo=timezone.utc)

    def test_invalid_price_types(self):
        """Test validation fails for invalid price types."""
        with pytest.raises(DataValidationError, match="All OHLCV prices must be numeric"):
            OHLCVData(
                symbol="BTC-USDT",
                timeframe="1m",
                timestamp=datetime.now(timezone.utc),
                open="invalid",  # Invalid type
                high=50100,
                low=49900,
                close=50050,
                volume=1.5,
            )

    def test_invalid_volume_type(self):
        """Test validation fails for invalid volume type."""
        with pytest.raises(DataValidationError, match="Volume must be numeric"):
            OHLCVData(
                symbol="BTC-USDT",
                timeframe="1m",
                timestamp=datetime.now(timezone.utc),
                open=50000,
                high=50100,
                low=49900,
                close=50050,
                volume="invalid",  # Invalid type
            )
|
|
|
|
|
|
class TestValidateOhlcvData:
    """Test cases for the ``validate_ohlcv_data`` raw-dict parser."""

    def test_validate_success(self):
        """Test successful OHLCV data validation."""
        raw_data = {
            "timestamp": 1609459200000,  # Unix timestamp in ms (2021-01-01T00:00:00Z)
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
            "trades_count": 100,
        }

        ohlcv = validate_ohlcv_data(raw_data, "BTC-USDT", "1m")

        assert ohlcv.symbol == "BTC-USDT"
        assert ohlcv.timeframe == "1m"
        assert ohlcv.trades_count == 100
        assert isinstance(ohlcv.open, Decimal)
        assert ohlcv.open == Decimal("50000")
        assert ohlcv.volume == Decimal("1.5")
        assert ohlcv.timestamp == datetime(2021, 1, 1, tzinfo=timezone.utc)

    def test_validate_missing_field(self):
        """Test validation with missing required field."""
        raw_data = {
            "timestamp": 1609459200000,
            "open": "50000",
            "high": "50100",
            # Missing 'low' field
            "close": "50050",
            "volume": "1.5",
        }

        with pytest.raises(DataValidationError, match="Missing required field: low"):
            validate_ohlcv_data(raw_data, "BTC-USDT", "1m")

    def test_validate_invalid_timestamp_string(self):
        """Test validation with invalid timestamp string."""
        raw_data = {
            "timestamp": "invalid_timestamp",
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

        with pytest.raises(DataValidationError):
            validate_ohlcv_data(raw_data, "BTC-USDT", "1m")

    def test_validate_timestamp_formats(self):
        """Test validation with different timestamp formats.

        Every supported input form must produce a timezone-aware ``datetime``.
        The millisecond epoch and the ISO string below denote the same instant
        (2021-01-01T00:00:00Z) and must parse to equal values; an isinstance
        check alone would not catch a mis-scaled or locally-shifted parse.
        """
        base_data = {
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }
        new_year_2021 = datetime(2021, 1, 1, tzinfo=timezone.utc)

        # Unix timestamp in milliseconds
        data1 = {**base_data, "timestamp": 1609459200000}
        ohlcv1 = validate_ohlcv_data(data1, "BTC-USDT", "1m")
        assert isinstance(ohlcv1.timestamp, datetime)
        assert ohlcv1.timestamp.tzinfo is not None
        assert ohlcv1.timestamp == new_year_2021

        # Numeric (float) timestamp.
        # NOTE(review): this case was originally described as "seconds", but
        # the millisecond case above is also a bare number — confirm how
        # validate_ohlcv_data distinguishes the two before asserting a value.
        data2 = {**base_data, "timestamp": 1609459200.5}
        ohlcv2 = validate_ohlcv_data(data2, "BTC-USDT", "1m")
        assert isinstance(ohlcv2.timestamp, datetime)
        assert ohlcv2.timestamp.tzinfo is not None

        # ISO format string with a "Z" (UTC) suffix
        data3 = {**base_data, "timestamp": "2021-01-01T00:00:00Z"}
        ohlcv3 = validate_ohlcv_data(data3, "BTC-USDT", "1m")
        assert isinstance(ohlcv3.timestamp, datetime)
        assert ohlcv3.timestamp.tzinfo is not None
        assert ohlcv3.timestamp == new_year_2021

        # Already a datetime object: must come back as the same instant
        now = datetime.now(timezone.utc)
        data4 = {**base_data, "timestamp": now}
        ohlcv4 = validate_ohlcv_data(data4, "BTC-USDT", "1m")
        assert isinstance(ohlcv4.timestamp, datetime)
        assert ohlcv4.timestamp == now

    def test_validate_invalid_numeric_data(self):
        """Test validation with invalid numeric price data."""
        raw_data = {
            "timestamp": 1609459200000,
            "open": "invalid_number",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
        }

        with pytest.raises(DataValidationError, match="Invalid OHLCV data for BTC-USDT"):
            validate_ohlcv_data(raw_data, "BTC-USDT", "1m")

    def test_validate_with_optional_fields(self):
        """Test validation works correctly with optional fields."""
        raw_data = {
            "timestamp": 1609459200000,
            "open": "50000",
            "high": "50100",
            "low": "49900",
            "close": "50050",
            "volume": "1.5",
            # No trades_count
        }

        ohlcv = validate_ohlcv_data(raw_data, "BTC-USDT", "1m")
        assert ohlcv.trades_count is None

        # With trades_count: build a new dict rather than mutating the one
        # already passed in, so the two calls stay independent.
        with_trades = {**raw_data, "trades_count": 250}
        ohlcv = validate_ohlcv_data(with_trades, "BTC-USDT", "1m")
        assert ohlcv.trades_count == 250
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of raising SystemExit;
    # propagate it so running this file directly exits non-zero on failures.
    raise SystemExit(pytest.main([__file__, "-v"]))