# Source: TCPDashboard/tests/strategies/test_strategy_factory.py
# (repository web-view metadata: 292 lines, 13 KiB, Python)

import pytest
import pandas as pd
from datetime import datetime
from unittest.mock import MagicMock, patch
from strategies.factory import StrategyFactory
from strategies.base import BaseStrategy
from strategies.data_types import StrategyResult, StrategySignal, SignalType
from data.common.data_types import OHLCVCandle
from data.common.indicators import TechnicalIndicators # For mocking purposes
# Mock logger for testing
class MockLogger:
    """In-memory logger double that records messages per level.

    Each level method appends the message it receives to a matching
    ``<level>_calls`` list so tests can assert on what was logged.

    The methods accept (and ignore) extra positional and keyword arguments
    so that code under test may use ``logging``-style calls such as
    ``logger.error("msg", exc_info=True)`` without raising ``TypeError``.
    Only the message object itself is recorded, unformatted.
    """

    def __init__(self):
        self.debug_calls = []
        self.info_calls = []
        self.warning_calls = []
        self.error_calls = []

    def debug(self, message, *args, **kwargs):
        """Record a debug-level message."""
        self.debug_calls.append(message)

    def info(self, message, *args, **kwargs):
        """Record an info-level message."""
        self.info_calls.append(message)

    def warning(self, message, *args, **kwargs):
        """Record a warning-level message."""
        self.warning_calls.append(message)

    def error(self, message, *args, **kwargs):
        """Record an error-level message."""
        self.error_calls.append(message)
# Mock Concrete Strategy for testing StrategyFactory
class MockEMAStrategy(BaseStrategy):
    """Test double registered under the strategy name ``"ema_crossover"``.

    Every ``calculate`` invocation is recorded in ``calculate_calls``. When
    the input frame and both EMA inputs (``'ema_12'`` and ``'ema_26'``) are
    present and non-empty, a single BUY result built from the first row of
    the frame is returned; otherwise the result list is empty.
    """

    def __init__(self, logger=None):
        super().__init__(strategy_name="ema_crossover", logger=logger)
        # Each entry is the (df, indicators_data, kwargs) tuple of one call.
        self.calculate_calls = []

    def get_required_indicators(self) -> list[dict]:
        """Return the two EMA indicator requests (periods 12 and 26)."""
        return [{'type': 'ema', 'period': span} for span in (12, 26)]

    def calculate(self, df: pd.DataFrame, indicators_data: dict, **kwargs) -> list[StrategyResult]:
        """Record the call and emit one BUY result if all inputs are usable."""
        self.calculate_calls.append((df, indicators_data, kwargs))

        fast_ema = indicators_data.get('ema_12')
        slow_ema = indicators_data.get('ema_26')

        # Guard clauses: any missing or empty input yields no results.
        if df.empty:
            return []
        if fast_ema is None or fast_ema.empty:
            return []
        if slow_ema is None or slow_ema.empty:
            return []

        head = df.iloc[0]
        stamp = head.name
        symbol = head['symbol']
        timeframe = head['timeframe']

        buy_signal = StrategySignal(
            timestamp=stamp,
            symbol=symbol,
            timeframe=timeframe,
            signal_type=SignalType.BUY,
            price=float(head['close']),
            confidence=1.0,
        )
        result = StrategyResult(
            timestamp=stamp,
            symbol=symbol,
            timeframe=timeframe,
            strategy_name=self.strategy_name,
            signals=[buy_signal],
            indicators_used=indicators_data,
        )
        return [result]
