TCPDashboard/tests/strategies/test_base_strategy.py
Vasily.onl fd5a59fc39 4.0 - 1.0 Implement strategy engine foundation with modular components
- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD).
- Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability.
- Updated `pyproject.toml` to include the new `strategies` package in the build configuration.
- Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards.

These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
2025-06-12 14:41:16 +08:00

162 lines
6.7 KiB
Python

from datetime import datetime
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
import pytest

from data.common.data_types import OHLCVCandle
from strategies.base import BaseStrategy
from strategies.data_types import StrategyResult, StrategySignal, SignalType
# Mock logger for testing
class MockLogger:
    """Logger test double that captures messages per level instead of emitting them.

    Each of ``info`` / ``warning`` / ``error`` appends its message, in call
    order, to ``info_calls`` / ``warning_calls`` / ``error_calls`` so tests can
    assert on what was (or was not) logged.
    """

    def __init__(self):
        # One capture list per supported level; tests read these directly.
        self.info_calls, self.warning_calls, self.error_calls = [], [], []

    def info(self, message):
        """Record an info-level *message*."""
        self.info_calls += [message]

    def warning(self, message):
        """Record a warning-level *message*."""
        self.warning_calls += [message]

    def error(self, message):
        """Record an error-level *message*."""
        self.error_calls += [message]
# Concrete implementation of BaseStrategy for testing purposes
class ConcreteStrategy(BaseStrategy):
    """Minimal BaseStrategy subclass used to exercise the base-class helpers.

    It declares no required indicators and, for any non-empty frame, emits a
    single result holding one BUY signal built from the frame's first row.
    """

    def __init__(self, logger=None):
        super().__init__("ConcreteStrategy", logger)

    def get_required_indicators(self) -> list[dict]:
        """Declare that this strategy depends on no indicators."""
        return []

    def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
        """Return one BUY StrategyResult for the first row of *data*.

        Args:
            data: frame with at least 'symbol', 'timeframe' and 'close' columns;
                the first row's index label is used as the timestamp.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            An empty list when *data* is empty, otherwise a single-element list.
        """
        if data.empty:
            return []

        head = data.iloc[0]
        ts = head.name  # index label of the first row
        sym = head['symbol']
        tf = head['timeframe']

        buy = StrategySignal(
            timestamp=ts,
            symbol=sym,
            timeframe=tf,
            signal_type=SignalType.BUY,
            price=float(head['close']),
            confidence=1.0,
        )
        return [
            StrategyResult(
                timestamp=ts,
                symbol=sym,
                timeframe=tf,
                strategy_name=self.strategy_name,
                signals=[buy],
                indicators_used={},
            )
        ]
@pytest.fixture
def mock_logger():
    """Provide a fresh MockLogger so every test starts with empty capture lists."""
    logger = MockLogger()
    return logger
@pytest.fixture
def concrete_strategy(mock_logger):
    """Provide a ConcreteStrategy wired to this test's MockLogger instance."""
    strategy = ConcreteStrategy(logger=mock_logger)
    return strategy
@pytest.fixture
def sample_ohlcv_data():
    """Five BTC/USDT '1h' candles stamped hourly from 2023-01-01 00:00 to 04:00.

    The index is a DatetimeIndex named 'timestamp'. The error expected in
    test_validate_dataframe_invalid_index ("DataFrame index must be named
    'timestamp' and be a DatetimeIndex.") shows validate_dataframe requires
    this. With an unnamed index the "valid" fixture would itself fail
    validation before any column or ordering check.
    """
    timestamps = pd.to_datetime([
        '2023-01-01 00:00:00',
        '2023-01-01 01:00:00',
        '2023-01-01 02:00:00',
        '2023-01-01 03:00:00',
        '2023-01-01 04:00:00',
    ]).rename('timestamp')
    return pd.DataFrame({
        'open': [100, 101, 102, 103, 104],
        'high': [105, 106, 107, 108, 109],
        'low': [99, 100, 101, 102, 103],
        'close': [102, 103, 104, 105, 106],
        'volume': [1000, 1100, 1200, 1300, 1400],
        'symbol': ['BTC/USDT'] * 5,
        'timeframe': ['1h'] * 5,
    }, index=timestamps)
def test_prepare_dataframe_initial_data(concrete_strategy, sample_ohlcv_data):
    """prepare_dataframe keeps every OHLCV/meta column and yields a sorted 'timestamp' index."""
    prepared_df = concrete_strategy.prepare_dataframe(sample_ohlcv_data)

    for column in ('open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe'):
        assert column in prepared_df.columns

    index = prepared_df.index
    assert index.name == 'timestamp'
    assert index.is_monotonic_increasing
