75 lines
2.7 KiB
Python
75 lines
2.7 KiB
Python
"""
|
|
Relative Strength Index (RSI) indicator implementation.
|
|
"""
|
|
|
|
from typing import List
|
|
import pandas as pd
|
|
|
|
from ..base import BaseIndicator
|
|
from ..result import IndicatorResult
|
|
|
|
|
|
class RSIIndicator(BaseIndicator):
    """
    Relative Strength Index (RSI) technical indicator.

    Measures momentum by comparing the magnitude of recent gains to recent
    losses. No interpolation is performed: a bar whose price change is
    undefined (the first bar, or a bar next to a missing price) contributes
    neither a gain nor a loss.
    """

    def calculate(self, df: pd.DataFrame, period: int = 14,
                  price_column: str = 'close') -> List[IndicatorResult]:
        """
        Calculate the Relative Strength Index (RSI).

        Average gain and average loss are smoothed with an exponential
        moving average using ``span=period`` (alpha = 2 / (period + 1)).
        This is NOT Wilder's original smoothing (alpha = 1 / period), so
        values differ from platforms that implement Wilder's RSI exactly.

        This method only reads ``df``; it adds no columns to it.

        Args:
            df: DataFrame with OHLCV data. It must contain ``price_column``,
                ``'symbol'`` and ``'timeframe'`` columns; its index is used as
                each result's timestamp.
            period: Number of periods for RSI calculation (default: 14).
            price_column: Price column to use ('open', 'high', 'low', 'close').

        Returns:
            One IndicatorResult per row at position ``period`` or later
            (earlier rows are warm-up and are skipped). Returns an empty list
            if ``validate_dataframe`` rejects the input or any error occurs
            during calculation; in the latter case the error is logged when
            ``self.logger`` is set.
        """
        # One extra row is requested because diff() leaves the first row
        # without a defined price change.
        if not self.validate_dataframe(df, period + 1):
            return []

        try:
            rsi = self._rsi_series(df[price_column], period)

            # Zip plain column lists instead of using iterrows(): iterrows
            # builds a Series per row, which is slow and up-casts dtypes.
            # .tolist() yields Python scalars, matching what iterrows handed
            # out for mixed-dtype frames.
            rows = zip(
                df.index,
                df['symbol'].tolist(),
                df['timeframe'].tolist(),
                rsi.tolist(),
            )

            results = []
            for position, (timestamp, symbol, timeframe, value) in enumerate(rows):
                # Only report rows after the warm-up period.
                if position < period or pd.isna(value):
                    continue
                results.append(IndicatorResult(
                    timestamp=timestamp,
                    symbol=symbol,
                    timeframe=timeframe,
                    values={'rsi': value},
                    metadata={'period': period, 'price_column': price_column},
                ))

            return results

        except Exception as e:
            # Failures are reported to callers as an empty result list, the
            # same convention used when validation fails above.
            if self.logger:
                self.logger.error(f"Error calculating RSI: {e}")
            return []

    @staticmethod
    def _rsi_series(prices: pd.Series, period: int) -> pd.Series:
        """
        Return the RSI of ``prices`` as a Series sharing its index.

        Works entirely on local Series so the caller's DataFrame is never
        modified.
        """
        change = prices.diff()

        # NaN changes compare False on both sides, so they count as a zero
        # gain and a zero loss rather than being interpolated.
        gain = change.where(change > 0, 0)
        loss = (-change).where(change < 0, 0)

        avg_gain = gain.ewm(span=period, adjust=False).mean()
        avg_loss = loss.ewm(span=period, adjust=False).mean()

        # Division edge cases:
        #   avg_loss == 0, avg_gain > 0  -> rs = inf -> RSI = 100
        #   avg_gain == 0, avg_loss > 0  -> rs = 0   -> RSI = 0
        #   both == 0 (no movement)      -> rs = NaN -> reported as neutral 50
        rs = avg_gain / avg_loss
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)