Cycles/test/indicators/test_atr_indicators.py

395 lines
16 KiB
Python

"""
ATR Indicators Comparison Test
Focused testing for ATR and Simple ATR implementations.
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
import sys
from pathlib import Path
# Add the project root to the import path so ``cycles`` and
# ``IncrementalTrader`` resolve when this file is run as a script. This file
# lives at <root>/test/indicators/, so the root is three ``.parent`` hops up
# (two hops would only reach <root>/test). ``resolve()`` makes this work even
# when ``__file__`` is relative or reached through a symlink.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
# Import original indicators
from cycles.IncStrategies.indicators import (
ATRState as OriginalATR,
SimpleATRState as OriginalSimpleATR
)
# Import new indicators
from IncrementalTrader.strategies.indicators import (
ATRState as NewATR,
SimpleATRState as NewSimpleATR
)
class ATRComparisonTest:
    """Test framework for comparing ATR implementations.

    Feeds identical OHLC bars to an "original" and a "new" implementation of
    each indicator, records both output series (NaN until the respective
    indicator reports ``is_warmed_up()``), and produces console output, plots
    and a Markdown report from the recorded results.
    """

    # Absolute max-difference thresholds: below PASS_TOLERANCE the two
    # implementations are considered equivalent, below WARN_TOLERANCE the
    # difference is attributed to floating-point noise, otherwise it fails.
    PASS_TOLERANCE = 1e-10
    WARN_TOLERANCE = 1e-6

    # Console verdicts and short report labels for each comparison status.
    _STATUS_MESSAGES = {
        "PASSED": "✅ PASSED: Mathematically equivalent",
        "WARNING": "⚠️ WARNING: Small differences (floating point precision)",
        "FAILED": "❌ FAILED: Significant differences detected",
        "ERROR": "❌ ERROR: No valid data points",
    }
    _STATUS_LABELS = {
        "PASSED": "✅ PASSED",
        "WARNING": "⚠️ WARNING",
        "FAILED": "❌ FAILED",
        "ERROR": "❌ ERROR",
    }

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Configure the test run and create the output directory.

        Args:
            data_file: CSV path (relative paths resolve against the current
                working directory) with ``Timestamp`` and OHLC columns.
            sample_size: Number of most recent rows to keep; a falsy value
                keeps every row.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        self.results = {}
        # Plots and the report are written here (relative to the CWD).
        self.results_dir = Path("test/results/atr_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def load_data(self):
        """Load the CSV and keep only the most recent ``sample_size`` rows.

        The file must provide ``Timestamp`` (Unix seconds) plus ``Open``,
        ``High``, ``Low`` and ``Close`` columns; a ``datetime`` column is
        derived from ``Timestamp``.

        Raises:
            ValueError: if the file contains no rows.
        """
        print(f"Loading data from {self.data_file}...")
        df = pd.read_csv(self.data_file)
        if df.empty:
            raise ValueError(f"No rows found in {self.data_file}")
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)
        self.data = df
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    def test_atr(self, periods=(7, 14, 21, 28)):
        """Compare original vs. new ATR (Wilder's smoothing) for each period."""
        print("\n=== Testing ATR (Wilder's Smoothing) ===")
        self._run_comparison("ATR", OriginalATR, NewATR, periods)

    def test_simple_atr(self, periods=(7, 14, 21, 28)):
        """Compare original vs. new Simple ATR for each period."""
        print("\n=== Testing Simple ATR (Simple Moving Average) ===")
        self._run_comparison("SimpleATR", OriginalSimpleATR, NewSimpleATR, periods)

    def _run_comparison(self, label, original_cls, new_cls, periods):
        """Run both implementations over ``self.data`` for every period.

        Each result is stored under ``self.results[f"{label}_{period}"]`` and
        a difference summary is printed.

        Args:
            label: Prefix for result keys and console messages.
            original_cls, new_cls: Indicator factories called as ``cls(period)``
                and driven via ``update(ohlc_dict)``, ``is_warmed_up()`` and
                ``get_current_value()``.
            periods: Iterable of periods to test.

        Raises:
            RuntimeError: if ``load_data()`` has not been called.
        """
        if self.data is None:
            raise RuntimeError("No data loaded; call load_data() first")

        ohlc_columns = self.data[['Open', 'High', 'Low', 'Close']]
        # The reference True Range does not depend on the period; compute once.
        true_ranges = self._true_ranges(self.data['High'], self.data['Low'], self.data['Close'])

        for period in periods:
            print(f"Testing {label}({period})...")
            original_indicator = original_cls(period)
            new_indicator = new_cls(period)
            original_values = []
            new_values = []

            for open_, high, low, close in ohlc_columns.itertuples(index=False, name=None):
                ohlc_data = {'open': open_, 'high': high, 'low': low, 'close': close}
                # Both implementations must see exactly the same bar.
                original_indicator.update(ohlc_data)
                new_indicator.update(ohlc_data)
                original_values.append(
                    original_indicator.get_current_value() if original_indicator.is_warmed_up() else np.nan
                )
                new_values.append(
                    new_indicator.get_current_value() if new_indicator.is_warmed_up() else np.nan
                )

            self.results[f'{label}_{period}'] = {
                'original': original_values,
                'new': new_values,
                'true_ranges': list(true_ranges),
                'highs': self.data['High'].tolist(),
                'lows': self.data['Low'].tolist(),
                'closes': self.data['Close'].tolist(),
                'dates': self.data['datetime'].tolist(),
                'period': period
            }
            self._print_stats(self._difference_stats(original_values, new_values))

    @staticmethod
    def _true_ranges(highs, lows, closes):
        """Return the True Range of each bar as a list.

        TR = max(High - Low, |High - PrevClose|, |Low - PrevClose|); the first
        bar has no previous close, so its TR is High - Low.
        """
        true_ranges = []
        prev_close = None
        for high, low, close in zip(highs, lows, closes):
            if prev_close is None:
                tr = high - low
            else:
                tr = max(high - low, abs(high - prev_close), abs(low - prev_close))
            true_ranges.append(tr)
            prev_close = close
        return true_ranges

    @classmethod
    def _difference_stats(cls, original, new):
        """Summarize the point-wise differences (new - original) of two series.

        Bars where either series is NaN are excluded from the numeric
        statistics. Bars where exactly one series is NaN are counted in
        ``warmup_mismatch``. Such bars mean the implementations disagree on
        when the indicator becomes valid, so they fail the comparison.

        Returns:
            dict with ``valid_diff`` (ndarray), absolute ``max_diff`` and
            ``mean_diff``, ``std_diff``, ``warmup_mismatch`` (int) and
            ``status`` ("PASSED", "WARNING", "FAILED" or "ERROR"). Numeric
            fields are NaN when there is no overlapping data.
        """
        original = np.asarray(original, dtype=float)
        new = np.asarray(new, dtype=float)
        diff = new - original
        valid_diff = diff[~np.isnan(diff)]
        warmup_mismatch = int(np.count_nonzero(np.isnan(original) != np.isnan(new)))

        if valid_diff.size == 0:
            return {
                'valid_diff': valid_diff,
                'max_diff': np.nan,
                'mean_diff': np.nan,
                'std_diff': np.nan,
                'warmup_mismatch': warmup_mismatch,
                'status': "ERROR",
            }

        abs_diff = np.abs(valid_diff)
        max_diff = float(np.max(abs_diff))
        if warmup_mismatch or max_diff >= cls.WARN_TOLERANCE:
            status = "FAILED"
        elif max_diff >= cls.PASS_TOLERANCE:
            status = "WARNING"
        else:
            status = "PASSED"
        return {
            'valid_diff': valid_diff,
            'max_diff': max_diff,
            'mean_diff': float(np.mean(abs_diff)),
            'std_diff': float(np.std(valid_diff)),
            'warmup_mismatch': warmup_mismatch,
            'status': status,
        }

    def _print_stats(self, stats):
        """Print the console summary produced by ``_difference_stats``."""
        if stats['status'] == "ERROR":
            print(f" {self._STATUS_MESSAGES['ERROR']}")
        else:
            print(f" Max difference: {stats['max_diff']:.12f}")
            print(f" Mean difference: {stats['mean_diff']:.12f}")
            print(f" Std difference: {stats['std_diff']:.12f}")
            if not stats['warmup_mismatch']:
                print(f" {self._STATUS_MESSAGES[stats['status']]}")
        if stats['warmup_mismatch']:
            print(f" ❌ FAILED: Warm-up mismatch on {stats['warmup_mismatch']} bars "
                  "(one implementation is NaN where the other is not)")

    def plot_comparison(self, indicator_name: str):
        """Plot and save a four-panel comparison for one recorded indicator.

        Panels: OHLC prices, reference True Range, both indicator series, and
        their difference. The figure is saved as a PNG in ``results_dir``,
        shown, and then closed.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return

        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        fig, axes = plt.subplots(4, 1, figsize=(15, 16))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Plot 1: OHLC data
        ax1 = axes[0]
        ax1.plot(dates, result['highs'], label='High', alpha=0.6, color='green')
        ax1.plot(dates, result['lows'], label='Low', alpha=0.6, color='red')
        ax1.plot(dates, result['closes'], label='Close', alpha=0.8, color='blue')
        ax1.set_title('OHLC Data')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: True Range
        ax2 = axes[1]
        ax2.plot(dates, result['true_ranges'], label='True Range', alpha=0.7, color='orange')
        ax2.set_title('True Range Values')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Plot 3: ATR comparison
        ax3 = axes[2]
        ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
        ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
        ax3.set_title(f'{indicator_name} Values Comparison')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Difference analysis
        ax4 = axes[3]
        diff = np.asarray(result['new'], dtype=float) - np.asarray(result['original'], dtype=float)
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        stats = self._difference_stats(result['original'], result['new'])
        if stats['status'] != "ERROR":
            stats_text = f"Max: {stats['max_diff']:.2e}\n"
            stats_text += f"Mean: {stats['mean_diff']:.2e}\n"
            stats_text += f"Std: {stats['std_diff']:.2e}"
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Let matplotlib choose tick spacing from the actual time span. A
        # DayLocator interval derived from the bar count produced almost no
        # ticks for intraday data. Locators are bound to a single axis, so
        # each axis gets its own locator instance.
        for ax in axes:
            locator = mdates.AutoDateLocator()
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))

        plt.tight_layout()

        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")
        plt.show()
        # Release the figure even when this method is called on its own.
        plt.close(fig)

    def plot_all_comparisons(self):
        """Plot comparisons for all tested indicators."""
        print("\n=== Generating Detailed Comparison Plots ===")
        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
        plt.close('all')

    def generate_report(self):
        """Write a Markdown report of all recorded comparisons to ``results_dir``."""
        print("\n=== Generating ATR Report ===")
        sample_count = len(self.data) if self.data is not None else 0
        all_stats = {
            name: self._difference_stats(result['original'], result['new'])
            for name, result in self.results.items()
        }

        report_lines = [
            "# ATR Indicators Comparison Report",
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Data file: {self.data_file}",
            f"Sample size: {sample_count} data points",
            "",
        ]

        # Summary table
        report_lines.append("## Summary Table")
        report_lines.append("| Indicator | Period | Max Diff | Mean Diff | Status |")
        report_lines.append("|-----------|--------|----------|-----------|--------|")
        for indicator_name, result in self.results.items():
            stats = all_stats[indicator_name]
            if stats['status'] == "ERROR":
                report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")
            else:
                report_lines.append(
                    f"| {indicator_name} | {result['period']} | {stats['max_diff']:.2e} | "
                    f"{stats['mean_diff']:.2e} | {self._STATUS_LABELS[stats['status']]} |"
                )
        report_lines.append("")

        # Methodology explanation
        report_lines.append("## Methodology")
        report_lines.append("### ATR (Average True Range)")
        report_lines.append("- Uses Wilder's smoothing method: ATR = (Previous ATR * (n-1) + Current TR) / n")
        report_lines.append("- True Range = max(High-Low, |High-PrevClose|, |Low-PrevClose|)")
        report_lines.append("")
        report_lines.append("### Simple ATR")
        report_lines.append("- Uses simple moving average of True Range values")
        report_lines.append("- More responsive to recent changes than Wilder's method")
        report_lines.append("")

        # Detailed analysis
        report_lines.append("## Detailed Analysis")
        for indicator_name, result in self.results.items():
            report_lines.append(f"### {indicator_name}")
            stats = all_stats[indicator_name]
            if stats['status'] != "ERROR":
                valid_diff = stats['valid_diff']
                report_lines.append(f"- **Period**: {result['period']}")
                report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
                report_lines.append(f"- **Max absolute difference**: {stats['max_diff']:.12f}")
                report_lines.append(f"- **Mean absolute difference**: {stats['mean_diff']:.12f}")
                report_lines.append(f"- **Standard deviation**: {stats['std_diff']:.12f}")
                report_lines.append(f"- **Warm-up mismatches**: {stats['warmup_mismatch']}")

                # ATR-specific metrics
                original = np.asarray(result['original'], dtype=float)
                valid_original = original[~np.isnan(original)]
                if len(valid_original) > 0:
                    mean_atr = np.mean(valid_original)
                    report_lines.append(f"- **Mean ATR value**: {mean_atr:.6f}")
                    # A flat market can yield a zero mean ATR; skip the ratio then.
                    if mean_atr != 0:
                        relative_error = stats['mean_diff'] / mean_atr * 100
                        report_lines.append(f"- **Relative error**: {relative_error:.2e}%")

                # Percentile analysis
                percentiles = [1, 5, 25, 50, 75, 95, 99]
                perc_values = np.percentile(np.abs(valid_diff), percentiles)
                perc_str = ", ".join(f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values))
                report_lines.append(f"- **Percentiles**: {perc_str}")
            report_lines.append("")

        # Save report
        report_path = self.results_dir / "atr_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))
        print(f"Report saved to {report_path}")

    def run_tests(self):
        """Load data, run both ATR comparisons, then plot and write the report."""
        print("Starting ATR Comparison Tests...")
        self.load_data()
        self.test_atr()
        self.test_simple_atr()
        self.plot_all_comparisons()
        self.generate_report()
        print("\n✅ ATR tests completed!")
if __name__ == "__main__":
    # Script entry point: compare the ATR variants on the latest 3000 bars.
    ATRComparisonTest(sample_size=3000).run_tests()