- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD). - Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability. - Updated `pyproject.toml` to include the new `strategies` package in the build configuration. - Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards. These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
166 lines
6.2 KiB
Python
166 lines
6.2 KiB
Python
"""
|
|
Trading Strategy Utilities
|
|
|
|
This module provides utility functions for managing trading strategy
|
|
configurations and validation.
|
|
"""
|
|
|
|
from typing import Dict, Any, List, Tuple
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
|
|
def create_default_strategies_config() -> Dict[str, Dict[str, Any]]:
    """
    Build the default parameter sets for the built-in trading strategies.

    A new dictionary (with new inner dictionaries) is assembled on every
    call, so callers may mutate the result without affecting later calls.

    Returns:
        Mapping of configuration name to its strategy parameters. Each entry
        carries a 'strategy' key naming the strategy type plus that
        strategy's numeric parameters.
    """
    ema_crossover_params: Dict[str, Any] = {
        'strategy': 'ema_crossover',
        'fast_period': 12,
        'slow_period': 26,
    }
    rsi_params: Dict[str, Any] = {
        'strategy': 'rsi',
        'period': 14,
        'overbought': 70,
        'oversold': 30,
    }
    macd_params: Dict[str, Any] = {
        'strategy': 'macd',
        'fast_period': 12,
        'slow_period': 26,
        'signal_period': 9,
    }

    defaults: Dict[str, Dict[str, Any]] = {}
    defaults['ema_crossover_default'] = ema_crossover_params
    defaults['rsi_default'] = rsi_params
    defaults['macd_default'] = macd_params
    return defaults
|
|
|
|
|
|
def create_indicator_key(indicator_config: Dict[str, Any]) -> str:
    """
    Derive a string key identifying an indicator configuration.

    The key starts with the configuration's 'type' ('unknown' when absent),
    followed by every recognised parameter that is present, rendered as
    "<name>_<value>" and joined with underscores. Parameters always appear
    in a fixed order, independent of the input dictionary's ordering. When
    none of the recognised parameters is present the suffix is 'default'.

    Args:
        indicator_config: Indicator configuration

    Returns:
        Key string for this indicator configuration
    """
    # Only these parameters take part in the key, always in this order.
    tracked_params = ('period', 'fast_period', 'slow_period', 'signal_period', 'std_dev')

    prefix = indicator_config.get('type', 'unknown')
    segments = [
        f"{name}_{indicator_config[name]}"
        for name in tracked_params
        if name in indicator_config
    ]

    suffix = '_'.join(segments) if segments else 'default'
    return f"{prefix}_{suffix}"
|
|
|
|
|
|
def detect_crossover_signals_vectorized(df: pd.DataFrame,
                                        fast_col: str,
                                        slow_col: str) -> Tuple[pd.Series, pd.Series]:
    """
    Detect crossover signals using vectorized operations for better performance.

    A bullish crossover is flagged on a row where the fast line was at or
    below the slow line on the previous row and is strictly above it on the
    current row. A bearish crossover is the mirror image. The first row
    never carries a signal because it has no previous row to compare with.

    Args:
        df: DataFrame with indicator data
        fast_col: Column name for fast line (e.g., 'ema_fast')
        slow_col: Column name for slow line (e.g., 'ema_slow')

    Returns:
        Tuple of (bullish_crossover_mask, bearish_crossover_mask) as boolean
        Series aligned to df's index. Both are empty when df has no rows.

    Raises:
        KeyError: If fast_col or slow_col is not a column of df.
    """
    fast_curr = df[fast_col]
    slow_curr = df[slow_col]

    # Previous-row values, aligned to the current row via shift
    fast_prev = fast_curr.shift(1)
    slow_prev = slow_curr.shift(1)

    # Bullish crossover: fast was <= slow, now fast > slow
    bullish_crossover = (fast_prev <= slow_prev) & (fast_curr > slow_curr)

    # Bearish crossover: fast was >= slow, now fast < slow
    bearish_crossover = (fast_prev >= slow_prev) & (fast_curr < slow_curr)

    # Ensure no signals on the first row, where there is no previous value.
    # Guarded because positional assignment on an empty Series raises
    # IndexError, which would make the function unusable on empty input.
    if len(bullish_crossover) > 0:
        bullish_crossover.iloc[0] = False
        bearish_crossover.iloc[0] = False

    return bullish_crossover, bearish_crossover
