"""
RSI Indicators Comparison Test

Focused testing for the RSI and Simple RSI implementations: the original
indicators from ``cycles.IncStrategies`` and the new ones from
``IncrementalTrader.strategies`` are run on the same price data and their
outputs are compared.
"""
|
|||
|
|
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
import matplotlib.dates as mdates
|
|||
|
|
from datetime import datetime
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# Add the project root (the parent of this file's directory) to the import
# path so the `cycles` and `IncrementalTrader` packages below can be imported.
# resolve() makes the path absolute: on older Pythons __file__ may be relative
# for the main script (e.g. when launched as `python <this_file>.py` from the
# file's own directory), and Path(__file__).parent.parent would then collapse
# to "." instead of the project root.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
|
|||
|
|
|
|||
|
|
# Import original indicators
|
|||
|
|
from cycles.IncStrategies.indicators import (
|
|||
|
|
RSIState as OriginalRSI,
|
|||
|
|
SimpleRSIState as OriginalSimpleRSI
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Import new indicators
|
|||
|
|
from IncrementalTrader.strategies.indicators import (
|
|||
|
|
RSIState as NewRSI,
|
|||
|
|
SimpleRSIState as NewSimpleRSI
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RSIComparisonTest:
    """Test framework for comparing RSI implementations.

    Feeds the same close-price series to the original (``cycles.IncStrategies``)
    and new (``IncrementalTrader``) RSI indicators, then prints, plots and
    writes a Markdown report of the differences between their outputs.
    """

    # Maximum absolute difference below which the two implementations are
    # considered identical, respectively equal up to floating-point noise.
    EXACT_TOLERANCE = 1e-10
    FLOAT_TOLERANCE = 1e-6

    # Console message printed for each comparison status.
    _STATUS_MESSAGES = {
        'PASSED': "✅ PASSED: Mathematically equivalent",
        'WARNING': "⚠️ WARNING: Small differences (floating point precision)",
        'FAILED': "❌ FAILED: Significant differences detected",
    }
    # Label used in the report's summary table for each comparison status.
    _STATUS_LABELS = {
        'PASSED': "✅ PASSED",
        'WARNING': "⚠️ WARNING",
        'FAILED': "❌ FAILED",
    }

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Configure a comparison run.

        Args:
            data_file: CSV file read by ``load_data``; it must contain
                ``Timestamp`` (Unix seconds) and ``Close`` columns.
            sample_size: Number of most recent rows to keep; a falsy value
                keeps every row.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        # Maps e.g. "RSI_14" -> dict of per-bar series (see _compare_period).
        self.results = {}

        # Create results directory (relative to the current working directory).
        self.results_dir = Path("test/results/rsi_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def load_data(self):
        """Load the CSV, add a ``datetime`` column and keep the most recent rows.

        Raises:
            ValueError: If the file contains no data rows.
        """
        print(f"Loading data from {self.data_file}...")

        df = pd.read_csv(self.data_file)
        if df.empty:
            # The summary print below indexes the first and last rows.
            raise ValueError(f"No data rows found in {self.data_file}")
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')

        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)

        self.data = df
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    def test_rsi(self, periods=(7, 14, 21, 28)):
        """Test RSI implementations (Wilder's smoothing).

        Args:
            periods: Lookback periods to compare; results are stored as
                ``self.results["RSI_<period>"]``.
        """
        print("\n=== Testing RSI (Wilder's Smoothing) ===")
        for period in periods:
            print(f"Testing RSI({period})...")
            self._compare_period('RSI', period, OriginalRSI, NewRSI)

    def test_simple_rsi(self, periods=(7, 14, 21, 28)):
        """Test Simple RSI implementations (Simple moving average).

        Args:
            periods: Lookback periods to compare; results are stored as
                ``self.results["SimpleRSI_<period>"]``.
        """
        print("\n=== Testing Simple RSI (Simple Moving Average) ===")
        for period in periods:
            print(f"Testing SimpleRSI({period})...")
            self._compare_period('SimpleRSI', period, OriginalSimpleRSI, NewSimpleRSI)

    def _compare_period(self, label, period, original_cls, new_cls):
        """Run one original/new indicator pair over ``self.data`` and record it.

        Each class is instantiated as ``cls(period)`` and fed every close
        price through ``update``. After each update the value of
        ``get_current_value()`` is recorded once ``is_warmed_up()`` is true,
        and NaN before that. The series are stored under
        ``self.results[f"{label}_{period}"]`` and a difference summary is
        printed.
        """
        original_indicator = original_cls(period)
        new_indicator = new_cls(period)

        prices = self.data['Close'].tolist()
        # The first bar has no predecessor, so it is paired with itself (change 0).
        price_changes = [curr - prev for prev, curr in zip(prices[:1] + prices, prices)]

        original_values = []
        new_values = []
        for price in prices:
            original_indicator.update(price)
            new_indicator.update(price)
            original_values.append(self._current_or_nan(original_indicator))
            new_values.append(self._current_or_nan(new_indicator))

        self.results[f'{label}_{period}'] = {
            'original': original_values,
            'new': new_values,
            'prices': prices,
            'price_changes': price_changes,
            'dates': self.data['datetime'].tolist(),
            'period': period,
        }
        self._print_diff_summary(original_values, new_values)

    @staticmethod
    def _current_or_nan(indicator):
        """Return the indicator's current value, or NaN while it is still warming up."""
        return indicator.get_current_value() if indicator.is_warmed_up() else np.nan

    @staticmethod
    def _diff_stats(original, new):
        """Compare two per-bar value series.

        Returns:
            ``(diff, valid_diff, warmup_mismatches)``. ``diff`` is the
            element-wise ``new - original`` (NaN where either side is NaN or
            None). ``valid_diff`` is ``diff`` without its NaNs.
            ``warmup_mismatches`` counts bars where exactly one side has a
            value; these bars are absent from ``valid_diff`` and would
            otherwise go unnoticed.
        """
        original_arr = np.asarray(original, dtype=float)
        new_arr = np.asarray(new, dtype=float)
        diff = new_arr - original_arr
        valid_diff = diff[~np.isnan(diff)]
        warmup_mismatches = int(np.count_nonzero(np.isnan(original_arr) != np.isnan(new_arr)))
        return diff, valid_diff, warmup_mismatches

    @classmethod
    def _status(cls, max_diff, warmup_mismatches=0):
        """Classify a comparison as 'PASSED', 'WARNING' or 'FAILED'.

        Any warm-up mismatch fails the comparison regardless of ``max_diff``.
        """
        if warmup_mismatches:
            return 'FAILED'
        if max_diff < cls.EXACT_TOLERANCE:
            return 'PASSED'
        if max_diff < cls.FLOAT_TOLERANCE:
            return 'WARNING'
        return 'FAILED'

    def _print_diff_summary(self, original_values, new_values):
        """Print max/mean of |new - original|, std of the signed difference, and the status."""
        _, valid_diff, warmup_mismatches = self._diff_stats(original_values, new_values)
        if len(valid_diff) == 0:
            print(" ❌ ERROR: No valid data points")
            return

        max_diff = np.max(np.abs(valid_diff))
        mean_diff = np.mean(np.abs(valid_diff))
        std_diff = np.std(valid_diff)

        print(f" Max difference: {max_diff:.12f}")
        print(f" Mean difference: {mean_diff:.12f}")
        print(f" Std difference: {std_diff:.12f}")

        if warmup_mismatches:
            print(f" ❌ FAILED: Warm-up mismatch at {warmup_mismatches} points "
                  f"(only one implementation produced a value)")
        else:
            print(f" {self._STATUS_MESSAGES[self._status(max_diff)]}")

    def plot_comparison(self, indicator_name: str):
        """Plot a four-panel comparison for one stored result and save it as PNG.

        Panels: close price; original vs new values with 30/50/70 reference
        lines; bar-to-bar gains and losses; and the ``new - original``
        difference with summary statistics. The figure is saved to
        ``<results_dir>/<indicator_name>_detailed_comparison.png`` and then
        displayed with ``plt.show()``.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return

        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        # Create figure with subplots
        fig, axes = plt.subplots(4, 1, figsize=(15, 16))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Plot 1: Price data
        ax1 = axes[0]
        ax1.plot(dates, result['prices'], label='Close Price', alpha=0.8, color='black', linewidth=1)
        ax1.set_title('Price Data')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: RSI comparison with levels
        ax2 = axes[1]
        ax2.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2, color='blue')
        ax2.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax2.axhline(y=70, color='red', linestyle=':', alpha=0.7, label='Overbought (70)')
        ax2.axhline(y=30, color='green', linestyle=':', alpha=0.7, label='Oversold (30)')
        ax2.axhline(y=50, color='gray', linestyle='-', alpha=0.5, label='Midline (50)')
        ax2.set_title(f'{indicator_name} Values Comparison')
        ax2.set_ylim(0, 100)
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Plot 3: Price changes
        ax3 = axes[2]
        changes = result['price_changes']
        ax3.plot(dates, [max(0, c) for c in changes], label='Positive Changes', alpha=0.7, color='green')
        ax3.plot(dates, [abs(min(0, c)) for c in changes], label='Negative Changes', alpha=0.7, color='red')
        ax3.set_title('Price Changes (Gains and Losses)')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Difference analysis
        ax4 = axes[3]
        diff, valid_diff, _ = self._diff_stats(result['original'], result['new'])
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        # Add statistics text
        if len(valid_diff) > 0:
            abs_diff = np.abs(valid_diff)
            stats_text = (f'Max: {np.max(abs_diff):.2e}\n'
                          f'Mean: {np.mean(abs_diff):.2e}\n'
                          f'Std: {np.std(valid_diff):.2e}')
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Format x-axis. Tick spacing is derived from the actual time span
        # rather than from the number of samples, so intraday data (e.g. a few
        # thousand 1-minute bars) still gets readable, time-resolved ticks.
        # Each axis needs its own locator instance.
        for ax in axes:
            locator = mdates.AutoDateLocator()
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))

        plt.tight_layout()

        # Save plot
        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")

        plt.show()

    def plot_all_comparisons(self):
        """Plot comparisons for all tested indicators, closing figures after each one."""
        print("\n=== Generating Detailed Comparison Plots ===")

        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
            plt.close('all')

    def generate_report(self):
        """Write a Markdown report summarising every stored comparison.

        The report is written to ``<results_dir>/rsi_indicators_report.md``.
        """
        print("\n=== Generating RSI Report ===")

        report_lines = [
            "# RSI Indicators Comparison Report",
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Data file: {self.data_file}",
            f"Sample size: {len(self.data)} data points",
            "",
            # Summary table
            "## Summary Table",
            "| Indicator | Period | Max Diff | Mean Diff | Status |",
            "|-----------|--------|----------|-----------|--------|",
        ]

        for indicator_name, result in self.results.items():
            _, valid_diff, warmup_mismatches = self._diff_stats(result['original'], result['new'])
            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                status = self._STATUS_LABELS[self._status(max_diff, warmup_mismatches)]
                report_lines.append(f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |")
            else:
                report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")

        report_lines.append("")

        # Methodology explanation
        report_lines += [
            "## Methodology",
            "### RSI (Relative Strength Index)",
            "- Uses Wilder's smoothing for average gains and losses",
            "- Average Gain = (Previous Average Gain × (n-1) + Current Gain) / n",
            "- Average Loss = (Previous Average Loss × (n-1) + Current Loss) / n",
            "- RS = Average Gain / Average Loss",
            "- RSI = 100 - (100 / (1 + RS))",
            "",
            "### Simple RSI",
            "- Uses simple moving average for average gains and losses",
            "- More responsive to recent price changes than Wilder's method",
            "",
        ]

        # Detailed analysis
        report_lines.append("## Detailed Analysis")
        for indicator_name, result in self.results.items():
            report_lines += self._detailed_section(indicator_name, result)

        # Save report
        report_path = self.results_dir / "rsi_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))

        print(f"Report saved to {report_path}")

    def _detailed_section(self, indicator_name, result):
        """Return the Markdown lines of one indicator's 'Detailed Analysis' entry."""
        lines = [f"### {indicator_name}"]

        _, valid_diff, warmup_mismatches = self._diff_stats(result['original'], result['new'])
        if len(valid_diff) > 0:
            abs_diff = np.abs(valid_diff)
            lines += [
                f"- **Period**: {result['period']}",
                f"- **Valid data points**: {len(valid_diff)}",
                f"- **Max absolute difference**: {np.max(abs_diff):.12f}",
                f"- **Mean absolute difference**: {np.mean(abs_diff):.12f}",
                f"- **Standard deviation**: {np.std(valid_diff):.12f}",
                f"- **Warm-up mismatches**: {warmup_mismatches}",
            ]

            # RSI-specific metrics, computed from the original implementation's values.
            original_arr = np.asarray(result['original'], dtype=float)
            valid_original = original_arr[~np.isnan(original_arr)]
            if len(valid_original) > 0:
                mean_rsi = np.mean(valid_original)
                overbought_count = np.sum(valid_original > 70)
                oversold_count = np.sum(valid_original < 30)

                lines += [
                    f"- **Mean RSI value**: {mean_rsi:.2f}",
                    f"- **Overbought periods (>70)**: {overbought_count} ({overbought_count/len(valid_original)*100:.1f}%)",
                    f"- **Oversold periods (<30)**: {oversold_count} ({oversold_count/len(valid_original)*100:.1f}%)",
                ]

            # Price change analysis: mean size of up-moves and of down-moves
            # (bars with zero change are excluded from both).
            gains = [c for c in result['price_changes'] if c > 0]
            losses = [-c for c in result['price_changes'] if c < 0]
            avg_gain = np.mean(gains) if gains else 0
            avg_loss = np.mean(losses) if losses else 0

            lines.append(f"- **Average gain**: {avg_gain:.6f}")
            lines.append(f"- **Average loss**: {avg_loss:.6f}")
            if avg_loss > 0:
                lines.append(f"- **Gain/Loss ratio**: {avg_gain/avg_loss:.3f}")

            # Percentile analysis of the absolute differences
            percentiles = [1, 5, 25, 50, 75, 95, 99]
            perc_values = np.percentile(abs_diff, percentiles)
            perc_str = ", ".join(f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values))
            lines.append(f"- **Percentiles**: {perc_str}")

        lines.append("")
        return lines

    def run_tests(self):
        """Run all RSI tests: load data, compare both RSI variants, plot and report."""
        print("Starting RSI Comparison Tests...")

        # Load data
        self.load_data()

        # Run tests
        self.test_rsi()
        self.test_simple_rsi()

        # Generate outputs
        self.plot_all_comparisons()
        self.generate_report()

        print("\n✅ RSI tests completed!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: compare both RSI variants on the latest 3000 rows.
    runner = RSIComparisonTest(sample_size=3000)
    runner.run_tests()
|