import pandas as pd
import numpy as np


class RSI:
    """
    A class to calculate the Relative Strength Index (RSI).

    ``calculate`` appends an 'RSI' column to a price DataFrame using the
    period taken from the configuration; the static ``calculate_custom_rsi``
    works on a bare price Series with an explicit window and smoothing mode.
    """

    def __init__(self, config):
        """
        Initializes the RSI calculator.

        Args:
            config: Mapping providing the key 'rsi_period', the look-back
                period for the RSI. The value must be a positive integer
                (a Python int or a NumPy integer). There is no default:
                the key is looked up unconditionally.

        Raises:
            ValueError: If the configured period is not a positive integer.
        """
        period = config['rsi_period']
        # NumPy integers are accepted as well because periods frequently come
        # out of numpy/pandas structures; the value is stored as a plain int.
        if not isinstance(period, (int, np.integer)) or period <= 0:
            raise ValueError("Period must be a positive integer.")
        self.period = int(period)

    def calculate(self, data_df: pd.DataFrame, price_column: str = 'close') -> pd.DataFrame:
        """
        Calculates the RSI (exponential smoothing with alpha = 1 / period)
        and returns a copy of the input DataFrame with an 'RSI' column.

        Args:
            data_df (pd.DataFrame): DataFrame with historical price data.
                Must contain the 'price_column'. It is never modified.
            price_column (str): The name of the column containing price data.
                Default is 'close'.

        Returns:
            pd.DataFrame: A copy of the input with an added 'RSI' column.
                When the frame has fewer than period + 1 rows a warning is
                printed and the 'RSI' column is entirely NaN.

        Raises:
            ValueError: If 'price_column' is not a column of 'data_df'.
        """
        if price_column not in data_df.columns:
            raise ValueError(f"Price column '{price_column}' not found in DataFrame.")

        df = data_df.copy()  # Work on a copy so the caller's frame is untouched

        # period + 1 prices are required to obtain `period` price changes.
        if len(df) < self.period + 1:
            print(f"Warning: Data length ({len(data_df)}) is less than RSI period ({self.period}) + 1. RSI will not be calculated meaningfully.")
            df['RSI'] = np.nan
            return df

        df['RSI'] = self.calculate_custom_rsi(
            df[price_column], window=self.period, smoothing='EMA'
        )
        return df

    @staticmethod
    def calculate_custom_rsi(price_series: pd.Series, window: int = 14, smoothing: str = 'SMA') -> pd.Series:
        """
        Calculates RSI with specified window and smoothing (SMA or EMA).

        Args:
            price_series (pd.Series): Series of prices.
            window (int): The period for RSI calculation. Must be a positive
                integer (Python int or NumPy integer).
            smoothing (str): Smoothing method, 'SMA' (rolling mean) or 'EMA'
                (exponential mean with alpha = 1 / window, adjust=False).
                Defaults to 'SMA'.

        Returns:
            pd.Series: RSI values on the index of 'price_series'. Rows that
                have fewer than 'window' actual price changes behind them
                are NaN, so for gap-free data the first value appears at
                position 'window'. If the series holds fewer than
                window + 1 prices the result is all NaN.

        Raises:
            TypeError: If 'price_series' is not a pandas Series.
            ValueError: If 'window' is not a positive integer or 'smoothing'
                is not 'SMA' or 'EMA'.
        """
        if not isinstance(price_series, pd.Series):
            raise TypeError("price_series must be a pandas Series.")
        if not isinstance(window, (int, np.integer)) or window <= 0:
            raise ValueError("window must be a positive integer.")
        if smoothing not in ['SMA', 'EMA']:
            raise ValueError("smoothing must be either 'SMA' or 'EMA'.")

        window = int(window)

        # Need at least window + 1 prices to obtain `window` price changes.
        if len(price_series) < window + 1:
            return pd.Series(np.nan, index=price_series.index)

        delta = price_series.diff()

        # Undefined changes (the first row, or rows next to missing prices)
        # fail both comparisons and therefore count as zero gain and zero
        # loss in the averages below.
        gain = delta.where(delta > 0, 0.0)
        loss = -delta.where(delta < 0, 0.0)  # Ensure loss is positive

        if smoothing == 'EMA':
            # adjust=False gives the recursive form used for Wilder-style RSI.
            avg_gain = gain.ewm(alpha=1 / window, adjust=False, min_periods=window).mean()
            avg_loss = loss.ewm(alpha=1 / window, adjust=False, min_periods=window).mean()
        else:  # SMA
            avg_gain = gain.rolling(window=window, min_periods=window).mean()
            avg_loss = loss.rolling(window=window, min_periods=window).mean()

        # Substitute a tiny divisor where avg_loss is 0; those rows are
        # overwritten with their exact value right below.
        no_loss = avg_loss == 0
        rs = avg_gain / avg_loss.replace(0, 1e-9)
        rsi = 100 - (100 / (1 + rs))

        # avg_loss == 0: RSI is 100 when there were gains, 50 (neutral) when
        # the price did not move at all.
        rsi[no_loss] = np.where(avg_gain[no_loss] > 0, 100, 50)

        # Ensure RSI is NaN where either average is NaN (due to min_periods).
        rsi[avg_gain.isna() | avg_loss.isna()] = np.nan

        # The zero-filled first delta counts toward min_periods, which would
        # otherwise emit a value backed by only window - 1 real price
        # changes. Blank every row that has fewer than `window` real changes.
        real_changes = delta.notna().cumsum()
        rsi[real_changes < window] = np.nan

        return rsi