188 lines
6.7 KiB
Python
188 lines
6.7 KiB
Python
|
|
"""
|
||
|
|
Tests for data validation module.
|
||
|
|
|
||
|
|
This module provides comprehensive test coverage for the data validation utilities
|
||
|
|
and base validator class.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from decimal import Decimal
|
||
|
|
from typing import Dict, Any
|
||
|
|
|
||
|
|
from data.common.validation import (
|
||
|
|
ValidationResult,
|
||
|
|
BaseDataValidator,
|
||
|
|
is_valid_decimal,
|
||
|
|
validate_required_fields
|
||
|
|
)
|
||
|
|
from data.common.data_types import DataValidationResult, StandardizedTrade, TradeSide
|
||
|
|
|
||
|
|
|
||
|
|
class TestValidationResult:
    """Constructor behaviour of ValidationResult."""

    def test_init_with_defaults(self):
        """Only is_valid given: errors/warnings are empty lists, sanitized_data is None."""
        outcome = ValidationResult(is_valid=True)

        assert outcome.is_valid
        assert outcome.errors == []
        assert outcome.warnings == []
        assert outcome.sanitized_data is None

    def test_init_with_errors(self):
        """Errors passed in are exposed unchanged; warnings stay empty."""
        error_messages = ["Error 1", "Error 2"]
        outcome = ValidationResult(is_valid=False, errors=error_messages)

        assert not outcome.is_valid
        assert outcome.errors == error_messages
        assert outcome.warnings == []

    def test_init_with_warnings(self):
        """Warnings passed in are exposed unchanged; errors stay empty."""
        warning_messages = ["Warning 1"]
        outcome = ValidationResult(is_valid=True, warnings=warning_messages)

        assert outcome.is_valid
        assert outcome.warnings == warning_messages
        assert outcome.errors == []

    def test_init_with_sanitized_data(self):
        """Sanitized data passed in is available on the result."""
        payload = {"key": "value"}
        outcome = ValidationResult(is_valid=True, sanitized_data=payload)

        assert outcome.sanitized_data == payload
|
||
|
|
|
||
|
|
|
||
|
|
class MockDataValidator(BaseDataValidator):
    """Minimal BaseDataValidator subclass used as the validator under test.

    It defines validate_symbol_format and validate_websocket_message with
    trivial type checks so the base-class validation helpers can be exercised.
    """

    def validate_symbol_format(self, symbol: str) -> ValidationResult:
        """Accept any non-empty string; reject everything else."""
        well_formed = bool(symbol) and isinstance(symbol, str)
        if well_formed:
            return ValidationResult(True)
        return ValidationResult(False, errors=["Invalid symbol format"])

    def validate_websocket_message(self, message: Dict[str, Any]) -> DataValidationResult:
        """Accept any dict; reject every other message type."""
        if isinstance(message, dict):
            return DataValidationResult(True, [], [])
        return DataValidationResult(False, ["Invalid message format"], [])
|
||
|
|
|
||
|
|
|
||
|
|
class TestBaseDataValidator:
    """Exercise the validate_* methods through the MockDataValidator subclass."""

    @pytest.fixture
    def validator(self):
        """Provide a MockDataValidator constructed with "test_exchange"."""
        return MockDataValidator("test_exchange")

    def test_validate_price(self, validator):
        """Cover a numeric price, a non-numeric price and a very small price."""
        # Numeric string: accepted and sanitized to the equivalent Decimal.
        accepted = validator.validate_price("123.45")
        assert accepted.is_valid
        assert accepted.sanitized_data == Decimal("123.45")

        # Non-numeric string: rejected with an explanatory error.
        rejected = validator.validate_price("invalid")
        assert not rejected.is_valid
        assert "Invalid price value" in rejected.errors[0]

        # Very small price (below min): still valid, but flagged with a warning.
        tiny = validator.validate_price("0.000000001")
        assert tiny.is_valid
        assert "below minimum" in tiny.warnings[0]

    def test_validate_size(self, validator):
        """Cover a positive size and a negative size."""
        # Positive size: accepted and sanitized to the equivalent Decimal.
        accepted = validator.validate_size("10.5")
        assert accepted.is_valid
        assert accepted.sanitized_data == Decimal("10.5")

        # Negative size: rejected.
        rejected = validator.validate_size("-1")
        assert not rejected.is_valid
        assert "must be positive" in rejected.errors[0]

    def test_validate_timestamp(self, validator):
        """Cover a current timestamp, a non-numeric value and a stale timestamp."""
        # Current UTC time expressed in milliseconds.
        now_utc = datetime.now(timezone.utc)
        now_ms = int(now_utc.timestamp() * 1000)

        # A present-day timestamp is accepted.
        fresh = validator.validate_timestamp(now_ms)
        assert fresh.is_valid

        # A non-numeric value is rejected as badly formatted.
        malformed = validator.validate_timestamp("invalid")
        assert not malformed.is_valid
        assert "Invalid timestamp format" in malformed.errors[0]

        # A value before the validator's min_timestamp is rejected as too old.
        stale_ms = 999999999999
        stale = validator.validate_timestamp(stale_ms)
        assert not stale.is_valid
        assert "too old" in stale.errors[0]

    def test_validate_trade_side(self, validator):
        """Cover both accepted sides and an unknown side."""
        # Both recognised sides pass.
        for side in ("buy", "sell"):
            assert validator.validate_trade_side(side).is_valid

        # Anything else is rejected with a message naming the allowed values.
        rejected = validator.validate_trade_side("invalid")
        assert not rejected.is_valid
        assert "Must be 'buy' or 'sell'" in rejected.errors[0]

    def test_validate_trade_id(self, validator):
        """Cover several accepted trade IDs and the empty string."""
        # Alphanumeric, purely numeric, and dash/underscore IDs all pass.
        for trade_id in ("trade123", "123", "trade-123_abc"):
            assert validator.validate_trade_id(trade_id).is_valid

        # The empty string is rejected.
        rejected = validator.validate_trade_id("")
        assert not rejected.is_valid
        assert "cannot be empty" in rejected.errors[0]

    def test_validate_symbol_match(self, validator):
        """Cover a lone symbol, two differing symbols and a non-string symbol."""
        # A single string symbol passes.
        assert validator.validate_symbol_match("BTC-USD").is_valid

        # Two different symbols: still valid, but a mismatch warning is raised.
        mismatched = validator.validate_symbol_match("BTC-USD", "ETH-USD")
        assert mismatched.is_valid
        assert "mismatch" in mismatched.warnings[0]

        # A non-string symbol is rejected.
        wrong_type = validator.validate_symbol_match(123)
        assert not wrong_type.is_valid
        assert "must be string" in wrong_type.errors[0]
|
||
|
|
|
||
|
|
|
||
|
|
def test_is_valid_decimal():
    """Check which inputs is_valid_decimal accepts and which it rejects."""
    # A numeric string, a float and a Decimal are all accepted.
    for candidate in ("123.45", 123.45, Decimal("123.45")):
        assert is_valid_decimal(candidate)

    # A non-numeric string, None and the empty string are all rejected.
    for candidate in ("invalid", None, ""):
        assert not is_valid_decimal(candidate)
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_required_fields():
    """Check that None-valued and absent required fields are reported as missing."""
    payload = {"field1": "value1", "field2": None, "field3": "value3"}
    required_names = ["field1", "field2", "field4"]

    missing = validate_required_fields(payload, required_names)

    assert "field2" in missing  # present in the payload, but its value is None
    assert "field4" in missing  # not in the payload at all
    assert "field1" not in missing  # present with a real value
|