|
|
|
|
|
|
def detect_threshold_signals_vectorized(df: pd.DataFrame,
                                        indicator_col: str,
                                        upper_threshold: float,
                                        lower_threshold: float) -> Tuple[pd.Series, pd.Series]:
    """
    Detect threshold crossing signals using vectorized operations (for RSI, etc.).

    A buy signal is flagged on a row where the indicator was at or below
    lower_threshold on the previous row and is strictly above it on the
    current row. A sell signal is flagged where the indicator was at or
    above upper_threshold on the previous row and is strictly below it on
    the current row. The first row never carries a signal because it has no
    previous row to compare with.

    Args:
        df: DataFrame with indicator data
        indicator_col: Column name for indicator (e.g., 'rsi')
        upper_threshold: Upper threshold (e.g., 70 for RSI overbought)
        lower_threshold: Lower threshold (e.g., 30 for RSI oversold)

    Returns:
        Tuple of (buy_signal_mask, sell_signal_mask) as boolean Series
        aligned to df's index. Both are empty when df has no rows.

    Raises:
        KeyError: If indicator_col is not a column of df.
    """
    indicator_curr = df[indicator_col]

    # Previous-row values, aligned to the current row via shift
    indicator_prev = indicator_curr.shift(1)

    # Buy signal: was <= lower_threshold, now > lower_threshold (oversold to normal)
    buy_signal = (indicator_prev <= lower_threshold) & (indicator_curr > lower_threshold)

    # Sell signal: was >= upper_threshold, now < upper_threshold (overbought to normal)
    sell_signal = (indicator_prev >= upper_threshold) & (indicator_curr < upper_threshold)

    # Ensure no signals on the first row, where there is no previous value.
    # Guarded because positional assignment on an empty Series raises
    # IndexError, which would make the function unusable on empty input.
    if len(buy_signal) > 0:
        buy_signal.iloc[0] = False
        sell_signal.iloc[0] = False

    return buy_signal, sell_signal
|
|
|
|
|
|
def _is_positive_int(value: Any) -> bool:
    """Return True when value is an int greater than zero and not a bool."""
    # bool is a subclass of int, so True would otherwise count as the period 1.
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


def _is_real_number(value: Any) -> bool:
    """Return True when value is an int or float and not a bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def validate_strategy_config(config: Dict[str, Any]) -> bool:
    """
    Validate trading strategy configuration.

    Rules applied:
        * 'strategy' must be present and be one of 'ema_crossover', 'rsi'
          or 'macd'.
        * 'ema_crossover' requires positive integer 'fast_period' and
          'slow_period' with fast_period < slow_period.
        * 'rsi' requires a positive integer 'period' and numeric
          'overbought' / 'oversold' with
          0 <= oversold < overbought <= 100.
        * 'macd' requires positive integer 'fast_period', 'slow_period' and
          'signal_period' with fast_period < slow_period.

    Boolean values are rejected wherever a number is required, even though
    bool is technically a subclass of int.

    Args:
        config: Strategy configuration dictionary

    Returns:
        True if configuration is valid, False otherwise
    """
    # Check required fields
    if 'strategy' not in config:
        return False

    strategy_type = config['strategy']

    # Select the integer period parameters each strategy type must provide;
    # any other strategy type is invalid.
    if strategy_type == 'ema_crossover':
        period_keys = ('fast_period', 'slow_period')
    elif strategy_type == 'rsi':
        period_keys = ('period',)
    elif strategy_type == 'macd':
        period_keys = ('fast_period', 'slow_period', 'signal_period')
    else:
        return False

    # Every required period must be present and a positive, non-bool integer
    for key in period_keys:
        if key not in config or not _is_positive_int(config[key]):
            return False

    if strategy_type == 'rsi':
        if 'overbought' not in config or 'oversold' not in config:
            return False
        oversold = config['oversold']
        overbought = config['overbought']
        if not (_is_real_number(oversold) and _is_real_number(overbought)):
            return False
        # Thresholds must be ordered and lie within the 0-100 range
        return 0 <= oversold < overbought <= 100

    # ema_crossover and macd: the fast window must be shorter than the slow one
    return config['fast_period'] < config['slow_period']