class MockRSIStrategy(BaseStrategy):
    """Test double registered under the strategy name ``"rsi"``.

    Records each ``calculate`` call in ``calculate_calls`` and, when the
    frame and the ``'rsi_14'`` input are present and non-empty, returns a
    single SELL result (confidence 0.9) built from the frame's first row.
    """

    def __init__(self, logger=None):
        super().__init__(strategy_name="rsi", logger=logger)
        # Each entry is the (df, indicators_data, kwargs) tuple of one call.
        self.calculate_calls = []

    def get_required_indicators(self) -> list[dict]:
        """Return the single RSI indicator request (period 14)."""
        rsi_request = {'type': 'rsi', 'period': 14}
        return [rsi_request]

    def calculate(self, df: pd.DataFrame, indicators_data: dict, **kwargs) -> list[StrategyResult]:
        """Record the call and emit one SELL result if all inputs are usable."""
        self.calculate_calls.append((df, indicators_data, kwargs))

        rsi_values = indicators_data.get('rsi_14')
        if df.empty or rsi_values is None or rsi_values.empty:
            return []

        head = df.iloc[0]
        stamp = head.name
        symbol = head['symbol']
        timeframe = head['timeframe']

        sell_signal = StrategySignal(
            timestamp=stamp,
            symbol=symbol,
            timeframe=timeframe,
            signal_type=SignalType.SELL,
            price=float(head['close']),
            confidence=0.9,
        )
        return [
            StrategyResult(
                timestamp=stamp,
                symbol=symbol,
                timeframe=timeframe,
                strategy_name=self.strategy_name,
                signals=[sell_signal],
                indicators_used=indicators_data,
            )
        ]
class MockMACDStrategy(BaseStrategy):
    """Test double registered under the strategy name ``"macd"``.

    Records each ``calculate`` call in ``calculate_calls`` and, when the
    frame and the ``'macd_12_26_9'`` input are present and non-empty,
    returns a single BUY result (confidence 1.0) built from the frame's
    first row.
    """

    def __init__(self, logger=None):
        super().__init__(strategy_name="macd", logger=logger)
        # Each entry is the (df, indicators_data, kwargs) tuple of one call.
        self.calculate_calls = []

    def get_required_indicators(self) -> list[dict]:
        """Return the single MACD indicator request (12/26/9)."""
        macd_request = dict(type='macd', fast_period=12, slow_period=26, signal_period=9)
        return [macd_request]

    def calculate(self, df: pd.DataFrame, indicators_data: dict, **kwargs) -> list[StrategyResult]:
        """Record the call and emit one BUY result if all inputs are usable."""
        self.calculate_calls.append((df, indicators_data, kwargs))

        macd_values = indicators_data.get('macd_12_26_9')
        inputs_usable = (
            not df.empty
            and macd_values is not None
            and not macd_values.empty
        )
        if not inputs_usable:
            return []

        head = df.iloc[0]
        stamp = head.name
        symbol = head['symbol']
        timeframe = head['timeframe']

        buy_signal = StrategySignal(
            timestamp=stamp,
            symbol=symbol,
            timeframe=timeframe,
            signal_type=SignalType.BUY,
            price=float(head['close']),
            confidence=1.0,
        )
        outcome = StrategyResult(
            timestamp=stamp,
            symbol=symbol,
            timeframe=timeframe,
            strategy_name=self.strategy_name,
            signals=[buy_signal],
            indicators_used=indicators_data,
        )
        return [outcome]
@pytest.fixture
def mock_logger():
    """Provide a fresh ``MockLogger`` instance to the requesting test."""
    recorder = MockLogger()
    return recorder
@pytest.fixture
def mock_technical_indicators():
    """Provide a ``MagicMock`` of ``TechnicalIndicators`` with canned output.

    The mock's ``calculate`` returns a DataFrame indexed like the input
    frame, with fixed dummy columns for the ``'ema'``, ``'rsi'`` and
    ``'macd'`` indicator types and no columns for any other type.

    The canned values are fitted to the length of the input frame
    (truncated, or repeated cyclically) so that frames which do not have
    exactly five rows -- including empty ones -- do not trigger a pandas
    length-mismatch ``ValueError``. For five-row input the output is the
    same as the fixed five-value columns.
    """
    mock_ti = MagicMock(spec=TechnicalIndicators)

    # Dummy column values per indicator type (five samples each).
    canned_columns = {
        'ema': {
            'ema_fast': [100.0, 101.0, 102.0, 103.0, 104.0],
            'ema_slow': [98.0, 99.0, 100.0, 101.0, 102.0],
        },
        'rsi': {
            'rsi': [60.0, 65.0, 72.0, 28.0, 35.0],
        },
        'macd': {
            'macd': [1.0, 1.1, 1.2, 1.3, 1.4],
            'signal': [0.9, 1.0, 1.1, 1.2, 1.3],
            'hist': [0.1, 0.1, 0.1, 0.1, 0.1],
        },
    }

    def fit_to_length(values, length):
        """Return exactly *length* items, cycling through *values* as needed."""
        return [values[i % len(values)] for i in range(length)]

    def mock_calculate(indicator_type, df, **kwargs):
        columns = canned_columns.get(indicator_type)
        if columns is None:
            # Default: empty (column-less) DataFrame for other indicators.
            return pd.DataFrame(index=df.index)
        row_count = len(df.index)
        return pd.DataFrame(
            {name: fit_to_length(values, row_count) for name, values in columns.items()},
            index=df.index,
        )

    mock_ti.calculate.side_effect = mock_calculate
    return mock_ti
