Ajasra 3a9dec543c Refactor test_bbrsi.py and enhance strategy implementations
- Removed unused configuration for daily data and consolidated minute configuration into a single config dictionary.
- Updated plotting logic to dynamically handle different strategies, ensuring appropriate bands and signals are displayed based on the selected strategy.
- Improved error handling and logging for missing data in plots.
- Enhanced the Bollinger Bands and RSI classes to support adaptive parameters based on market regimes, improving flexibility in strategy execution.
- Added new CryptoTradingStrategy with multi-timeframe analysis and volume confirmation for better trading signal accuracy.
- Updated documentation to reflect changes in strategy implementations and configuration requirements.
2025-05-22 17:57:04 +08:00

114 lines
5.1 KiB
Python

import pandas as pd
import numpy as np
class RSI:
    """
    A class to calculate the Relative Strength Index (RSI).

    RSI = 100 - 100 / (1 + RS), where RS is the ratio of the smoothed average
    gain to the smoothed average loss over the look-back period.
    """

    def __init__(self, config):
        """
        Initializes the RSI calculator.

        Args:
            config (dict): Configuration mapping. Must contain 'rsi_period',
                the look-back period for the RSI calculation. It must be a
                positive integer (a Python int or NumPy integer; bools are
                rejected).

        Raises:
            KeyError: If 'rsi_period' is missing from config.
            ValueError: If 'rsi_period' is not a positive integer.
        """
        period = config['rsi_period']
        if not self._is_positive_int(period):
            raise ValueError("Period must be a positive integer.")
        self.period = int(period)

    @staticmethod
    def _is_positive_int(value) -> bool:
        """Return True if value is a positive integer (bools excluded)."""
        # bool is a subclass of int, so without the explicit exclusion True
        # would silently be accepted as a period of 1.
        return (
            isinstance(value, (int, np.integer))
            and not isinstance(value, bool)
            and value > 0
        )

    def calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame:
        """
        Calculates the RSI with EMA (alpha = 1/period) smoothing and adds it as a column.

        Args:
            data_df (pd.DataFrame): DataFrame with historical price data.
                Must contain the 'price_column'.
            price_column (str): The name of the column containing price data.
                Default is 'close'.

        Returns:
            pd.DataFrame: A copy of the input DataFrame with an added 'RSI'
                column. If there are fewer than period + 1 rows, a warning is
                printed and the 'RSI' column is all NaN. The input DataFrame
                is never modified.

        Raises:
            ValueError: If price_column is not a column of data_df.
        """
        if price_column not in data_df.columns:
            raise ValueError(f"Price column '{price_column}' not found in DataFrame.")

        df = data_df.copy()  # Never mutate the caller's frame.

        # One RSI value needs `period` price changes, i.e. period + 1 prices.
        if len(df) < self.period + 1:
            print(f"Warning: Data length ({len(df)}) is less than RSI period ({self.period}) + 1. RSI will not be calculated meaningfully.")
            df['RSI'] = np.nan
            return df

        df['RSI'] = self.calculate_custom_rsi(df[price_column], window=self.period, smoothing='EMA')
        return df

    @staticmethod
    def calculate_custom_rsi(price_series: pd.Series, window: int = 14, smoothing: str = 'SMA') -> pd.Series:
        """
        Calculates RSI with the given window and smoothing (SMA or EMA).

        'SMA' averages gains/losses with a simple rolling mean. 'EMA' uses the
        recursive form avg_t = (1 - 1/window) * avg_{t-1} + x_t / window
        (pandas ewm with adjust=False). That form is seeded with the first real
        change rather than an SMA, so early values can differ slightly from
        textbook Wilder RSI.

        Args:
            price_series (pd.Series): Series of prices.
            window (int): The period for RSI calculation. Must be a positive
                integer (Python int or NumPy integer; bools are rejected).
            smoothing (str): Smoothing method, 'SMA' or 'EMA'. Defaults to 'SMA'.

        Returns:
            pd.Series: RSI values aligned to price_series.index. The first
                `window` entries are NaN, because the first RSI value needs
                `window` real price changes. If fewer than window + 1 prices
                are given, the whole Series is NaN.

        Raises:
            TypeError: If price_series is not a pandas Series.
            ValueError: If window is not a positive integer, or smoothing is
                not 'SMA' or 'EMA'.
        """
        if not isinstance(price_series, pd.Series):
            raise TypeError("price_series must be a pandas Series.")
        if not RSI._is_positive_int(window):
            raise ValueError("window must be a positive integer.")
        if smoothing not in ['SMA', 'EMA']:
            raise ValueError("smoothing must be either 'SMA' or 'EMA'.")
        window = int(window)

        if len(price_series) < window + 1:  # Need window + 1 prices for `window` diffs.
            return pd.Series(np.nan, index=price_series.index)

        delta = price_series.diff()
        gain = delta.where(delta > 0, 0.0)
        loss = (-delta).where(delta < 0, 0.0)  # Losses as positive magnitudes.
        # The first diff has no predecessor. Keep it NaN rather than a fake
        # zero change, so min_periods counts only real price changes and the
        # first RSI appears at index `window`, not one bar early.
        gain.iloc[0] = np.nan
        loss.iloc[0] = np.nan

        if smoothing == 'EMA':
            avg_gain = gain.ewm(alpha=1 / window, adjust=False, min_periods=window).mean()
            avg_loss = loss.ewm(alpha=1 / window, adjust=False, min_periods=window).mean()
        else:  # SMA
            avg_gain = gain.rolling(window=window, min_periods=window).mean()
            avg_loss = loss.rolling(window=window, min_periods=window).mean()

        zero_loss = avg_loss == 0
        # Divide only where avg_loss is non-zero; zero-loss bars are set explicitly below.
        rs = avg_gain / avg_loss.where(~zero_loss)
        rsi = 100.0 - 100.0 / (1.0 + rs)
        # No losses in the window: RSI is 100 if there were gains, 50 (neutral) if flat.
        rsi[zero_loss] = np.where(avg_gain[zero_loss] > 0, 100.0, 50.0)
        # Not enough observations yet (min_periods not reached): leave undefined.
        rsi[avg_gain.isna() | avg_loss.isna()] = np.nan
        return rsi