"""
RSI Indicators Comparison Test

Focused testing for RSI and Simple RSI implementations.

The script feeds the same close-price series through the original and the
new incremental RSI indicators, prints the size of the difference between
their outputs, saves a comparison plot per indicator/period and writes a
Markdown report.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
import sys
from pathlib import Path

# Add project root to path (must happen before the project imports below)
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import original indicators
from cycles.IncStrategies.indicators import (
    RSIState as OriginalRSI,
    SimpleRSIState as OriginalSimpleRSI
)

# Import new indicators
from IncrementalTrader.strategies.indicators import (
    RSIState as NewRSI,
    SimpleRSIState as NewSimpleRSI
)


class RSIComparisonTest:
    """Test framework for comparing RSI implementations.

    Results of every comparison are stored in ``self.results`` keyed by
    ``"<Indicator>_<period>"``. Each entry is a dict with the keys
    ``original``, ``new``, ``prices``, ``price_changes``, ``dates`` (all
    lists of equal length) and ``period``.
    """

    # Periods used when a test method is called without an explicit list.
    # A tuple is used so the default can never be mutated between calls.
    DEFAULT_PERIODS = (7, 14, 21, 28)

    # Max absolute difference below which the implementations count as equal.
    PASS_TOLERANCE = 1e-10
    # Max absolute difference below which the result is only a warning.
    WARN_TOLERANCE = 1e-6

    # Console messages per classification (see _classify).
    _CONSOLE_STATUS = {
        'passed': " ✅ PASSED: Mathematically equivalent",
        'warning': " ⚠️ WARNING: Small differences (floating point precision)",
        'failed': " ❌ FAILED: Significant differences detected",
    }

    # Report table cell per classification (see _classify).
    _REPORT_STATUS = {
        'passed': "✅ PASSED",
        'warning': "⚠️ WARNING",
        'failed': "❌ FAILED",
    }

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Store configuration and create the results directory.

        Args:
            data_file: Path of the CSV file read by ``load_data``.
            sample_size: Maximum number of trailing rows to keep; a falsy
                value keeps every row.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        self.results = {}

        # Create results directory
        self.results_dir = Path("test/results/rsi_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def load_data(self):
        """Load and prepare the data for testing.

        Reads ``self.data_file``, derives a ``datetime`` column from the
        ``Timestamp`` column (interpreted as seconds) and keeps only the last
        ``self.sample_size`` rows when the file holds more than that.
        """
        print(f"Loading data from {self.data_file}...")

        df = pd.read_csv(self.data_file)
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')

        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)

        self.data = df
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    @classmethod
    def _classify(cls, max_diff):
        """Map a maximum absolute difference to 'passed', 'warning' or 'failed'."""
        if max_diff < cls.PASS_TOLERANCE:
            return 'passed'
        if max_diff < cls.WARN_TOLERANCE:
            return 'warning'
        return 'failed'

    @staticmethod
    def _differences(result):
        """Return ``(diff, valid_diff)`` for one result entry.

        ``diff`` is ``new - original`` element-wise; ``valid_diff`` is the
        same array with NaN entries (warm-up samples) removed.
        """
        diff = np.array(result['new']) - np.array(result['original'])
        return diff, diff[~np.isnan(diff)]

    @staticmethod
    def _split_gains_losses(price_changes):
        """Split price changes into non-negative gain and loss magnitudes.

        Entries that are not strictly positive (respectively negative),
        including NaN, become 0 in the gain (respectively loss) array.
        """
        changes = np.asarray(price_changes, dtype=float)
        gains = np.where(changes > 0, changes, 0.0)
        losses = np.where(changes < 0, -changes, 0.0)
        return gains, losses

    def _compare_implementations(self, label, original_cls, new_cls, periods):
        """Run one original/new indicator pair over the data for each period.

        Args:
            label: Indicator name used in console output and as the prefix of
                the ``self.results`` key (``f"{label}_{period}"``).
            original_cls: Reference indicator class, constructed with the period.
            new_cls: Indicator class under test, constructed with the period.
            periods: Iterable of periods to test.
        """
        # Read the column once instead of materialising a Series per row
        # with DataFrame.iterrows(); the inputs are identical for every period.
        closes = self.data['Close'].tolist()
        dates = self.data['datetime'].tolist()

        # The first sample has no predecessor, so its change is recorded as 0.
        changes = [0] if closes else []
        changes.extend(curr - prev for prev, curr in zip(closes, closes[1:]))

        for period in periods:
            print(f"Testing {label}({period})...")

            # Initialize indicators
            original_rsi = original_cls(period)
            new_rsi = new_cls(period)

            original_values = []
            new_values = []

            # Process data
            for price in closes:
                original_rsi.update(price)
                new_rsi.update(price)

                # Values produced before warm-up are recorded as NaN so they
                # are excluded from the difference statistics.
                original_values.append(
                    original_rsi.get_current_value() if original_rsi.is_warmed_up() else np.nan
                )
                new_values.append(
                    new_rsi.get_current_value() if new_rsi.is_warmed_up() else np.nan
                )

            # Store results (each entry gets its own list objects)
            entry = {
                'original': original_values,
                'new': new_values,
                'prices': list(closes),
                'price_changes': list(changes),
                'dates': list(dates),
                'period': period
            }
            self.results[f'{label}_{period}'] = entry

            # Calculate differences
            _, valid_diff = self._differences(entry)

            if len(valid_diff) == 0:
                print(" ❌ ERROR: No valid data points")
                continue

            max_diff = np.max(np.abs(valid_diff))
            mean_diff = np.mean(np.abs(valid_diff))
            std_diff = np.std(valid_diff)

            print(f" Max difference: {max_diff:.12f}")
            print(f" Mean difference: {mean_diff:.12f}")
            print(f" Std difference: {std_diff:.12f}")

            # Status check
            print(self._CONSOLE_STATUS[self._classify(max_diff)])

    def test_rsi(self, periods=None):
        """Test RSI implementations (Wilder's smoothing).

        Args:
            periods: Iterable of periods to test; defaults to
                ``DEFAULT_PERIODS`` (7, 14, 21, 28).
        """
        print("\n=== Testing RSI (Wilder's Smoothing) ===")
        if periods is None:
            periods = self.DEFAULT_PERIODS
        self._compare_implementations('RSI', OriginalRSI, NewRSI, periods)

    def test_simple_rsi(self, periods=None):
        """Test Simple RSI implementations (Simple moving average).

        Args:
            periods: Iterable of periods to test; defaults to
                ``DEFAULT_PERIODS`` (7, 14, 21, 28).
        """
        print("\n=== Testing Simple RSI (Simple Moving Average) ===")
        if periods is None:
            periods = self.DEFAULT_PERIODS
        self._compare_implementations('SimpleRSI', OriginalSimpleRSI, NewSimpleRSI, periods)

    def plot_comparison(self, indicator_name: str):
        """Plot detailed comparison for a specific indicator.

        Draws four stacked panels (price, both indicator series with the
        30/50/70 levels, gains and losses, and the new-minus-original
        difference), saves the figure as a PNG in ``self.results_dir`` and
        shows it.

        Args:
            indicator_name: Key into ``self.results``; a message is printed
                and nothing is plotted when the key is unknown.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return

        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        # Create figure with subplots
        fig, axes = plt.subplots(4, 1, figsize=(15, 16))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Plot 1: Price data
        ax1 = axes[0]
        ax1.plot(dates, result['prices'], label='Close Price', alpha=0.8, color='black', linewidth=1)
        ax1.set_title('Price Data')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: RSI comparison with levels
        ax2 = axes[1]
        ax2.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2, color='blue')
        ax2.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax2.axhline(y=70, color='red', linestyle=':', alpha=0.7, label='Overbought (70)')
        ax2.axhline(y=30, color='green', linestyle=':', alpha=0.7, label='Oversold (30)')
        ax2.axhline(y=50, color='gray', linestyle='-', alpha=0.5, label='Midline (50)')
        ax2.set_title(f'{indicator_name} Values Comparison')
        ax2.set_ylim(0, 100)
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Plot 3: Price changes
        ax3 = axes[2]
        positive_changes, negative_changes = self._split_gains_losses(result['price_changes'])
        ax3.plot(dates, positive_changes, label='Positive Changes', alpha=0.7, color='green')
        ax3.plot(dates, negative_changes, label='Negative Changes', alpha=0.7, color='red')
        ax3.set_title('Price Changes (Gains and Losses)')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Difference analysis
        ax4 = axes[3]
        diff, valid_diff = self._differences(result)
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        # Add statistics text
        if len(valid_diff) > 0:
            stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Std: {np.std(valid_diff):.2e}'
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Format x-axis.
        # The tick spacing must follow the time span of the data, not the
        # number of samples: a day interval of len(dates)//10 leaves
        # minute-resolution data without usable ticks. AutoDateLocator picks
        # a spacing from the visible range, and the label includes the time
        # so intraday ticks stay distinguishable.
        for ax in axes:
            ax.xaxis.set_major_locator(mdates.AutoDateLocator())
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        plt.tight_layout()

        # Save plot
        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")

        plt.show()

    def plot_all_comparisons(self):
        """Plot comparisons for all tested indicators.

        Figures are closed after each indicator so they do not accumulate.
        """
        print("\n=== Generating Detailed Comparison Plots ===")

        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
            plt.close('all')

    def generate_report(self):
        """Generate detailed report for RSI indicators.

        Builds a Markdown document (summary table, methodology notes and a
        per-indicator analysis) from ``self.results`` and writes it to
        ``rsi_indicators_report.md`` inside ``self.results_dir``.
        """
        print("\n=== Generating RSI Report ===")

        report_lines = [
            "# RSI Indicators Comparison Report",
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Data file: {self.data_file}",
            f"Sample size: {len(self.data)} data points",
            "",
        ]

        # Summary table
        report_lines.append("## Summary Table")
        report_lines.append("| Indicator | Period | Max Diff | Mean Diff | Status |")
        report_lines.append("|-----------|--------|----------|-----------|--------|")

        for indicator_name, result in self.results.items():
            _, valid_diff = self._differences(result)

            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                status = self._REPORT_STATUS[self._classify(max_diff)]

                report_lines.append(
                    f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |"
                )
            else:
                report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")

        report_lines.append("")

        # Methodology explanation
        report_lines.extend([
            "## Methodology",
            "### RSI (Relative Strength Index)",
            "- Uses Wilder's smoothing for average gains and losses",
            "- Average Gain = (Previous Average Gain × (n-1) + Current Gain) / n",
            "- Average Loss = (Previous Average Loss × (n-1) + Current Loss) / n",
            "- RS = Average Gain / Average Loss",
            "- RSI = 100 - (100 / (1 + RS))",
            "",
            "### Simple RSI",
            "- Uses simple moving average for average gains and losses",
            "- More responsive to recent price changes than Wilder's method",
            "",
        ])

        # Detailed analysis
        report_lines.append("## Detailed Analysis")

        for indicator_name, result in self.results.items():
            report_lines.append(f"### {indicator_name}")

            _, valid_diff = self._differences(result)

            if len(valid_diff) > 0:
                abs_diff = np.abs(valid_diff)

                report_lines.append(f"- **Period**: {result['period']}")
                report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
                report_lines.append(f"- **Max absolute difference**: {np.max(abs_diff):.12f}")
                report_lines.append(f"- **Mean absolute difference**: {np.mean(abs_diff):.12f}")
                report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")

                # RSI-specific metrics (computed on the original implementation)
                valid_original = np.array(result['original'])[~np.isnan(result['original'])]
                if len(valid_original) > 0:
                    mean_rsi = np.mean(valid_original)
                    overbought_count = np.sum(valid_original > 70)
                    oversold_count = np.sum(valid_original < 30)

                    report_lines.append(f"- **Mean RSI value**: {mean_rsi:.2f}")
                    report_lines.append(f"- **Overbought periods (>70)**: {overbought_count} ({overbought_count/len(valid_original)*100:.1f}%)")
                    report_lines.append(f"- **Oversold periods (<30)**: {oversold_count} ({oversold_count/len(valid_original)*100:.1f}%)")

                # Price change analysis: averages are taken over the strictly
                # positive gains / losses only, and default to 0 when none exist.
                gains, losses = self._split_gains_losses(result['price_changes'])
                nonzero_gains = gains[gains > 0]
                nonzero_losses = losses[losses > 0]
                avg_gain = np.mean(nonzero_gains) if nonzero_gains.size else 0
                avg_loss = np.mean(nonzero_losses) if nonzero_losses.size else 0

                report_lines.append(f"- **Average gain**: {avg_gain:.6f}")
                report_lines.append(f"- **Average loss**: {avg_loss:.6f}")
                if avg_loss > 0:
                    report_lines.append(f"- **Gain/Loss ratio**: {avg_gain/avg_loss:.3f}")

                # Percentile analysis
                percentiles = [1, 5, 25, 50, 75, 95, 99]
                perc_values = np.percentile(abs_diff, percentiles)
                perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
                report_lines.append(f"- **Percentiles**: {perc_str}")

            report_lines.append("")

        # Save report
        report_path = self.results_dir / "rsi_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))

        print(f"Report saved to {report_path}")

    def run_tests(self):
        """Run all RSI tests.

        Loads the data, runs both comparisons with their default periods,
        then produces the plots and the report.
        """
        print("Starting RSI Comparison Tests...")

        # Load data
        self.load_data()

        # Run tests
        self.test_rsi()
        self.test_simple_rsi()

        # Generate outputs
        self.plot_all_comparisons()
        self.generate_report()

        print("\n✅ RSI tests completed!")


if __name__ == "__main__":
    tester = RSIComparisonTest(sample_size=3000)
    tester.run_tests()