@pytest.fixture
def strategy_factory(mock_technical_indicators, mock_logger):
    """Yield a ``StrategyFactory`` wired to the mock strategies and indicators.

    The three strategy names in ``strategies.factory`` are patched with the
    mock strategy classes for as long as the fixture is active, and the
    factory's ``technical_indicators`` attribute is replaced by the mock.
    """
    ema_patch = patch('strategies.factory.EMAStrategy', MockEMAStrategy)
    rsi_patch = patch('strategies.factory.RSIStrategy', MockRSIStrategy)
    macd_patch = patch('strategies.factory.MACDStrategy', MockMACDStrategy)

    # The patches stay active until the test using this fixture finishes.
    with ema_patch, rsi_patch, macd_patch:
        built_factory = StrategyFactory(logger=mock_logger)
        # Explicitly swap in the mocked TechnicalIndicators.
        built_factory.technical_indicators = mock_technical_indicators
        yield built_factory
@pytest.fixture
def sample_ohlcv_data():
    """Return a five-row OHLCV frame for BTC/USDT on the 1h timeframe.

    The index holds hourly timestamps from 2023-01-01 00:00:00 through
    2023-01-01 04:00:00; each price column rises by 1 per row and volume
    rises by 100 per row.
    """
    row_count = 5
    hourly_index = pd.to_datetime(
        [f'2023-01-01 0{hour}:00:00' for hour in range(row_count)]
    )
    candle_columns = {
        'open': list(range(100, 105)),
        'high': list(range(105, 110)),
        'low': list(range(99, 104)),
        'close': list(range(102, 107)),
        'volume': list(range(1000, 1500, 100)),
        'symbol': ['BTC/USDT'] * row_count,
        'timeframe': ['1h'] * row_count,
    }
    return pd.DataFrame(candle_columns, index=hourly_index)
def test_get_available_strategies(strategy_factory):
    """The factory lists all three patched strategies (MACD included)."""
    listed = strategy_factory.get_available_strategies()
    for expected_name in ("ema_crossover", "rsi", "macd"):
        assert expected_name in listed
def test_create_strategy_success(strategy_factory):
    """Creating "ema_crossover" yields the patched mock EMA strategy."""
    requested_name = "ema_crossover"
    created = strategy_factory.create_strategy(requested_name)
    assert isinstance(created, MockEMAStrategy)
    assert created.strategy_name == requested_name
def test_create_strategy_unknown(strategy_factory, mock_logger):
    """An unregistered name returns None and logs the expected error."""
    created = strategy_factory.create_strategy("unknown_strategy")
    logged_errors = mock_logger.error_calls
    assert created is None
    assert "Unknown strategy: unknown_strategy" in logged_errors
