# File-viewer header residue (not part of the module): 166 lines, 6.2 KiB, Python.

"""
Trading Strategy Utilities
This module provides utility functions for managing trading strategy
configurations and validation.
"""
from typing import Dict, Any, List, Tuple
import pandas as pd
import numpy as np
def create_default_strategies_config() -> Dict[str, Dict[str, Any]]:
    """
    Build the default configurations for the built-in trading strategies.

    A new dictionary (with new nested dictionaries) is constructed on every
    call, so callers may mutate the result without affecting later calls.

    Returns:
        Dictionary mapping a preset name ('ema_crossover_default',
        'rsi_default', 'macd_default') to that preset's parameters. Every
        preset carries a 'strategy' entry naming the strategy type.
    """
    # Each preset is named separately so its parameters are easy to scan.
    ema_preset: Dict[str, Any] = {
        'strategy': 'ema_crossover',
        'fast_period': 12,
        'slow_period': 26,
    }
    rsi_preset: Dict[str, Any] = {
        'strategy': 'rsi',
        'period': 14,
        'overbought': 70,
        'oversold': 30,
    }
    macd_preset: Dict[str, Any] = {
        'strategy': 'macd',
        'fast_period': 12,
        'slow_period': 26,
        'signal_period': 9,
    }

    presets: Dict[str, Dict[str, Any]] = {
        'ema_crossover_default': ema_preset,
        'rsi_default': rsi_preset,
        'macd_default': macd_preset,
    }
    return presets
def create_indicator_key(indicator_config: Dict[str, Any]) -> str:
    """
    Derive a string key that identifies an indicator configuration.

    The key starts with the configuration's 'type' entry ('unknown' when the
    entry is absent), followed by a '<param>_<value>' segment for each
    recognised parameter that is present. Segments always appear in the fixed
    order period, fast_period, slow_period, signal_period, std_dev,
    regardless of the order of the entries in the configuration. When none of
    those parameters is present the key is '<type>_default'.

    Args:
        indicator_config: Indicator configuration

    Returns:
        String key for this indicator configuration
    """
    kind = indicator_config.get('type', 'unknown')

    # Only these parameters contribute to the key; any other entry is ignored.
    tracked_params = ('period', 'fast_period', 'slow_period', 'signal_period', 'std_dev')
    segments = [
        f"{name}_{indicator_config[name]}"
        for name in tracked_params
        if name in indicator_config
    ]

    if not segments:
        return f"{kind}_default"
    return f"{kind}_" + '_'.join(segments)
def detect_crossover_signals_vectorized(df: pd.DataFrame,
                                        fast_col: str,
                                        slow_col: str) -> Tuple[pd.Series, pd.Series]:
    """
    Detect crossover signals using vectorized operations for better performance.

    Each row is compared with the row before it (obtained with ``shift(1)``):
    a bullish crossover is flagged where the fast line was at or below the
    slow line on the previous row and is strictly above it on the current
    row; a bearish crossover is the mirror image.

    Args:
        df: DataFrame with indicator data
        fast_col: Column name for fast line (e.g., 'ema_fast')
        slow_col: Column name for slow line (e.g., 'ema_slow')

    Returns:
        Tuple of (bullish_crossover_mask, bearish_crossover_mask) as boolean
        Series indexed like ``df``. The first row is always False because it
        has no previous row. An empty ``df`` yields two empty masks.
    """
    fast_curr = df[fast_col]
    slow_curr = df[slow_col]

    # Previous-row values; the first row has no predecessor.
    fast_prev = fast_curr.shift(1)
    slow_prev = slow_curr.shift(1)

    # Bullish crossover: fast was <= slow, now fast > slow
    bullish_crossover = (fast_prev <= slow_prev) & (fast_curr > slow_curr)
    # Bearish crossover: fast was >= slow, now fast < slow
    bearish_crossover = (fast_prev >= slow_prev) & (fast_curr < slow_curr)

    # Explicitly clear the first row, which has no previous value. Guarded
    # because positional assignment to row 0 of an empty Series raises
    # IndexError.
    if not bullish_crossover.empty:
        bullish_crossover.iloc[0] = False
        bearish_crossover.iloc[0] = False

    return bullish_crossover, bearish_crossover
