- Introduced a new `strategies` package containing the core structure for trading strategies, including `BaseStrategy`, `StrategyFactory`, and various strategy implementations (EMA, RSI, MACD). - Added utility functions for signal detection and validation in `strategies/utils.py`, enhancing modularity and maintainability. - Updated `pyproject.toml` to include the new `strategies` package in the build configuration. - Implemented comprehensive unit tests for the strategy foundation components, ensuring reliability and adherence to project standards. These changes establish a solid foundation for the strategy engine, aligning with project goals for modularity, performance, and maintainability.
168 lines
6.3 KiB
Python
168 lines
6.3 KiB
Python
"""
|
|
Relative Strength Index (RSI) Strategy Implementation
|
|
|
|
This module implements an RSI-based momentum trading strategy.
|
|
It extends the BaseStrategy and generates buy/sell signals based on
|
|
RSI crossing overbought/oversold thresholds.
|
|
"""
|
|
|
|
import pandas as pd
|
|
from typing import List, Dict, Any
|
|
|
|
from ..base import BaseStrategy
|
|
from ..data_types import StrategyResult, StrategySignal, SignalType
|
|
from ..utils import create_indicator_key, detect_threshold_signals_vectorized
|
|
|
|
|
|
class RSIStrategy(BaseStrategy):
    """
    RSI Strategy.

    Generates buy/sell signals when RSI crosses overbought/oversold thresholds.

    The threshold detection itself is delegated to
    ``detect_threshold_signals_vectorized``; this class validates the inputs,
    aligns the RSI values with the price data and wraps every detected signal
    in a ``StrategyResult``.
    """

    # Confidence attached to every generated signal.
    _SIGNAL_CONFIDENCE = 0.7

    def __init__(self, logger=None):
        """
        Initialise the strategy.

        Args:
            logger: Optional logger forwarded to ``BaseStrategy``. Logging is
                skipped whenever ``self.logger`` is falsy.
        """
        super().__init__(logger)
        self.strategy_name = "rsi"

    def get_required_indicators(self) -> List[Dict[str, Any]]:
        """
        Defines the indicators required by the RSI strategy.

        It needs one RSI indicator.

        Returns:
            A single-element list holding the default RSI configuration
            (period 14 computed on the ``close`` column).
        """
        # Default period for RSI, can be overridden by strategy config
        return [
            {'type': 'rsi', 'period': 14, 'price_column': 'close'}
        ]

    def calculate(self, df: pd.DataFrame, indicators_data: Dict[str, pd.DataFrame], **kwargs) -> List[StrategyResult]:
        """
        Calculate RSI strategy signals.

        Args:
            df: DataFrame with OHLCV data. It must provide the price column
                as well as ``symbol`` and ``timeframe`` columns.
            indicators_data: Dictionary of pre-calculated indicator DataFrames,
                keyed by ``create_indicator_key``. The RSI frame must provide
                an ``rsi`` column and is joined to ``df`` on the index.
                Expected key for the default period: 'rsi_period_14'.
            **kwargs: Additional strategy parameters:
                ``period`` (default 14), ``overbought`` (default 70),
                ``oversold`` (default 30), ``price_column`` (default 'close').

        Returns:
            List of StrategyResult objects, each containing exactly one
            generated signal. All BUY results come first, followed by all
            SELL results. An empty list is returned when validation fails or
            when the price and RSI data share no index values.

        Raises:
            KeyError: If ``df`` lacks the price/``symbol``/``timeframe``
                columns or the RSI frame lacks an ``rsi`` column.
        """
        # Extract parameters from kwargs or use defaults
        period = kwargs.get('period', 14)
        overbought = kwargs.get('overbought', 70)
        oversold = kwargs.get('oversold', 30)
        price_column = kwargs.get('price_column', 'close')

        # Generate indicator key using shared utility function
        rsi_key = create_indicator_key({'type': 'rsi', 'period': period})

        # Validate that the main DataFrame has enough data for strategy calculation
        if not self.validate_dataframe(df, period):
            if self.logger:
                self.logger.warning(f"{self.strategy_name}: Insufficient main DataFrame for calculation.")
            return []

        # Validate that the required RSI indicator data is present and sufficient
        required_indicators = [
            {'type': 'rsi', 'period': period}
        ]
        if not self.validate_indicators_data(indicators_data, required_indicators):
            if self.logger:
                self.logger.warning(f"{self.strategy_name}: Missing or insufficient RSI indicator data.")
            return []

        rsi_df = indicators_data.get(rsi_key)
        if rsi_df is None or rsi_df.empty:
            if self.logger:
                self.logger.warning(f"{self.strategy_name}: RSI indicator DataFrame is not found or empty.")
            return []

        # Merge all necessary data into a single DataFrame for easier processing.
        # The inner join keeps only index values present in both frames.
        merged_df = pd.merge(
            df[[price_column, 'symbol', 'timeframe']],
            rsi_df[['rsi']],
            left_index=True,
            right_index=True,
            how='inner',
        )
        if merged_df.empty:
            if self.logger:
                self.logger.warning(f"{self.strategy_name}: Merged DataFrame is empty after indicator alignment. Check data ranges.")
            return []

        # Use vectorized signal detection for better performance.
        # NOTE(review): the exact crossing direction that counts as a buy or a
        # sell is defined by detect_threshold_signals_vectorized - confirm there.
        buy_signals, sell_signals = detect_threshold_signals_vectorized(
            merged_df, 'rsi', overbought, oversold
        )

        strategy_metadata = {
            'period': period,
            'overbought': overbought,
            'oversold': oversold,
        }

        results: List[StrategyResult] = []
        results.extend(self._build_results(
            merged_df, buy_signals, SignalType.BUY, 'BUY',
            'oversold_to_buy', price_column, strategy_metadata,
        ))
        results.extend(self._build_results(
            merged_df, sell_signals, SignalType.SELL, 'SELL',
            'overbought_to_sell', price_column, strategy_metadata,
        ))
        return results

    def _build_results(self, merged_df: pd.DataFrame, signal_mask, signal_type,
                       label: str, cross_label: str, price_column: str,
                       strategy_metadata: Dict[str, Any]) -> List[StrategyResult]:
        """
        Turn the rows selected by ``signal_mask`` into StrategyResult objects.

        Args:
            merged_df: Price/symbol/timeframe data joined with the ``rsi`` column.
            signal_mask: Boolean selector applied to ``merged_df``.
            signal_type: ``SignalType`` member assigned to every signal.
            label: Human-readable side ("BUY"/"SELL") used in the log message.
            cross_label: Value stored under ``rsi_cross`` in the signal metadata.
            price_column: Name of the column providing the signal price.
            strategy_metadata: Strategy parameters recorded with each signal.

        Returns:
            One StrategyResult per selected row whose RSI value is not NaN.
        """
        built: List[StrategyResult] = []

        # Iterate over the selected rows themselves instead of looking each
        # timestamp up again with .loc: on a non-unique index .loc returns a
        # DataFrame, which would break the scalar NaN check below.
        for timestamp, row in merged_df[signal_mask].iterrows():
            rsi_value = row['rsi']

            # Skip if RSI value is NaN
            if pd.isna(rsi_value):
                continue

            symbol = row['symbol']
            timeframe = row['timeframe']

            signal = StrategySignal(
                timestamp=timestamp,
                symbol=symbol,
                timeframe=timeframe,
                signal_type=signal_type,
                price=float(row[price_column]),
                confidence=self._SIGNAL_CONFIDENCE,
                metadata={'rsi_cross': cross_label, **strategy_metadata},
            )

            built.append(StrategyResult(
                timestamp=timestamp,
                symbol=symbol,
                timeframe=timeframe,
                strategy_name=self.strategy_name,
                signals=[signal],
                indicators_used={'rsi': float(rsi_value)},
                # Each result gets its own copy so that mutating one result's
                # metadata cannot silently alter every other result.
                metadata=dict(strategy_metadata),
            ))

            if self.logger:
                self.logger.info(f"{self.strategy_name}: {label} signal at {timestamp} for {symbol} (RSI: {rsi_value:.2f})")

        return built