- Updated all technical indicators to return pandas DataFrames instead of lists, improving consistency and usability. - Modified the `calculate` method in `TechnicalIndicators` to directly return DataFrames with relevant indicator values. - Enhanced the `data_integration.py` to utilize the new DataFrame outputs for better integration with charting. - Updated documentation to reflect the new DataFrame-centric approach, including usage examples and output structures. - Improved error handling to ensure empty DataFrames are returned when insufficient data is available. These changes streamline the indicator calculations and improve the overall architecture, aligning with project standards for maintainability and performance.
342 lines
14 KiB
Python
342 lines
14 KiB
Python
"""
|
|
Unit tests for technical indicators module.

Tests verify that all technical indicators work correctly with sparse OHLCV data
and handle edge cases appropriately.
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone, timedelta
|
|
from decimal import Decimal
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from data.common.indicators import (
|
|
TechnicalIndicators,
|
|
IndicatorResult,
|
|
create_default_indicators_config,
|
|
validate_indicator_config
|
|
)
|
|
from data.common.data_types import OHLCVCandle
|
|
|
|
|
|
class TestTechnicalIndicators:
    """Test suite for the TechnicalIndicators class.

    The indicator methods exercised here (sma, ema, rsi, macd,
    bollinger_bands) are all asserted to return pandas DataFrames. The
    tests check result size, column names and a few sanity properties of
    the computed values, on both gap-free and sparse candle series.
    """

    @staticmethod
    def _make_candle(start_time, price):
        """Build one complete 1-minute BTC-USDT candle closing at ``price``.

        Args:
            start_time: Timezone-aware start of the candle; the candle ends
                one minute later.
            price: Closing price as a float. Open/high/low are derived from
                it with fixed offsets (-0.2 / +0.5 / -0.5).

        Returns:
            An ``OHLCVCandle`` with constant volume and trade count.
        """
        return OHLCVCandle(
            symbol='BTC-USDT',
            timeframe='1m',
            start_time=start_time,
            end_time=start_time + timedelta(minutes=1),
            # Go through str() so the Decimal holds the printed float value
            # rather than its full binary expansion.
            open=Decimal(str(price - 0.2)),
            high=Decimal(str(price + 0.5)),
            low=Decimal(str(price - 0.5)),
            close=Decimal(str(price)),
            volume=Decimal('1000'),
            trade_count=10,
            exchange='test',
            is_complete=True,
        )

    @pytest.fixture
    def sample_candles(self):
        """Return 30 consecutive 1-minute candles with a rising price series."""
        base_time = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
        prices = [
            100.0, 101.0, 102.5, 101.8, 103.0, 104.2, 103.8, 105.0, 104.5, 106.0,
            107.5, 108.0, 107.2, 109.0, 108.5, 110.0, 109.8, 111.0, 110.5, 112.0,
            111.8, 113.0, 112.5, 114.0, 113.2, 115.0, 114.8, 116.0, 115.5, 117.0,
        ]
        return [
            self._make_candle(base_time + timedelta(minutes=minute), price)
            for minute, price in enumerate(prices)
        ]

    @pytest.fixture
    def sparse_candles(self):
        """Return 10 candles whose start times leave gaps between them."""
        base_time = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
        # Minute offsets are irregular on purpose: this is the sparse case.
        gap_minutes = [0, 1, 3, 5, 8, 10, 15, 18, 22, 25]
        prices = [100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0]
        return [
            self._make_candle(base_time + timedelta(minutes=gap), price)
            for gap, price in zip(gap_minutes, prices)
        ]

    @pytest.fixture
    def indicators(self):
        """Return a TechnicalIndicators instance built with default arguments."""
        return TechnicalIndicators()

    def test_initialization(self, indicators):
        """A default-constructed instance exists and has no logger attached."""
        assert indicators is not None
        assert indicators.logger is None

    def test_prepare_dataframe(self, indicators, sample_candles):
        """The candle list converts to a DataFrame with the expected layout."""
        expected_columns = [
            'symbol', 'timeframe', 'open', 'high', 'low', 'close',
            'volume', 'trade_count', 'timestamp',
        ]

        df = indicators._prepare_dataframe_from_list(sample_candles)

        assert not df.empty
        assert len(df) == len(sample_candles)
        assert list(df.columns) == expected_columns
        assert df.index.name == 'timestamp'
        # Rows must come back in chronological order.
        assert df.index.is_monotonic_increasing

    def test_prepare_dataframe_empty(self, indicators):
        """An empty candle list yields an empty DataFrame."""
        df = indicators._prepare_dataframe_from_list([])
        assert df.empty

    def test_sma_calculation(self, indicators, sample_candles):
        """SMA returns a DataFrame whose first value matches a manual average."""
        period = 5
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        result_df = indicators.sma(df, period)

        assert isinstance(result_df, pd.DataFrame)
        assert not result_df.empty
        assert 'sma' in result_df.columns

        # Find the candle whose end_time labels the first SMA row, then
        # average by hand the `period` closes that end at that candle.
        end_times = [candle.end_time for candle in sample_candles]
        first_idx = end_times.index(result_df.index[0])
        window = sample_candles[first_idx - period + 1:first_idx + 1]
        window_closes = [float(candle.close) for candle in window]
        expected_sma = sum(window_closes) / len(window_closes)
        assert abs(result_df.iloc[0]['sma'] - expected_sma) < 0.001

    def test_sma_insufficient_data(self, indicators, sample_candles):
        """SMA returns an empty DataFrame when the period exceeds the data."""
        period = 50  # More than the 30 available candles
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        result_df = indicators.sma(df, period)

        assert isinstance(result_df, pd.DataFrame)
        assert result_df.empty

    def test_ema_calculation(self, indicators, sample_candles):
        """EMA returns one row per candle from the ``period``-th candle on."""
        period = 10
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        result_df = indicators.ema(df, period)

        assert isinstance(result_df, pd.DataFrame)
        # The first period - 1 candles produce no EMA row.
        assert len(result_df) == len(sample_candles) - period + 1
        assert 'ema' in result_df.columns

        # The first EMA value must lie within the range of the closes of
        # the first `period` candles.
        seed_closes = [float(c.close) for c in sample_candles[:period]]
        assert min(seed_closes) <= result_df.iloc[0]['ema'] <= max(seed_closes)

    def test_rsi_calculation(self, indicators, sample_candles):
        """RSI returns a non-empty DataFrame with a first value in [0, 100]."""
        period = 14
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        result_df = indicators.rsi(df, period)

        assert isinstance(result_df, pd.DataFrame)
        assert not result_df.empty
        assert 'rsi' in result_df.columns
        assert 0 <= result_df.iloc[0]['rsi'] <= 100

    def test_macd_calculation(self, indicators, sample_candles):
        """MACD returns macd/signal/histogram columns of the expected length."""
        fast_period = 12
        slow_period = 26
        signal_period = 9
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        result_df = indicators.macd(df, fast_period, slow_period, signal_period)

        # MACD results start after max(slow_period, signal_period) - 1 rows.
        min_required = max(slow_period, signal_period)
        expected_count = max(0, len(sample_candles) - (min_required - 1))

        assert isinstance(result_df, pd.DataFrame)
        assert len(result_df) == expected_count
        for column in ('macd', 'signal', 'histogram'):
            assert column in result_df.columns

        if not result_df.empty:
            # Histogram should equal MACD - Signal.
            first_row = result_df.iloc[0]
            expected_histogram = first_row['macd'] - first_row['signal']
            assert abs(first_row['histogram'] - expected_histogram) < 0.001

    def test_bollinger_bands_calculation(self, indicators, sample_candles):
        """Bollinger Bands return three strictly ordered bands per row."""
        period = 20
        std_dev = 2.0
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        result_df = indicators.bollinger_bands(df, period, std_dev)

        assert isinstance(result_df, pd.DataFrame)
        # The first period - 1 candles produce no band row.
        assert len(result_df) == len(sample_candles) - period + 1
        for column in ('upper_band', 'middle_band', 'lower_band'):
            assert column in result_df.columns

        # upper > middle > lower on the first row.
        first_row = result_df.iloc[0]
        assert first_row['upper_band'] > first_row['middle_band']
        assert first_row['middle_band'] > first_row['lower_band']

    def test_sparse_data_handling(self, indicators, sparse_candles):
        """SMA still produces timestamped rows when candles have time gaps."""
        period = 5
        df = indicators._prepare_dataframe_from_list(sparse_candles)
        df['timestamp'] = df.index

        sma_df = indicators.sma(df, period)

        # Gaps between candles must not prevent a result.
        assert not sma_df.empty
        # Every result row must be labelled with a real datetime.
        for ts in sma_df.index:
            assert ts is not None
            assert isinstance(ts, datetime)

    def test_calculate_multiple_indicators(self, indicators, sample_candles):
        """A multi-indicator config yields one result entry per config key."""
        config = {
            'sma_10': {'type': 'sma', 'period': 10},
            'ema_12': {'type': 'ema', 'period': 12},
            'rsi_14': {'type': 'rsi', 'period': 14},
            'macd': {'type': 'macd'},
            'bb_20': {'type': 'bollinger_bands', 'period': 20}
        }
        df = indicators._prepare_dataframe_from_list(sample_candles)
        df['timestamp'] = df.index

        results = indicators.calculate_multiple_indicators(df, config)

        assert len(results) == len(config)
        for name in config:
            assert name in results

        # Spot-check that the moving averages actually produced rows.
        assert len(results['sma_10']) > 0
        assert len(results['ema_12']) > 0

    def test_invalid_indicator_config(self, indicators, sample_candles):
        """An unknown indicator type is reported with an empty result."""
        config = {
            'invalid_indicator': {'type': 'unknown_type', 'period': 10}
        }
        df = indicators._prepare_dataframe_from_list(sample_candles)

        results = indicators.calculate_multiple_indicators(df, config)

        assert 'invalid_indicator' in results
        # The entry is present but holds no rows.
        assert len(results['invalid_indicator']) == 0

    def test_different_price_columns(self, indicators, sample_candles):
        """SMA can be computed on a price column other than ``close``."""
        df = indicators._prepare_dataframe_from_list(sample_candles)

        sma_high = indicators.sma(df, 5, price_column='high')
        sma_close = indicators.sma(df, 5, price_column='close')

        assert len(sma_high) == len(sma_close)
        # Each fixture high is close + 0.5, so the SMA of highs cannot be
        # below the SMA of closes.
        assert sma_high.iloc[0]['sma'] >= sma_close.iloc[0]['sma']
|
|
|
|
|
|
class TestIndicatorHelperFunctions:
    """Tests for the indicator configuration helper functions."""

    def test_create_default_indicators_config(self):
        """The default config is a dict holding the standard indicator keys."""
        config = create_default_indicators_config()

        assert isinstance(config, dict)
        for key in ('sma_20', 'ema_12', 'rsi_14', 'macd_default', 'bollinger_bands_20'):
            assert key in config, f"missing default indicator: {key!r}"

        # Check structure of configurations
        assert config['sma_20']['type'] == 'sma'
        assert config['sma_20']['period'] == 20
        assert config['macd_default']['type'] == 'macd'

    def test_validate_indicator_config_valid(self):
        """Well-formed indicator configurations are accepted."""
        valid_configs = [
            {'type': 'sma', 'period': 20},
            {'type': 'ema', 'period': 12},
            {'type': 'rsi', 'period': 14},
            {'type': 'macd'},
            {'type': 'bollinger_bands', 'period': 20, 'std_dev': 2.0}
        ]

        for config in valid_configs:
            # The message names the offending config, since several cases
            # share this one test.
            assert validate_indicator_config(config), (
                f"expected config to be valid: {config!r}"
            )

    def test_validate_indicator_config_invalid(self):
        """Malformed indicator configurations are rejected."""
        invalid_configs = [
            {},  # Missing type
            {'type': 'unknown'},  # Invalid type
            {'type': 'sma', 'period': -5},  # Invalid period
            {'type': 'sma', 'period': 'not_a_number'},  # Invalid period type
            {'type': 'bollinger_bands', 'std_dev': -1.0},  # Invalid std_dev
        ]

        for config in invalid_configs:
            assert not validate_indicator_config(config), (
                f"expected config to be invalid: {config!r}"
            )
|
|
|
|
|
|
class TestIndicatorResultDataClass:
    """Tests covering construction of IndicatorResult objects."""

    def test_indicator_result_creation(self):
        """Every constructor argument is readable back as an attribute."""
        fields = {
            'timestamp': datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            'symbol': 'BTC-USDT',
            'timeframe': '1m',
            'values': {'sma': 100.5, 'ema': 101.2},
            'metadata': {'period': 20},
        }

        result = IndicatorResult(**fields)

        assert result.timestamp == fields['timestamp']
        assert result.symbol == fields['symbol']
        assert result.timeframe == fields['timeframe']
        assert result.values == fields['values']
        assert result.metadata == fields['metadata']

    def test_indicator_result_without_metadata(self):
        """Omitting ``metadata`` leaves the attribute set to None."""
        result = IndicatorResult(
            timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            symbol='ETH-USDT',
            timeframe='5m',
            values={'rsi': 65.5},
        )

        assert result.metadata is None
|
|
|
|
|
|
if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file directly
    # reports failures to the calling shell or CI step instead of
    # always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))