360 lines
14 KiB
Python
360 lines
14 KiB
Python
|
|
"""
|
||
|
|
Unit tests for technical indicators module.
|
||
|
|
|
||
|
|
Tests verify that all technical indicators work correctly with sparse OHLCV data
|
||
|
|
and handle edge cases appropriately.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from datetime import datetime, timezone, timedelta
|
||
|
|
from decimal import Decimal
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from data.common.indicators import (
|
||
|
|
TechnicalIndicators,
|
||
|
|
IndicatorResult,
|
||
|
|
create_default_indicators_config,
|
||
|
|
validate_indicator_config
|
||
|
|
)
|
||
|
|
from data.common.data_types import OHLCVCandle
|
||
|
|
|
||
|
|
|
||
|
|
class TestTechnicalIndicators:
    """Test suite for TechnicalIndicators class."""

    @staticmethod
    def _make_candle(start_time, price):
        """Build a complete one-minute BTC-USDT test candle closing at ``price``.

        open/high/low are derived from ``price`` with fixed offsets
        (-0.2 / +0.5 / -0.5); volume and trade count are constant.
        """
        return OHLCVCandle(
            symbol='BTC-USDT',
            timeframe='1m',
            start_time=start_time,
            end_time=start_time + timedelta(minutes=1),
            open=Decimal(str(price - 0.2)),
            high=Decimal(str(price + 0.5)),
            low=Decimal(str(price - 0.5)),
            close=Decimal(str(price)),
            volume=Decimal('1000'),
            trade_count=10,
            exchange='test',
            is_complete=True,
        )

    @pytest.fixture
    def sample_candles(self):
        """Create 30 contiguous one-minute candles with a generally rising price."""
        base_time = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)

        # Create 30 candles with realistic price movement
        prices = [100.0, 101.0, 102.5, 101.8, 103.0, 104.2, 103.8, 105.0, 104.5, 106.0,
                  107.5, 108.0, 107.2, 109.0, 108.5, 110.0, 109.8, 111.0, 110.5, 112.0,
                  111.8, 113.0, 112.5, 114.0, 113.2, 115.0, 114.8, 116.0, 115.5, 117.0]

        return [
            self._make_candle(base_time + timedelta(minutes=i), price)
            for i, price in enumerate(prices)
        ]

    @pytest.fixture
    def sparse_candles(self):
        """Create 10 one-minute candles whose start times have gaps (sparse data)."""
        base_time = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)

        # Minute offsets from base_time; the skipped minutes are the gaps.
        gap_minutes = [0, 1, 3, 5, 8, 10, 15, 18, 22, 25]
        prices = [100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0]

        return [
            self._make_candle(base_time + timedelta(minutes=gap), price)
            for gap, price in zip(gap_minutes, prices)
        ]

    @pytest.fixture
    def indicators(self):
        """Create TechnicalIndicators instance."""
        return TechnicalIndicators()

    def test_initialization(self, indicators):
        """Test TechnicalIndicators initialization."""
        assert indicators is not None
        assert indicators.logger is None

    def test_prepare_dataframe(self, indicators, sample_candles):
        """Test DataFrame preparation from OHLCV candles."""
        df = indicators.prepare_dataframe(sample_candles)

        assert not df.empty
        assert len(df) == len(sample_candles)
        assert list(df.columns) == ['symbol', 'timeframe', 'open', 'high', 'low', 'close', 'volume', 'trade_count']
        assert df.index.name == 'timestamp'

        # Check that timestamps are sorted
        assert df.index.is_monotonic_increasing

    def test_prepare_dataframe_empty(self, indicators):
        """Test DataFrame preparation with empty candles list."""
        df = indicators.prepare_dataframe([])
        assert df.empty

    def test_sma_calculation(self, indicators, sample_candles):
        """Test Simple Moving Average calculation."""
        period = 5
        results = indicators.sma(sample_candles, period)

        # One result per full window; the first lands on the period-th candle.
        assert len(results) == len(sample_candles) - period + 1

        # Check first result
        first_result = results[0]
        assert isinstance(first_result, IndicatorResult)
        assert first_result.symbol == 'BTC-USDT'
        assert first_result.timeframe == '1m'
        assert 'sma' in first_result.values
        assert first_result.metadata['period'] == period

        # Verify SMA calculation manually for first result
        first_window_closes = [float(candle.close) for candle in sample_candles[:period]]
        expected_sma = sum(first_window_closes) / len(first_window_closes)
        assert first_result.values['sma'] == pytest.approx(expected_sma, abs=1e-3)

    def test_sma_insufficient_data(self, indicators, sample_candles):
        """Test SMA with insufficient data."""
        period = 50  # More than available candles
        results = indicators.sma(sample_candles, period)
        assert len(results) == 0

    def test_ema_calculation(self, indicators, sample_candles):
        """Test Exponential Moving Average calculation."""
        period = 10
        results = indicators.ema(sample_candles, period)

        # One result per full window; the first lands on the period-th candle.
        assert len(results) == len(sample_candles) - period + 1

        # Check first result
        first_result = results[0]
        assert isinstance(first_result, IndicatorResult)
        assert 'ema' in first_result.values
        assert first_result.metadata['period'] == period

        # A weighted average of the first window must lie within that window's range.
        min_price = min(float(c.close) for c in sample_candles[:period])
        max_price = max(float(c.close) for c in sample_candles[:period])
        assert min_price <= first_result.values['ema'] <= max_price

    def test_rsi_calculation(self, indicators, sample_candles):
        """Test Relative Strength Index calculation."""
        period = 14
        results = indicators.rsi(sample_candles, period)

        # RSI works on price changes, so it needs period + 1 closes for its first value.
        assert len(results) == len(sample_candles) - period

        # Check first result
        first_result = results[0]
        assert isinstance(first_result, IndicatorResult)
        assert 'rsi' in first_result.values
        assert 0 <= first_result.values['rsi'] <= 100  # RSI should be between 0 and 100
        assert first_result.metadata['period'] == period

    def test_macd_calculation(self, indicators, sample_candles):
        """Test MACD calculation.

        The fixture holds only 30 candles, fewer than the default (12, 26, 9)
        parameters need. Shorter periods are used so that the value checks
        below actually execute instead of being skipped on an empty result.
        """
        fast_period = 5
        slow_period = 10
        signal_period = 3
        results = indicators.macd(sample_candles, fast_period, slow_period, signal_period)

        # MACD needs slow_period + signal_period data points
        expected_count = len(sample_candles) - slow_period - signal_period + 1
        assert expected_count > 0  # guard: the fixture must be long enough for these periods
        assert len(results) == expected_count

        first_result = results[0]
        assert isinstance(first_result, IndicatorResult)
        assert 'macd' in first_result.values
        assert 'signal' in first_result.values
        assert 'histogram' in first_result.values

        # Histogram should equal MACD - Signal
        expected_histogram = first_result.values['macd'] - first_result.values['signal']
        assert first_result.values['histogram'] == pytest.approx(expected_histogram, abs=1e-3)

    def test_macd_insufficient_data(self, indicators, sample_candles):
        """Test MACD with default periods, which need more candles than available."""
        fast_period = 12
        slow_period = 26
        signal_period = 9
        # 30 candles < slow_period + signal_period (35), so nothing can be produced.
        results = indicators.macd(sample_candles, fast_period, slow_period, signal_period)
        assert len(results) == 0

    def test_bollinger_bands_calculation(self, indicators, sample_candles):
        """Test Bollinger Bands calculation."""
        period = 20
        std_dev = 2.0
        results = indicators.bollinger_bands(sample_candles, period, std_dev)

        # One result per full window; the first lands on the period-th candle.
        assert len(results) == len(sample_candles) - period + 1

        # Check first result
        first_result = results[0]
        assert isinstance(first_result, IndicatorResult)
        assert 'upper_band' in first_result.values
        assert 'middle_band' in first_result.values
        assert 'lower_band' in first_result.values
        assert 'bandwidth' in first_result.values
        assert 'percent_b' in first_result.values

        # Upper band should be greater than middle band, which should be greater than lower band
        assert first_result.values['upper_band'] > first_result.values['middle_band']
        assert first_result.values['middle_band'] > first_result.values['lower_band']

    def test_sparse_data_handling(self, indicators, sparse_candles):
        """Test indicators with sparse data (time gaps)."""
        period = 5
        sma_results = indicators.sma(sparse_candles, period)

        # Should handle sparse data without issues
        assert len(sma_results) > 0

        # Check that timestamps are preserved correctly
        for result in sma_results:
            assert result.timestamp is not None
            assert isinstance(result.timestamp, datetime)

    def test_calculate_multiple_indicators(self, indicators, sample_candles):
        """Test calculating multiple indicators at once."""
        config = {
            'sma_10': {'type': 'sma', 'period': 10},
            'ema_12': {'type': 'ema', 'period': 12},
            'rsi_14': {'type': 'rsi', 'period': 14},
            'macd': {'type': 'macd'},
            'bb_20': {'type': 'bollinger_bands', 'period': 20}
        }

        results = indicators.calculate_multiple_indicators(sample_candles, config)

        assert len(results) == len(config)
        assert 'sma_10' in results
        assert 'ema_12' in results
        assert 'rsi_14' in results
        assert 'macd' in results
        assert 'bb_20' in results

        # Check that each indicator with enough data has results. 'macd' is
        # only checked for presence: the fixture's 30 candles are likely too
        # few for the default MACD periods.
        assert len(results['sma_10']) > 0
        assert len(results['ema_12']) > 0
        assert len(results['rsi_14']) > 0
        assert len(results['bb_20']) > 0

    def test_invalid_indicator_config(self, indicators, sample_candles):
        """Test handling of invalid indicator configuration."""
        config = {
            'invalid_indicator': {'type': 'unknown_type', 'period': 10}
        }

        results = indicators.calculate_multiple_indicators(sample_candles, config)

        assert 'invalid_indicator' in results
        assert len(results['invalid_indicator']) == 0  # Should return empty list

    def test_different_price_columns(self, indicators, sample_candles):
        """Test indicators with different price columns."""
        # Test SMA with 'high' price column
        sma_high = indicators.sma(sample_candles, 5, price_column='high')
        sma_close = indicators.sma(sample_candles, 5, price_column='close')

        assert len(sma_high) == len(sma_close)
        # Every fixture candle has high = close + 0.5, so the high-based SMA must be larger.
        assert sma_high[0].values['sma'] >= sma_close[0].values['sma']
|
||
|
|
|
||
|
|
|
||
|
|
class TestIndicatorHelperFunctions:
    """Test helper functions for indicators."""

    def test_create_default_indicators_config(self):
        """Test default indicators configuration creation."""
        config = create_default_indicators_config()

        assert isinstance(config, dict)
        assert 'sma_20' in config
        assert 'ema_12' in config
        assert 'rsi_14' in config
        assert 'macd_default' in config
        assert 'bollinger_bands_20' in config

        # Check structure of configurations
        assert config['sma_20']['type'] == 'sma'
        assert config['sma_20']['period'] == 20
        assert config['macd_default']['type'] == 'macd'

    # Parametrized so that a failure names the offending config instead of
    # aborting a shared loop at the first bad entry.
    @pytest.mark.parametrize(
        'config',
        [
            {'type': 'sma', 'period': 20},
            {'type': 'ema', 'period': 12},
            {'type': 'rsi', 'period': 14},
            {'type': 'macd'},
            {'type': 'bollinger_bands', 'period': 20, 'std_dev': 2.0},
        ],
        ids=['sma', 'ema', 'rsi', 'macd', 'bollinger_bands'],
    )
    def test_validate_indicator_config_valid(self, config):
        """Test validation of valid indicator configurations."""
        assert validate_indicator_config(config) is True

    @pytest.mark.parametrize(
        'config',
        [
            {},  # Missing type
            {'type': 'unknown'},  # Invalid type
            {'type': 'sma', 'period': -5},  # Invalid period
            {'type': 'sma', 'period': 'not_a_number'},  # Invalid period type
            # Valid period so that std_dev is the only reason for rejection.
            {'type': 'bollinger_bands', 'period': 20, 'std_dev': -1.0},  # Invalid std_dev
        ],
        ids=[
            'missing_type',
            'unknown_type',
            'negative_period',
            'non_numeric_period',
            'negative_std_dev',
        ],
    )
    def test_validate_indicator_config_invalid(self, config):
        """Test validation of invalid indicator configurations."""
        assert validate_indicator_config(config) is False
|
||
|
|
|
||
|
|
|
||
|
|
class TestIndicatorResultDataClass:
    """Tests for the IndicatorResult container."""

    def test_indicator_result_creation(self):
        """Every constructor argument is exposed unchanged as an attribute."""
        constructor_kwargs = {
            'timestamp': datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            'symbol': 'BTC-USDT',
            'timeframe': '1m',
            'values': {'sma': 100.5, 'ema': 101.2},
            'metadata': {'period': 20},
        }

        result = IndicatorResult(**constructor_kwargs)

        # Checked in declaration order: timestamp, symbol, timeframe, values, metadata.
        for attr_name, expected in constructor_kwargs.items():
            assert getattr(result, attr_name) == expected

    def test_indicator_result_without_metadata(self):
        """Omitting metadata leaves the attribute as None."""
        result = IndicatorResult(
            timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            symbol='ETH-USDT',
            timeframe='5m',
            values={'rsi': 65.5},
        )

        assert result.metadata is None
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Propagate pytest's exit status so `python <this file>` fails (non-zero)
    # when tests fail; the return value of pytest.main was previously dropped.
    raise SystemExit(pytest.main([__file__]))
|