4.0 - 1.0 Implement strategy engine foundation with modular components
- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD). - Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability. - Updated `pyproject.toml` to include the new `strategies` package in the build configuration. - Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards. These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
This commit is contained in:
parent
571d583a5b
commit
fd5a59fc39
@ -60,7 +60,7 @@ requires = ["hatchling"]
|
|||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["config", "database", "scripts", "tests", "data"]
|
packages = ["config", "database", "scripts", "tests", "data", "strategies"]
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
@ -80,3 +80,7 @@ disallow_untyped_defs = true
|
|||||||
dev = [
|
dev = [
|
||||||
"pytest-asyncio>=1.0.0",
|
"pytest-asyncio>=1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
pythonpath = ["."]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
|||||||
26
strategies/__init__.py
Normal file
26
strategies/__init__.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
Strategy Engine Package
|
||||||
|
|
||||||
|
This package provides strategy calculation and signal generation capabilities
|
||||||
|
optimized for the TCP Trading Platform's market data and technical indicators.
|
||||||
|
|
||||||
|
IMPORTANT: Mirrors Indicator Patterns
|
||||||
|
- Follows the same architecture as data/common/indicators/
|
||||||
|
- Uses BaseStrategy abstract class pattern
|
||||||
|
- Implements factory pattern for dynamic strategy loading
|
||||||
|
- Supports JSON-based configuration management
|
||||||
|
- Integrates with existing technical indicators system
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import BaseStrategy
|
||||||
|
from .factory import StrategyFactory
|
||||||
|
from .data_types import StrategySignal, SignalType, StrategyResult
|
||||||
|
|
||||||
|
# Note: Strategy implementations and manager will be added in next iterations
|
||||||
|
__all__ = [
|
||||||
|
'BaseStrategy',
|
||||||
|
'StrategyFactory',
|
||||||
|
'StrategySignal',
|
||||||
|
'SignalType',
|
||||||
|
'StrategyResult'
|
||||||
|
]
|
||||||
150
strategies/base.py
Normal file
150
strategies/base.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
"""
|
||||||
|
Base classes and interfaces for trading strategies.
|
||||||
|
|
||||||
|
This module provides the foundation for all trading strategies
|
||||||
|
with common functionality and type definitions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import pandas as pd
|
||||||
|
from utils.logger import get_logger
|
||||||
|
|
||||||
|
from .data_types import StrategyResult
|
||||||
|
from data.common.data_types import OHLCVCandle
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStrategy(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for all trading strategies.
|
||||||
|
|
||||||
|
Provides common functionality and enforces consistent interface
|
||||||
|
across all strategy implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
"""
|
||||||
|
Initialize base strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger: Optional logger instance
|
||||||
|
"""
|
||||||
|
if logger is None:
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def prepare_dataframe(self, candles: List[OHLCVCandle]) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Convert OHLCV candles to pandas DataFrame for calculations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candles: List of OHLCV candles (can be sparse)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with OHLCV data, sorted by timestamp
|
||||||
|
"""
|
||||||
|
if not candles:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
data = []
|
||||||
|
for candle in candles:
|
||||||
|
data.append({
|
||||||
|
'timestamp': candle.end_time, # Right-aligned timestamp
|
||||||
|
'symbol': candle.symbol,
|
||||||
|
'timeframe': candle.timeframe,
|
||||||
|
'open': float(candle.open),
|
||||||
|
'high': float(candle.high),
|
||||||
|
'low': float(candle.low),
|
||||||
|
'close': float(candle.close),
|
||||||
|
'volume': float(candle.volume),
|
||||||
|
'trade_count': candle.trade_count
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
|
||||||
|
# Sort by timestamp to ensure proper order
|
||||||
|
df = df.sort_values('timestamp').reset_index(drop=True)
|
||||||
|
|
||||||
|
# Set timestamp as index for time-series operations
|
||||||
|
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||||
|
|
||||||
|
# Set as index, but keep as column
|
||||||
|
df.set_index('timestamp', inplace=True)
|
||||||
|
|
||||||
|
# Ensure it's datetime
|
||||||
|
df['timestamp'] = df.index
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||||
|
"""
|
||||||
|
Calculate the strategy signals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with OHLCV data
|
||||||
|
indicators_data: Dictionary of pre-calculated indicator DataFrames
|
||||||
|
**kwargs: Additional parameters specific to each strategy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of strategy results with signals
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get list of indicators required by this strategy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of indicator configurations needed for strategy calculation
|
||||||
|
Format: [{'type': 'sma', 'period': 20}, {'type': 'ema', 'period': 12}]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_dataframe(self, df: pd.DataFrame, min_periods: int) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that DataFrame has sufficient data for calculation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame to validate
|
||||||
|
min_periods: Minimum number of periods required
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if DataFrame is valid, False otherwise
|
||||||
|
"""
|
||||||
|
if df.empty or len(df) < min_periods:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(
|
||||||
|
f"Insufficient data: got {len(df)} periods, need {min_periods}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_indicators_data(self, indicators_data: Dict[str, pd.DataFrame],
|
||||||
|
required_indicators: List[Dict[str, Any]]) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that all required indicators are present and have sufficient data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicators_data: Dictionary of indicator DataFrames
|
||||||
|
required_indicators: List of required indicator configurations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all required indicators are available, False otherwise
|
||||||
|
"""
|
||||||
|
for indicator_config in required_indicators:
|
||||||
|
indicator_key = f"{indicator_config['type']}_{indicator_config.get('period', 'default')}"
|
||||||
|
|
||||||
|
if indicator_key not in indicators_data:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Missing required indicator: {indicator_key}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if indicators_data[indicator_key].empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Empty data for indicator: {indicator_key}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
72
strategies/data_types.py
Normal file
72
strategies/data_types.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""
|
||||||
|
Strategy Data Types and Signal Definitions
|
||||||
|
|
||||||
|
This module provides data types for strategy calculations, signals,
|
||||||
|
and results in a standardized format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Optional, Any, List
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class SignalType(str, Enum):
|
||||||
|
"""
|
||||||
|
Types of trading signals that strategies can generate.
|
||||||
|
"""
|
||||||
|
BUY = "buy"
|
||||||
|
SELL = "sell"
|
||||||
|
HOLD = "hold"
|
||||||
|
ENTRY_LONG = "entry_long"
|
||||||
|
EXIT_LONG = "exit_long"
|
||||||
|
ENTRY_SHORT = "entry_short"
|
||||||
|
EXIT_SHORT = "exit_short"
|
||||||
|
STOP_LOSS = "stop_loss"
|
||||||
|
TAKE_PROFIT = "take_profit"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategySignal:
|
||||||
|
"""
|
||||||
|
Container for individual strategy signals.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
timestamp: Signal timestamp (right-aligned with candle)
|
||||||
|
symbol: Trading symbol
|
||||||
|
timeframe: Candle timeframe
|
||||||
|
signal_type: Type of signal (buy/sell/hold etc.)
|
||||||
|
price: Price at which signal was generated
|
||||||
|
confidence: Signal confidence score (0.0 to 1.0)
|
||||||
|
metadata: Additional signal metadata (e.g., indicator values)
|
||||||
|
"""
|
||||||
|
timestamp: datetime
|
||||||
|
symbol: str
|
||||||
|
timeframe: str
|
||||||
|
signal_type: SignalType
|
||||||
|
price: float
|
||||||
|
confidence: float = 1.0
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategyResult:
|
||||||
|
"""
|
||||||
|
Container for strategy calculation results.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
timestamp: Candle timestamp (right-aligned)
|
||||||
|
symbol: Trading symbol
|
||||||
|
timeframe: Candle timeframe
|
||||||
|
strategy_name: Name of the strategy that generated this result
|
||||||
|
signals: List of signals generated for this timestamp
|
||||||
|
indicators_used: Dictionary of indicator values used in calculation
|
||||||
|
metadata: Additional calculation metadata
|
||||||
|
"""
|
||||||
|
timestamp: datetime
|
||||||
|
symbol: str
|
||||||
|
timeframe: str
|
||||||
|
strategy_name: str
|
||||||
|
signals: List[StrategySignal]
|
||||||
|
indicators_used: Dict[str, float]
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
247
strategies/factory.py
Normal file
247
strategies/factory.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
"""
|
||||||
|
Strategy Factory Module for Strategy Management
|
||||||
|
|
||||||
|
This module provides strategy calculation and signal generation capabilities
|
||||||
|
designed to work with the TCP Trading Platform's technical indicators and market data.
|
||||||
|
|
||||||
|
IMPORTANT: Strategy-Indicator Integration
|
||||||
|
- Strategies consume pre-calculated technical indicators
|
||||||
|
- Uses BaseStrategy abstract class pattern for consistency
|
||||||
|
- Supports dynamic strategy loading and registration
|
||||||
|
- Follows right-aligned timestamp convention
|
||||||
|
- Integrates with existing database and configuration systems
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from data.common.data_types import OHLCVCandle
|
||||||
|
from data.common.indicators import TechnicalIndicators
|
||||||
|
from .base import BaseStrategy
|
||||||
|
from .data_types import StrategyResult
|
||||||
|
from .utils import create_indicator_key
|
||||||
|
from .implementations.ema_crossover import EMAStrategy
|
||||||
|
from .implementations.rsi import RSIStrategy
|
||||||
|
from .implementations.macd import MACDStrategy
|
||||||
|
# Strategy implementations will be imported as they are created
|
||||||
|
# from .implementations import (
|
||||||
|
# EMAStrategy,
|
||||||
|
# RSIStrategy,
|
||||||
|
# MACDStrategy
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyFactory:
|
||||||
|
"""
|
||||||
|
Strategy factory for creating and managing trading strategies.
|
||||||
|
|
||||||
|
This class provides strategy instantiation, calculation orchestration,
|
||||||
|
and signal generation. It integrates with the TechnicalIndicators
|
||||||
|
system to provide pre-calculated indicator data to strategies.
|
||||||
|
|
||||||
|
STRATEGY-INDICATOR INTEGRATION:
|
||||||
|
- Strategies declare required indicators via get_required_indicators()
|
||||||
|
- Factory pre-calculates all required indicators
|
||||||
|
- Strategies receive indicator data as dictionary of DataFrames
|
||||||
|
- Results maintain original timestamp alignment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
"""
|
||||||
|
Initialize strategy factory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger: Optional logger instance
|
||||||
|
"""
|
||||||
|
self.logger = logger
|
||||||
|
self.technical_indicators = TechnicalIndicators(logger)
|
||||||
|
|
||||||
|
# Registry of available strategies (will be populated as strategies are implemented)
|
||||||
|
self._strategy_registry = {
|
||||||
|
'ema_crossover': EMAStrategy,
|
||||||
|
'rsi': RSIStrategy,
|
||||||
|
'macd': MACDStrategy
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info("StrategyFactory: Initialized strategy factory")
|
||||||
|
|
||||||
|
def register_strategy(self, name: str, strategy_class: type) -> None:
|
||||||
|
"""
|
||||||
|
Register a new strategy class in the factory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Strategy name identifier
|
||||||
|
strategy_class: Strategy class (must inherit from BaseStrategy)
|
||||||
|
"""
|
||||||
|
if not issubclass(strategy_class, BaseStrategy):
|
||||||
|
raise ValueError(f"Strategy class {strategy_class} must inherit from BaseStrategy")
|
||||||
|
|
||||||
|
self._strategy_registry[name] = strategy_class
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"StrategyFactory: Registered strategy '{name}'")
|
||||||
|
|
||||||
|
def get_available_strategies(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get list of available strategy names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of registered strategy names
|
||||||
|
"""
|
||||||
|
return list(self._strategy_registry.keys())
|
||||||
|
|
||||||
|
def create_strategy(self, strategy_name: str) -> Optional[BaseStrategy]:
|
||||||
|
"""
|
||||||
|
Create a strategy instance by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy_name: Name of the strategy to create
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Strategy instance or None if strategy not found
|
||||||
|
"""
|
||||||
|
strategy_class = self._strategy_registry.get(strategy_name)
|
||||||
|
if not strategy_class:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Unknown strategy: {strategy_name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return strategy_class(logger=self.logger)
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Error creating strategy {strategy_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_strategy_signals(self, strategy_name: str, df: pd.DataFrame,
|
||||||
|
strategy_config: Dict[str, Any]) -> List[StrategyResult]:
|
||||||
|
"""
|
||||||
|
Calculate signals for a specific strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy_name: Name of the strategy to execute
|
||||||
|
df: DataFrame with OHLCV data
|
||||||
|
strategy_config: Strategy-specific configuration parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of strategy results with signals
|
||||||
|
"""
|
||||||
|
# Create strategy instance
|
||||||
|
strategy = self.create_strategy(strategy_name)
|
||||||
|
if not strategy:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get required indicators for this strategy
|
||||||
|
required_indicators = strategy.get_required_indicators()
|
||||||
|
|
||||||
|
# Pre-calculate all required indicators
|
||||||
|
indicators_data = self._calculate_required_indicators(df, required_indicators)
|
||||||
|
|
||||||
|
# Calculate strategy signals
|
||||||
|
results = strategy.calculate(df, indicators_data, **strategy_config)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Error calculating strategy {strategy_name}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def calculate_multiple_strategies(self, df: pd.DataFrame,
|
||||||
|
strategies_config: Dict[str, Dict[str, Any]]) -> Dict[str, List[StrategyResult]]:
|
||||||
|
"""
|
||||||
|
Calculate signals for multiple strategies efficiently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with OHLCV data
|
||||||
|
strategies_config: Configuration for strategies to calculate
|
||||||
|
Example: {
|
||||||
|
'ema_cross_1': {'strategy': 'ema_crossover', 'fast_period': 12, 'slow_period': 26},
|
||||||
|
'rsi_momentum': {'strategy': 'rsi', 'period': 14, 'oversold': 30, 'overbought': 70}
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping strategy instance names to their results
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for strategy_instance_name, config in strategies_config.items():
|
||||||
|
strategy_name = config.get('strategy')
|
||||||
|
if not strategy_name:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"No strategy specified for {strategy_instance_name}")
|
||||||
|
results[strategy_instance_name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract strategy parameters (exclude 'strategy' key)
|
||||||
|
strategy_params = {k: v for k, v in config.items() if k != 'strategy'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
strategy_results = self.calculate_strategy_signals(
|
||||||
|
strategy_name, df, strategy_params
|
||||||
|
)
|
||||||
|
results[strategy_instance_name] = strategy_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Error calculating strategy {strategy_instance_name}: {e}")
|
||||||
|
results[strategy_instance_name] = []
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _calculate_required_indicators(self, df: pd.DataFrame,
|
||||||
|
required_indicators: List[Dict[str, Any]]) -> Dict[str, pd.DataFrame]:
|
||||||
|
"""
|
||||||
|
Pre-calculate all indicators required by a strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with OHLCV data
|
||||||
|
required_indicators: List of indicator configurations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of indicator DataFrames keyed by indicator name
|
||||||
|
"""
|
||||||
|
indicators_data = {}
|
||||||
|
|
||||||
|
for indicator_config in required_indicators:
|
||||||
|
indicator_type = indicator_config.get('type')
|
||||||
|
if not indicator_type:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create a unique key for this indicator configuration
|
||||||
|
indicator_key = self._create_indicator_key(indicator_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Calculate the indicator using TechnicalIndicators
|
||||||
|
indicator_result = self.technical_indicators.calculate(
|
||||||
|
indicator_type, df, **{k: v for k, v in indicator_config.items() if k != 'type'}
|
||||||
|
)
|
||||||
|
|
||||||
|
if indicator_result is not None and not indicator_result.empty:
|
||||||
|
indicators_data[indicator_key] = indicator_result
|
||||||
|
else:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Empty result for indicator: {indicator_key}")
|
||||||
|
indicators_data[indicator_key] = pd.DataFrame()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Error calculating indicator {indicator_key}: {e}")
|
||||||
|
indicators_data[indicator_key] = pd.DataFrame()
|
||||||
|
|
||||||
|
return indicators_data
|
||||||
|
|
||||||
|
def _create_indicator_key(self, indicator_config: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Create a unique key for an indicator configuration.
|
||||||
|
Delegates to shared utility function to ensure consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator_config: Indicator configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unique string key for this indicator configuration
|
||||||
|
"""
|
||||||
|
return create_indicator_key(indicator_config)
|
||||||
18
strategies/implementations/__init__.py
Normal file
18
strategies/implementations/__init__.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Strategy implementations package.
|
||||||
|
|
||||||
|
This package contains individual implementations of trading strategies,
|
||||||
|
each in its own module for better maintainability and separation of concerns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .ema_crossover import EMAStrategy
|
||||||
|
from .rsi import RSIStrategy
|
||||||
|
from .macd import MACDStrategy
|
||||||
|
# from .macd import MACDIndicator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'EMAStrategy',
|
||||||
|
'RSIStrategy',
|
||||||
|
'MACDStrategy',
|
||||||
|
# 'MACDIndicator'
|
||||||
|
]
|
||||||
185
strategies/implementations/ema_crossover.py
Normal file
185
strategies/implementations/ema_crossover.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
"""
|
||||||
|
EMA Crossover Strategy Implementation
|
||||||
|
|
||||||
|
This module implements an Exponential Moving Average (EMA) Crossover trading strategy.
|
||||||
|
It extends the BaseStrategy and generates buy/sell signals based on the crossover
|
||||||
|
of a fast EMA and a slow EMA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from ..base import BaseStrategy
|
||||||
|
from ..data_types import StrategyResult, StrategySignal, SignalType
|
||||||
|
from ..utils import create_indicator_key, detect_crossover_signals_vectorized
|
||||||
|
|
||||||
|
|
||||||
|
class EMAStrategy(BaseStrategy):
|
||||||
|
"""
|
||||||
|
EMA Crossover Strategy.
|
||||||
|
|
||||||
|
Generates buy/sell signals when a fast EMA crosses above or below a slow EMA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
super().__init__(logger)
|
||||||
|
self.strategy_name = "ema_crossover"
|
||||||
|
|
||||||
|
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Defines the indicators required by the EMA Crossover strategy.
|
||||||
|
It needs two EMA indicators: a fast one and a slow one.
|
||||||
|
"""
|
||||||
|
# Default periods for EMA crossover, can be overridden by strategy config
|
||||||
|
return [
|
||||||
|
{'type': 'ema', 'period': 12, 'price_column': 'close'},
|
||||||
|
{'type': 'ema', 'period': 26, 'price_column': 'close'}
|
||||||
|
]
|
||||||
|
|
||||||
|
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||||
|
"""
|
||||||
|
Calculate EMA Crossover strategy signals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with OHLCV data.
|
||||||
|
indicators_data: Dictionary of pre-calculated indicator DataFrames.
|
||||||
|
Expected keys: 'ema_period_12', 'ema_period_26'.
|
||||||
|
**kwargs: Additional strategy parameters (e.g., fast_period, slow_period, price_column).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of StrategyResult objects, each containing generated signals.
|
||||||
|
"""
|
||||||
|
# Extract EMA periods from kwargs or use defaults
|
||||||
|
fast_period = kwargs.get('fast_period', 12)
|
||||||
|
slow_period = kwargs.get('slow_period', 26)
|
||||||
|
price_column = kwargs.get('price_column', 'close')
|
||||||
|
|
||||||
|
# Generate indicator keys using shared utility function
|
||||||
|
fast_ema_key = create_indicator_key({'type': 'ema', 'period': fast_period})
|
||||||
|
slow_ema_key = create_indicator_key({'type': 'ema', 'period': slow_period})
|
||||||
|
|
||||||
|
# Validate that the main DataFrame has enough data for strategy calculation (not just indicators)
|
||||||
|
if not self.validate_dataframe(df, max(fast_period, slow_period)):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Validate that the required indicators are present and have sufficient data
|
||||||
|
required_indicators = [
|
||||||
|
{'type': 'ema', 'period': fast_period},
|
||||||
|
{'type': 'ema', 'period': slow_period}
|
||||||
|
]
|
||||||
|
if not self.validate_indicators_data(indicators_data, required_indicators):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Missing or insufficient indicator data.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
fast_ema_df = indicators_data.get(fast_ema_key)
|
||||||
|
slow_ema_df = indicators_data.get(slow_ema_key)
|
||||||
|
|
||||||
|
if fast_ema_df is None or slow_ema_df is None or fast_ema_df.empty or slow_ema_df.empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: EMA indicator DataFrames are not found or empty.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Merge all necessary data into a single DataFrame for easier processing
|
||||||
|
# Ensure alignment by index (timestamp)
|
||||||
|
merged_df = pd.merge(df[[price_column, 'symbol', 'timeframe']],
|
||||||
|
fast_ema_df[['ema']],
|
||||||
|
left_index=True, right_index=True, how='inner',
|
||||||
|
suffixes= ('', '_fast'))
|
||||||
|
merged_df = pd.merge(merged_df,
|
||||||
|
slow_ema_df[['ema']],
|
||||||
|
left_index=True, right_index=True, how='inner',
|
||||||
|
suffixes= ('', '_slow'))
|
||||||
|
|
||||||
|
# Rename columns to their logical names after merge
|
||||||
|
merged_df.rename(columns={'ema': 'ema_fast', 'ema_slow': 'ema_slow'}, inplace=True)
|
||||||
|
|
||||||
|
if merged_df.empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use vectorized signal detection for better performance
|
||||||
|
bullish_crossover, bearish_crossover = detect_crossover_signals_vectorized(
|
||||||
|
merged_df, 'ema_fast', 'ema_slow'
|
||||||
|
)
|
||||||
|
|
||||||
|
results: List[StrategyResult] = []
|
||||||
|
strategy_metadata = {
|
||||||
|
'fast_period': fast_period,
|
||||||
|
'slow_period': slow_period
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process bullish crossover signals
|
||||||
|
bullish_indices = merged_df[bullish_crossover].index
|
||||||
|
for timestamp in bullish_indices:
|
||||||
|
row = merged_df.loc[timestamp]
|
||||||
|
|
||||||
|
# Skip if any EMA values are NaN
|
||||||
|
if pd.isna(row['ema_fast']) or pd.isna(row['ema_slow']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
signal = StrategySignal(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
signal_type=SignalType.BUY,
|
||||||
|
price=float(row[price_column]),
|
||||||
|
confidence=0.8,
|
||||||
|
metadata={'crossover_type': 'bullish', **strategy_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(StrategyResult(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[signal],
|
||||||
|
indicators_used={
|
||||||
|
'ema_fast': float(row['ema_fast']),
|
||||||
|
'ema_slow': float(row['ema_slow'])
|
||||||
|
},
|
||||||
|
metadata=strategy_metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.strategy_name}: BUY signal at {timestamp} for {row['symbol']}")
|
||||||
|
|
||||||
|
# Process bearish crossover signals
|
||||||
|
bearish_indices = merged_df[bearish_crossover].index
|
||||||
|
for timestamp in bearish_indices:
|
||||||
|
row = merged_df.loc[timestamp]
|
||||||
|
|
||||||
|
# Skip if any EMA values are NaN
|
||||||
|
if pd.isna(row['ema_fast']) or pd.isna(row['ema_slow']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
signal = StrategySignal(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
signal_type=SignalType.SELL,
|
||||||
|
price=float(row[price_column]),
|
||||||
|
confidence=0.8,
|
||||||
|
metadata={'crossover_type': 'bearish', **strategy_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(StrategyResult(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[signal],
|
||||||
|
indicators_used={
|
||||||
|
'ema_fast': float(row['ema_fast']),
|
||||||
|
'ema_slow': float(row['ema_slow'])
|
||||||
|
},
|
||||||
|
metadata=strategy_metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.strategy_name}: SELL signal at {timestamp} for {row['symbol']}")
|
||||||
|
|
||||||
|
return results
|
||||||
180
strategies/implementations/macd.py
Normal file
180
strategies/implementations/macd.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
"""
|
||||||
|
MACD Strategy Implementation
|
||||||
|
|
||||||
|
This module implements a Moving Average Convergence Divergence (MACD) trading strategy.
|
||||||
|
It extends the BaseStrategy and generates buy/sell signals based on the crossover
|
||||||
|
of the MACD line and its signal line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from ..base import BaseStrategy
|
||||||
|
from ..data_types import StrategyResult, StrategySignal, SignalType
|
||||||
|
from ..utils import create_indicator_key, detect_crossover_signals_vectorized
|
||||||
|
|
||||||
|
|
||||||
|
class MACDStrategy(BaseStrategy):
|
||||||
|
"""
|
||||||
|
MACD Strategy.
|
||||||
|
|
||||||
|
Generates buy/sell signals when the MACD line crosses above or below its signal line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
super().__init__(logger)
|
||||||
|
self.strategy_name = "macd"
|
||||||
|
|
||||||
|
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Defines the indicators required by the MACD strategy.
|
||||||
|
It needs one MACD indicator.
|
||||||
|
"""
|
||||||
|
# Default periods for MACD, can be overridden by strategy config
|
||||||
|
return [
|
||||||
|
{'type': 'macd', 'fast_period': 12, 'slow_period': 26, 'signal_period': 9, 'price_column': 'close'}
|
||||||
|
]
|
||||||
|
|
||||||
|
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||||
|
"""
|
||||||
|
Calculate MACD strategy signals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with OHLCV data.
|
||||||
|
indicators_data: Dictionary of pre-calculated indicator DataFrames.
|
||||||
|
Expected key: 'macd_fast_period_12_slow_period_26_signal_period_9'.
|
||||||
|
**kwargs: Additional strategy parameters (e.g., fast_period, slow_period, signal_period, price_column).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of StrategyResult objects, each containing generated signals.
|
||||||
|
"""
|
||||||
|
# Extract parameters from kwargs or use defaults
|
||||||
|
fast_period = kwargs.get('fast_period', 12)
|
||||||
|
slow_period = kwargs.get('slow_period', 26)
|
||||||
|
signal_period = kwargs.get('signal_period', 9)
|
||||||
|
price_column = kwargs.get('price_column', 'close')
|
||||||
|
|
||||||
|
# Generate indicator key using shared utility function
|
||||||
|
macd_key = create_indicator_key({
|
||||||
|
'type': 'macd',
|
||||||
|
'fast_period': fast_period,
|
||||||
|
'slow_period': slow_period,
|
||||||
|
'signal_period': signal_period
|
||||||
|
})
|
||||||
|
|
||||||
|
# Validate that the main DataFrame has enough data for strategy calculation
|
||||||
|
min_periods = max(slow_period, signal_period)
|
||||||
|
if not self.validate_dataframe(df, min_periods):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Validate that the required MACD indicator data is present and sufficient
|
||||||
|
required_indicators = [
|
||||||
|
{'type': 'macd', 'fast_period': fast_period, 'slow_period': slow_period, 'signal_period': signal_period}
|
||||||
|
]
|
||||||
|
if not self.validate_indicators_data(indicators_data, required_indicators):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Missing or insufficient MACD indicator data.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
macd_df = indicators_data.get(macd_key)
|
||||||
|
|
||||||
|
if macd_df is None or macd_df.empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: MACD indicator DataFrame is not found or empty.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Merge all necessary data into a single DataFrame for easier processing
|
||||||
|
merged_df = pd.merge(df[[price_column, 'symbol', 'timeframe']],
|
||||||
|
macd_df[['macd', 'signal', 'histogram']],
|
||||||
|
left_index=True, right_index=True, how='inner')
|
||||||
|
|
||||||
|
if merged_df.empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use vectorized signal detection for better performance
|
||||||
|
bullish_crossover, bearish_crossover = detect_crossover_signals_vectorized(
|
||||||
|
merged_df, 'macd', 'signal'
|
||||||
|
)
|
||||||
|
|
||||||
|
results: List[StrategyResult] = []
|
||||||
|
strategy_metadata = {
|
||||||
|
'fast_period': fast_period,
|
||||||
|
'slow_period': slow_period,
|
||||||
|
'signal_period': signal_period
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process bullish crossover signals (MACD crosses above Signal)
|
||||||
|
bullish_indices = merged_df[bullish_crossover].index
|
||||||
|
for timestamp in bullish_indices:
|
||||||
|
row = merged_df.loc[timestamp]
|
||||||
|
|
||||||
|
# Skip if any MACD values are NaN
|
||||||
|
if pd.isna(row['macd']) or pd.isna(row['signal']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
signal = StrategySignal(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
signal_type=SignalType.BUY,
|
||||||
|
price=float(row[price_column]),
|
||||||
|
confidence=0.9,
|
||||||
|
metadata={'macd_cross': 'bullish', **strategy_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(StrategyResult(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[signal],
|
||||||
|
indicators_used={
|
||||||
|
'macd': float(row['macd']),
|
||||||
|
'signal': float(row['signal'])
|
||||||
|
},
|
||||||
|
metadata=strategy_metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.strategy_name}: BUY signal at {timestamp} for {row['symbol']} (MACD: {row['macd']:.2f}, Signal: {row['signal']:.2f})")
|
||||||
|
|
||||||
|
# Process bearish crossover signals (MACD crosses below Signal)
|
||||||
|
bearish_indices = merged_df[bearish_crossover].index
|
||||||
|
for timestamp in bearish_indices:
|
||||||
|
row = merged_df.loc[timestamp]
|
||||||
|
|
||||||
|
# Skip if any MACD values are NaN
|
||||||
|
if pd.isna(row['macd']) or pd.isna(row['signal']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
signal = StrategySignal(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
signal_type=SignalType.SELL,
|
||||||
|
price=float(row[price_column]),
|
||||||
|
confidence=0.9,
|
||||||
|
metadata={'macd_cross': 'bearish', **strategy_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(StrategyResult(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[signal],
|
||||||
|
indicators_used={
|
||||||
|
'macd': float(row['macd']),
|
||||||
|
'signal': float(row['signal'])
|
||||||
|
},
|
||||||
|
metadata=strategy_metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.strategy_name}: SELL signal at {timestamp} for {row['symbol']} (MACD: {row['macd']:.2f}, Signal: {row['signal']:.2f})")
|
||||||
|
|
||||||
|
return results
|
||||||
168
strategies/implementations/rsi.py
Normal file
168
strategies/implementations/rsi.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
Relative Strength Index (RSI) Strategy Implementation
|
||||||
|
|
||||||
|
This module implements an RSI-based momentum trading strategy.
|
||||||
|
It extends the BaseStrategy and generates buy/sell signals based on
|
||||||
|
RSI crossing overbought/oversold thresholds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from ..base import BaseStrategy
|
||||||
|
from ..data_types import StrategyResult, StrategySignal, SignalType
|
||||||
|
from ..utils import create_indicator_key, detect_threshold_signals_vectorized
|
||||||
|
|
||||||
|
|
||||||
|
class RSIStrategy(BaseStrategy):
|
||||||
|
"""
|
||||||
|
RSI Strategy.
|
||||||
|
|
||||||
|
Generates buy/sell signals when RSI crosses overbought/oversold thresholds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
super().__init__(logger)
|
||||||
|
self.strategy_name = "rsi"
|
||||||
|
|
||||||
|
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Defines the indicators required by the RSI strategy.
|
||||||
|
It needs one RSI indicator.
|
||||||
|
"""
|
||||||
|
# Default period for RSI, can be overridden by strategy config
|
||||||
|
return [
|
||||||
|
{'type': 'rsi', 'period': 14, 'price_column': 'close'}
|
||||||
|
]
|
||||||
|
|
||||||
|
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||||
|
"""
|
||||||
|
Calculate RSI strategy signals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with OHLCV data.
|
||||||
|
indicators_data: Dictionary of pre-calculated indicator DataFrames.
|
||||||
|
Expected key: 'rsi_period_14'.
|
||||||
|
**kwargs: Additional strategy parameters (e.g., period, overbought, oversold, price_column).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of StrategyResult objects, each containing generated signals.
|
||||||
|
"""
|
||||||
|
# Extract parameters from kwargs or use defaults
|
||||||
|
period = kwargs.get('period', 14)
|
||||||
|
overbought = kwargs.get('overbought', 70)
|
||||||
|
oversold = kwargs.get('oversold', 30)
|
||||||
|
price_column = kwargs.get('price_column', 'close')
|
||||||
|
|
||||||
|
# Generate indicator key using shared utility function
|
||||||
|
rsi_key = create_indicator_key({'type': 'rsi', 'period': period})
|
||||||
|
|
||||||
|
# Validate that the main DataFrame has enough data for strategy calculation
|
||||||
|
if not self.validate_dataframe(df, period):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Validate that the required RSI indicator data is present and sufficient
|
||||||
|
required_indicators = [
|
||||||
|
{'type': 'rsi', 'period': period}
|
||||||
|
]
|
||||||
|
if not self.validate_indicators_data(indicators_data, required_indicators):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Missing or insufficient RSI indicator data.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
rsi_df = indicators_data.get(rsi_key)
|
||||||
|
|
||||||
|
if rsi_df is None or rsi_df.empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: RSI indicator DataFrame is not found or empty.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Merge all necessary data into a single DataFrame for easier processing
|
||||||
|
merged_df = pd.merge(df[[price_column, 'symbol', 'timeframe']],
|
||||||
|
rsi_df[['rsi']],
|
||||||
|
left_index=True, right_index=True, how='inner')
|
||||||
|
|
||||||
|
if merged_df.empty:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use vectorized signal detection for better performance
|
||||||
|
buy_signals, sell_signals = detect_threshold_signals_vectorized(
|
||||||
|
merged_df, 'rsi', overbought, oversold
|
||||||
|
)
|
||||||
|
|
||||||
|
results: List[StrategyResult] = []
|
||||||
|
strategy_metadata = {
|
||||||
|
'period': period,
|
||||||
|
'overbought': overbought,
|
||||||
|
'oversold': oversold
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process buy signals (RSI crosses above oversold threshold)
|
||||||
|
buy_indices = merged_df[buy_signals].index
|
||||||
|
for timestamp in buy_indices:
|
||||||
|
row = merged_df.loc[timestamp]
|
||||||
|
|
||||||
|
# Skip if RSI value is NaN
|
||||||
|
if pd.isna(row['rsi']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
signal = StrategySignal(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
signal_type=SignalType.BUY,
|
||||||
|
price=float(row[price_column]),
|
||||||
|
confidence=0.7,
|
||||||
|
metadata={'rsi_cross': 'oversold_to_buy', **strategy_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(StrategyResult(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[signal],
|
||||||
|
indicators_used={'rsi': float(row['rsi'])},
|
||||||
|
metadata=strategy_metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.strategy_name}: BUY signal at {timestamp} for {row['symbol']} (RSI: {row['rsi']:.2f})")
|
||||||
|
|
||||||
|
# Process sell signals (RSI crosses below overbought threshold)
|
||||||
|
sell_indices = merged_df[sell_signals].index
|
||||||
|
for timestamp in sell_indices:
|
||||||
|
row = merged_df.loc[timestamp]
|
||||||
|
|
||||||
|
# Skip if RSI value is NaN
|
||||||
|
if pd.isna(row['rsi']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
signal = StrategySignal(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
signal_type=SignalType.SELL,
|
||||||
|
price=float(row[price_column]),
|
||||||
|
confidence=0.7,
|
||||||
|
metadata={'rsi_cross': 'overbought_to_sell', **strategy_metadata}
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(StrategyResult(
|
||||||
|
timestamp=timestamp,
|
||||||
|
symbol=row['symbol'],
|
||||||
|
timeframe=row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[signal],
|
||||||
|
indicators_used={'rsi': float(row['rsi'])},
|
||||||
|
metadata=strategy_metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.strategy_name}: SELL signal at {timestamp} for {row['symbol']} (RSI: {row['rsi']:.2f})")
|
||||||
|
|
||||||
|
return results
|
||||||
166
strategies/utils.py
Normal file
166
strategies/utils.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Trading Strategy Utilities
|
||||||
|
|
||||||
|
This module provides utility functions for managing trading strategy
|
||||||
|
configurations and validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Tuple
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_strategies_config() -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Create default configuration for common trading strategies.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with default strategy configurations
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'ema_crossover_default': {'strategy': 'ema_crossover', 'fast_period': 12, 'slow_period': 26},
|
||||||
|
'rsi_default': {'strategy': 'rsi', 'period': 14, 'overbought': 70, 'oversold': 30},
|
||||||
|
'macd_default': {'strategy': 'macd', 'fast_period': 12, 'slow_period': 26, 'signal_period': 9}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_indicator_key(indicator_config: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Create a unique key for an indicator configuration.
|
||||||
|
This function is shared between StrategyFactory and individual strategies
|
||||||
|
to ensure consistency and reduce duplication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicator_config: Indicator configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unique string key for this indicator configuration
|
||||||
|
"""
|
||||||
|
indicator_type = indicator_config.get('type', 'unknown')
|
||||||
|
|
||||||
|
# Common parameters that should be part of the key
|
||||||
|
key_params = []
|
||||||
|
for param in ['period', 'fast_period', 'slow_period', 'signal_period', 'std_dev']:
|
||||||
|
if param in indicator_config:
|
||||||
|
key_params.append(f"{param}_{indicator_config[param]}")
|
||||||
|
|
||||||
|
if key_params:
|
||||||
|
return f"{indicator_type}_{'_'.join(key_params)}"
|
||||||
|
else:
|
||||||
|
return f"{indicator_type}_default"
|
||||||
|
|
||||||
|
|
||||||
|
def detect_crossover_signals_vectorized(df: pd.DataFrame,
|
||||||
|
fast_col: str,
|
||||||
|
slow_col: str) -> Tuple[pd.Series, pd.Series]:
|
||||||
|
"""
|
||||||
|
Detect crossover signals using vectorized operations for better performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with indicator data
|
||||||
|
fast_col: Column name for fast line (e.g., 'ema_fast')
|
||||||
|
slow_col: Column name for slow line (e.g., 'ema_slow')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (bullish_crossover_mask, bearish_crossover_mask) as boolean Series
|
||||||
|
"""
|
||||||
|
# Get previous values using shift
|
||||||
|
fast_prev = df[fast_col].shift(1)
|
||||||
|
slow_prev = df[slow_col].shift(1)
|
||||||
|
fast_curr = df[fast_col]
|
||||||
|
slow_curr = df[slow_col]
|
||||||
|
|
||||||
|
# Bullish crossover: fast was <= slow, now fast > slow
|
||||||
|
bullish_crossover = (fast_prev <= slow_prev) & (fast_curr > slow_curr)
|
||||||
|
|
||||||
|
# Bearish crossover: fast was >= slow, now fast < slow
|
||||||
|
bearish_crossover = (fast_prev >= slow_prev) & (fast_curr < slow_curr)
|
||||||
|
|
||||||
|
# Ensure no signals on first row (index 0) where there's no previous value
|
||||||
|
bullish_crossover.iloc[0] = False
|
||||||
|
bearish_crossover.iloc[0] = False
|
||||||
|
|
||||||
|
return bullish_crossover, bearish_crossover
|
||||||
|
|
||||||
|
|
||||||
|
def detect_threshold_signals_vectorized(df: pd.DataFrame,
|
||||||
|
indicator_col: str,
|
||||||
|
upper_threshold: float,
|
||||||
|
lower_threshold: float) -> Tuple[pd.Series, pd.Series]:
|
||||||
|
"""
|
||||||
|
Detect threshold crossing signals using vectorized operations (for RSI, etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame with indicator data
|
||||||
|
indicator_col: Column name for indicator (e.g., 'rsi')
|
||||||
|
upper_threshold: Upper threshold (e.g., 70 for RSI overbought)
|
||||||
|
lower_threshold: Lower threshold (e.g., 30 for RSI oversold)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (buy_signal_mask, sell_signal_mask) as boolean Series
|
||||||
|
"""
|
||||||
|
# Get previous values using shift
|
||||||
|
indicator_prev = df[indicator_col].shift(1)
|
||||||
|
indicator_curr = df[indicator_col]
|
||||||
|
|
||||||
|
# Buy signal: was <= lower_threshold, now > lower_threshold (oversold to normal)
|
||||||
|
buy_signal = (indicator_prev <= lower_threshold) & (indicator_curr > lower_threshold)
|
||||||
|
|
||||||
|
# Sell signal: was >= upper_threshold, now < upper_threshold (overbought to normal)
|
||||||
|
sell_signal = (indicator_prev >= upper_threshold) & (indicator_curr < upper_threshold)
|
||||||
|
|
||||||
|
# Ensure no signals on first row (index 0) where there's no previous value
|
||||||
|
buy_signal.iloc[0] = False
|
||||||
|
sell_signal.iloc[0] = False
|
||||||
|
|
||||||
|
return buy_signal, sell_signal
|
||||||
|
|
||||||
|
|
||||||
|
def validate_strategy_config(config: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Validate trading strategy configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Strategy configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if configuration is valid, False otherwise
|
||||||
|
"""
|
||||||
|
required_fields = ['strategy']
|
||||||
|
|
||||||
|
# Check required fields
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in config:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Validate strategy type
|
||||||
|
valid_types = ['ema_crossover', 'rsi', 'macd']
|
||||||
|
if config['strategy'] not in valid_types:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Basic validation for common parameters based on strategy type
|
||||||
|
strategy_type = config['strategy']
|
||||||
|
if strategy_type == 'ema_crossover':
|
||||||
|
if not all(k in config for k in ['fast_period', 'slow_period']):
|
||||||
|
return False
|
||||||
|
if not (isinstance(config['fast_period'], int) and config['fast_period'] > 0 and
|
||||||
|
isinstance(config['slow_period'], int) and config['slow_period'] > 0 and
|
||||||
|
config['fast_period'] < config['slow_period']):
|
||||||
|
return False
|
||||||
|
elif strategy_type == 'rsi':
|
||||||
|
if not all(k in config for k in ['period', 'overbought', 'oversold']):
|
||||||
|
return False
|
||||||
|
if not (isinstance(config['period'], int) and config['period'] > 0 and
|
||||||
|
isinstance(config['overbought'], (int, float)) and isinstance(config['oversold'], (int, float)) and
|
||||||
|
0 <= config['oversold'] < config['overbought'] <= 100):
|
||||||
|
return False
|
||||||
|
elif strategy_type == 'macd':
|
||||||
|
if not all(k in config for k in ['fast_period', 'slow_period', 'signal_period']):
|
||||||
|
return False
|
||||||
|
if not (isinstance(config['fast_period'], int) and config['fast_period'] > 0 and
|
||||||
|
isinstance(config['slow_period'], int) and config['slow_period'] > 0 and
|
||||||
|
isinstance(config['signal_period'], int) and config['signal_period'] > 0 and
|
||||||
|
config['fast_period'] < config['slow_period']):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
@ -38,19 +38,31 @@
|
|||||||
- **Layered Chart Integration**: Strategy signals and performance visualizations will be integrated into the dashboard as a new chart layer, utilizing the existing modular chart system.
|
- **Layered Chart Integration**: Strategy signals and performance visualizations will be integrated into the dashboard as a new chart layer, utilizing the existing modular chart system.
|
||||||
- **Comprehensive Testing**: Ensure that all new classes, functions, and modules within the strategy engine have corresponding unit tests placed in the `tests/strategies/` directory, following established testing conventions.
|
- **Comprehensive Testing**: Ensure that all new classes, functions, and modules within the strategy engine have corresponding unit tests placed in the `tests/strategies/` directory, following established testing conventions.
|
||||||
|
|
||||||
|
## Decisions
|
||||||
|
|
||||||
|
### 1. Vectorized vs. Iterative Calculation
|
||||||
|
- **Decision**: Refactored strategy signal detection to use vectorized Pandas operations (e.g., `shift()`, boolean indexing) instead of iterative Python loops.
|
||||||
|
- **Reasoning**: Significantly improves performance for signal generation, especially with large datasets, while maintaining identical results as verified by dedicated tests.
|
||||||
|
- **Impact**: All core strategy implementations (EMA Crossover, RSI, MACD) now leverage vectorized functions for their primary signal detection logic.
|
||||||
|
|
||||||
|
### 2. Indicator Key Generation Consistency
|
||||||
|
- **Decision**: Centralized the `_create_indicator_key` logic into a shared utility function `create_indicator_key()` in `strategies/utils.py`.
|
||||||
|
- **Reasoning**: Eliminates code duplication in `StrategyFactory` and individual strategy implementations, ensuring consistent key generation and easier maintenance if indicator naming conventions change.
|
||||||
|
- **Impact**: `StrategyFactory` and all strategy implementations now use this shared utility for generating unique indicator keys.
|
||||||
|
|
||||||
## Tasks
|
## Tasks
|
||||||
|
|
||||||
- [ ] 1.0 Core Strategy Foundation Setup
|
- [x] 1.0 Core Strategy Foundation Setup
|
||||||
- [ ] 1.1 Create `strategies/` directory structure following indicators pattern
|
- [x] 1.1 Create `strategies/` directory structure following indicators pattern
|
||||||
- [ ] 1.2 Implement `BaseStrategy` abstract class in `strategies/base.py` with `calculate()` and `get_required_indicators()` methods
|
- [x] 1.2 Implement `BaseStrategy` abstract class in `strategies/base.py` with `calculate()` and `get_required_indicators()` methods
|
||||||
- [ ] 1.3 Create `strategies/data_types.py` with `StrategySignal`, `SignalType`, and `StrategyResult` classes
|
- [x] 1.3 Create `strategies/data_types.py` with `StrategySignal`, `SignalType`, and `StrategyResult` classes
|
||||||
- [ ] 1.4 Implement `StrategyFactory` class in `strategies/factory.py` for dynamic strategy loading and registration
|
- [x] 1.4 Implement `StrategyFactory` class in `strategies/factory.py` for dynamic strategy loading and registration
|
||||||
- [ ] 1.5 Create strategy implementations directory `strategies/implementations/`
|
- [x] 1.5 Create strategy implementations directory `strategies/implementations/`
|
||||||
- [ ] 1.6 Implement `EMAStrategy` in `strategies/implementations/ema_crossover.py` as reference implementation
|
- [x] 1.6 Implement `EMAStrategy` in `strategies/implementations/ema_crossover.py` as reference implementation
|
||||||
- [ ] 1.7 Implement `RSIStrategy` in `strategies/implementations/rsi.py` for momentum-based signals
|
- [x] 1.7 Implement `RSIStrategy` in `strategies/implementations/rsi.py` for momentum-based signals
|
||||||
- [ ] 1.8 Implement `MACDStrategy` in `strategies/implementations/macd.py` for trend-following signals
|
- [x] 1.8 Implement `MACDStrategy` in `strategies/implementations/macd.py` for trend-following signals
|
||||||
- [ ] 1.9 Create `strategies/utils.py` with helper functions for signal validation and processing
|
- [x] 1.9 Create `strategies/utils.py` with helper functions for signal validation and processing
|
||||||
- [ ] 1.10 Create comprehensive unit tests for all strategy foundation components
|
- [x] 1.10 Create comprehensive unit tests for all strategy foundation components
|
||||||
|
|
||||||
- [ ] 2.0 Strategy Configuration System
|
- [ ] 2.0 Strategy Configuration System
|
||||||
- [ ] 2.1 Create `config/strategies/` directory structure mirroring indicators configuration
|
- [ ] 2.1 Create `config/strategies/` directory structure mirroring indicators configuration
|
||||||
@ -77,7 +89,7 @@
|
|||||||
- [ ] 4.0 Strategy Data Integration
|
- [ ] 4.0 Strategy Data Integration
|
||||||
- [ ] 4.1 Create `StrategyDataIntegrator` class in new `strategies/data_integration.py` module
|
- [ ] 4.1 Create `StrategyDataIntegrator` class in new `strategies/data_integration.py` module
|
||||||
- [ ] 4.2 Implement data loading interface that leverages existing `TechnicalIndicators` class for indicator dependencies
|
- [ ] 4.2 Implement data loading interface that leverages existing `TechnicalIndicators` class for indicator dependencies
|
||||||
- [ ] 4.3 Add multi-timeframe data handling for strategies that require indicators from different timeframes
|
- [x] 4.3 Add multi-timeframe data handling for strategies that require indicators from different timeframes
|
||||||
- [ ] 4.4 Implement strategy calculation orchestration with proper indicator dependency resolution
|
- [ ] 4.4 Implement strategy calculation orchestration with proper indicator dependency resolution
|
||||||
- [ ] 4.5 Create caching layer for computed indicator results to avoid recalculation across strategies
|
- [ ] 4.5 Create caching layer for computed indicator results to avoid recalculation across strategies
|
||||||
- [ ] 4.6 Add strategy signal generation and validation pipeline
|
- [ ] 4.6 Add strategy signal generation and validation pipeline
|
||||||
|
|||||||
1
tests/strategies/__init__.py
Normal file
1
tests/strategies/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
162
tests/strategies/test_base_strategy.py
Normal file
162
tests/strategies/test_base_strategy.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from strategies.base import BaseStrategy
|
||||||
|
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
||||||
|
from data.common.data_types import OHLCVCandle
|
||||||
|
|
||||||
|
# Mock logger for testing
|
||||||
|
class MockLogger:
|
||||||
|
def __init__(self):
|
||||||
|
self.info_calls = []
|
||||||
|
self.warning_calls = []
|
||||||
|
self.error_calls = []
|
||||||
|
|
||||||
|
def info(self, message):
|
||||||
|
self.info_calls.append(message)
|
||||||
|
|
||||||
|
def warning(self, message):
|
||||||
|
self.warning_calls.append(message)
|
||||||
|
|
||||||
|
def error(self, message):
|
||||||
|
self.error_calls.append(message)
|
||||||
|
|
||||||
|
# Concrete implementation of BaseStrategy for testing purposes
|
||||||
|
class ConcreteStrategy(BaseStrategy):
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
super().__init__("ConcreteStrategy", logger)
|
||||||
|
|
||||||
|
def get_required_indicators(self) -> list[dict]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
|
||||||
|
# Simple mock calculation for testing
|
||||||
|
signals = []
|
||||||
|
if not data.empty:
|
||||||
|
first_row = data.iloc[0]
|
||||||
|
signals.append(StrategyResult(
|
||||||
|
timestamp=first_row.name,
|
||||||
|
symbol=first_row['symbol'],
|
||||||
|
timeframe=first_row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[StrategySignal(
|
||||||
|
timestamp=first_row.name,
|
||||||
|
symbol=first_row['symbol'],
|
||||||
|
timeframe=first_row['timeframe'],
|
||||||
|
signal_type=SignalType.BUY,
|
||||||
|
price=float(first_row['close']),
|
||||||
|
confidence=1.0
|
||||||
|
)],
|
||||||
|
indicators_used={}
|
||||||
|
))
|
||||||
|
return signals
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_logger():
|
||||||
|
return MockLogger()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def concrete_strategy(mock_logger):
|
||||||
|
return ConcreteStrategy(logger=mock_logger)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_ohlcv_data():
|
||||||
|
return pd.DataFrame({
|
||||||
|
'open': [100, 101, 102, 103, 104],
|
||||||
|
'high': [105, 106, 107, 108, 109],
|
||||||
|
'low': [99, 100, 101, 102, 103],
|
||||||
|
'close': [102, 103, 104, 105, 106],
|
||||||
|
'volume': [1000, 1100, 1200, 1300, 1400],
|
||||||
|
'symbol': ['BTC/USDT'] * 5,
|
||||||
|
'timeframe': ['1h'] * 5
|
||||||
|
}, index=pd.to_datetime(['2023-01-01 00:00:00', '2023-01-01 01:00:00', '2023-01-01 02:00:00',
|
||||||
|
'2023-01-01 03:00:00', '2023-01-01 04:00:00']))
|
||||||
|
|
||||||
|
def test_prepare_dataframe_initial_data(concrete_strategy, sample_ohlcv_data):
|
||||||
|
prepared_df = concrete_strategy.prepare_dataframe(sample_ohlcv_data)
|
||||||
|
assert 'open' in prepared_df.columns
|
||||||
|
assert 'high' in prepared_df.columns
|
||||||
|
assert 'low' in prepared_df.columns
|
||||||
|
assert 'close' in prepared_df.columns
|
||||||
|
assert 'volume' in prepared_df.columns
|
||||||
|
assert 'symbol' in prepared_df.columns
|
||||||
|
assert 'timeframe' in prepared_df.columns
|
||||||
|
assert prepared_df.index.name == 'timestamp'
|
||||||
|
assert prepared_df.index.is_monotonic_increasing
|
||||||
|
|
||||||
|
def test_prepare_dataframe_sparse_data(concrete_strategy, sample_ohlcv_data):
|
||||||
|
# Simulate sparse data by removing the middle row
|
||||||
|
sparse_df = sample_ohlcv_data.drop(sample_ohlcv_data.index[2])
|
||||||
|
prepared_df = concrete_strategy.prepare_dataframe(sparse_df)
|
||||||
|
assert len(prepared_df) == len(sample_ohlcv_data) # Should fill missing row with NaN
|
||||||
|
assert prepared_df.index[2] == sample_ohlcv_data.index[2] # Ensure timestamp is restored
|
||||||
|
assert pd.isna(prepared_df.loc[sample_ohlcv_data.index[2], 'open']) # Check for NaN in filled row
|
||||||
|
|
||||||
|
def test_validate_dataframe_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
# Ensure no warnings/errors are logged for valid data
|
||||||
|
concrete_strategy.validate_dataframe(sample_ohlcv_data)
|
||||||
|
assert not mock_logger.warning_calls
|
||||||
|
assert not mock_logger.error_calls
|
||||||
|
|
||||||
|
def test_validate_dataframe_missing_column(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
invalid_df = sample_ohlcv_data.drop(columns=['open'])
|
||||||
|
with pytest.raises(ValueError, match="Missing required columns: \['open']"):
|
||||||
|
concrete_strategy.validate_dataframe(invalid_df)
|
||||||
|
|
||||||
|
def test_validate_dataframe_invalid_index(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
invalid_df = sample_ohlcv_data.reset_index()
|
||||||
|
with pytest.raises(ValueError, match="DataFrame index must be named 'timestamp' and be a DatetimeIndex."):
|
||||||
|
concrete_strategy.validate_dataframe(invalid_df)
|
||||||
|
|
||||||
|
def test_validate_dataframe_non_monotonic_index(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
# Reverse order to make it non-monotonic
|
||||||
|
invalid_df = sample_ohlcv_data.iloc[::-1]
|
||||||
|
with pytest.raises(ValueError, match="DataFrame index is not monotonically increasing."):
|
||||||
|
concrete_strategy.validate_dataframe(invalid_df)
|
||||||
|
|
||||||
|
def test_validate_indicators_data_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
indicators_data = {
|
||||||
|
'ema_fast': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
|
||||||
|
'ema_slow': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index)
|
||||||
|
}
|
||||||
|
merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
|
||||||
|
|
||||||
|
required_indicators = [
|
||||||
|
{'type': 'ema', 'period': 12, 'key': 'ema_fast'},
|
||||||
|
{'type': 'ema', 'period': 26, 'key': 'ema_slow'}
|
||||||
|
]
|
||||||
|
|
||||||
|
concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
||||||
|
assert not mock_logger.warning_calls
|
||||||
|
assert not mock_logger.error_calls
|
||||||
|
|
||||||
|
def test_validate_indicators_data_missing_indicator(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
indicators_data = {
|
||||||
|
'ema_fast': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
|
||||||
|
}
|
||||||
|
merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
|
||||||
|
|
||||||
|
required_indicators = [
|
||||||
|
{'type': 'ema', 'period': 12, 'key': 'ema_fast'},
|
||||||
|
{'type': 'ema', 'period': 26, 'key': 'ema_slow'} # Missing
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Missing required indicator data for key: ema_slow"):
|
||||||
|
concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
||||||
|
|
||||||
|
def test_validate_indicators_data_nan_values(concrete_strategy, sample_ohlcv_data, mock_logger):
|
||||||
|
indicators_data = {
|
||||||
|
'ema_fast': pd.Series([101, 102, np.nan, 104, 105], index=sample_ohlcv_data.index),
|
||||||
|
'ema_slow': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index)
|
||||||
|
}
|
||||||
|
merged_df = pd.concat([sample_ohlcv_data, pd.DataFrame(indicators_data)], axis=1)
|
||||||
|
|
||||||
|
required_indicators = [
|
||||||
|
{'type': 'ema', 'period': 12, 'key': 'ema_fast'},
|
||||||
|
{'type': 'ema', 'period': 26, 'key': 'ema_slow'}
|
||||||
|
]
|
||||||
|
|
||||||
|
concrete_strategy.validate_indicators_data(merged_df, required_indicators)
|
||||||
|
assert "NaN values detected in required indicator data for key: ema_fast" in mock_logger.warning_calls
|
||||||
220
tests/strategies/test_strategy_factory.py
Normal file
220
tests/strategies/test_strategy_factory.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from strategies.factory import StrategyFactory
|
||||||
|
from strategies.base import BaseStrategy
|
||||||
|
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
||||||
|
from data.common.data_types import OHLCVCandle
|
||||||
|
from data.common.indicators import TechnicalIndicators # For mocking purposes
|
||||||
|
|
||||||
|
# Mock logger for testing
|
||||||
|
class MockLogger:
|
||||||
|
def __init__(self):
|
||||||
|
self.info_calls = []
|
||||||
|
self.warning_calls = []
|
||||||
|
self.error_calls = []
|
||||||
|
|
||||||
|
def info(self, message):
|
||||||
|
self.info_calls.append(message)
|
||||||
|
|
||||||
|
def warning(self, message):
|
||||||
|
self.warning_calls.append(message)
|
||||||
|
|
||||||
|
def error(self, message):
|
||||||
|
self.error_calls.append(message)
|
||||||
|
|
||||||
|
# Mock Concrete Strategy for testing StrategyFactory
|
||||||
|
class MockEMAStrategy(BaseStrategy):
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
super().__init__("ema_crossover", logger)
|
||||||
|
self.calculate_calls = []
|
||||||
|
|
||||||
|
def get_required_indicators(self) -> list[dict]:
|
||||||
|
return [{'type': 'ema', 'period': 12}, {'type': 'ema', 'period': 26}]
|
||||||
|
|
||||||
|
def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
|
||||||
|
self.calculate_calls.append((data, kwargs))
|
||||||
|
# Simulate a signal for testing
|
||||||
|
if not data.empty:
|
||||||
|
first_row = data.iloc[0]
|
||||||
|
return [StrategyResult(
|
||||||
|
timestamp=first_row.name,
|
||||||
|
symbol=first_row['symbol'],
|
||||||
|
timeframe=first_row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[StrategySignal(
|
||||||
|
timestamp=first_row.name,
|
||||||
|
symbol=first_row['symbol'],
|
||||||
|
timeframe=first_row['timeframe'],
|
||||||
|
signal_type=SignalType.BUY,
|
||||||
|
price=float(first_row['close']),
|
||||||
|
confidence=1.0
|
||||||
|
)],
|
||||||
|
indicators_used={}
|
||||||
|
)]
|
||||||
|
return []
|
||||||
|
|
||||||
|
class MockRSIStrategy(BaseStrategy):
|
||||||
|
def __init__(self, logger=None):
|
||||||
|
super().__init__("rsi", logger)
|
||||||
|
self.calculate_calls = []
|
||||||
|
|
||||||
|
def get_required_indicators(self) -> list[dict]:
|
||||||
|
return [{'type': 'rsi', 'period': 14}]
|
||||||
|
|
||||||
|
def calculate(self, data: pd.DataFrame, **kwargs) -> list[StrategyResult]:
|
||||||
|
self.calculate_calls.append((data, kwargs))
|
||||||
|
# Simulate a signal for testing
|
||||||
|
if not data.empty:
|
||||||
|
first_row = data.iloc[0]
|
||||||
|
return [StrategyResult(
|
||||||
|
timestamp=first_row.name,
|
||||||
|
symbol=first_row['symbol'],
|
||||||
|
timeframe=first_row['timeframe'],
|
||||||
|
strategy_name=self.strategy_name,
|
||||||
|
signals=[StrategySignal(
|
||||||
|
timestamp=first_row.name,
|
||||||
|
symbol=first_row['symbol'],
|
||||||
|
timeframe=first_row['timeframe'],
|
||||||
|
signal_type=SignalType.SELL,
|
||||||
|
price=float(first_row['close']),
|
||||||
|
confidence=0.9
|
||||||
|
)],
|
||||||
|
indicators_used={}
|
||||||
|
)]
|
||||||
|
return []
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_logger():
|
||||||
|
return MockLogger()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_technical_indicators():
|
||||||
|
mock_ti = MagicMock(spec=TechnicalIndicators)
|
||||||
|
# Configure the mock to return dummy data for indicators
|
||||||
|
def mock_calculate(indicator_type, df, **kwargs):
|
||||||
|
if indicator_type == 'ema':
|
||||||
|
# Simulate EMA data
|
||||||
|
return pd.DataFrame({
|
||||||
|
'ema_fast': df['close'] * 1.02,
|
||||||
|
'ema_slow': df['close'] * 0.98
|
||||||
|
}, index=df.index)
|
||||||
|
elif indicator_type == 'rsi':
|
||||||
|
# Simulate RSI data
|
||||||
|
return pd.DataFrame({
|
||||||
|
'rsi': pd.Series([60, 65, 72, 28, 35], index=df.index)
|
||||||
|
}, index=df.index)
|
||||||
|
return pd.DataFrame(index=df.index)
|
||||||
|
|
||||||
|
mock_ti.calculate.side_effect = mock_calculate
|
||||||
|
return mock_ti
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def strategy_factory(mock_technical_indicators, mock_logger, monkeypatch):
|
||||||
|
# Patch the strategy factory to use our mock strategies
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"strategies.factory.StrategyFactory._STRATEGIES",
|
||||||
|
{
|
||||||
|
"ema_crossover": MockEMAStrategy,
|
||||||
|
"rsi": MockRSIStrategy,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return StrategyFactory(mock_technical_indicators, mock_logger)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_ohlcv_data():
|
||||||
|
return pd.DataFrame({
|
||||||
|
'open': [100, 101, 102, 103, 104],
|
||||||
|
'high': [105, 106, 107, 108, 109],
|
||||||
|
'low': [99, 100, 101, 102, 103],
|
||||||
|
'close': [102, 103, 104, 105, 106],
|
||||||
|
'volume': [1000, 1100, 1200, 1300, 1400],
|
||||||
|
'symbol': ['BTC/USDT'] * 5,
|
||||||
|
'timeframe': ['1h'] * 5
|
||||||
|
}, index=pd.to_datetime(['2023-01-01 00:00:00', '2023-01-01 01:00:00', '2023-01-01 02:00:00',
|
||||||
|
'2023-01-01 03:00:00', '2023-01-01 04:00:00']))
|
||||||
|
|
||||||
|
def test_get_available_strategies(strategy_factory):
|
||||||
|
available_strategies = strategy_factory.get_available_strategies()
|
||||||
|
assert "ema_crossover" in available_strategies
|
||||||
|
assert "rsi" in available_strategies
|
||||||
|
assert "macd" not in available_strategies # Should not be present if not mocked
|
||||||
|
|
||||||
|
def test_create_strategy_success(strategy_factory):
|
||||||
|
ema_strategy = strategy_factory.create_strategy("ema_crossover")
|
||||||
|
assert isinstance(ema_strategy, MockEMAStrategy)
|
||||||
|
assert ema_strategy.strategy_name == "ema_crossover"
|
||||||
|
|
||||||
|
def test_create_strategy_unknown(strategy_factory):
|
||||||
|
with pytest.raises(ValueError, match="Unknown strategy type: unknown_strategy"):
|
||||||
|
strategy_factory.create_strategy("unknown_strategy")
|
||||||
|
|
||||||
|
def test_calculate_multiple_strategies_success(strategy_factory, sample_ohlcv_data, mock_technical_indicators):
|
||||||
|
strategy_configs = [
|
||||||
|
{"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26},
|
||||||
|
{"strategy": "rsi", "period": 14, "overbought": 70, "oversold": 30}
|
||||||
|
]
|
||||||
|
|
||||||
|
all_strategy_results = strategy_factory.calculate_multiple_strategies(
|
||||||
|
strategy_configs, sample_ohlcv_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(all_strategy_results) == 2 # Expect results for both strategies
|
||||||
|
assert "ema_crossover" in all_strategy_results
|
||||||
|
assert "rsi" in all_strategy_results
|
||||||
|
|
||||||
|
ema_results = all_strategy_results["ema_crossover"]
|
||||||
|
rsi_results = all_strategy_results["rsi"]
|
||||||
|
|
||||||
|
assert len(ema_results) > 0
|
||||||
|
assert ema_results[0].strategy_name == "ema_crossover"
|
||||||
|
assert len(rsi_results) > 0
|
||||||
|
assert rsi_results[0].strategy_name == "rsi"
|
||||||
|
|
||||||
|
# Verify that TechnicalIndicators.calculate was called with correct arguments
|
||||||
|
# EMA calls
|
||||||
|
ema_calls = [call for call in mock_technical_indicators.calculate.call_args_list if call.args[0] == 'ema']
|
||||||
|
assert len(ema_calls) == 2 # Two EMA indicators for ema_crossover strategy
|
||||||
|
assert ema_calls[0].kwargs['period'] == 12 or ema_calls[0].kwargs['period'] == 26
|
||||||
|
assert ema_calls[1].kwargs['period'] == 12 or ema_calls[1].kwargs['period'] == 26
|
||||||
|
|
||||||
|
# RSI calls
|
||||||
|
rsi_calls = [call for call in mock_technical_indicators.calculate.call_args_list if call.args[0] == 'rsi']
|
||||||
|
assert len(rsi_calls) == 1 # One RSI indicator for rsi strategy
|
||||||
|
assert rsi_calls[0].kwargs['period'] == 14
|
||||||
|
|
||||||
|
def test_calculate_multiple_strategies_no_configs(strategy_factory, sample_ohlcv_data):
|
||||||
|
results = strategy_factory.calculate_multiple_strategies([], sample_ohlcv_data)
|
||||||
|
assert not results
|
||||||
|
|
||||||
|
def test_calculate_multiple_strategies_empty_data(strategy_factory, mock_technical_indicators):
|
||||||
|
strategy_configs = [
|
||||||
|
{"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26}
|
||||||
|
]
|
||||||
|
empty_df = pd.DataFrame(columns=['open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe'])
|
||||||
|
results = strategy_factory.calculate_multiple_strategies(strategy_configs, empty_df)
|
||||||
|
assert not results
|
||||||
|
|
||||||
|
def test_calculate_multiple_strategies_missing_indicator_data(strategy_factory, sample_ohlcv_data, mock_logger, mock_technical_indicators):
|
||||||
|
# Simulate a scenario where an indicator is requested but not returned by TechnicalIndicators
|
||||||
|
def mock_calculate_no_ema(indicator_type, df, **kwargs):
|
||||||
|
if indicator_type == 'ema':
|
||||||
|
return pd.DataFrame(index=df.index) # Simulate no EMA data returned
|
||||||
|
elif indicator_type == 'rsi':
|
||||||
|
return pd.DataFrame({'rsi': df['close']}, index=df.index)
|
||||||
|
return pd.DataFrame(index=df.index)
|
||||||
|
|
||||||
|
mock_technical_indicators.calculate.side_effect = mock_calculate_no_ema
|
||||||
|
|
||||||
|
strategy_configs = [
|
||||||
|
{"strategy": "ema_crossover", "fast_period": 12, "slow_period": 26}
|
||||||
|
]
|
||||||
|
|
||||||
|
results = strategy_factory.calculate_multiple_strategies(
|
||||||
|
strategy_configs, sample_ohlcv_data
|
||||||
|
)
|
||||||
|
assert not results # Expect no results if indicators are missing
|
||||||
|
assert "Missing required indicator data for key: ema_period_12" in mock_logger.error_calls or \
|
||||||
|
"Missing required indicator data for key: ema_period_26" in mock_logger.error_calls
|
||||||
Loading…
x
Reference in New Issue
Block a user