""" ATR Indicators Comparison Test Focused testing for ATR and Simple ATR implementations. """ import pandas as pd import numpy as np import matplotlib.pyplot as plt import matplotlib.dates as mdates from datetime import datetime import sys from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) # Import original indicators from cycles.IncStrategies.indicators import ( ATRState as OriginalATR, SimpleATRState as OriginalSimpleATR ) # Import new indicators from IncrementalTrader.strategies.indicators import ( ATRState as NewATR, SimpleATRState as NewSimpleATR ) class ATRComparisonTest: """Test framework for comparing ATR implementations.""" def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000): self.data_file = data_file self.sample_size = sample_size self.data = None self.results = {} # Create results directory self.results_dir = Path("test/results/atr_indicators") self.results_dir.mkdir(parents=True, exist_ok=True) def load_data(self): """Load and prepare the data for testing.""" print(f"Loading data from {self.data_file}...") df = pd.read_csv(self.data_file) df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s') if self.sample_size and len(df) > self.sample_size: df = df.tail(self.sample_size).reset_index(drop=True) self.data = df print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}") def test_atr(self, periods=[7, 14, 21, 28]): """Test ATR implementations.""" print("\n=== Testing ATR (Wilder's Smoothing) ===") for period in periods: print(f"Testing ATR({period})...") # Initialize indicators original_atr = OriginalATR(period) new_atr = NewATR(period) original_values = [] new_values = [] true_ranges = [] # Process data for _, row in self.data.iterrows(): high, low, close = row['High'], row['Low'], row['Close'] # Create OHLC dictionary for both indicators ohlc_data = { 'open': row['Open'], 'high': high, 'low': low, 'close': close } original_atr.update(ohlc_data) new_atr.update(ohlc_data) original_values.append(original_atr.get_current_value() if original_atr.is_warmed_up() else np.nan) new_values.append(new_atr.get_current_value() if new_atr.is_warmed_up() else np.nan) # Calculate true range for reference if len(self.data) > 1: prev_close = self.data.iloc[max(0, len(true_ranges)-1)]['Close'] if true_ranges else close tr = max(high - low, abs(high - prev_close), abs(low - prev_close)) true_ranges.append(tr) else: true_ranges.append(high - low) # Store results self.results[f'ATR_{period}'] = { 'original': original_values, 'new': new_values, 'true_ranges': true_ranges, 'highs': self.data['High'].tolist(), 'lows': self.data['Low'].tolist(), 'closes': self.data['Close'].tolist(), 'dates': self.data['datetime'].tolist(), 'period': period } # Calculate differences diff = np.array(new_values) - np.array(original_values) valid_diff = diff[~np.isnan(diff)] if len(valid_diff) > 0: max_diff = np.max(np.abs(valid_diff)) mean_diff = np.mean(np.abs(valid_diff)) std_diff = np.std(valid_diff) print(f" Max difference: {max_diff:.12f}") print(f" Mean difference: {mean_diff:.12f}") print(f" Std difference: {std_diff:.12f}") # Status check if max_diff < 1e-10: print(f" ✅ PASSED: Mathematically equivalent") elif max_diff < 1e-6: print(f" ⚠️ WARNING: Small differences (floating point precision)") else: print(f" ❌ FAILED: Significant differences detected") else: print(f" ❌ ERROR: No valid data points") def test_simple_atr(self, periods=[7, 14, 21, 28]): """Test Simple ATR implementations.""" print("\n=== Testing Simple ATR (Simple Moving Average) ===") for period in periods: print(f"Testing SimpleATR({period})...") # Initialize indicators original_atr = OriginalSimpleATR(period) new_atr = NewSimpleATR(period) original_values = [] new_values = [] true_ranges = [] # Process data for _, row in self.data.iterrows(): high, low, close = row['High'], row['Low'], row['Close'] # Create OHLC dictionary for both indicators ohlc_data = { 'open': row['Open'], 'high': high, 'low': low, 'close': close } original_atr.update(ohlc_data) new_atr.update(ohlc_data) original_values.append(original_atr.get_current_value() if original_atr.is_warmed_up() else np.nan) new_values.append(new_atr.get_current_value() if new_atr.is_warmed_up() else np.nan) # Calculate true range for reference if len(self.data) > 1: prev_close = self.data.iloc[max(0, len(true_ranges)-1)]['Close'] if true_ranges else close tr = max(high - low, abs(high - prev_close), abs(low - prev_close)) true_ranges.append(tr) else: true_ranges.append(high - low) # Store results self.results[f'SimpleATR_{period}'] = { 'original': original_values, 'new': new_values, 'true_ranges': true_ranges, 'highs': self.data['High'].tolist(), 'lows': self.data['Low'].tolist(), 'closes': self.data['Close'].tolist(), 'dates': self.data['datetime'].tolist(), 'period': period } # Calculate differences diff = np.array(new_values) - np.array(original_values) valid_diff = diff[~np.isnan(diff)] if len(valid_diff) > 0: max_diff = np.max(np.abs(valid_diff)) mean_diff = np.mean(np.abs(valid_diff)) std_diff = np.std(valid_diff) print(f" Max difference: {max_diff:.12f}") print(f" Mean difference: {mean_diff:.12f}") print(f" Std difference: {std_diff:.12f}") # Status check if max_diff < 1e-10: print(f" ✅ PASSED: Mathematically equivalent") elif max_diff < 1e-6: print(f" ⚠️ WARNING: Small differences (floating point precision)") else: print(f" ❌ FAILED: Significant differences detected") else: print(f" ❌ ERROR: No valid data points") def plot_comparison(self, indicator_name: str): """Plot detailed comparison for a specific indicator.""" if indicator_name not in self.results: print(f"No results found for {indicator_name}") return result = self.results[indicator_name] dates = pd.to_datetime(result['dates']) # Create figure with subplots fig, axes = plt.subplots(4, 1, figsize=(15, 16)) fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16) # Plot 1: OHLC data ax1 = axes[0] ax1.plot(dates, result['highs'], label='High', alpha=0.6, color='green') ax1.plot(dates, result['lows'], label='Low', alpha=0.6, color='red') ax1.plot(dates, result['closes'], label='Close', alpha=0.8, color='blue') ax1.set_title('OHLC Data') ax1.legend() ax1.grid(True, alpha=0.3) # Plot 2: True Range ax2 = axes[1] ax2.plot(dates, result['true_ranges'], label='True Range', alpha=0.7, color='orange') ax2.set_title('True Range Values') ax2.legend() ax2.grid(True, alpha=0.3) # Plot 3: ATR comparison ax3 = axes[2] ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2) ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--') ax3.set_title(f'{indicator_name} Values Comparison') ax3.legend() ax3.grid(True, alpha=0.3) # Plot 4: Difference analysis ax4 = axes[3] diff = np.array(result['new']) - np.array(result['original']) ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1) ax4.set_title(f'{indicator_name} Difference (New - Original)') ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5) ax4.grid(True, alpha=0.3) # Add statistics text valid_diff = diff[~np.isnan(diff)] if len(valid_diff) > 0: stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n' stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n' stats_text += f'Std: {np.std(valid_diff):.2e}' ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8)) # Format x-axis for ax in axes: ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10))) plt.setp(ax.xaxis.get_majorticklabels(), rotation=45) plt.tight_layout() # Save plot plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png" plt.savefig(plot_path, dpi=300, bbox_inches='tight') print(f"Plot saved to {plot_path}") plt.show() def plot_all_comparisons(self): """Plot comparisons for all tested indicators.""" print("\n=== Generating Detailed Comparison Plots ===") for indicator_name in self.results.keys(): print(f"Plotting {indicator_name}...") self.plot_comparison(indicator_name) plt.close('all') def generate_report(self): """Generate detailed report for ATR indicators.""" print("\n=== Generating ATR Report ===") report_lines = [] report_lines.append("# ATR Indicators Comparison Report") report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") report_lines.append(f"Data file: {self.data_file}") report_lines.append(f"Sample size: {len(self.data)} data points") report_lines.append("") # Summary table report_lines.append("## Summary Table") report_lines.append("| Indicator | Period | Max Diff | Mean Diff | Status |") report_lines.append("|-----------|--------|----------|-----------|--------|") for indicator_name, result in self.results.items(): diff = np.array(result['new']) - np.array(result['original']) valid_diff = diff[~np.isnan(diff)] if len(valid_diff) > 0: max_diff = np.max(np.abs(valid_diff)) mean_diff = np.mean(np.abs(valid_diff)) if max_diff < 1e-10: status = "✅ PASSED" elif max_diff < 1e-6: status = "⚠️ WARNING" else: status = "❌ FAILED" report_lines.append(f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |") else: report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |") report_lines.append("") # Methodology explanation report_lines.append("## Methodology") report_lines.append("### ATR (Average True Range)") report_lines.append("- Uses Wilder's smoothing method: ATR = (Previous ATR * (n-1) + Current TR) / n") report_lines.append("- True Range = max(High-Low, |High-PrevClose|, |Low-PrevClose|)") report_lines.append("") report_lines.append("### Simple ATR") report_lines.append("- Uses simple moving average of True Range values") report_lines.append("- More responsive to recent changes than Wilder's method") report_lines.append("") # Detailed analysis report_lines.append("## Detailed Analysis") for indicator_name, result in self.results.items(): report_lines.append(f"### {indicator_name}") diff = np.array(result['new']) - np.array(result['original']) valid_diff = diff[~np.isnan(diff)] if len(valid_diff) > 0: report_lines.append(f"- **Period**: {result['period']}") report_lines.append(f"- **Valid data points**: {len(valid_diff)}") report_lines.append(f"- **Max absolute difference**: {np.max(np.abs(valid_diff)):.12f}") report_lines.append(f"- **Mean absolute difference**: {np.mean(np.abs(valid_diff)):.12f}") report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}") # ATR-specific metrics valid_original = np.array(result['original'])[~np.isnan(result['original'])] if len(valid_original) > 0: mean_atr = np.mean(valid_original) relative_error = np.mean(np.abs(valid_diff)) / mean_atr * 100 report_lines.append(f"- **Mean ATR value**: {mean_atr:.6f}") report_lines.append(f"- **Relative error**: {relative_error:.2e}%") # Percentile analysis percentiles = [1, 5, 25, 50, 75, 95, 99] perc_values = np.percentile(np.abs(valid_diff), percentiles) perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)]) report_lines.append(f"- **Percentiles**: {perc_str}") report_lines.append("") # Save report report_path = self.results_dir / "atr_indicators_report.md" with open(report_path, 'w', encoding='utf-8') as f: f.write('\n'.join(report_lines)) print(f"Report saved to {report_path}") def run_tests(self): """Run all ATR tests.""" print("Starting ATR Comparison Tests...") # Load data self.load_data() # Run tests self.test_atr() self.test_simple_atr() # Generate outputs self.plot_all_comparisons() self.generate_report() print("\n✅ ATR tests completed!") if __name__ == "__main__": tester = ATRComparisonTest(sample_size=3000) tester.run_tests()