# Source: Cycles/test/indicators/test_rsi_indicators.py (401 lines, 17 KiB, Python)

"""
RSI Indicators Comparison Test
Focused testing for RSI and Simple RSI implementations.
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
import sys
from pathlib import Path
# Add project root to path so the project packages imported below
# (`cycles`, `IncrementalTrader`) can be resolved when this file is run directly.
# NOTE(review): this resolves to the parent of the directory containing this
# file; confirm that directory really is the project root for this test's location.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# Import original indicators
from cycles.IncStrategies.indicators import (
RSIState as OriginalRSI,
SimpleRSIState as OriginalSimpleRSI
)
# Import new indicators
from IncrementalTrader.strategies.indicators import (
RSIState as NewRSI,
SimpleRSIState as NewSimpleRSI
)
class RSIComparisonTest:
    """Test framework for comparing RSI implementations.

    Feeds the same close-price series into the original and the new indicator
    classes, records both outputs per period in ``self.results`` and then
    renders comparison plots and a Markdown report into ``self.results_dir``.

    Each entry of ``self.results`` is keyed ``"<Name>_<period>"`` and holds the
    keys ``original``, ``new``, ``prices``, ``price_changes``, ``dates`` and
    ``period``.
    """

    # Largest |new - original| that is still reported as an exact match.
    _EXACT_TOLERANCE = 1e-10
    # Largest |new - original| that is still attributed to floating point noise.
    _PRECISION_TOLERANCE = 1e-6

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Store the configuration and create the results directory.

        Args:
            data_file: Path of the CSV file read by :meth:`load_data`.
            sample_size: Maximum number of trailing rows to keep; a falsy
                value keeps every row.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        self.results = {}
        # Create results directory
        self.results_dir = Path("test/results/rsi_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def load_data(self):
        """Load and prepare the data for testing.

        Reads ``self.data_file``, derives a ``datetime`` column from the
        ``Timestamp`` column (interpreted as seconds) and keeps only the last
        ``self.sample_size`` rows when the file is longer than that.
        """
        print(f"Loading data from {self.data_file}...")
        df = pd.read_csv(self.data_file)
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)
        self.data = df
        if df.empty:
            # There is no first/last row to report a date range for.
            print("Loaded 0 data points")
            return
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    def test_rsi(self, periods=(7, 14, 21, 28)):
        """Test RSI implementations (Wilder's smoothing).

        Args:
            periods: Iterable of indicator periods to compare.
        """
        print("\n=== Testing RSI (Wilder's Smoothing) ===")
        self._compare_implementations("RSI", OriginalRSI, NewRSI, periods)

    def test_simple_rsi(self, periods=(7, 14, 21, 28)):
        """Test Simple RSI implementations (Simple moving average).

        Args:
            periods: Iterable of indicator periods to compare.
        """
        print("\n=== Testing Simple RSI (Simple Moving Average) ===")
        self._compare_implementations("SimpleRSI", OriginalSimpleRSI, NewSimpleRSI, periods)

    def _compare_implementations(self, name, original_cls, new_cls, periods):
        """Run both implementations over the close prices for every period.

        Stores one ``self.results`` entry per period under ``f"{name}_{period}"``
        and prints the difference statistics and verdict for it.
        """
        if self.data is None:
            # Allow the test methods to be called without an explicit load_data().
            self.load_data()

        closes = self.data['Close'].tolist()
        dates = self.data['datetime'].tolist()
        # Change versus the previous row; the first row has no predecessor and gets 0.
        price_changes = [current - previous for previous, current in zip(closes, closes[1:])]
        if closes:
            price_changes.insert(0, 0)

        for period in periods:
            print(f"Testing {name}({period})...")
            original_indicator = original_cls(period)
            new_indicator = new_cls(period)
            original_values = []
            new_values = []
            for price in closes:
                original_indicator.update(price)
                new_indicator.update(price)
                # Values produced before warm-up are recorded as NaN so they are
                # excluded from the difference statistics.
                original_values.append(
                    original_indicator.get_current_value() if original_indicator.is_warmed_up() else np.nan
                )
                new_values.append(
                    new_indicator.get_current_value() if new_indicator.is_warmed_up() else np.nan
                )

            result = {
                'original': original_values,
                'new': new_values,
                'prices': list(closes),
                'price_changes': list(price_changes),
                'dates': list(dates),
                'period': period
            }
            self.results[f'{name}_{period}'] = result

            valid_diff = self._valid_differences(result)
            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                std_diff = np.std(valid_diff)
                print(f" Max difference: {max_diff:.12f}")
                print(f" Mean difference: {mean_diff:.12f}")
                print(f" Std difference: {std_diff:.12f}")
                _, verdict = self._classify(max_diff)
                print(f" {verdict}")
            else:
                print(f" ❌ ERROR: No valid data points")

    @staticmethod
    def _valid_differences(result):
        """Return ``new - original`` for one result entry with NaN entries removed."""
        diff = np.array(result['new']) - np.array(result['original'])
        return diff[~np.isnan(diff)]

    @classmethod
    def _classify(cls, max_diff):
        """Map a maximum absolute difference to ``(short status, console message)``."""
        if max_diff < cls._EXACT_TOLERANCE:
            return "✅ PASSED", "✅ PASSED: Mathematically equivalent"
        if max_diff < cls._PRECISION_TOLERANCE:
            return "⚠️ WARNING", "⚠️ WARNING: Small differences (floating point precision)"
        return "❌ FAILED", "❌ FAILED: Significant differences detected"

    def plot_comparison(self, indicator_name: str):
        """Plot detailed comparison for a specific indicator.

        Draws four stacked panels (price, indicator values, gains/losses and
        the new-minus-original difference), saves the figure as a PNG in
        ``self.results_dir`` and shows it.

        Args:
            indicator_name: Key of the entry in ``self.results`` to plot.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return
        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        # Create figure with subplots
        fig, axes = plt.subplots(4, 1, figsize=(15, 16))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Plot 1: Price data
        ax1 = axes[0]
        ax1.plot(dates, result['prices'], label='Close Price', alpha=0.8, color='black', linewidth=1)
        ax1.set_title('Price Data')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: RSI comparison with levels
        ax2 = axes[1]
        ax2.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2, color='blue')
        ax2.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax2.axhline(y=70, color='red', linestyle=':', alpha=0.7, label='Overbought (70)')
        ax2.axhline(y=30, color='green', linestyle=':', alpha=0.7, label='Oversold (30)')
        ax2.axhline(y=50, color='gray', linestyle='-', alpha=0.5, label='Midline (50)')
        ax2.set_title(f'{indicator_name} Values Comparison')
        ax2.set_ylim(0, 100)
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Plot 3: Price changes, split into gain and loss magnitudes
        ax3 = axes[2]
        changes = np.asarray(result['price_changes'], dtype=float)
        positive_changes = np.where(changes > 0, changes, 0.0)
        negative_changes = np.where(changes < 0, -changes, 0.0)
        ax3.plot(dates, positive_changes, label='Positive Changes', alpha=0.7, color='green')
        ax3.plot(dates, negative_changes, label='Negative Changes', alpha=0.7, color='red')
        ax3.set_title('Price Changes (Gains and Losses)')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Difference analysis
        ax4 = axes[3]
        diff = np.array(result['new']) - np.array(result['original'])
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        # Add statistics text
        valid_diff = diff[~np.isnan(diff)]
        if len(valid_diff) > 0:
            stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Std: {np.std(valid_diff):.2e}'
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Format x-axis. The tick spacing is derived from the plotted time span
        # rather than from the number of samples: a day-based interval computed
        # from the sample count leaves intraday data without usable ticks.
        for ax in axes:
            locator = mdates.AutoDateLocator()
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
        plt.tight_layout()

        # Save plot
        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")
        plt.show()

    def plot_all_comparisons(self):
        """Plot comparisons for all tested indicators."""
        print("\n=== Generating Detailed Comparison Plots ===")
        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
        plt.close('all')

    def generate_report(self):
        """Generate detailed report for RSI indicators.

        Writes ``rsi_indicators_report.md`` into ``self.results_dir`` with a
        summary table, a methodology section and per-indicator statistics
        computed from ``self.results``.
        """
        print("\n=== Generating RSI Report ===")
        # The report can be produced before any data was loaded; report 0 rows then.
        sample_count = 0 if self.data is None else len(self.data)

        report_lines = [
            "# RSI Indicators Comparison Report",
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Data file: {self.data_file}",
            f"Sample size: {sample_count} data points",
            "",
            # Summary table
            "## Summary Table",
            "| Indicator | Period | Max Diff | Mean Diff | Status |",
            "|-----------|--------|----------|-----------|--------|",
        ]
        for indicator_name, result in self.results.items():
            valid_diff = self._valid_differences(result)
            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                status, _ = self._classify(max_diff)
                report_lines.append(f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |")
            else:
                report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")
        report_lines.append("")

        # Methodology explanation
        report_lines.extend([
            "## Methodology",
            "### RSI (Relative Strength Index)",
            "- Uses Wilder's smoothing for average gains and losses",
            "- Average Gain = (Previous Average Gain × (n-1) + Current Gain) / n",
            "- Average Loss = (Previous Average Loss × (n-1) + Current Loss) / n",
            "- RS = Average Gain / Average Loss",
            "- RSI = 100 - (100 / (1 + RS))",
            "",
            "### Simple RSI",
            "- Uses simple moving average for average gains and losses",
            "- More responsive to recent price changes than Wilder's method",
            "",
        ])

        # Detailed analysis
        report_lines.append("## Detailed Analysis")
        for indicator_name, result in self.results.items():
            report_lines.append(f"### {indicator_name}")
            valid_diff = self._valid_differences(result)
            if len(valid_diff) > 0:
                report_lines.extend(self._detail_lines(result, valid_diff))
            report_lines.append("")

        # Save report
        report_path = self.results_dir / "rsi_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))
        print(f"Report saved to {report_path}")

    @staticmethod
    def _detail_lines(result, valid_diff):
        """Build the per-indicator bullet list of the report's detailed section."""
        abs_diff = np.abs(valid_diff)
        lines = [
            f"- **Period**: {result['period']}",
            f"- **Valid data points**: {len(valid_diff)}",
            f"- **Max absolute difference**: {np.max(abs_diff):.12f}",
            f"- **Mean absolute difference**: {np.mean(abs_diff):.12f}",
            f"- **Standard deviation**: {np.std(valid_diff):.12f}",
        ]

        # RSI-specific metrics, computed on the original implementation's output
        original = np.array(result['original'])
        valid_original = original[~np.isnan(original)]
        if len(valid_original) > 0:
            mean_rsi = np.mean(valid_original)
            overbought_count = np.sum(valid_original > 70)
            oversold_count = np.sum(valid_original < 30)
            lines.append(f"- **Mean RSI value**: {mean_rsi:.2f}")
            lines.append(f"- **Overbought periods (>70)**: {overbought_count} ({overbought_count/len(valid_original)*100:.1f}%)")
            lines.append(f"- **Oversold periods (<30)**: {oversold_count} ({oversold_count/len(valid_original)*100:.1f}%)")

        # Price change analysis: averages over strictly positive gains / losses only
        changes = np.asarray(result['price_changes'], dtype=float)
        gains = changes[changes > 0]
        losses = -changes[changes < 0]
        avg_gain = np.mean(gains) if gains.size else 0
        avg_loss = np.mean(losses) if losses.size else 0
        lines.append(f"- **Average gain**: {avg_gain:.6f}")
        lines.append(f"- **Average loss**: {avg_loss:.6f}")
        if avg_loss > 0:
            lines.append(f"- **Gain/Loss ratio**: {avg_gain/avg_loss:.3f}")

        # Percentile analysis
        percentiles = [1, 5, 25, 50, 75, 95, 99]
        perc_values = np.percentile(abs_diff, percentiles)
        perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
        lines.append(f"- **Percentiles**: {perc_str}")
        return lines

    def run_tests(self):
        """Run all RSI tests."""
        print("Starting RSI Comparison Tests...")
        # Load data
        self.load_data()
        # Run tests
        self.test_rsi()
        self.test_simple_rsi()
        # Generate outputs
        self.plot_all_comparisons()
        self.generate_report()
        print("\n✅ RSI tests completed!")
if __name__ == "__main__":
    # Script entry point: run the full comparison with a sample size of 3000.
    comparison = RSIComparisonTest(sample_size=3000)
    comparison.run_tests()