- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD). - Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability. - Updated `pyproject.toml` to include the new `strategies` package in the build configuration. - Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards. These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
162 lines
6.7 KiB
Python
162 lines
6.7 KiB
Python
import pytest
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
from unittest.mock import MagicMock
|
|
|
|
from strategies.base import BaseStrategy
|
|
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
|
from data.common.data_types import OHLCVCandle
|
|
|
|
# Test double that records log messages instead of emitting them.
class MockLogger:
    """Stand-in logger that captures each message in a per-level list.

    Only ``info``, ``warning`` and ``error`` are implemented; tests inspect
    ``info_calls``, ``warning_calls`` and ``error_calls`` directly.
    """

    def __init__(self):
        # Fresh list objects per level and per instance, so recordings never mix.
        self.info_calls, self.warning_calls, self.error_calls = [], [], []

    @staticmethod
    def _record(sink, message):
        """Append *message* to the given per-level list."""
        sink.append(message)

    def info(self, message):
        """Capture an info-level message."""
        self._record(self.info_calls, message)

    def warning(self, message):
        """Capture a warning-level message."""
        self._record(self.warning_calls, message)

    def error(self, message):
        """Capture an error-level message."""
        self._record(self.error_calls, message)
|
|
|
|
# Minimal concrete subclass so the abstract BaseStrategy helpers can be exercised.
class ConcreteStrategy(BaseStrategy):
    """Test-only strategy: no indicators, one BUY result from the first row."""

    def __init__(self, logger=None):
        super().__init__("ConcreteStrategy", logger)

    def get_required_indicators(self) -> list[dict]:
        """This strategy declares no indicator dependencies."""
        return []

    def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
        """Build a single BUY result from the first row of *data*.

        The signal uses the row's index label as timestamp, its ``close`` as
        price and a confidence of 1.0. An empty frame yields an empty list;
        ``kwargs`` are accepted for interface compatibility and ignored.
        """
        if data.empty:
            return []

        head = data.iloc[0]
        ts = head.name
        symbol = head['symbol']
        timeframe = head['timeframe']

        buy = StrategySignal(
            timestamp=ts,
            symbol=symbol,
            timeframe=timeframe,
            signal_type=SignalType.BUY,
            price=float(head['close']),
            confidence=1.0,
        )
        result = StrategyResult(
            timestamp=ts,
            symbol=symbol,
            timeframe=timeframe,
            strategy_name=self.strategy_name,
            signals=[buy],
            indicators_used={},
        )
        return [result]
|
|
|
|
@pytest.fixture
def mock_logger():
    """Provide a fresh recording MockLogger for each test."""
    recorder = MockLogger()
    return recorder
|
|
|
|
@pytest.fixture
def concrete_strategy(mock_logger):
    """ConcreteStrategy wired to the recording logger so tests can inspect its output."""
    strategy = ConcreteStrategy(logger=mock_logger)
    return strategy
|
|
|
|
@pytest.fixture
def sample_ohlcv_data():
    """Five consecutive hourly BTC/USDT OHLCV candles.

    The index is a DatetimeIndex named ``'timestamp'``. The error message
    asserted in ``test_validate_dataframe_invalid_index`` indicates that
    ``validate_dataframe`` requires exactly that. With an unnamed index, the
    tests that treat this fixture as *valid* input (e.g.
    ``test_validate_dataframe_valid``) would fail, or would trip the wrong
    validation error.
    """
    timestamps = pd.DatetimeIndex(
        [
            '2023-01-01 00:00:00',
            '2023-01-01 01:00:00',
            '2023-01-01 02:00:00',
            '2023-01-01 03:00:00',
            '2023-01-01 04:00:00',
        ],
        name='timestamp',
    )
    return pd.DataFrame(
        {
            'open': [100, 101, 102, 103, 104],
            'high': [105, 106, 107, 108, 109],
            'low': [99, 100, 101, 102, 103],
            'close': [102, 103, 104, 105, 106],
            'volume': [1000, 1100, 1200, 1300, 1400],
            'symbol': ['BTC/USDT'] * 5,
            'timeframe': ['1h'] * 5,
        },
        index=timestamps,
    )
|
|
|
|
def test_prepare_dataframe_initial_data(concrete_strategy, sample_ohlcv_data):
    """prepare_dataframe keeps all OHLCV/meta columns and yields a sorted 'timestamp' index."""
    result = concrete_strategy.prepare_dataframe(sample_ohlcv_data)

    expected_columns = ('open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe')
    for column in expected_columns:
        assert column in result.columns

    index = result.index
    assert index.name == 'timestamp'
    assert index.is_monotonic_increasing
