4.0 - 1.0 Implement strategy engine foundation with modular components
- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD). - Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability. - Updated `pyproject.toml` to include the new `strategies` package in the build configuration. - Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards. These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
This commit is contained in:
26
strategies/__init__.py
Normal file
26
strategies/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Strategy Engine Package
|
||||
|
||||
This package provides strategy calculation and signal generation capabilities
|
||||
optimized for the TCP Trading Platform's market data and technical indicators.
|
||||
|
||||
IMPORTANT: Mirrors Indicator Patterns
|
||||
- Follows the same architecture as data/common/indicators/
|
||||
- Uses BaseStrategy abstract class pattern
|
||||
- Implements factory pattern for dynamic strategy loading
|
||||
- Supports JSON-based configuration management
|
||||
- Integrates with existing technical indicators system
|
||||
"""
|
||||
|
||||
from .base import BaseStrategy
|
||||
from .factory import StrategyFactory
|
||||
from .data_types import StrategySignal, SignalType, StrategyResult
|
||||
|
||||
# Note: Strategy implementations and manager will be added in next iterations
|
||||
__all__ = [
|
||||
'BaseStrategy',
|
||||
'StrategyFactory',
|
||||
'StrategySignal',
|
||||
'SignalType',
|
||||
'StrategyResult'
|
||||
]
|
||||
150
strategies/base.py
Normal file
150
strategies/base.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Base classes and interfaces for trading strategies.
|
||||
|
||||
This module provides the foundation for all trading strategies
|
||||
with common functionality and type definitions.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
import pandas as pd
|
||||
from utils.logger import get_logger
|
||||
|
||||
from .data_types import StrategyResult
|
||||
from data.common.data_types import OHLCVCandle
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""
|
||||
Abstract base class for all trading strategies.
|
||||
|
||||
Provides common functionality and enforces consistent interface
|
||||
across all strategy implementations.
|
||||
"""
|
||||
|
||||
def __init__(self, logger=None):
|
||||
"""
|
||||
Initialize base strategy.
|
||||
|
||||
Args:
|
||||
logger: Optional logger instance
|
||||
"""
|
||||
if logger is None:
|
||||
self.logger = get_logger(__name__)
|
||||
self.logger = logger
|
||||
|
||||
def prepare_dataframe(self, candles: List[OHLCVCandle]) -> pd.DataFrame:
|
||||
"""
|
||||
Convert OHLCV candles to pandas DataFrame for calculations.
|
||||
|
||||
Args:
|
||||
candles: List of OHLCV candles (can be sparse)
|
||||
|
||||
Returns:
|
||||
DataFrame with OHLCV data, sorted by timestamp
|
||||
"""
|
||||
if not candles:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Convert to DataFrame
|
||||
data = []
|
||||
for candle in candles:
|
||||
data.append({
|
||||
'timestamp': candle.end_time, # Right-aligned timestamp
|
||||
'symbol': candle.symbol,
|
||||
'timeframe': candle.timeframe,
|
||||
'open': float(candle.open),
|
||||
'high': float(candle.high),
|
||||
'low': float(candle.low),
|
||||
'close': float(candle.close),
|
||||
'volume': float(candle.volume),
|
||||
'trade_count': candle.trade_count
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Sort by timestamp to ensure proper order
|
||||
df = df.sort_values('timestamp').reset_index(drop=True)
|
||||
|
||||
# Set timestamp as index for time-series operations
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||
|
||||
# Set as index, but keep as column
|
||||
df.set_index('timestamp', inplace=True)
|
||||
|
||||
# Ensure it's datetime
|
||||
df['timestamp'] = df.index
|
||||
|
||||
return df
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||
"""
|
||||
Calculate the strategy signals.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV data
|
||||
indicators_data: Dictionary of pre-calculated indicator DataFrames
|
||||
**kwargs: Additional parameters specific to each strategy
|
||||
|
||||
Returns:
|
||||
List of strategy results with signals
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get list of indicators required by this strategy.
|
||||
|
||||
Returns:
|
||||
List of indicator configurations needed for strategy calculation
|
||||
Format: [{'type': 'sma', 'period': 20}, {'type': 'ema', 'period': 12}]
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_dataframe(self, df: pd.DataFrame, min_periods: int) -> bool:
|
||||
"""
|
||||
Validate that DataFrame has sufficient data for calculation.
|
||||
|
||||
Args:
|
||||
df: DataFrame to validate
|
||||
min_periods: Minimum number of periods required
|
||||
|
||||
Returns:
|
||||
True if DataFrame is valid, False otherwise
|
||||
"""
|
||||
if df.empty or len(df) < min_periods:
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"Insufficient data: got {len(df)} periods, need {min_periods}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_indicators_data(self, indicators_data: Dict[str, pd.DataFrame],
|
||||
required_indicators: List[Dict[str, Any]]) -> bool:
|
||||
"""
|
||||
Validate that all required indicators are present and have sufficient data.
|
||||
|
||||
Args:
|
||||
indicators_data: Dictionary of indicator DataFrames
|
||||
required_indicators: List of required indicator configurations
|
||||
|
||||
Returns:
|
||||
True if all required indicators are available, False otherwise
|
||||
"""
|
||||
for indicator_config in required_indicators:
|
||||
indicator_key = f"{indicator_config['type']}_{indicator_config.get('period', 'default')}"
|
||||
|
||||
if indicator_key not in indicators_data:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Missing required indicator: {indicator_key}")
|
||||
return False
|
||||
|
||||
if indicators_data[indicator_key].empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Empty data for indicator: {indicator_key}")
|
||||
return False
|
||||
|
||||
return True
|
||||
72
strategies/data_types.py
Normal file
72
strategies/data_types.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Strategy Data Types and Signal Definitions
|
||||
|
||||
This module provides data types for strategy calculations, signals,
|
||||
and results in a standardized format.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SignalType(str, Enum):
|
||||
"""
|
||||
Types of trading signals that strategies can generate.
|
||||
"""
|
||||
BUY = "buy"
|
||||
SELL = "sell"
|
||||
HOLD = "hold"
|
||||
ENTRY_LONG = "entry_long"
|
||||
EXIT_LONG = "exit_long"
|
||||
ENTRY_SHORT = "entry_short"
|
||||
EXIT_SHORT = "exit_short"
|
||||
STOP_LOSS = "stop_loss"
|
||||
TAKE_PROFIT = "take_profit"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategySignal:
|
||||
"""
|
||||
Container for individual strategy signals.
|
||||
|
||||
Attributes:
|
||||
timestamp: Signal timestamp (right-aligned with candle)
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
signal_type: Type of signal (buy/sell/hold etc.)
|
||||
price: Price at which signal was generated
|
||||
confidence: Signal confidence score (0.0 to 1.0)
|
||||
metadata: Additional signal metadata (e.g., indicator values)
|
||||
"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
timeframe: str
|
||||
signal_type: SignalType
|
||||
price: float
|
||||
confidence: float = 1.0
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyResult:
|
||||
"""
|
||||
Container for strategy calculation results.
|
||||
|
||||
Attributes:
|
||||
timestamp: Candle timestamp (right-aligned)
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
strategy_name: Name of the strategy that generated this result
|
||||
signals: List of signals generated for this timestamp
|
||||
indicators_used: Dictionary of indicator values used in calculation
|
||||
metadata: Additional calculation metadata
|
||||
"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
timeframe: str
|
||||
strategy_name: str
|
||||
signals: List[StrategySignal]
|
||||
indicators_used: Dict[str, float]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
247
strategies/factory.py
Normal file
247
strategies/factory.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Strategy Factory Module for Strategy Management
|
||||
|
||||
This module provides strategy calculation and signal generation capabilities
|
||||
designed to work with the TCP Trading Platform's technical indicators and market data.
|
||||
|
||||
IMPORTANT: Strategy-Indicator Integration
|
||||
- Strategies consume pre-calculated technical indicators
|
||||
- Uses BaseStrategy abstract class pattern for consistency
|
||||
- Supports dynamic strategy loading and registration
|
||||
- Follows right-aligned timestamp convention
|
||||
- Integrates with existing database and configuration systems
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from data.common.indicators import TechnicalIndicators
|
||||
from .base import BaseStrategy
|
||||
from .data_types import StrategyResult
|
||||
from .utils import create_indicator_key
|
||||
from .implementations.ema_crossover import EMAStrategy
|
||||
from .implementations.rsi import RSIStrategy
|
||||
from .implementations.macd import MACDStrategy
|
||||
# Strategy implementations will be imported as they are created
|
||||
# from .implementations import (
|
||||
# EMAStrategy,
|
||||
# RSIStrategy,
|
||||
# MACDStrategy
|
||||
# )
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""
|
||||
Strategy factory for creating and managing trading strategies.
|
||||
|
||||
This class provides strategy instantiation, calculation orchestration,
|
||||
and signal generation. It integrates with the TechnicalIndicators
|
||||
system to provide pre-calculated indicator data to strategies.
|
||||
|
||||
STRATEGY-INDICATOR INTEGRATION:
|
||||
- Strategies declare required indicators via get_required_indicators()
|
||||
- Factory pre-calculates all required indicators
|
||||
- Strategies receive indicator data as dictionary of DataFrames
|
||||
- Results maintain original timestamp alignment
|
||||
"""
|
||||
|
||||
def __init__(self, logger=None):
|
||||
"""
|
||||
Initialize strategy factory.
|
||||
|
||||
Args:
|
||||
logger: Optional logger instance
|
||||
"""
|
||||
self.logger = logger
|
||||
self.technical_indicators = TechnicalIndicators(logger)
|
||||
|
||||
# Registry of available strategies (will be populated as strategies are implemented)
|
||||
self._strategy_registry = {
|
||||
'ema_crossover': EMAStrategy,
|
||||
'rsi': RSIStrategy,
|
||||
'macd': MACDStrategy
|
||||
}
|
||||
|
||||
if self.logger:
|
||||
self.logger.info("StrategyFactory: Initialized strategy factory")
|
||||
|
||||
def register_strategy(self, name: str, strategy_class: type) -> None:
|
||||
"""
|
||||
Register a new strategy class in the factory.
|
||||
|
||||
Args:
|
||||
name: Strategy name identifier
|
||||
strategy_class: Strategy class (must inherit from BaseStrategy)
|
||||
"""
|
||||
if not issubclass(strategy_class, BaseStrategy):
|
||||
raise ValueError(f"Strategy class {strategy_class} must inherit from BaseStrategy")
|
||||
|
||||
self._strategy_registry[name] = strategy_class
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"StrategyFactory: Registered strategy '{name}'")
|
||||
|
||||
def get_available_strategies(self) -> List[str]:
|
||||
"""
|
||||
Get list of available strategy names.
|
||||
|
||||
Returns:
|
||||
List of registered strategy names
|
||||
"""
|
||||
return list(self._strategy_registry.keys())
|
||||
|
||||
def create_strategy(self, strategy_name: str) -> Optional[BaseStrategy]:
|
||||
"""
|
||||
Create a strategy instance by name.
|
||||
|
||||
Args:
|
||||
strategy_name: Name of the strategy to create
|
||||
|
||||
Returns:
|
||||
Strategy instance or None if strategy not found
|
||||
"""
|
||||
strategy_class = self._strategy_registry.get(strategy_name)
|
||||
if not strategy_class:
|
||||
if self.logger:
|
||||
self.logger.error(f"Unknown strategy: {strategy_name}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return strategy_class(logger=self.logger)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Error creating strategy {strategy_name}: {e}")
|
||||
return None
|
||||
|
||||
def calculate_strategy_signals(self, strategy_name: str, df: pd.DataFrame,
|
||||
strategy_config: Dict[str, Any]) -> List[StrategyResult]:
|
||||
"""
|
||||
Calculate signals for a specific strategy.
|
||||
|
||||
Args:
|
||||
strategy_name: Name of the strategy to execute
|
||||
df: DataFrame with OHLCV data
|
||||
strategy_config: Strategy-specific configuration parameters
|
||||
|
||||
Returns:
|
||||
List of strategy results with signals
|
||||
"""
|
||||
# Create strategy instance
|
||||
strategy = self.create_strategy(strategy_name)
|
||||
if not strategy:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get required indicators for this strategy
|
||||
required_indicators = strategy.get_required_indicators()
|
||||
|
||||
# Pre-calculate all required indicators
|
||||
indicators_data = self._calculate_required_indicators(df, required_indicators)
|
||||
|
||||
# Calculate strategy signals
|
||||
results = strategy.calculate(df, indicators_data, **strategy_config)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Error calculating strategy {strategy_name}: {e}")
|
||||
return []
|
||||
|
||||
def calculate_multiple_strategies(self, df: pd.DataFrame,
|
||||
strategies_config: Dict[str, Dict[str, Any]]) -> Dict[str, List[StrategyResult]]:
|
||||
"""
|
||||
Calculate signals for multiple strategies efficiently.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV data
|
||||
strategies_config: Configuration for strategies to calculate
|
||||
Example: {
|
||||
'ema_cross_1': {'strategy': 'ema_crossover', 'fast_period': 12, 'slow_period': 26},
|
||||
'rsi_momentum': {'strategy': 'rsi', 'period': 14, 'oversold': 30, 'overbought': 70}
|
||||
}
|
||||
|
||||
Returns:
|
||||
Dictionary mapping strategy instance names to their results
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for strategy_instance_name, config in strategies_config.items():
|
||||
strategy_name = config.get('strategy')
|
||||
if not strategy_name:
|
||||
if self.logger:
|
||||
self.logger.warning(f"No strategy specified for {strategy_instance_name}")
|
||||
results[strategy_instance_name] = []
|
||||
continue
|
||||
|
||||
# Extract strategy parameters (exclude 'strategy' key)
|
||||
strategy_params = {k: v for k, v in config.items() if k != 'strategy'}
|
||||
|
||||
try:
|
||||
strategy_results = self.calculate_strategy_signals(
|
||||
strategy_name, df, strategy_params
|
||||
)
|
||||
results[strategy_instance_name] = strategy_results
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Error calculating strategy {strategy_instance_name}: {e}")
|
||||
results[strategy_instance_name] = []
|
||||
|
||||
return results
|
||||
|
||||
def _calculate_required_indicators(self, df: pd.DataFrame,
|
||||
required_indicators: List[Dict[str, Any]]) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Pre-calculate all indicators required by a strategy.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV data
|
||||
required_indicators: List of indicator configurations
|
||||
|
||||
Returns:
|
||||
Dictionary of indicator DataFrames keyed by indicator name
|
||||
"""
|
||||
indicators_data = {}
|
||||
|
||||
for indicator_config in required_indicators:
|
||||
indicator_type = indicator_config.get('type')
|
||||
if not indicator_type:
|
||||
continue
|
||||
|
||||
# Create a unique key for this indicator configuration
|
||||
indicator_key = self._create_indicator_key(indicator_config)
|
||||
|
||||
try:
|
||||
# Calculate the indicator using TechnicalIndicators
|
||||
indicator_result = self.technical_indicators.calculate(
|
||||
indicator_type, df, **{k: v for k, v in indicator_config.items() if k != 'type'}
|
||||
)
|
||||
|
||||
if indicator_result is not None and not indicator_result.empty:
|
||||
indicators_data[indicator_key] = indicator_result
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Empty result for indicator: {indicator_key}")
|
||||
indicators_data[indicator_key] = pd.DataFrame()
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Error calculating indicator {indicator_key}: {e}")
|
||||
indicators_data[indicator_key] = pd.DataFrame()
|
||||
|
||||
return indicators_data
|
||||
|
||||
def _create_indicator_key(self, indicator_config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create a unique key for an indicator configuration.
|
||||
Delegates to shared utility function to ensure consistency.
|
||||
|
||||
Args:
|
||||
indicator_config: Indicator configuration
|
||||
|
||||
Returns:
|
||||
Unique string key for this indicator configuration
|
||||
"""
|
||||
return create_indicator_key(indicator_config)
|
||||
18
strategies/implementations/__init__.py
Normal file
18
strategies/implementations/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Strategy implementations package.
|
||||
|
||||
This package contains individual implementations of trading strategies,
|
||||
each in its own module for better maintainability and separation of concerns.
|
||||
"""
|
||||
|
||||
from .ema_crossover import EMAStrategy
|
||||
from .rsi import RSIStrategy
|
||||
from .macd import MACDStrategy
|
||||
# from .macd import MACDIndicator
|
||||
|
||||
__all__ = [
|
||||
'EMAStrategy',
|
||||
'RSIStrategy',
|
||||
'MACDStrategy',
|
||||
# 'MACDIndicator'
|
||||
]
|
||||
185
strategies/implementations/ema_crossover.py
Normal file
185
strategies/implementations/ema_crossover.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
EMA Crossover Strategy Implementation
|
||||
|
||||
This module implements an Exponential Moving Average (EMA) Crossover trading strategy.
|
||||
It extends the BaseStrategy and generates buy/sell signals based on the crossover
|
||||
of a fast EMA and a slow EMA.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..base import BaseStrategy
|
||||
from ..data_types import StrategyResult, StrategySignal, SignalType
|
||||
from ..utils import create_indicator_key, detect_crossover_signals_vectorized
|
||||
|
||||
|
||||
class EMAStrategy(BaseStrategy):
|
||||
"""
|
||||
EMA Crossover Strategy.
|
||||
|
||||
Generates buy/sell signals when a fast EMA crosses above or below a slow EMA.
|
||||
"""
|
||||
|
||||
def __init__(self, logger=None):
|
||||
super().__init__(logger)
|
||||
self.strategy_name = "ema_crossover"
|
||||
|
||||
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Defines the indicators required by the EMA Crossover strategy.
|
||||
It needs two EMA indicators: a fast one and a slow one.
|
||||
"""
|
||||
# Default periods for EMA crossover, can be overridden by strategy config
|
||||
return [
|
||||
{'type': 'ema', 'period': 12, 'price_column': 'close'},
|
||||
{'type': 'ema', 'period': 26, 'price_column': 'close'}
|
||||
]
|
||||
|
||||
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||
"""
|
||||
Calculate EMA Crossover strategy signals.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV data.
|
||||
indicators_data: Dictionary of pre-calculated indicator DataFrames.
|
||||
Expected keys: 'ema_period_12', 'ema_period_26'.
|
||||
**kwargs: Additional strategy parameters (e.g., fast_period, slow_period, price_column).
|
||||
|
||||
Returns:
|
||||
List of StrategyResult objects, each containing generated signals.
|
||||
"""
|
||||
# Extract EMA periods from kwargs or use defaults
|
||||
fast_period = kwargs.get('fast_period', 12)
|
||||
slow_period = kwargs.get('slow_period', 26)
|
||||
price_column = kwargs.get('price_column', 'close')
|
||||
|
||||
# Generate indicator keys using shared utility function
|
||||
fast_ema_key = create_indicator_key({'type': 'ema', 'period': fast_period})
|
||||
slow_ema_key = create_indicator_key({'type': 'ema', 'period': slow_period})
|
||||
|
||||
# Validate that the main DataFrame has enough data for strategy calculation (not just indicators)
|
||||
if not self.validate_dataframe(df, max(fast_period, slow_period)):
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
|
||||
return []
|
||||
|
||||
# Validate that the required indicators are present and have sufficient data
|
||||
required_indicators = [
|
||||
{'type': 'ema', 'period': fast_period},
|
||||
{'type': 'ema', 'period': slow_period}
|
||||
]
|
||||
if not self.validate_indicators_data(indicators_data, required_indicators):
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Missing or insufficient indicator data.")
|
||||
return []
|
||||
|
||||
fast_ema_df = indicators_data.get(fast_ema_key)
|
||||
slow_ema_df = indicators_data.get(slow_ema_key)
|
||||
|
||||
if fast_ema_df is None or slow_ema_df is None or fast_ema_df.empty or slow_ema_df.empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: EMA indicator DataFrames are not found or empty.")
|
||||
return []
|
||||
|
||||
# Merge all necessary data into a single DataFrame for easier processing
|
||||
# Ensure alignment by index (timestamp)
|
||||
merged_df = pd.merge(df[[price_column, 'symbol', 'timeframe']],
|
||||
fast_ema_df[['ema']],
|
||||
left_index=True, right_index=True, how='inner',
|
||||
suffixes= ('', '_fast'))
|
||||
merged_df = pd.merge(merged_df,
|
||||
slow_ema_df[['ema']],
|
||||
left_index=True, right_index=True, how='inner',
|
||||
suffixes= ('', '_slow'))
|
||||
|
||||
# Rename columns to their logical names after merge
|
||||
merged_df.rename(columns={'ema': 'ema_fast', 'ema_slow': 'ema_slow'}, inplace=True)
|
||||
|
||||
if merged_df.empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
|
||||
return []
|
||||
|
||||
# Use vectorized signal detection for better performance
|
||||
bullish_crossover, bearish_crossover = detect_crossover_signals_vectorized(
|
||||
merged_df, 'ema_fast', 'ema_slow'
|
||||
)
|
||||
|
||||
results: List[StrategyResult] = []
|
||||
strategy_metadata = {
|
||||
'fast_period': fast_period,
|
||||
'slow_period': slow_period
|
||||
}
|
||||
|
||||
# Process bullish crossover signals
|
||||
bullish_indices = merged_df[bullish_crossover].index
|
||||
for timestamp in bullish_indices:
|
||||
row = merged_df.loc[timestamp]
|
||||
|
||||
# Skip if any EMA values are NaN
|
||||
if pd.isna(row['ema_fast']) or pd.isna(row['ema_slow']):
|
||||
continue
|
||||
|
||||
signal = StrategySignal(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
signal_type=SignalType.BUY,
|
||||
price=float(row[price_column]),
|
||||
confidence=0.8,
|
||||
metadata={'crossover_type': 'bullish', **strategy_metadata}
|
||||
)
|
||||
|
||||
results.append(StrategyResult(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[signal],
|
||||
indicators_used={
|
||||
'ema_fast': float(row['ema_fast']),
|
||||
'ema_slow': float(row['ema_slow'])
|
||||
},
|
||||
metadata=strategy_metadata
|
||||
))
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.strategy_name}: BUY signal at {timestamp} for {row['symbol']}")
|
||||
|
||||
# Process bearish crossover signals
|
||||
bearish_indices = merged_df[bearish_crossover].index
|
||||
for timestamp in bearish_indices:
|
||||
row = merged_df.loc[timestamp]
|
||||
|
||||
# Skip if any EMA values are NaN
|
||||
if pd.isna(row['ema_fast']) or pd.isna(row['ema_slow']):
|
||||
continue
|
||||
|
||||
signal = StrategySignal(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
signal_type=SignalType.SELL,
|
||||
price=float(row[price_column]),
|
||||
confidence=0.8,
|
||||
metadata={'crossover_type': 'bearish', **strategy_metadata}
|
||||
)
|
||||
|
||||
results.append(StrategyResult(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[signal],
|
||||
indicators_used={
|
||||
'ema_fast': float(row['ema_fast']),
|
||||
'ema_slow': float(row['ema_slow'])
|
||||
},
|
||||
metadata=strategy_metadata
|
||||
))
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.strategy_name}: SELL signal at {timestamp} for {row['symbol']}")
|
||||
|
||||
return results
|
||||
180
strategies/implementations/macd.py
Normal file
180
strategies/implementations/macd.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
MACD Strategy Implementation
|
||||
|
||||
This module implements a Moving Average Convergence Divergence (MACD) trading strategy.
|
||||
It extends the BaseStrategy and generates buy/sell signals based on the crossover
|
||||
of the MACD line and its signal line.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..base import BaseStrategy
|
||||
from ..data_types import StrategyResult, StrategySignal, SignalType
|
||||
from ..utils import create_indicator_key, detect_crossover_signals_vectorized
|
||||
|
||||
|
||||
class MACDStrategy(BaseStrategy):
|
||||
"""
|
||||
MACD Strategy.
|
||||
|
||||
Generates buy/sell signals when the MACD line crosses above or below its signal line.
|
||||
"""
|
||||
|
||||
def __init__(self, logger=None):
|
||||
super().__init__(logger)
|
||||
self.strategy_name = "macd"
|
||||
|
||||
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Defines the indicators required by the MACD strategy.
|
||||
It needs one MACD indicator.
|
||||
"""
|
||||
# Default periods for MACD, can be overridden by strategy config
|
||||
return [
|
||||
{'type': 'macd', 'fast_period': 12, 'slow_period': 26, 'signal_period': 9, 'price_column': 'close'}
|
||||
]
|
||||
|
||||
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||
"""
|
||||
Calculate MACD strategy signals.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV data.
|
||||
indicators_data: Dictionary of pre-calculated indicator DataFrames.
|
||||
Expected key: 'macd_fast_period_12_slow_period_26_signal_period_9'.
|
||||
**kwargs: Additional strategy parameters (e.g., fast_period, slow_period, signal_period, price_column).
|
||||
|
||||
Returns:
|
||||
List of StrategyResult objects, each containing generated signals.
|
||||
"""
|
||||
# Extract parameters from kwargs or use defaults
|
||||
fast_period = kwargs.get('fast_period', 12)
|
||||
slow_period = kwargs.get('slow_period', 26)
|
||||
signal_period = kwargs.get('signal_period', 9)
|
||||
price_column = kwargs.get('price_column', 'close')
|
||||
|
||||
# Generate indicator key using shared utility function
|
||||
macd_key = create_indicator_key({
|
||||
'type': 'macd',
|
||||
'fast_period': fast_period,
|
||||
'slow_period': slow_period,
|
||||
'signal_period': signal_period
|
||||
})
|
||||
|
||||
# Validate that the main DataFrame has enough data for strategy calculation
|
||||
min_periods = max(slow_period, signal_period)
|
||||
if not self.validate_dataframe(df, min_periods):
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
|
||||
return []
|
||||
|
||||
# Validate that the required MACD indicator data is present and sufficient
|
||||
required_indicators = [
|
||||
{'type': 'macd', 'fast_period': fast_period, 'slow_period': slow_period, 'signal_period': signal_period}
|
||||
]
|
||||
if not self.validate_indicators_data(indicators_data, required_indicators):
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Missing or insufficient MACD indicator data.")
|
||||
return []
|
||||
|
||||
macd_df = indicators_data.get(macd_key)
|
||||
|
||||
if macd_df is None or macd_df.empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: MACD indicator DataFrame is not found or empty.")
|
||||
return []
|
||||
|
||||
# Merge all necessary data into a single DataFrame for easier processing
|
||||
merged_df = pd.merge(df[[price_column, 'symbol', 'timeframe']],
|
||||
macd_df[['macd', 'signal', 'histogram']],
|
||||
left_index=True, right_index=True, how='inner')
|
||||
|
||||
if merged_df.empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
|
||||
return []
|
||||
|
||||
# Use vectorized signal detection for better performance
|
||||
bullish_crossover, bearish_crossover = detect_crossover_signals_vectorized(
|
||||
merged_df, 'macd', 'signal'
|
||||
)
|
||||
|
||||
results: List[StrategyResult] = []
|
||||
strategy_metadata = {
|
||||
'fast_period': fast_period,
|
||||
'slow_period': slow_period,
|
||||
'signal_period': signal_period
|
||||
}
|
||||
|
||||
# Process bullish crossover signals (MACD crosses above Signal)
|
||||
bullish_indices = merged_df[bullish_crossover].index
|
||||
for timestamp in bullish_indices:
|
||||
row = merged_df.loc[timestamp]
|
||||
|
||||
# Skip if any MACD values are NaN
|
||||
if pd.isna(row['macd']) or pd.isna(row['signal']):
|
||||
continue
|
||||
|
||||
signal = StrategySignal(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
signal_type=SignalType.BUY,
|
||||
price=float(row[price_column]),
|
||||
confidence=0.9,
|
||||
metadata={'macd_cross': 'bullish', **strategy_metadata}
|
||||
)
|
||||
|
||||
results.append(StrategyResult(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[signal],
|
||||
indicators_used={
|
||||
'macd': float(row['macd']),
|
||||
'signal': float(row['signal'])
|
||||
},
|
||||
metadata=strategy_metadata
|
||||
))
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.strategy_name}: BUY signal at {timestamp} for {row['symbol']} (MACD: {row['macd']:.2f}, Signal: {row['signal']:.2f})")
|
||||
|
||||
# Process bearish crossover signals (MACD crosses below Signal)
|
||||
bearish_indices = merged_df[bearish_crossover].index
|
||||
for timestamp in bearish_indices:
|
||||
row = merged_df.loc[timestamp]
|
||||
|
||||
# Skip if any MACD values are NaN
|
||||
if pd.isna(row['macd']) or pd.isna(row['signal']):
|
||||
continue
|
||||
|
||||
signal = StrategySignal(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
signal_type=SignalType.SELL,
|
||||
price=float(row[price_column]),
|
||||
confidence=0.9,
|
||||
metadata={'macd_cross': 'bearish', **strategy_metadata}
|
||||
)
|
||||
|
||||
results.append(StrategyResult(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[signal],
|
||||
indicators_used={
|
||||
'macd': float(row['macd']),
|
||||
'signal': float(row['signal'])
|
||||
},
|
||||
metadata=strategy_metadata
|
||||
))
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.strategy_name}: SELL signal at {timestamp} for {row['symbol']} (MACD: {row['macd']:.2f}, Signal: {row['signal']:.2f})")
|
||||
|
||||
return results
|
||||
168
strategies/implementations/rsi.py
Normal file
168
strategies/implementations/rsi.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Relative Strength Index (RSI) Strategy Implementation
|
||||
|
||||
This module implements an RSI-based momentum trading strategy.
|
||||
It extends the BaseStrategy and generates buy/sell signals based on
|
||||
RSI crossing overbought/oversold thresholds.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..base import BaseStrategy
|
||||
from ..data_types import StrategyResult, StrategySignal, SignalType
|
||||
from ..utils import create_indicator_key, detect_threshold_signals_vectorized
|
||||
|
||||
|
||||
class RSIStrategy(BaseStrategy):
|
||||
"""
|
||||
RSI Strategy.
|
||||
|
||||
Generates buy/sell signals when RSI crosses overbought/oversold thresholds.
|
||||
"""
|
||||
|
||||
def __init__(self, logger=None):
|
||||
super().__init__(logger)
|
||||
self.strategy_name = "rsi"
|
||||
|
||||
def get_required_indicators(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Defines the indicators required by the RSI strategy.
|
||||
It needs one RSI indicator.
|
||||
"""
|
||||
# Default period for RSI, can be overridden by strategy config
|
||||
return [
|
||||
{'type': 'rsi', 'period': 14, 'price_column': 'close'}
|
||||
]
|
||||
|
||||
def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
|
||||
"""
|
||||
Calculate RSI strategy signals.
|
||||
|
||||
Args:
|
||||
df: DataFrame with OHLCV data.
|
||||
indicators_data: Dictionary of pre-calculated indicator DataFrames.
|
||||
Expected key: 'rsi_period_14'.
|
||||
**kwargs: Additional strategy parameters (e.g., period, overbought, oversold, price_column).
|
||||
|
||||
Returns:
|
||||
List of StrategyResult objects, each containing generated signals.
|
||||
"""
|
||||
# Extract parameters from kwargs or use defaults
|
||||
period = kwargs.get('period', 14)
|
||||
overbought = kwargs.get('overbought', 70)
|
||||
oversold = kwargs.get('oversold', 30)
|
||||
price_column = kwargs.get('price_column', 'close')
|
||||
|
||||
# Generate indicator key using shared utility function
|
||||
rsi_key = create_indicator_key({'type': 'rsi', 'period': period})
|
||||
|
||||
# Validate that the main DataFrame has enough data for strategy calculation
|
||||
if not self.validate_dataframe(df, period):
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
|
||||
return []
|
||||
|
||||
# Validate that the required RSI indicator data is present and sufficient
|
||||
required_indicators = [
|
||||
{'type': 'rsi', 'period': period}
|
||||
]
|
||||
if not self.validate_indicators_data(indicators_data, required_indicators):
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Missing or insufficient RSI indicator data.")
|
||||
return []
|
||||
|
||||
rsi_df = indicators_data.get(rsi_key)
|
||||
|
||||
if rsi_df is None or rsi_df.empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: RSI indicator DataFrame is not found or empty.")
|
||||
return []
|
||||
|
||||
# Merge all necessary data into a single DataFrame for easier processing
|
||||
merged_df = pd.merge(df[[price_column, 'symbol', 'timeframe']],
|
||||
rsi_df[['rsi']],
|
||||
left_index=True, right_index=True, how='inner')
|
||||
|
||||
if merged_df.empty:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
|
||||
return []
|
||||
|
||||
# Use vectorized signal detection for better performance
|
||||
buy_signals, sell_signals = detect_threshold_signals_vectorized(
|
||||
merged_df, 'rsi', overbought, oversold
|
||||
)
|
||||
|
||||
results: List[StrategyResult] = []
|
||||
strategy_metadata = {
|
||||
'period': period,
|
||||
'overbought': overbought,
|
||||
'oversold': oversold
|
||||
}
|
||||
|
||||
# Process buy signals (RSI crosses above oversold threshold)
|
||||
buy_indices = merged_df[buy_signals].index
|
||||
for timestamp in buy_indices:
|
||||
row = merged_df.loc[timestamp]
|
||||
|
||||
# Skip if RSI value is NaN
|
||||
if pd.isna(row['rsi']):
|
||||
continue
|
||||
|
||||
signal = StrategySignal(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
signal_type=SignalType.BUY,
|
||||
price=float(row[price_column]),
|
||||
confidence=0.7,
|
||||
metadata={'rsi_cross': 'oversold_to_buy', **strategy_metadata}
|
||||
)
|
||||
|
||||
results.append(StrategyResult(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[signal],
|
||||
indicators_used={'rsi': float(row['rsi'])},
|
||||
metadata=strategy_metadata
|
||||
))
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.strategy_name}: BUY signal at {timestamp} for {row['symbol']} (RSI: {row['rsi']:.2f})")
|
||||
|
||||
# Process sell signals (RSI crosses below overbought threshold)
|
||||
sell_indices = merged_df[sell_signals].index
|
||||
for timestamp in sell_indices:
|
||||
row = merged_df.loc[timestamp]
|
||||
|
||||
# Skip if RSI value is NaN
|
||||
if pd.isna(row['rsi']):
|
||||
continue
|
||||
|
||||
signal = StrategySignal(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
signal_type=SignalType.SELL,
|
||||
price=float(row[price_column]),
|
||||
confidence=0.7,
|
||||
metadata={'rsi_cross': 'overbought_to_sell', **strategy_metadata}
|
||||
)
|
||||
|
||||
results.append(StrategyResult(
|
||||
timestamp=timestamp,
|
||||
symbol=row['symbol'],
|
||||
timeframe=row['timeframe'],
|
||||
strategy_name=self.strategy_name,
|
||||
signals=[signal],
|
||||
indicators_used={'rsi': float(row['rsi'])},
|
||||
metadata=strategy_metadata
|
||||
))
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"{self.strategy_name}: SELL signal at {timestamp} for {row['symbol']} (RSI: {row['rsi']:.2f})")
|
||||
|
||||
return results
|
||||
166
strategies/utils.py
Normal file
166
strategies/utils.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Trading Strategy Utilities
|
||||
|
||||
This module provides utility functions for managing trading strategy
|
||||
configurations and validation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_default_strategies_config() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Create default configuration for common trading strategies.
|
||||
|
||||
Returns:
|
||||
Dictionary with default strategy configurations
|
||||
"""
|
||||
return {
|
||||
'ema_crossover_default': {'strategy': 'ema_crossover', 'fast_period': 12, 'slow_period': 26},
|
||||
'rsi_default': {'strategy': 'rsi', 'period': 14, 'overbought': 70, 'oversold': 30},
|
||||
'macd_default': {'strategy': 'macd', 'fast_period': 12, 'slow_period': 26, 'signal_period': 9}
|
||||
}
|
||||
|
||||
|
||||
def create_indicator_key(indicator_config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create a unique key for an indicator configuration.
|
||||
This function is shared between StrategyFactory and individual strategies
|
||||
to ensure consistency and reduce duplication.
|
||||
|
||||
Args:
|
||||
indicator_config: Indicator configuration
|
||||
|
||||
Returns:
|
||||
Unique string key for this indicator configuration
|
||||
"""
|
||||
indicator_type = indicator_config.get('type', 'unknown')
|
||||
|
||||
# Common parameters that should be part of the key
|
||||
key_params = []
|
||||
for param in ['period', 'fast_period', 'slow_period', 'signal_period', 'std_dev']:
|
||||
if param in indicator_config:
|
||||
key_params.append(f"{param}_{indicator_config[param]}")
|
||||
|
||||
if key_params:
|
||||
return f"{indicator_type}_{'_'.join(key_params)}"
|
||||
else:
|
||||
return f"{indicator_type}_default"
|
||||
|
||||
|
||||
def detect_crossover_signals_vectorized(df: pd.DataFrame,
|
||||
fast_col: str,
|
||||
slow_col: str) -> Tuple[pd.Series, pd.Series]:
|
||||
"""
|
||||
Detect crossover signals using vectorized operations for better performance.
|
||||
|
||||
Args:
|
||||
df: DataFrame with indicator data
|
||||
fast_col: Column name for fast line (e.g., 'ema_fast')
|
||||
slow_col: Column name for slow line (e.g., 'ema_slow')
|
||||
|
||||
Returns:
|
||||
Tuple of (bullish_crossover_mask, bearish_crossover_mask) as boolean Series
|
||||
"""
|
||||
# Get previous values using shift
|
||||
fast_prev = df[fast_col].shift(1)
|
||||
slow_prev = df[slow_col].shift(1)
|
||||
fast_curr = df[fast_col]
|
||||
slow_curr = df[slow_col]
|
||||
|
||||
# Bullish crossover: fast was <= slow, now fast > slow
|
||||
bullish_crossover = (fast_prev <= slow_prev) & (fast_curr > slow_curr)
|
||||
|
||||
# Bearish crossover: fast was >= slow, now fast < slow
|
||||
bearish_crossover = (fast_prev >= slow_prev) & (fast_curr < slow_curr)
|
||||
|
||||
# Ensure no signals on first row (index 0) where there's no previous value
|
||||
bullish_crossover.iloc[0] = False
|
||||
bearish_crossover.iloc[0] = False
|
||||
|
||||
return bullish_crossover, bearish_crossover
|
||||
|
||||
|
||||
def detect_threshold_signals_vectorized(df: pd.DataFrame,
|
||||
indicator_col: str,
|
||||
upper_threshold: float,
|
||||
lower_threshold: float) -> Tuple[pd.Series, pd.Series]:
|
||||
"""
|
||||
Detect threshold crossing signals using vectorized operations (for RSI, etc.).
|
||||
|
||||
Args:
|
||||
df: DataFrame with indicator data
|
||||
indicator_col: Column name for indicator (e.g., 'rsi')
|
||||
upper_threshold: Upper threshold (e.g., 70 for RSI overbought)
|
||||
lower_threshold: Lower threshold (e.g., 30 for RSI oversold)
|
||||
|
||||
Returns:
|
||||
Tuple of (buy_signal_mask, sell_signal_mask) as boolean Series
|
||||
"""
|
||||
# Get previous values using shift
|
||||
indicator_prev = df[indicator_col].shift(1)
|
||||
indicator_curr = df[indicator_col]
|
||||
|
||||
# Buy signal: was <= lower_threshold, now > lower_threshold (oversold to normal)
|
||||
buy_signal = (indicator_prev <= lower_threshold) & (indicator_curr > lower_threshold)
|
||||
|
||||
# Sell signal: was >= upper_threshold, now < upper_threshold (overbought to normal)
|
||||
sell_signal = (indicator_prev >= upper_threshold) & (indicator_curr < upper_threshold)
|
||||
|
||||
# Ensure no signals on first row (index 0) where there's no previous value
|
||||
buy_signal.iloc[0] = False
|
||||
sell_signal.iloc[0] = False
|
||||
|
||||
return buy_signal, sell_signal
|
||||
|
||||
|
||||
def validate_strategy_config(config: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate trading strategy configuration.
|
||||
|
||||
Args:
|
||||
config: Strategy configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration is valid, False otherwise
|
||||
"""
|
||||
required_fields = ['strategy']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
return False
|
||||
|
||||
# Validate strategy type
|
||||
valid_types = ['ema_crossover', 'rsi', 'macd']
|
||||
if config['strategy'] not in valid_types:
|
||||
return False
|
||||
|
||||
# Basic validation for common parameters based on strategy type
|
||||
strategy_type = config['strategy']
|
||||
if strategy_type == 'ema_crossover':
|
||||
if not all(k in config for k in ['fast_period', 'slow_period']):
|
||||
return False
|
||||
if not (isinstance(config['fast_period'], int) and config['fast_period'] > 0 and
|
||||
isinstance(config['slow_period'], int) and config['slow_period'] > 0 and
|
||||
config['fast_period'] < config['slow_period']):
|
||||
return False
|
||||
elif strategy_type == 'rsi':
|
||||
if not all(k in config for k in ['period', 'overbought', 'oversold']):
|
||||
return False
|
||||
if not (isinstance(config['period'], int) and config['period'] > 0 and
|
||||
isinstance(config['overbought'], (int, float)) and isinstance(config['oversold'], (int, float)) and
|
||||
0 <= config['oversold'] < config['overbought'] <= 100):
|
||||
return False
|
||||
elif strategy_type == 'macd':
|
||||
if not all(k in config for k in ['fast_period', 'slow_period', 'signal_period']):
|
||||
return False
|
||||
if not (isinstance(config['fast_period'], int) and config['fast_period'] > 0 and
|
||||
isinstance(config['slow_period'], int) and config['slow_period'] > 0 and
|
||||
isinstance(config['signal_period'], int) and config['signal_period'] > 0 and
|
||||
config['fast_period'] < config['slow_period']):
|
||||
return False
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user