4.0 - 1.0 Implement strategy engine foundation with modular components
- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD). - Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability. - Updated `pyproject.toml` to include the new `strategies` package in the build configuration. - Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards. These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
This commit is contained in:
1
tests/strategies/__init__.py
Normal file
1
tests/strategies/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
162
tests/strategies/test_base_strategy.py
Normal file
162
tests/strategies/test_base_strategy.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
||||
from data.common.data_types import OHLCVCandle
|
||||
|
||||
# Mock logger for testing
|
||||
class MockLogger:
|
||||
def __init__(self):
|
||||
self.info_calls = []
|
||||
self.warning_calls = []
|
||||
self.error_calls = []
|
||||
|
||||
def info(self, message):
|
||||
self.info_calls.append(message)
|
||||
|
||||
def warning(self, message):
|
||||
self.warning_calls.append(message)
|
||||
|
||||
def error(self, message):
|
||||
self.error_calls.append(message)
|
||||
|
||||
# Concrete implementation of BaseStrategy for testing purposes
|
||||
class ConcreteStrategy(BaseStrategy):
|
||||
def __init__(self, logger=None):
|
||||
super().__init__("ConcreteStrategy", logger)
|
||||
|
||||
def get_required_indicators(self) -> list[dict]:
|
||||
return []
|
||||
|
||||
def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
|
||||
# Simple mock calculation for testing
|
||||
signals = []
|
||||
if not data.empty:
|
||||
first_row = data.iloc[0]
|
||||
signals.append(StrategyResult(
|
||||
timestamp=first_row.name,
|
||||
symbol=first_row['symbol'],
|
||||
timeframe=first_row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[StrategySignal(
|
||||
timestamp=first_row.name,
|
||||
symbol=first_row['symbol'],
|
||||
timeframe=first_row['timeframe'],
|
||||
signal_type=SignalType.BUY,
|
||||
price=float(first_row['close']),
|
||||
confidence=1.0
|
||||
)],
|
||||
indicators_used={}
|
||||
))
|
||||
return signals
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
return MockLogger()
|
||||
|
||||
@pytest.fixture
|
||||
def concrete_strategy(mock_logger):
|
||||
return ConcreteStrategy(logger=mock_logger)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ohlcv_data():
|
||||
return pd.DataFrame({
|
||||
'open': [100, 101, 102, 103, 104],
|
||||
'high': [105, 106, 107, 108, 109],
|
||||
'low': [99, 100, 101, 102, 103],
|
||||
'close': [102, 103, 104, 105, 106],
|
||||
'volume': [1000, 1100, 1200, 1300, 1400],
|
||||
'symbol': ['BTC/USDT'] * 5,
|
||||
'timeframe': ['1h'] * 5
|
||||
}, index=pd.to_datetime(['2023-01-01 00:00:00', '2023-01-01 01:00:00', '2023-01-01 02:00:00',
|
||||
'2023-01-01 03:00:00', '2023-01-01 04:00:00']))
|
||||
|
||||
def test_prepare_dataframe_initial_data(concrete_strategy, sample_ohlcv_data):
|
||||
prepared_df = concrete_strategy.prepare_dataframe(sample_ohlcv_data)
|
||||
assert 'open' in prepared_df.columns
|
||||
assert 'high' in prepared_df.columns
|
||||
assert 'low' in prepared_df.columns
|
||||
assert 'close' in prepared_df.columns
|
||||
assert 'volume' in prepared_df.columns
|
||||
assert 'symbol' in prepared_df.columns
|
||||
assert 'timeframe' in prepared_df.columns
|
||||
assert prepared_df.index.name == 'timestamp'
|
||||
assert prepared_df.index.is_monotonic_increasing
|
||||
|
||||
def test_prepare_dataframe_sparse_data(concrete_strategy, sample_ohlcv_data):
|
||||
# Simulate sparse data by removing the middle row
|
||||
sparse_df = sample_ohlcv_data.drop(sample_ohlcv_data.index[2])
|
||||
prepared_df = concrete_strategy.prepare_dataframe(sparse_df)
|
||||
assert len(prepared_df) == len(sample_ohlcv_data) # Should fill missing row with NaN
|
||||
assert prepared_df.index[2] == sample_ohlcv_data.index[2] # Ensure timestamp is restored
|
||||
assert pd.isna(prepared_df.loc[sample_ohlcv_data.index[2], 'open']) # Check for NaN in filled row
|
||||
|
||||
def test_validate_dataframe_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
# Ensure no warnings/errors are logged for valid data
|
||||
concrete_strategy.validate_dataframe(sample_ohlcv_data)
|
||||
assert not mock_logger.warning_calls
|
||||
assert not mock_logger.error_calls
|
||||
|
||||
def test_validate_dataframe_missing_column(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
invalid_df = sample_ohlcv_data.drop(columns=['open'])
|
||||
with pytest.raises(ValueError, match="Missing required columns: \['open']"):
|
||||
concrete_strategy.validate_dataframe(invalid_df)
|
||||
|
||||
def test_validate_dataframe_invalid_index(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
invalid_df = sample_ohlcv_data.reset_index()
|
||||
with pytest.raises(ValueError, match="DataFrame index must be named 'timestamp' and be a DatetimeIndex."):
|
||||
concrete_strategy.validate_dataframe(invalid_df)
|
||||
|
||||
def test_validate_dataframe_non_monotonic_index(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
# Reverse order to make it non-monotonic
|
||||
invalid_df = sample_ohlcv_data.iloc[::-1]
|
||||
with pytest.raises(ValueError, match="DataFrame index is not monotonically increasing."):
|
||||
concrete_strategy.validate_dataframe(invalid_df)
|
||||
|
||||
def test_validate_indicators_data_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
indicators_data = {
|
||||
'ema_fast': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
|
||||
'ema_slow': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index)
|
||||
}
|
||||
merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
|
||||
|
||||
required_indicators = [
|
||||
{'type': 'ema', 'period': 12, 'key': 'ema_fast'},
|
||||
{'type': 'ema', 'period': 26, 'key': 'ema_slow'}
|
||||
]
|
||||
|
||||
concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
||||
assert not mock_logger.warning_calls
|
||||
assert not mock_logger.error_calls
|
||||
|
||||
def test_validate_indicators_data_missing_indicator(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
indicators_data = {
|
||||
'ema_fast': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
|
||||
}
|
||||
merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
|
||||
|
||||
required_indicators = [
|
||||
{'type': 'ema', 'period': 12, 'key': 'ema_fast'},
|
||||
{'type': 'ema', 'period': 26, 'key': 'ema_slow'} # Missing
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required indicator data for key: ema_slow"):
|
||||
concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
||||
|
||||
def test_validate_indicators_data_nan_values(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||
indicators_data = {
|
||||
'ema_fast': pd.Series([101, 102, np.nan, 104, 105], index=sample_ohlcv_data.index),
|
||||
'ema_slow': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index)
|
||||
}
|
||||
merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
|
||||
|
||||
required_indicators = [
|
||||
{'type': 'ema', 'period': 12, 'key': 'ema_fast'},
|
||||
{'type': 'ema', 'period': 26, 'key': 'ema_slow'}
|
||||
]
|
||||
|
||||
concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
||||
assert "NaN values detected in required indicator data for key: ema_fast" in mock_logger.warning_calls
|
||||
220
tests/strategies/test_strategy_factory.py
Normal file
220
tests/strategies/test_strategy_factory.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from strategies.factory import StrategyFactory
|
||||
from strategies.base import BaseStrategy
|
||||
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from data.common.indicators import TechnicalIndicators # For mocking purposes
|
||||
|
||||
# Mock logger for testing
|
||||
class MockLogger:
|
||||
def __init__(self):
|
||||
self.info_calls = []
|
||||
self.warning_calls = []
|
||||
self.error_calls = []
|
||||
|
||||
def info(self, message):
|
||||
self.info_calls.append(message)
|
||||
|
||||
def warning(self, message):
|
||||
self.warning_calls.append(message)
|
||||
|
||||
def error(self, message):
|
||||
self.error_calls.append(message)
|
||||
|
||||
# Mock Concrete Strategy for testing StrategyFactory
|
||||
class MockEMAStrategy(BaseStrategy):
|
||||
def __init__(self, logger=None):
|
||||
super().__init__("ema_crossover", logger)
|
||||
self.calculate_calls = []
|
||||
|
||||
def get_required_indicators(self) -> list[dict]:
|
||||
return [{'type': 'ema', 'period': 12}, {'type': 'ema', 'period': 26}]
|
||||
|
||||
def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
|
||||
self.calculate_calls.append((data, kwargs))
|
||||
# Simulate a signal for testing
|
||||
if not data.empty:
|
||||
first_row = data.iloc[0]
|
||||
return [StrategyResult(
|
||||
timestamp=first_row.name,
|
||||
symbol=first_row['symbol'],
|
||||
timeframe=first_row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[StrategySignal(
|
||||
timestamp=first_row.name,
|
||||
symbol=first_row['symbol'],
|
||||
timeframe=first_row['timeframe'],
|
||||
signal_type=SignalType.BUY,
|
||||
price=float(first_row['close']),
|
||||
confidence=1.0
|
||||
)],
|
||||
indicators_used={}
|
||||
)]
|
||||
return []
|
||||
|
||||
class MockRSIStrategy(BaseStrategy):
|
||||
def __init__(self, logger=None):
|
||||
super().__init__("rsi", logger)
|
||||
self.calculate_calls = []
|
||||
|
||||
def get_required_indicators(self) -> list[dict]:
|
||||
return [{'type': 'rsi', 'period': 14}]
|
||||
|
||||
def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
|
||||
self.calculate_calls.append((data, kwargs))
|
||||
# Simulate a signal for testing
|
||||
if not data.empty:
|
||||
first_row = data.iloc[0]
|
||||
return [StrategyResult(
|
||||
timestamp=first_row.name,
|
||||
symbol=first_row['symbol'],
|
||||
timeframe=first_row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[StrategySignal(
|
||||
timestamp=first_row.name,
|
||||
symbol=first_row['symbol'],
|
||||
timeframe=first_row['timeframe'],
|
||||
signal_type=SignalType.SELL,
|
||||
price=float(first_row['close']),
|
||||
confidence=0.9
|
||||
)],
|
||||
indicators_used={}
|
||||
)]
|
||||
return []
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger():
|
||||
return MockLogger()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_technical_indicators():
|
||||
mock_ti = MagicMock(spec=TechnicalIndicators)
|
||||
# Configure the mock to return dummy data for indicators
|
||||
def mock_calculate(indicator_type, df, **kwargs):
|
||||
if indicator_type == 'ema':
|
||||
# Simulate EMA data
|
||||
return pd.DataFrame({
|
||||
'ema_fast': df['close'] * 1.02,
|
||||
'ema_slow': df['close'] * 0.98
|
||||
}, index=df.index)
|
||||
elif indicator_type == 'rsi':
|
||||
# Simulate RSI data
|
||||
return pd.DataFrame({
|
||||
'rsi': pd.Series([60, 65, 72, 28, 35], index=df.index)
|
||||
}, index=df.index)
|
||||
return pd.DataFrame(index=df.index)
|
||||
|
||||
mock_ti.calculate.side_effect = mock_calculate
|
||||
return mock_ti
|
||||
|
||||
@pytest.fixture
|
||||
def strategy_factory(mock_technical_indicators, mock_logger, monkeypatch):
|
||||
# Patch the strategy factory to use our mock strategies
|
||||
monkeypatch.setattr(
|
||||
"strategies.factory.StrategyFactory._STRATEGIES",
|
||||
{
|
||||
"ema_crossover": MockEMAStrategy,
|
||||
"rsi": MockRSIStrategy,
|
||||
}
|
||||
)
|
||||
return StrategyFactory(mock_technical_indicators, mock_logger)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ohlcv_data():
|
||||
return pd.DataFrame({
|
||||
'open': [100, 101, 102, 103, 104],
|
||||
'high': [105, 106, 107, 108, 109],
|
||||
'low': [99, 100, 101, 102, 103],
|
||||
'close': [102, 103, 104, 105, 106],
|
||||
'volume': [1000, 1100, 1200, 1300, 1400],
|
||||
'symbol': ['BTC/USDT'] * 5,
|
||||
'timeframe': ['1h'] * 5
|
||||
}, index=pd.to_datetime(['2023-01-01 00:00:00', '2023-01-01 01:00:00', '2023-01-01 02:00:00',
|
||||
'2023-01-01 03:00:00', '2023-01-01 04:00:00']))
|
||||
|
||||
def test_get_available_strategies(strategy_factory):
|
||||
available_strategies = strategy_factory.get_available_strategies()
|
||||
assert "ema_crossover" in available_strategies
|
||||
assert "rsi" in available_strategies
|
||||
assert "macd" not in available_strategies # Should not be present if not mocked
|
||||
|
||||
def test_create_strategy_success(strategy_factory):
|
||||
ema_strategy = strategy_factory.create_strategy("ema_crossover")
|
||||
assert isinstance(ema_strategy, MockEMAStrategy)
|
||||
assert ema_strategy.strategy_name == "ema_crossover"
|
||||
|
||||
def test_create_strategy_unknown(strategy_factory):
|
||||
with pytest.raises(ValueError, match="Unknown strategy type: unknown_strategy"):
|
||||
strategy_factory.create_strategy("unknown_strategy")
|
||||
|
||||
def test_calculate_multiple_strategies_success(strategy_factory, sample_ohlcv_data, mock_technical_indicators):
|
||||
strategy_configs = [
|
||||
{"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26},
|
||||
{"strategy": "rsi", "period": 14, "overbought": 70, "oversold": 30}
|
||||
]
|
||||
|
||||
all_strategy_results = strategy_factory.calculate_multiple_strategies(
|
||||
strategy_configs, sample_ohlcv_data
|
||||
)
|
||||
|
||||
assert len(all_strategy_results) == 2 # Expect results for both strategies
|
||||
assert "ema_crossover" in all_strategy_results
|
||||
assert "rsi" in all_strategy_results
|
||||
|
||||
ema_results = all_strategy_results["ema_crossover"]
|
||||
rsi_results = all_strategy_results["rsi"]
|
||||
|
||||
assert len(ema_results) > 0
|
||||
assert ema_results[0].strategy_name == "ema_crossover"
|
||||
assert len(rsi_results) > 0
|
||||
assert rsi_results[0].strategy_name == "rsi"
|
||||
|
||||
# Verify that TechnicalIndicators.calculate was called with correct arguments
|
||||
# EMA calls
|
||||
ema_calls = [call for call in mock_technical_indicators.calculate.call_args_list if call.args[0] == 'ema']
|
||||
assert len(ema_calls) == 2 # Two EMA indicators for ema_crossover strategy
|
||||
assert ema_calls[0].kwargs['period'] == 12 or ema_calls[0].kwargs['period'] == 26
|
||||
assert ema_calls[1].kwargs['period'] == 12 or ema_calls[1].kwargs['period'] == 26
|
||||
|
||||
# RSI calls
|
||||
rsi_calls = [call for call in mock_technical_indicators.calculate.call_args_list if call.args[0] == 'rsi']
|
||||
assert len(rsi_calls) == 1 # One RSI indicator for rsi strategy
|
||||
assert rsi_calls[0].kwargs['period'] == 14
|
||||
|
||||
def test_calculate_multiple_strategies_no_configs(strategy_factory, sample_ohlcv_data):
|
||||
results = strategy_factory.calculate_multiple_strategies([], sample_ohlcv_data)
|
||||
assert not results
|
||||
|
||||
def test_calculate_multiple_strategies_empty_data(strategy_factory, mock_technical_indicators):
|
||||
strategy_configs = [
|
||||
{"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26}
|
||||
]
|
||||
empty_df = pd.DataFrame(columns=['open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe'])
|
||||
results = strategy_factory.calculate_multiple_strategies(strategy_configs, empty_df)
|
||||
assert not results
|
||||
|
||||
def test_calculate_multiple_strategies_missing_indicator_data(strategy_factory, sample_ohlcv_data, mock_logger, mock_technical_indicators):
|
||||
# Simulate a scenario where an indicator is requested but not returned by TechnicalIndicators
|
||||
def mock_calculate_no_ema(indicator_type, df, **kwargs):
|
||||
if indicator_type == 'ema':
|
||||
return pd.DataFrame(index=df.index) # Simulate no EMA data returned
|
||||
elif indicator_type == 'rsi':
|
||||
return pd.DataFrame({'rsi': df['close']}, index=df.index)
|
||||
return pd.DataFrame(index=df.index)
|
||||
|
||||
mock_technical_indicators.calculate.side_effect = mock_calculate_no_ema
|
||||
|
||||
strategy_configs = [
|
||||
{"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26}
|
||||
]
|
||||
|
||||
results = strategy_factory.calculate_multiple_strategies(
|
||||
strategy_configs, sample_ohlcv_data
|
||||
)
|
||||
assert not results # Expect no results if indicators are missing
|
||||
assert "Missing required indicator data for key: ema_period_12" in mock_logger.error_calls or \
|
||||
"Missing required indicator data for key: ema_period_26" in mock_logger.error_calls
|
||||
Reference in New Issue
Block a user