|
|
|
|
def test_prepare_dataframe_sparse_data(concrete_strategy, sample_ohlcv_data):
    """A missing candle is re-inserted by prepare_dataframe with NaN values."""
    # Knock out the middle candle so the input has a one-bar hole.
    gap_ts = sample_ohlcv_data.index[2]
    with_gap = sample_ohlcv_data.drop(gap_ts)

    filled = concrete_strategy.prepare_dataframe(with_gap)

    # Original length and the dropped timestamp must both come back...
    assert len(filled) == len(sample_ohlcv_data)
    assert filled.index[2] == gap_ts
    # ...but the reinserted row carries no data.
    assert pd.isna(filled.loc[gap_ts, 'open'])
|
|
|
|
def test_validate_dataframe_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Well-formed OHLCV data validates without raising or logging problems."""
    concrete_strategy.validate_dataframe(sample_ohlcv_data)

    # Neither warnings nor errors may have been recorded.
    logged_problems = mock_logger.warning_calls + mock_logger.error_calls
    assert not logged_problems
|
|
|
|
def test_validate_dataframe_missing_column(concrete_strategy, sample_ohlcv_data, mock_logger):
    """validate_dataframe raises ValueError naming a dropped required column."""
    invalid_df = sample_ohlcv_data.drop(columns=['open'])

    # `match` is a regex (re.search), so the brackets are escaped. A raw string
    # is required: in a plain literal "\[" is an invalid escape sequence, which
    # emits a SyntaxWarning on Python 3.12+ and is slated to become an error.
    with pytest.raises(ValueError, match=r"Missing required columns: \['open'\]"):
        concrete_strategy.validate_dataframe(invalid_df)
|
|
|
|
def test_validate_dataframe_invalid_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """validate_dataframe rejects a frame whose index is not a 'timestamp' DatetimeIndex."""
    # reset_index() moves the datetimes into a column and leaves a default RangeIndex.
    range_indexed = sample_ohlcv_data.reset_index()
    expected = "DataFrame index must be named 'timestamp' and be a DatetimeIndex."

    with pytest.raises(ValueError, match=expected):
        concrete_strategy.validate_dataframe(range_indexed)
|
|
|
|
def test_validate_dataframe_non_monotonic_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """validate_dataframe rejects candles that are not in ascending time order."""
    # Reversing the rows turns the ascending index into a descending one.
    descending = sample_ohlcv_data.iloc[::-1]
    expected = "DataFrame index is not monotonically increasing."

    with pytest.raises(ValueError, match=expected):
        concrete_strategy.validate_dataframe(descending)
|
|
|
|
def test_validate_indicators_data_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Complete, NaN-free indicator columns validate without any warning or error."""
    idx = sample_ohlcv_data.index
    indicator_frame = pd.DataFrame({
        'ema_fast': pd.Series([101, 102, 103, 104, 105], index=idx),
        'ema_slow': pd.Series([100, 101, 102, 103, 104], index=idx),
    })
    merged_df = pd.concat([sample_ohlcv_data, indicator_frame], axis=1)

    # Each spec's 'key' names the merged column the validator must find.
    required_indicators = [
        {'type': 'ema', 'period': period, 'key': key}
        for key, period in (('ema_fast', 12), ('ema_slow', 26))
    ]

    concrete_strategy.validate_indicators_data(merged_df, required_indicators)

    assert not mock_logger.warning_calls
    assert not mock_logger.error_calls
|
|
|
|
def test_validate_indicators_data_missing_indicator(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A required indicator key with no matching column raises ValueError."""
    # Only the fast EMA column is supplied; the slow one is deliberately absent.
    fast_only = pd.DataFrame({
        'ema_fast': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
    })
    merged_df = pd.concat([sample_ohlcv_data, fast_only], axis=1)

    required_indicators = [
        {'type': 'ema', 'period': period, 'key': key}
        for key, period in (('ema_fast', 12), ('ema_slow', 26))
    ]
    expected = "Missing required indicator data for key: ema_slow"

    with pytest.raises(ValueError, match=expected):
        concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
|
|
|
def test_validate_indicators_data_nan_values(concrete_strategy, sample_ohlcv_data, mock_logger):
    """NaNs inside a required indicator are expected to be logged as a warning, not raised."""
    # float("nan") rather than np.nan: numpy is not imported in this module, so
    # np.nan raised NameError before validate_indicators_data was ever reached.
    # The resulting Series is float64, the same as it would be with np.nan.
    indicators_data = {
        'ema_fast': pd.Series([101, 102, float("nan"), 104, 105], index=sample_ohlcv_data.index),
        'ema_slow': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index),
    }
    merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)

    required_indicators = [
        {'type': 'ema', 'period': 12, 'key': 'ema_fast'},
        {'type': 'ema', 'period': 26, 'key': 'ema_slow'},
    ]

    concrete_strategy.validate_indicators_data(merged_df, required_indicators)
    assert "NaN values detected in required indicator data for key: ema_fast" in mock_logger.warning_calls