# def test_calculate_multiple_strategies_success(strategy_factory, sample_ohlcv_data, mock_technical_indicators):
# strategy_configs = {
# "ema_cross_1": {"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26},
# "rsi_momentum": {"strategy": "rsi", "period": 14, "overbought": 70, "oversold": 30}
# }
# all_strategy_results = strategy_factory.calculate_multiple_strategies(
# sample_ohlcv_data, strategy_configs
# )
# assert len(all_strategy_results) == 2 # Expect results for both strategies
# assert "ema_cross_1" in all_strategy_results
# assert "rsi_momentum" in all_strategy_results
# ema_results = all_strategy_results["ema_cross_1"]
# rsi_results = all_strategy_results["rsi_momentum"]
# assert len(ema_results) > 0
# assert ema_results[0].strategy_name == "ema_crossover"
# assert len(rsi_results) > 0
# assert rsi_results[0].strategy_name == "rsi"
# # Verify that TechnicalIndicators.calculate was called with correct arguments
# # EMA calls
# # Check for calls with 'ema' type and specific periods
# ema_calls_12 = [call for call in mock_technical_indicators.calculate.call_args_list
# if call.args[0] == 'ema' and call.kwargs.get('period') == 12]
# ema_calls_26 = [call for call in mock_technical_indicators.calculate.call_args_list
# if call.args[0] == 'ema' and call.kwargs.get('period') == 26]
# assert len(ema_calls_12) == 1
# assert len(ema_calls_26) == 1
# # RSI calls
# rsi_calls = [call for call in mock_technical_indicators.calculate.call_args_list if call.args[0] == 'rsi']
# assert len(rsi_calls) == 1 # One RSI indicator for rsi strategy
# assert rsi_calls[0].kwargs['period'] == 14
# def test_calculate_multiple_strategies_no_configs(strategy_factory, sample_ohlcv_data):
# results = strategy_factory.calculate_multiple_strategies(sample_ohlcv_data, {})
# assert results == {}
# def test_calculate_multiple_strategies_empty_data(strategy_factory, mock_technical_indicators):
# strategy_configs = {
# "ema_cross_1": {"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26}
# }
# empty_df = pd.DataFrame(columns=['open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe'])
# results = strategy_factory.calculate_multiple_strategies(empty_df, strategy_configs)
# assert results == {"ema_cross_1": []} # Expect empty list for the strategy if data is empty
# def test_calculate_multiple_strategies_missing_indicator_data(strategy_factory, sample_ohlcv_data, mock_logger, mock_technical_indicators):
# # Simulate a scenario where an indicator is requested but not returned by TechnicalIndicators
# def mock_calculate_no_ema(indicator_type, df, **kwargs):
# if indicator_type == 'ema':
# return pd.DataFrame(index=df.index) # Simulate no EMA data returned
# elif indicator_type == 'rsi':
# return pd.DataFrame({'rsi': df['close']}, index=df.index)
# return pd.DataFrame(index=df.index)
# mock_technical_indicators.calculate.side_effect = mock_calculate_no_ema
# strategy_configs = {
# "ema_cross_1": {"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26}
# }
# results = strategy_factory.calculate_multiple_strategies(
# sample_ohlcv_data, strategy_configs
# )
# assert results == {"ema_cross_1": []} # Expect empty results if indicators are missing
# assert "Empty result for indicator: ema_12" in mock_logger.warning_calls or \
# "Empty result for indicator: ema_26" in mock_logger.warning_calls
# def test_calculate_multiple_strategies_exception_in_one(strategy_factory, sample_ohlcv_data, mock_logger, mock_technical_indicators):
# def mock_calculate_indicator_with_error(indicator_type, df, **kwargs):
# if indicator_type == 'ema':
# raise Exception("EMA calculation error")
# elif indicator_type == 'rsi':
# return pd.DataFrame({'rsi': [50, 55, 60, 65, 70]}, index=df.index)
# return pd.DataFrame() # Default empty DataFrame
# mock_technical_indicators.calculate.side_effect = mock_calculate_indicator_with_error
# strategy_configs = {
# "ema_cross_1": {"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26},
# "rsi_momentum": {"strategy": "rsi", "period": 14, "overbought": 70, "oversold": 30}
# }
# all_strategy_results = strategy_factory.calculate_multiple_strategies(
# sample_ohlcv_data, strategy_configs
# )
# assert "ema_cross_1" in all_strategy_results and all_strategy_results["ema_cross_1"] == []
# assert "rsi_momentum" in all_strategy_results and len(all_strategy_results["rsi_momentum"]) > 0
# assert "Error calculating strategy ema_cross_1: EMA calculation error" in mock_logger.error_calls