# Cycles/test/indicators/test_supertrend_indicators.py
# (374 lines, 17 KiB, Python)

"""
Supertrend Indicators Comparison Test
Focused testing for Supertrend implementations.
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
import sys
from pathlib import Path
# Add the repository root to the import path so the `cycles` and
# `IncrementalTrader` packages imported below resolve when this script is run
# directly. This file lives at <root>/test/indicators/, so the root is three
# levels above the file itself (not two, which would only reach <root>/test).
# NOTE(review): assumes both packages sit directly under the repository root.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
# Import original indicators
from cycles.IncStrategies.indicators import (
SupertrendState as OriginalSupertrend
)
# Import new indicators
from IncrementalTrader.strategies.indicators import (
SupertrendState as NewSupertrend
)
class SupertrendComparisonTest:
    """Test framework for comparing Supertrend implementations.

    Feeds the same OHLC candles to the original (``cycles.IncStrategies``)
    and new (``IncrementalTrader``) ``SupertrendState`` indicators, then
    prints per-configuration statistics, saves comparison plots and writes a
    Markdown report into ``results_dir``.
    """

    # Classification thresholds shared by the console output and the report.
    # PASSED:  max |new - original| < PASS_TOLERANCE and every trend matches.
    # WARNING: max |new - original| < WARN_TOLERANCE and at least
    #          WARN_TREND_MATCH percent of trend directions match.
    # FAILED:  anything else.
    PASS_TOLERANCE = 1e-10
    WARN_TOLERANCE = 1e-6
    WARN_TREND_MATCH = 99

    # Console message and report label for each status returned by _status().
    _STATUS_MESSAGES = {
        'PASSED': "✅ PASSED: Mathematically equivalent",
        'WARNING': "⚠️ WARNING: Small differences (floating point precision)",
        'FAILED': "❌ FAILED: Significant differences detected",
    }
    _STATUS_LABELS = {
        'PASSED': "✅ PASSED",
        'WARNING': "⚠️ WARNING",
        'FAILED': "❌ FAILED",
    }

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Store the configuration and create the results directory.

        Args:
            data_file: CSV path (relative to the working directory) with a
                ``Timestamp`` column in Unix seconds and ``Open``, ``High``,
                ``Low`` and ``Close`` columns.
            sample_size: Keep only the last ``sample_size`` rows; a falsy
                value keeps every row.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        self.results = {}
        # Created eagerly (relative to the working directory) so every later
        # save can assume it exists.
        self.results_dir = Path("test/results/supertrend_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def load_data(self):
        """Load the CSV into ``self.data`` and add a ``datetime`` column.

        When ``sample_size`` is set and smaller than the file, only the last
        ``sample_size`` rows are kept, re-indexed from 0.
        """
        print(f"Loading data from {self.data_file}...")
        df = pd.read_csv(self.data_file)
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)
        self.data = df
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    # ------------------------------------------------------------------
    # Shared statistics helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _match_pct(new, original):
        """Return the percentage of positions at which ``new`` equals ``original``."""
        diff = np.array(new) - np.array(original)
        return np.sum(diff == 0) / len(diff) * 100

    @staticmethod
    def _valid_diff(result):
        """Return ``new - original`` Supertrend values with NaN (warm-up) entries dropped."""
        diff = np.array(result['new']) - np.array(result['original'])
        return diff[~np.isnan(diff)]

    @classmethod
    def _status(cls, max_diff, trend_matches):
        """Classify a configuration as ``'PASSED'``, ``'WARNING'`` or ``'FAILED'``."""
        if max_diff < cls.PASS_TOLERANCE and trend_matches == 100:
            return 'PASSED'
        if max_diff < cls.WARN_TOLERANCE and trend_matches >= cls.WARN_TREND_MATCH:
            return 'WARNING'
        return 'FAILED'

    @staticmethod
    def _snapshot(indicator):
        """Return ``(supertrend, trend)`` for ``indicator``, or ``(nan, 0)`` before warm-up."""
        if not indicator.is_warmed_up():
            return np.nan, 0
        current = indicator.get_current_value()
        return current['supertrend'], current['trend']

    @staticmethod
    def _trend_changes(trends):
        """Return a 0/1 list flagging bars whose trend differs from the previous bar.

        The first bar is always 0. The step from the pre-warm-up placeholder
        trend (0) to the first real trend value is counted as a change.
        """
        if not trends:
            return []
        return [0] + [int(cur != prev) for prev, cur in zip(trends, trends[1:])]

    # ------------------------------------------------------------------
    # Test execution
    # ------------------------------------------------------------------
    def test_supertrend(self, periods=(7, 10, 14, 21), multipliers=(2.0, 3.0, 4.0)):
        """Run both implementations for every (period, multiplier) pair.

        Loads the data first if :meth:`load_data` has not been called. For
        each pair, the per-bar Supertrend values, trend directions and
        trend-change flags of both implementations are stored in
        ``self.results`` under ``'Supertrend_{period}_{multiplier}'``, and a
        summary is printed.

        Args:
            periods: Iterable of periods to test.
            multipliers: Iterable of band multipliers to test.
        """
        print("\n=== Testing Supertrend ===")
        if self.data is None:
            self.load_data()

        # Extract the columns once as plain lists: much faster than
        # DataFrame.iterrows() and reused by every configuration.
        opens = self.data['Open'].tolist()
        highs = self.data['High'].tolist()
        lows = self.data['Low'].tolist()
        closes = self.data['Close'].tolist()
        dates = self.data['datetime'].tolist()

        for period in periods:
            for multiplier in multipliers:
                print(f"Testing Supertrend({period}, {multiplier})...")
                original_st = OriginalSupertrend(period, multiplier)
                new_st = NewSupertrend(period, multiplier)
                original_values, new_values = [], []
                original_trends, new_trends = [], []

                for open_, high, low, close in zip(opens, highs, lows, closes):
                    # Both implementations receive the identical OHLC dictionary.
                    ohlc_data = {'open': open_, 'high': high, 'low': low, 'close': close}
                    original_st.update(ohlc_data)
                    new_st.update(ohlc_data)

                    value, trend = self._snapshot(original_st)
                    original_values.append(value)
                    original_trends.append(trend)
                    value, trend = self._snapshot(new_st)
                    new_values.append(value)
                    new_trends.append(trend)

                original_signals = self._trend_changes(original_trends)
                new_signals = self._trend_changes(new_trends)

                key = f'Supertrend_{period}_{multiplier}'
                self.results[key] = {
                    'original': original_values,
                    'new': new_values,
                    'original_trend': original_trends,
                    'new_trend': new_trends,
                    'original_signals': original_signals,
                    'new_signals': new_signals,
                    'highs': highs,
                    'lows': lows,
                    'closes': closes,
                    'dates': dates,
                    'period': period,
                    'multiplier': multiplier
                }

                valid_diff = self._valid_diff(self.results[key])
                trend_matches = self._match_pct(new_trends, original_trends)
                signal_matches = self._match_pct(new_signals, original_signals)

                if len(valid_diff) == 0:
                    print(" ❌ ERROR: No valid data points")
                    continue

                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                std_diff = np.std(valid_diff)
                print(f" Max difference: {max_diff:.12f}")
                print(f" Mean difference: {mean_diff:.12f}")
                print(f" Std difference: {std_diff:.12f}")
                print(f" Trend match: {trend_matches:.2f}%")
                print(f" Signal match: {signal_matches:.2f}%")
                print(f" {self._STATUS_MESSAGES[self._status(max_diff, trend_matches)]}")

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def plot_comparison(self, indicator_name: str):
        """Save and show a five-panel comparison figure for one configuration.

        Panels: close price with both Supertrend lines, trend direction,
        Supertrend values, new-minus-original difference (with summary
        statistics), and trend-change signals. The figure is written to
        ``results_dir / f"{indicator_name}_detailed_comparison.png"``.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return

        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        fig, axes = plt.subplots(5, 1, figsize=(15, 20))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Panel 1: price and both Supertrend lines.
        ax1 = axes[0]
        ax1.plot(dates, result['closes'], label='Close Price', alpha=0.7, color='black', linewidth=1)
        ax1.plot(dates, result['original'], label='Original Supertrend', alpha=0.8, linewidth=2, color='blue')
        ax1.plot(dates, result['new'], label='New Supertrend', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax1.set_title(f'{indicator_name} vs Price')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: trend direction.
        ax2 = axes[1]
        ax2.plot(dates, result['original_trend'], label='Original Trend', alpha=0.8, linewidth=2, color='blue')
        ax2.plot(dates, result['new_trend'], label='New Trend', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax2.set_title(f'{indicator_name} Trend Direction (1=Up, -1=Down)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(-1.5, 1.5)

        # Panel 3: Supertrend values only.
        ax3 = axes[2]
        ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
        ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
        ax3.set_title(f'{indicator_name} Values Comparison')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Panel 4: difference (New - Original).
        ax4 = axes[3]
        diff = np.array(result['new']) - np.array(result['original'])
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        # Panel 5: trend-change signals. The first bar can never be a change.
        ax5 = axes[4]
        signal_dates = dates[1:]
        ax5.scatter(signal_dates, np.array(result['original_signals'][1:]),
                    label='Original Signals', alpha=0.7, color='blue', s=30)
        # Offset the new signals by 0.1 so coinciding markers stay visible.
        ax5.scatter(signal_dates, np.array(result['new_signals'][1:]) + 0.1,
                    label='New Signals', alpha=0.7, color='red', s=30, marker='^')
        ax5.set_title(f'{indicator_name} Trend Change Signals')
        ax5.legend()
        ax5.grid(True, alpha=0.3)
        ax5.set_ylim(-0.2, 1.3)

        # Summary statistics box on the difference panel.
        valid_diff = diff[~np.isnan(diff)]
        if len(valid_diff) > 0:
            trend_matches = self._match_pct(result['new_trend'], result['original_trend'])
            stats_text = (f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
                          f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
                          f'Trend Match: {trend_matches:.1f}%')
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Let matplotlib pick tick spacing from the actual time span. The data
        # may be intraday (the default file holds 1-minute bars), so a
        # day-based locator whose interval is derived from the bar count
        # yields few or no ticks. Each axis needs its own locator instance.
        for ax in axes:
            locator = mdates.AutoDateLocator()
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        plt.tight_layout()

        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")
        plt.show()

    def plot_all_comparisons(self):
        """Plot every stored configuration, closing all figures after each one."""
        print("\n=== Generating Detailed Comparison Plots ===")
        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
            plt.close('all')

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------
    def generate_report(self):
        """Write ``supertrend_indicators_report.md`` into ``results_dir``.

        The report has a summary table (one row per configuration), a
        methodology section and detailed statistics per configuration.
        Requires ``self.data`` (used for the sample size) and ``self.results``.
        """
        print("\n=== Generating Supertrend Report ===")
        report_lines = [
            "# Supertrend Indicators Comparison Report",
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Data file: {self.data_file}",
            f"Sample size: {len(self.data)} data points",
            "",
        ]

        # Summary table
        report_lines.append("## Summary Table")
        report_lines.append("| Indicator | Period | Multiplier | Max Diff | Mean Diff | Trend Match | Status |")
        report_lines.append("|-----------|--------|------------|----------|-----------|-------------|--------|")
        for indicator_name, result in self.results.items():
            valid_diff = self._valid_diff(result)
            trend_matches = self._match_pct(result['new_trend'], result['original_trend'])
            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                status = self._STATUS_LABELS[self._status(max_diff, trend_matches)]
                report_lines.append(f"| {indicator_name} | {result['period']} | {result['multiplier']} | "
                                    f"{max_diff:.2e} | {mean_diff:.2e} | {trend_matches:.1f}% | {status} |")
            else:
                report_lines.append(f"| {indicator_name} | {result['period']} | {result['multiplier']} | "
                                    f"N/A | N/A | N/A | ❌ ERROR |")
        report_lines.append("")

        # Methodology explanation
        report_lines.append("## Methodology")
        report_lines.append("### Supertrend Calculation")
        report_lines.append("1. **Basic Upper Band**: (High + Low) / 2 + (Multiplier × ATR)")
        report_lines.append("2. **Basic Lower Band**: (High + Low) / 2 - (Multiplier × ATR)")
        report_lines.append("3. **Final Upper Band**: min(Basic Upper Band, Previous Final Upper Band if Close[1] <= Previous Final Upper Band)")
        report_lines.append("4. **Final Lower Band**: max(Basic Lower Band, Previous Final Lower Band if Close[1] >= Previous Final Lower Band)")
        report_lines.append("5. **Supertrend**: Final Lower Band if trend is up, Final Upper Band if trend is down")
        report_lines.append("6. **Trend**: Up if Close > Previous Supertrend, Down if Close <= Previous Supertrend")
        report_lines.append("")

        # Detailed analysis
        report_lines.append("## Detailed Analysis")
        for indicator_name, result in self.results.items():
            report_lines.append(f"### {indicator_name}")
            valid_diff = self._valid_diff(result)
            trend_matches = self._match_pct(result['new_trend'], result['original_trend'])
            signal_matches = self._match_pct(result['new_signals'], result['original_signals'])

            if len(valid_diff) == 0:
                report_lines.append("- **No valid data points**")
                report_lines.append("")
                continue

            abs_diff = np.abs(valid_diff)
            report_lines.append(f"- **Period**: {result['period']}")
            report_lines.append(f"- **Multiplier**: {result['multiplier']}")
            report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
            report_lines.append(f"- **Max absolute difference**: {np.max(abs_diff):.12f}")
            report_lines.append(f"- **Mean absolute difference**: {np.mean(abs_diff):.12f}")
            report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
            report_lines.append(f"- **Trend direction match**: {trend_matches:.2f}%")
            report_lines.append(f"- **Signal timing match**: {signal_matches:.2f}%")

            # Supertrend-specific metrics: error relative to the typical level.
            original = np.array(result['original'])
            valid_original = original[~np.isnan(original)]
            if len(valid_original) > 0:
                mean_st = np.mean(valid_original)
                relative_error = np.mean(abs_diff) / mean_st * 100
                report_lines.append(f"- **Mean Supertrend value**: {mean_st:.6f}")
                report_lines.append(f"- **Relative error**: {relative_error:.2e}%")

            # Count trend changes
            original_changes = np.sum(np.array(result['original_signals']))
            new_changes = np.sum(np.array(result['new_signals']))
            report_lines.append(f"- **Original trend changes**: {original_changes}")
            report_lines.append(f"- **New trend changes**: {new_changes}")

            # Percentile analysis of the absolute differences
            percentiles = [1, 5, 25, 50, 75, 95, 99]
            perc_values = np.percentile(abs_diff, percentiles)
            perc_str = ", ".join(f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values))
            report_lines.append(f"- **Percentiles**: {perc_str}")
            report_lines.append("")

        report_path = self.results_dir / "supertrend_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))
        print(f"Report saved to {report_path}")

    def run_tests(self):
        """Load data, run every configuration, then write the plots and the report."""
        print("Starting Supertrend Comparison Tests...")
        self.load_data()
        self.test_supertrend()
        self.plot_all_comparisons()
        self.generate_report()
        print("\n✅ Supertrend tests completed!")
if __name__ == "__main__":
    # Script entry point: run the full comparison on a 3000-row sample,
    # smaller than the class default of 5000, to keep the run short.
    SupertrendComparisonTest(sample_size=3000).run_tests()