def detect_threshold_signals_vectorized(df: pd.DataFrame,
                                        indicator_col: str,
                                        upper_threshold: float,
                                        lower_threshold: float) -> Tuple[pd.Series, pd.Series]:
    """
    Detect threshold crossing signals using vectorized operations (for RSI, etc.).

    Each row is compared with the row before it (obtained with ``shift(1)``):
    a buy signal is flagged where the indicator was at or below
    ``lower_threshold`` on the previous row and is strictly above it on the
    current row; a sell signal is flagged where it was at or above
    ``upper_threshold`` and is now strictly below it.

    Args:
        df: DataFrame with indicator data
        indicator_col: Column name for indicator (e.g., 'rsi')
        upper_threshold: Upper threshold (e.g., 70 for RSI overbought)
        lower_threshold: Lower threshold (e.g., 30 for RSI oversold)

    Returns:
        Tuple of (buy_signal_mask, sell_signal_mask) as boolean Series
        indexed like ``df``. The first row is always False because it has no
        previous row. An empty ``df`` yields two empty masks.
    """
    indicator_curr = df[indicator_col]
    # Previous-row values; the first row has no predecessor.
    indicator_prev = indicator_curr.shift(1)

    # Buy signal: was <= lower_threshold, now > lower_threshold (oversold to normal)
    buy_signal = (indicator_prev <= lower_threshold) & (indicator_curr > lower_threshold)
    # Sell signal: was >= upper_threshold, now < upper_threshold (overbought to normal)
    sell_signal = (indicator_prev >= upper_threshold) & (indicator_curr < upper_threshold)

    # Explicitly clear the first row, which has no previous value. Guarded
    # because positional assignment to row 0 of an empty Series raises
    # IndexError.
    if not buy_signal.empty:
        buy_signal.iloc[0] = False
        sell_signal.iloc[0] = False

    return buy_signal, sell_signal
def _is_positive_int(value: Any) -> bool:
    """Return True when value is an int greater than zero; bools are rejected."""
    # bool is a subclass of int, so True/False must be excluded explicitly.
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


def _is_numeric_threshold(value: Any) -> bool:
    """Return True when value is an int or float; bools are rejected."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def validate_strategy_config(config: Dict[str, Any]) -> bool:
    """
    Validate trading strategy configuration.

    Rules checked:
        * 'strategy' must be present and one of 'ema_crossover', 'rsi', 'macd'.
        * 'ema_crossover' requires 'fast_period' and 'slow_period', both
          positive ints, with fast_period < slow_period.
        * 'macd' requires 'fast_period', 'slow_period' and 'signal_period',
          all positive ints, with fast_period < slow_period.
        * 'rsi' requires 'period' (positive int) plus 'overbought' and
          'oversold' (int or float) satisfying
          0 <= oversold < overbought <= 100.

    Boolean values are never accepted as periods or thresholds, even though
    bool is a subclass of int.

    Args:
        config: Strategy configuration dictionary

    Returns:
        True if configuration is valid, False otherwise (including when
        config is None)
    """
    # A missing configuration is invalid rather than an error.
    if config is None or 'strategy' not in config:
        return False

    strategy_type = config['strategy']

    if strategy_type in ('ema_crossover', 'macd'):
        # Both strategies share the fast/slow rules; MACD adds a signal period.
        period_keys = ['fast_period', 'slow_period']
        if strategy_type == 'macd':
            period_keys.append('signal_period')

        if any(key not in config for key in period_keys):
            return False
        if not all(_is_positive_int(config[key]) for key in period_keys):
            return False
        return config['fast_period'] < config['slow_period']

    if strategy_type == 'rsi':
        if any(key not in config for key in ('period', 'overbought', 'oversold')):
            return False
        if not _is_positive_int(config['period']):
            return False

        oversold = config['oversold']
        overbought = config['overbought']
        if not (_is_numeric_threshold(oversold) and _is_numeric_threshold(overbought)):
            return False
        return 0 <= oversold < overbought <= 100

    # Unknown strategy type
    return False