def test_prepare_dataframe_sparse_data(concrete_strategy, sample_ohlcv_data):
    """A gap in the input is restored by prepare_dataframe as a NaN-filled row."""
    # Remove the middle candle (position 2) to simulate sparse data.
    gap_ts = sample_ohlcv_data.index[2]
    sparse_df = sample_ohlcv_data.drop(gap_ts)

    prepared_df = concrete_strategy.prepare_dataframe(sparse_df)

    # The missing timestamp must be re-inserted at its original position with NaN values.
    assert len(prepared_df) == len(sample_ohlcv_data)
    assert prepared_df.index[2] == gap_ts
    assert pd.isna(prepared_df.loc[gap_ts, 'open'])
def test_validate_dataframe_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A well-formed OHLCV frame validates silently: no warnings and no errors are logged."""
    concrete_strategy.validate_dataframe(sample_ohlcv_data)
    assert len(mock_logger.warning_calls) == 0
    assert len(mock_logger.error_calls) == 0
def test_validate_dataframe_missing_column(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Dropping a required OHLCV column makes validate_dataframe raise ValueError.

    ``pytest.raises(match=...)`` treats the pattern as a regular expression.
    The square brackets in the expected message are regex metacharacters, so
    they are escaped. A raw string literal keeps the backslashes from being read
    as (invalid) string escape sequences, which Python 3.12+ reports as a
    SyntaxWarning.
    """
    invalid_df = sample_ohlcv_data.drop(columns=['open'])
    with pytest.raises(ValueError, match=r"Missing required columns: \['open'\]"):
        concrete_strategy.validate_dataframe(invalid_df)
def test_validate_dataframe_invalid_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Replacing the DatetimeIndex with a default integer index is rejected.

    ``reset_index`` moves the timestamps into a column, leaving a RangeIndex.
    The trailing period of the expected message is escaped so the regex matches
    it literally instead of "any character".
    """
    invalid_df = sample_ohlcv_data.reset_index()
    with pytest.raises(
        ValueError,
        match=r"DataFrame index must be named 'timestamp' and be a DatetimeIndex\.",
    ):
        concrete_strategy.validate_dataframe(invalid_df)
def test_validate_dataframe_non_monotonic_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A frame whose timestamps run backwards fails validation.

    The trailing period of the expected message is escaped so the regex matches
    it literally instead of "any character".
    """
    # Reverse row order to make the index non-monotonic.
    invalid_df = sample_ohlcv_data.iloc[::-1]
    with pytest.raises(ValueError, match=r"DataFrame index is not monotonically increasing\."):
        concrete_strategy.validate_dataframe(invalid_df)
def test_validate_indicators_data_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Complete, NaN-free indicator columns pass validation without any log output."""
    idx = sample_ohlcv_data.index
    indicator_frame = pd.DataFrame({
        'ema_fast': pd.Series([101, 102, 103, 104, 105], index=idx),
        'ema_slow': pd.Series([100, 101, 102, 103, 104], index=idx),
    })
    merged_df = pd.concat([sample_ohlcv_data, indicator_frame], axis=1)

    # Both requested keys have a matching column in merged_df.
    required_indicators = [
        {'type': 'ema', 'period': period, 'key': key}
        for key, period in (('ema_fast', 12), ('ema_slow', 26))
    ]

    concrete_strategy.validate_indicators_data(merged_df, required_indicators)
    assert len(mock_logger.warning_calls) == 0
    assert len(mock_logger.error_calls) == 0
def test_validate_indicators_data_missing_indicator(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Requesting an indicator that has no column raises ValueError naming its key."""
    fast_only = pd.DataFrame({
        'ema_fast': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
    })
    merged_df = pd.concat([sample_ohlcv_data, fast_only], axis=1)

    required_indicators = [
        {'type': 'ema', 'period': period, 'key': key}
        for key, period in (('ema_fast', 12), ('ema_slow', 26))  # 'ema_slow' has no column
    ]

    expected_message = "Missing required indicator data for key: ema_slow"
    with pytest.raises(ValueError, match=expected_message):
        concrete_strategy.validate_indicators_data(merged_df, required_indicators)
def test_validate_indicators_data_nan_values(concrete_strategy, sample_ohlcv_data, mock_logger):
    """NaNs in a required indicator column produce a warning instead of an exception.

    Only 'ema_fast' contains a NaN (injected via ``np.nan``, which needs the
    module-level ``import numpy as np``). The warning must name that key, and
    the fully populated 'ema_slow' column must not be reported.
    """
    idx = sample_ohlcv_data.index
    indicators_data = {
        'ema_fast': pd.Series([101, 102, np.nan, 104, 105], index=idx),
        'ema_slow': pd.Series([100, 101, 102, 103, 104], index=idx),
    }
    merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
    required_indicators = [
        {'type': 'ema', 'period': 12, 'key': 'ema_fast'},
        {'type': 'ema', 'period': 26, 'key': 'ema_slow'},
    ]

    concrete_strategy.validate_indicators_data(merged_df, required_indicators)

    assert (
        "NaN values detected in required indicator data for key: ema_fast"
        in mock_logger.warning_calls
    )
    # The NaN-free column must not trigger a warning of its own.
    assert not any('ema_slow' in message for message in mock_logger.warning_calls)