"""
Supertrend Indicators Comparison Test

Focused testing for Supertrend implementations: compares the original
(cycles.IncStrategies) and new (IncrementalTrader) SupertrendState
indicators on the same OHLC data.
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.dates as mdates
|
||
from datetime import datetime
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Put the directory two levels above this file at the front of sys.path so
# the project packages imported below resolve from any working directory.
project_root = Path(__file__).parent.parent
sys.path[:0] = [str(project_root)]
|
||
|
||
# Import original indicators
|
||
from cycles.IncStrategies.indicators import (
|
||
SupertrendState as OriginalSupertrend
|
||
)
|
||
|
||
# Import new indicators
|
||
from IncrementalTrader.strategies.indicators import (
|
||
SupertrendState as NewSupertrend
|
||
)
|
||
|
||
|
||
class SupertrendComparisonTest:
    """Test framework for comparing Supertrend implementations.

    For every (period, multiplier) combination the same OHLC bar stream is fed
    to ``OriginalSupertrend`` and ``NewSupertrend``. The resulting Supertrend
    values, trend directions and trend-change signals are stored in
    ``self.results`` and summarised on stdout, in per-indicator plots and in a
    Markdown report written to ``self.results_dir``.
    """

    # Acceptance thresholds shared by the console summary and the report.
    PASS_MAX_DIFF = 1e-10    # max |new - original| to count as equivalent
    WARN_MAX_DIFF = 1e-6     # max |new - original| still treated as float noise
    WARN_TREND_MATCH = 99    # minimum trend-match % for WARNING instead of FAILED

    # Human-readable status strings, keyed by the result of ``_status``.
    _CONSOLE_STATUS = {
        "passed": " ✅ PASSED: Mathematically equivalent",
        "warning": " ⚠️ WARNING: Small differences (floating point precision)",
        "failed": " ❌ FAILED: Significant differences detected",
    }
    _REPORT_STATUS = {
        "passed": "✅ PASSED",
        "warning": "⚠️ WARNING",
        "failed": "❌ FAILED",
    }

    def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
        """Configure the test run and create the output directory.

        Args:
            data_file: CSV path with ``Timestamp`` (epoch seconds) and
                ``Open``/``High``/``Low``/``Close`` columns.
            sample_size: keep only the last ``sample_size`` rows; a falsy
                value keeps every row.
        """
        self.data_file = data_file
        self.sample_size = sample_size
        self.data = None
        self.results = {}

        # Create results directory (relative to the current working directory)
        self.results_dir = Path("test/results/supertrend_indicators")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _match_percentage(new, original) -> float:
        """Return the percentage of positions where ``new`` equals ``original``.

        Returns NaN for empty input instead of dividing by zero.
        """
        diff = np.array(new) - np.array(original)
        if len(diff) == 0:
            return float("nan")
        return np.sum(diff == 0) / len(diff) * 100

    @staticmethod
    def _valid_differences(result) -> np.ndarray:
        """Return ``new - original`` Supertrend values with NaN entries dropped."""
        diff = np.array(result['new']) - np.array(result['original'])
        return diff[~np.isnan(diff)]

    @staticmethod
    def _trend_changes(trends) -> list:
        """Return 1 where a trend differs from the previous bar, else 0.

        The first bar has no predecessor and is always 0.
        """
        if not trends:
            return []
        return [0] + [1 if cur != prev else 0 for prev, cur in zip(trends, trends[1:])]

    @staticmethod
    def _read_state(indicator):
        """Return ``(supertrend, trend)`` for ``indicator``, or ``(nan, 0)`` before warm-up.

        NOTE(review): assumes ``is_warmed_up``/``get_current_value`` are
        side-effect-free getters, so calling each once per bar is enough.
        """
        if not indicator.is_warmed_up():
            return np.nan, 0
        current = indicator.get_current_value()
        return current['supertrend'], current['trend']

    @classmethod
    def _status(cls, max_diff, trend_matches) -> str:
        """Classify a comparison as ``"passed"``, ``"warning"`` or ``"failed"``."""
        if max_diff < cls.PASS_MAX_DIFF and trend_matches == 100:
            return "passed"
        if max_diff < cls.WARN_MAX_DIFF and trend_matches >= cls.WARN_TREND_MATCH:
            return "warning"
        return "failed"

    # ---------------------------------------------------------------- pipeline

    def load_data(self):
        """Load the CSV into ``self.data`` and add a ``datetime`` column.

        Raises:
            ValueError: if the CSV contains no data rows.
        """
        print(f"Loading data from {self.data_file}...")

        df = pd.read_csv(self.data_file)
        if df.empty:
            raise ValueError(f"No data rows found in {self.data_file}")
        df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')

        if self.sample_size and len(df) > self.sample_size:
            df = df.tail(self.sample_size).reset_index(drop=True)

        self.data = df
        print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")

    def test_supertrend(self, periods=(7, 10, 14, 21), multipliers=(2.0, 3.0, 4.0)):
        """Run both implementations for every (period, multiplier) pair.

        Args:
            periods: iterable of indicator periods to test.
            multipliers: iterable of ATR multipliers to test.

        Results are stored in ``self.results`` under
        ``'Supertrend_{period}_{multiplier}'`` and summarised on stdout.
        """
        print("\n=== Testing Supertrend ===")

        # Extract the OHLC columns once. Iterating plain lists is far cheaper
        # than building a pandas Series per row with DataFrame.iterrows().
        bars = list(zip(
            self.data['Open'].tolist(),
            self.data['High'].tolist(),
            self.data['Low'].tolist(),
            self.data['Close'].tolist(),
        ))

        for period in periods:
            for multiplier in multipliers:
                print(f"Testing Supertrend({period}, {multiplier})...")

                # Initialize indicators
                original_st = OriginalSupertrend(period, multiplier)
                new_st = NewSupertrend(period, multiplier)

                original_values, new_values = [], []
                original_trends, new_trends = [], []

                # Process data: both indicators receive the very same dict per bar.
                for open_, high, low, close in bars:
                    ohlc_data = {'open': open_, 'high': high, 'low': low, 'close': close}

                    original_st.update(ohlc_data)
                    new_st.update(ohlc_data)

                    original_value, original_trend = self._read_state(original_st)
                    new_value, new_trend = self._read_state(new_st)
                    original_values.append(original_value)
                    new_values.append(new_value)
                    original_trends.append(original_trend)
                    new_trends.append(new_trend)

                # A signal marks a bar whose trend differs from the previous one.
                original_signals = self._trend_changes(original_trends)
                new_signals = self._trend_changes(new_trends)

                # Store results
                key = f'Supertrend_{period}_{multiplier}'
                self.results[key] = {
                    'original': original_values,
                    'new': new_values,
                    'original_trend': original_trends,
                    'new_trend': new_trends,
                    'original_signals': original_signals,
                    'new_signals': new_signals,
                    'highs': self.data['High'].tolist(),
                    'lows': self.data['Low'].tolist(),
                    'closes': self.data['Close'].tolist(),
                    'dates': self.data['datetime'].tolist(),
                    'period': period,
                    'multiplier': multiplier,
                }

                valid_diff = self._valid_differences(self.results[key])
                trend_matches = self._match_percentage(new_trends, original_trends)
                signal_matches = self._match_percentage(new_signals, original_signals)

                if len(valid_diff) > 0:
                    max_diff = np.max(np.abs(valid_diff))
                    mean_diff = np.mean(np.abs(valid_diff))
                    std_diff = np.std(valid_diff)

                    print(f" Max difference: {max_diff:.12f}")
                    print(f" Mean difference: {mean_diff:.12f}")
                    print(f" Std difference: {std_diff:.12f}")
                    print(f" Trend match: {trend_matches:.2f}%")
                    print(f" Signal match: {signal_matches:.2f}%")

                    print(self._CONSOLE_STATUS[self._status(max_diff, trend_matches)])
                else:
                    print(" ❌ ERROR: No valid data points")

    def plot_comparison(self, indicator_name: str):
        """Plot and save a five-panel comparison for one tested indicator.

        The PNG is written to ``self.results_dir``. The figure is shown and
        then closed so repeated calls do not accumulate open figures.
        """
        if indicator_name not in self.results:
            print(f"No results found for {indicator_name}")
            return

        result = self.results[indicator_name]
        dates = pd.to_datetime(result['dates'])

        # Create figure with subplots
        fig, axes = plt.subplots(5, 1, figsize=(15, 20))
        fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)

        # Plot 1: Price and Supertrend
        ax1 = axes[0]
        ax1.plot(dates, result['closes'], label='Close Price', alpha=0.7, color='black', linewidth=1)
        ax1.plot(dates, result['original'], label='Original Supertrend', alpha=0.8, linewidth=2, color='blue')
        ax1.plot(dates, result['new'], label='New Supertrend', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax1.set_title(f'{indicator_name} vs Price')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: Trend comparison
        ax2 = axes[1]
        ax2.plot(dates, result['original_trend'], label='Original Trend', alpha=0.8, linewidth=2, color='blue')
        ax2.plot(dates, result['new_trend'], label='New Trend', alpha=0.8, linewidth=2, linestyle='--', color='red')
        ax2.set_title(f'{indicator_name} Trend Direction (1=Up, -1=Down)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(-1.5, 1.5)

        # Plot 3: Supertrend values comparison
        ax3 = axes[2]
        ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
        ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
        ax3.set_title(f'{indicator_name} Values Comparison')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Difference analysis
        ax4 = axes[3]
        diff = np.array(result['new']) - np.array(result['original'])
        ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
        ax4.set_title(f'{indicator_name} Difference (New - Original)')
        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.grid(True, alpha=0.3)

        # Plot 5: Signal comparison (new signals offset by 0.1 so both stay visible)
        ax5 = axes[4]
        signal_dates = dates[1:]  # Signals start from second data point
        ax5.scatter(signal_dates, np.array(result['original_signals'][1:]),
                    label='Original Signals', alpha=0.7, color='blue', s=30)
        ax5.scatter(signal_dates, np.array(result['new_signals'][1:]) + 0.1,
                    label='New Signals', alpha=0.7, color='red', s=30, marker='^')
        ax5.set_title(f'{indicator_name} Trend Change Signals')
        ax5.legend()
        ax5.grid(True, alpha=0.3)
        ax5.set_ylim(-0.2, 1.3)

        # Add statistics text
        valid_diff = diff[~np.isnan(diff)]
        if len(valid_diff) > 0:
            trend_matches = self._match_percentage(result['new_trend'], result['original_trend'])

            stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
            stats_text += f'Trend Match: {trend_matches:.1f}%'
            ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Format x-axis. A DayLocator with interval=len(dates)//10 treated a
        # point count as a number of days, which only works for daily bars; for
        # the default 1-minute data it produced almost no ticks. Let matplotlib
        # pick ticks suited to the actual time span (one locator per axis).
        for ax in axes:
            locator = mdates.AutoDateLocator()
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        plt.tight_layout()

        # Save plot
        plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plot_path}")

        plt.show()
        plt.close(fig)

    def plot_all_comparisons(self):
        """Plot comparisons for all tested indicators."""
        print("\n=== Generating Detailed Comparison Plots ===")

        for indicator_name in self.results:
            print(f"Plotting {indicator_name}...")
            self.plot_comparison(indicator_name)
            plt.close('all')

    def generate_report(self):
        """Write ``supertrend_indicators_report.md`` to ``self.results_dir``.

        Requires ``self.data`` (for the sample size) and ``self.results``.
        """
        print("\n=== Generating Supertrend Report ===")

        report_lines = []
        report_lines.append("# Supertrend Indicators Comparison Report")
        report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report_lines.append(f"Data file: {self.data_file}")
        report_lines.append(f"Sample size: {len(self.data)} data points")
        report_lines.append("")

        # Summary table
        report_lines.append("## Summary Table")
        report_lines.append("| Indicator | Period | Multiplier | Max Diff | Mean Diff | Trend Match | Status |")
        report_lines.append("|-----------|--------|------------|----------|-----------|-------------|--------|")

        for indicator_name, result in self.results.items():
            valid_diff = self._valid_differences(result)
            trend_matches = self._match_percentage(result['new_trend'], result['original_trend'])
            row_prefix = f"| {indicator_name} | {result['period']} | {result['multiplier']} | "

            if len(valid_diff) > 0:
                max_diff = np.max(np.abs(valid_diff))
                mean_diff = np.mean(np.abs(valid_diff))
                status = self._REPORT_STATUS[self._status(max_diff, trend_matches)]
                report_lines.append(row_prefix +
                                    f"{max_diff:.2e} | {mean_diff:.2e} | {trend_matches:.1f}% | {status} |")
            else:
                report_lines.append(row_prefix + "N/A | N/A | N/A | ❌ ERROR |")

        report_lines.append("")

        # Methodology explanation
        report_lines.append("## Methodology")
        report_lines.append("### Supertrend Calculation")
        report_lines.append("1. **Basic Upper Band**: (High + Low) / 2 + (Multiplier × ATR)")
        report_lines.append("2. **Basic Lower Band**: (High + Low) / 2 - (Multiplier × ATR)")
        report_lines.append("3. **Final Upper Band**: min(Basic Upper Band, Previous Final Upper Band if Close[1] <= Previous Final Upper Band)")
        report_lines.append("4. **Final Lower Band**: max(Basic Lower Band, Previous Final Lower Band if Close[1] >= Previous Final Lower Band)")
        report_lines.append("5. **Supertrend**: Final Lower Band if trend is up, Final Upper Band if trend is down")
        report_lines.append("6. **Trend**: Up if Close > Previous Supertrend, Down if Close <= Previous Supertrend")
        report_lines.append("")

        # Detailed analysis
        report_lines.append("## Detailed Analysis")

        for indicator_name, result in self.results.items():
            report_lines.append(f"### {indicator_name}")

            valid_diff = self._valid_differences(result)
            trend_matches = self._match_percentage(result['new_trend'], result['original_trend'])
            signal_matches = self._match_percentage(result['new_signals'], result['original_signals'])

            if len(valid_diff) > 0:
                report_lines.append(f"- **Period**: {result['period']}")
                report_lines.append(f"- **Multiplier**: {result['multiplier']}")
                report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
                report_lines.append(f"- **Max absolute difference**: {np.max(np.abs(valid_diff)):.12f}")
                report_lines.append(f"- **Mean absolute difference**: {np.mean(np.abs(valid_diff)):.12f}")
                report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
                report_lines.append(f"- **Trend direction match**: {trend_matches:.2f}%")
                report_lines.append(f"- **Signal timing match**: {signal_matches:.2f}%")

                # Supertrend-specific metrics
                original_arr = np.array(result['original'])
                valid_original = original_arr[~np.isnan(original_arr)]
                if len(valid_original) > 0:
                    mean_st = np.mean(valid_original)
                    relative_error = np.mean(np.abs(valid_diff)) / mean_st * 100
                    report_lines.append(f"- **Mean Supertrend value**: {mean_st:.6f}")
                    report_lines.append(f"- **Relative error**: {relative_error:.2e}%")

                # Count trend changes
                original_changes = np.sum(np.array(result['original_signals']))
                new_changes = np.sum(np.array(result['new_signals']))
                report_lines.append(f"- **Original trend changes**: {original_changes}")
                report_lines.append(f"- **New trend changes**: {new_changes}")

                # Percentile analysis
                percentiles = [1, 5, 25, 50, 75, 95, 99]
                perc_values = np.percentile(np.abs(valid_diff), percentiles)
                perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
                report_lines.append(f"- **Percentiles**: {perc_str}")
            else:
                # Mirror the console output instead of leaving an empty section.
                report_lines.append("- ❌ ERROR: No valid data points")

            report_lines.append("")

        # Save report
        report_path = self.results_dir / "supertrend_indicators_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))

        print(f"Report saved to {report_path}")

    def run_tests(self):
        """Run all Supertrend tests: load data, compare, plot, then report."""
        print("Starting Supertrend Comparison Tests...")

        # Load data
        self.load_data()

        # Run tests
        self.test_supertrend()

        # Generate outputs
        self.plot_all_comparisons()
        self.generate_report()

        print("\n✅ Supertrend tests completed!")
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: compare both implementations on the last 3000 rows.
    tester = SupertrendComparisonTest(sample_size=3000)
    tester.run_tests()