"""
Supertrend Indicators Comparison Test

Focused testing for Supertrend implementations.
"""

import sys
from datetime import datetime
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import original indicators
from cycles.IncStrategies.indicators import (
    SupertrendState as OriginalSupertrend
)

# Import new indicators
from IncrementalTrader.strategies.indicators import (
    SupertrendState as NewSupertrend
)


class SupertrendComparisonTest:
    """Test framework for comparing Supertrend implementations.

    Feeds the same OHLC bars to the original (``cycles``) and new
    (``IncrementalTrader``) ``SupertrendState`` classes and compares their
    Supertrend values, trend directions and trend-change signals.
    """

    # Console and report wording for each comparison verdict (see _status).
    _CONSOLE_STATUS = {
        'passed': " ✅ PASSED: Mathematically equivalent",
        'warning': " ⚠️ WARNING: Small differences (floating point precision)",
        'failed': " ❌ FAILED: Significant differences detected",
    }
    _REPORT_STATUS = {
        'passed': "✅ PASSED",
        'warning': "⚠️ WARNING",
        'failed': "❌ FAILED",
    }

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Configure the test run.

        Args:
            data_file: CSV file with a ``Timestamp`` column (unix seconds) and
                ``Open``/``High``/``Low``/``Close`` columns.
            sample_size: If truthy and smaller than the file, only the last
                ``sample_size`` rows are used.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        self.results = {}

        # Create results directory (relative to the current working directory)
        self.results_dir = Path("test/results/supertrend_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def load_data(self):
        """Load the CSV into ``self.data`` and add a parsed ``datetime`` column."""
        print(f"Loading data from {self.data_file}...")
        df = pd.read_csv(self.data_file)
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')

        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)

        self.data = df
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _read_indicator(indicator):
        """Return ``(supertrend, trend)`` for *indicator*, or ``(nan, 0)`` before warm-up."""
        if not indicator.is_warmed_up():
            return np.nan, 0
        current = indicator.get_current_value()
        return current['supertrend'], current['trend']

    @staticmethod
    def _is_trend_change(previous, current):
        """Return 1 when the trend flips between two established directions, else 0.

        ``0`` is the placeholder stored before warm-up, so the first transition
        from ``0`` to a real trend is not counted as a trend change.
        """
        return 1 if previous != 0 and current != 0 and previous != current else 0

    @staticmethod
    def _match_percentage(a, b):
        """Return the percentage of positions where *a* and *b* are equal (nan if empty)."""
        a = np.asarray(a)
        b = np.asarray(b)
        if a.size == 0:
            return float('nan')
        return float(np.mean(a == b) * 100)

    @staticmethod
    def _difference(result):
        """Return ``(diff, valid_diff)``: new - original, and the same with NaNs dropped."""
        diff = np.asarray(result['new'], dtype=float) - np.asarray(result['original'], dtype=float)
        return diff, diff[~np.isnan(diff)]

    @staticmethod
    def _status(max_diff, trend_matches):
        """Classify a comparison as ``'passed'``, ``'warning'`` or ``'failed'``."""
        if max_diff < 1e-10 and trend_matches == 100:
            return 'passed'
        if max_diff < 1e-6 and trend_matches >= 99:
            return 'warning'
        return 'failed'

    # -------------------------------------------------------------------- tests

    def test_supertrend(self, periods=(7, 10, 14, 21), multipliers=(2.0, 3.0, 4.0)):
        """Run both implementations for every (period, multiplier) pair.

        Results are stored in ``self.results`` under ``Supertrend_{period}_{multiplier}``
        and summary statistics are printed. Requires :meth:`load_data` first.

        Args:
            periods: Iterable of ATR periods to test.
            multipliers: Iterable of ATR multipliers to test.
        """
        print("\n=== Testing Supertrend ===")

        # Column data is identical for every configuration: extract it once.
        opens = self.data['Open'].tolist()
        highs = self.data['High'].tolist()
        lows = self.data['Low'].tolist()
        closes = self.data['Close'].tolist()
        dates = self.data['datetime'].tolist()

        for period in periods:
            for multiplier in multipliers:
                print(f"Testing Supertrend({period}, {multiplier})...")

                # Initialize indicators
                original_st = OriginalSupertrend(period, multiplier)
                new_st = NewSupertrend(period, multiplier)

                original_values, new_values = [], []
                original_trends, new_trends = [], []
                original_signals, new_signals = [], []
                prev_original_trend = prev_new_trend = 0

                # Process data
                for open_, high, low, close in zip(opens, highs, lows, closes):
                    # Create OHLC dictionary for both indicators
                    ohlc_data = {'open': open_, 'high': high, 'low': low, 'close': close}
                    original_st.update(ohlc_data)
                    new_st.update(ohlc_data)

                    original_value, original_trend = self._read_indicator(original_st)
                    new_value, new_trend = self._read_indicator(new_st)

                    original_values.append(original_value)
                    new_values.append(new_value)
                    original_trends.append(original_trend)
                    new_trends.append(new_trend)

                    # Check for trend changes (signals)
                    original_signals.append(self._is_trend_change(prev_original_trend, original_trend))
                    new_signals.append(self._is_trend_change(prev_new_trend, new_trend))
                    prev_original_trend, prev_new_trend = original_trend, new_trend

                # Store results (each entry gets its own copy of the price lists)
                key = f'Supertrend_{period}_{multiplier}'
                result = {
                    'original': original_values,
                    'new': new_values,
                    'original_trend': original_trends,
                    'new_trend': new_trends,
                    'original_signals': original_signals,
                    'new_signals': new_signals,
                    'highs': list(highs),
                    'lows': list(lows),
                    'closes': list(closes),
                    'dates': list(dates),
                    'period': period,
                    'multiplier': multiplier
                }
                self.results[key] = result

                # Calculate differences
                _, valid_diff = self._difference(result)
                trend_matches = self._match_percentage(new_trends, original_trends)
                signal_matches = self._match_percentage(new_signals, original_signals)

                if len(valid_diff) == 0:
                    print(" ❌ ERROR: No valid data points")
                    continue

                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                std_diff = np.std(valid_diff)

                print(f" Max difference: {max_diff:.12f}")
                print(f" Mean difference: {mean_diff:.12f}")
                print(f" Std difference: {std_diff:.12f}")
                print(f" Trend match: {trend_matches:.2f}%")
                print(f" Signal match: {signal_matches:.2f}%")

                # Status check
                print(self._CONSOLE_STATUS[self._status(max_diff, trend_matches)])

    # ----------------------------------------------------------------- plotting

    def plot_comparison(self, indicator_name: str):
        """Plot a five-panel comparison for one tested configuration.

        Saves ``{indicator_name}_detailed_comparison.png`` in ``results_dir``
        and calls ``plt.show()``.

        Args:
            indicator_name: A key of ``self.results``, e.g. ``'Supertrend_10_3.0'``.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return

        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        # Create figure with subplots
        fig, axes = plt.subplots(5, 1, figsize=(15, 20))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Plot 1: Price and Supertrend
        ax1 = axes[0]
        ax1.plot(dates, result['closes'], label='Close Price', alpha=0.7, color='black', linewidth=1)
        ax1.plot(dates, result['original'], label='Original Supertrend', alpha=0.8, linewidth=2, color='blue')
        ax1.plot(dates, result['new'], label='New Supertrend', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax1.set_title(f'{indicator_name} vs Price')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: Trend comparison
        ax2 = axes[1]
        ax2.plot(dates, result['original_trend'], label='Original Trend', alpha=0.8, linewidth=2, color='blue')
        ax2.plot(dates, result['new_trend'], label='New Trend', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax2.set_title(f'{indicator_name} Trend Direction (1=Up, -1=Down)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(-1.5, 1.5)

        # Plot 3: Supertrend values comparison
        ax3 = axes[2]
        ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
        ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
        ax3.set_title(f'{indicator_name} Values Comparison')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Difference analysis
        ax4 = axes[3]
        diff, valid_diff = self._difference(result)
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        # Plot 5: Signal comparison
        ax5 = axes[4]
        signal_dates = dates[1:]  # Signals start from second data point
        ax5.scatter(signal_dates, np.array(result['original_signals'][1:]),
                    label='Original Signals', alpha=0.7, color='blue', s=30)
        ax5.scatter(signal_dates, np.array(result['new_signals'][1:]) + 0.1,
                    label='New Signals', alpha=0.7, color='red', s=30, marker='^')
        ax5.set_title(f'{indicator_name} Trend Change Signals')
        ax5.legend()
        ax5.grid(True, alpha=0.3)
        ax5.set_ylim(-0.2, 1.3)

        # Add statistics text
        if len(valid_diff) > 0:
            trend_matches = self._match_percentage(result['new_trend'], result['original_trend'])
            stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Trend Match: {trend_matches:.1f}%'
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Format x-axis. The tick spacing must follow the time span, not the
        # sample count (minute data spans far fewer days than it has points).
        # Each axis gets its own locator/formatter because they bind to one axis.
        for ax in axes:
            locator = mdates.AutoDateLocator()
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        plt.tight_layout()

        # Save plot
        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")

        plt.show()

    def plot_all_comparisons(self):
        """Plot comparisons for all tested indicators, closing figures after each."""
        print("\n=== Generating Detailed Comparison Plots ===")

        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
            plt.close('all')

    # ------------------------------------------------------------------- report

    def generate_report(self):
        """Write ``supertrend_indicators_report.md`` to ``results_dir``.

        The report has a summary table, a methodology section and detailed
        per-configuration statistics. Requires :meth:`load_data` and
        :meth:`test_supertrend` to have run first.
        """
        print("\n=== Generating Supertrend Report ===")

        report_lines = [
            "# Supertrend Indicators Comparison Report",
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Data file: {self.data_file}",
            f"Sample size: {len(self.data)} data points",
            "",
            # Summary table
            "## Summary Table",
            "| Indicator | Period | Multiplier | Max Diff | Mean Diff | Trend Match | Status |",
            "|-----------|--------|------------|----------|-----------|-------------|--------|",
        ]

        for indicator_name, result in self.results.items():
            _, valid_diff = self._difference(result)
            trend_matches = self._match_percentage(result['new_trend'], result['original_trend'])
            row_prefix = f"| {indicator_name} | {result['period']} | {result['multiplier']} | "

            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                status = self._REPORT_STATUS[self._status(max_diff, trend_matches)]
                report_lines.append(row_prefix +
                                    f"{max_diff:.2e} | {mean_diff:.2e} | {trend_matches:.1f}% | {status} |")
            else:
                report_lines.append(row_prefix + "N/A | N/A | N/A | ❌ ERROR |")

        report_lines.append("")

        # Methodology explanation
        report_lines.extend([
            "## Methodology",
            "### Supertrend Calculation",
            "1. **Basic Upper Band**: (High + Low) / 2 + (Multiplier × ATR)",
            "2. **Basic Lower Band**: (High + Low) / 2 - (Multiplier × ATR)",
            "3. **Final Upper Band**: min(Basic Upper Band, Previous Final Upper Band if Close[1] <= Previous Final Upper Band)",
            "4. **Final Lower Band**: max(Basic Lower Band, Previous Final Lower Band if Close[1] >= Previous Final Lower Band)",
            "5. **Supertrend**: Final Lower Band if trend is up, Final Upper Band if trend is down",
            "6. **Trend**: Up if Close > Previous Supertrend, Down if Close <= Previous Supertrend",
            "",
        ])

        # Detailed analysis
        report_lines.append("## Detailed Analysis")

        for indicator_name, result in self.results.items():
            report_lines.append(f"### {indicator_name}")

            _, valid_diff = self._difference(result)
            trend_matches = self._match_percentage(result['new_trend'], result['original_trend'])
            signal_matches = self._match_percentage(result['new_signals'], result['original_signals'])

            if len(valid_diff) > 0:
                abs_diff = np.abs(valid_diff)
                report_lines.append(f"- **Period**: {result['period']}")
                report_lines.append(f"- **Multiplier**: {result['multiplier']}")
                report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
                report_lines.append(f"- **Max absolute difference**: {np.max(abs_diff):.12f}")
                report_lines.append(f"- **Mean absolute difference**: {np.mean(abs_diff):.12f}")
                report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
                report_lines.append(f"- **Trend direction match**: {trend_matches:.2f}%")
                report_lines.append(f"- **Signal timing match**: {signal_matches:.2f}%")

                # Supertrend-specific metrics
                original = np.asarray(result['original'], dtype=float)
                valid_original = original[~np.isnan(original)]
                if len(valid_original) > 0:
                    mean_st = np.mean(valid_original)
                    report_lines.append(f"- **Mean Supertrend value**: {mean_st:.6f}")
                    # Guard the division: a zero mean would make the ratio meaningless.
                    if mean_st != 0:
                        relative_error = np.mean(abs_diff) / mean_st * 100
                        report_lines.append(f"- **Relative error**: {relative_error:.2e}%")

                # Count trend changes
                original_changes = np.sum(np.array(result['original_signals']))
                new_changes = np.sum(np.array(result['new_signals']))
                report_lines.append(f"- **Original trend changes**: {original_changes}")
                report_lines.append(f"- **New trend changes**: {new_changes}")

                # Percentile analysis
                percentiles = [1, 5, 25, 50, 75, 95, 99]
                perc_values = np.percentile(abs_diff, percentiles)
                perc_str = ", ".join(f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values))
                report_lines.append(f"- **Percentiles**: {perc_str}")

            report_lines.append("")

        # Save report
        report_path = self.results_dir / "supertrend_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))

        print(f"Report saved to {report_path}")

    def run_tests(self):
        """Run all Supertrend tests: load data, compare, plot, and report."""
        print("Starting Supertrend Comparison Tests...")

        # Load data
        self.load_data()

        # Run tests
        self.test_supertrend()

        # Generate outputs
        self.plot_all_comparisons()
        self.generate_report()

        print("\n✅ Supertrend tests completed!")


if __name__ == "__main__":
    tester = SupertrendComparisonTest(sample_size=3000)
    tester.run_tests()