cleanup of the old Incremental trader after refactopring
This commit is contained in:
@@ -1,395 +0,0 @@
|
||||
"""
|
||||
ATR Indicators Comparison Test
|
||||
|
||||
Focused testing for ATR and Simple ATR implementations.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
ATRState as OriginalATR,
|
||||
SimpleATRState as OriginalSimpleATR
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
ATRState as NewATR,
|
||||
SimpleATRState as NewSimpleATR
|
||||
)
|
||||
|
||||
|
||||
class ATRComparisonTest:
|
||||
"""Test framework for comparing ATR implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results/atr_indicators")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
df = pd.read_csv(self.data_file)
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def test_atr(self, periods=[7, 14, 21, 28]):
|
||||
"""Test ATR implementations."""
|
||||
print("\n=== Testing ATR (Wilder's Smoothing) ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing ATR({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_atr = OriginalATR(period)
|
||||
new_atr = NewATR(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
true_ranges = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
|
||||
# Create OHLC dictionary for both indicators
|
||||
ohlc_data = {
|
||||
'open': row['Open'],
|
||||
'high': high,
|
||||
'low': low,
|
||||
'close': close
|
||||
}
|
||||
|
||||
original_atr.update(ohlc_data)
|
||||
new_atr.update(ohlc_data)
|
||||
|
||||
original_values.append(original_atr.get_current_value() if original_atr.is_warmed_up() else np.nan)
|
||||
new_values.append(new_atr.get_current_value() if new_atr.is_warmed_up() else np.nan)
|
||||
|
||||
# Calculate true range for reference
|
||||
if len(self.data) > 1:
|
||||
prev_close = self.data.iloc[max(0, len(true_ranges)-1)]['Close'] if true_ranges else close
|
||||
tr = max(high - low, abs(high - prev_close), abs(low - prev_close))
|
||||
true_ranges.append(tr)
|
||||
else:
|
||||
true_ranges.append(high - low)
|
||||
|
||||
# Store results
|
||||
self.results[f'ATR_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'true_ranges': true_ranges,
|
||||
'highs': self.data['High'].tolist(),
|
||||
'lows': self.data['Low'].tolist(),
|
||||
'closes': self.data['Close'].tolist(),
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def test_simple_atr(self, periods=[7, 14, 21, 28]):
|
||||
"""Test Simple ATR implementations."""
|
||||
print("\n=== Testing Simple ATR (Simple Moving Average) ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SimpleATR({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_atr = OriginalSimpleATR(period)
|
||||
new_atr = NewSimpleATR(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
true_ranges = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
|
||||
# Create OHLC dictionary for both indicators
|
||||
ohlc_data = {
|
||||
'open': row['Open'],
|
||||
'high': high,
|
||||
'low': low,
|
||||
'close': close
|
||||
}
|
||||
|
||||
original_atr.update(ohlc_data)
|
||||
new_atr.update(ohlc_data)
|
||||
|
||||
original_values.append(original_atr.get_current_value() if original_atr.is_warmed_up() else np.nan)
|
||||
new_values.append(new_atr.get_current_value() if new_atr.is_warmed_up() else np.nan)
|
||||
|
||||
# Calculate true range for reference
|
||||
if len(self.data) > 1:
|
||||
prev_close = self.data.iloc[max(0, len(true_ranges)-1)]['Close'] if true_ranges else close
|
||||
tr = max(high - low, abs(high - prev_close), abs(low - prev_close))
|
||||
true_ranges.append(tr)
|
||||
else:
|
||||
true_ranges.append(high - low)
|
||||
|
||||
# Store results
|
||||
self.results[f'SimpleATR_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'true_ranges': true_ranges,
|
||||
'highs': self.data['High'].tolist(),
|
||||
'lows': self.data['Low'].tolist(),
|
||||
'closes': self.data['Close'].tolist(),
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def plot_comparison(self, indicator_name: str):
|
||||
"""Plot detailed comparison for a specific indicator."""
|
||||
if indicator_name not in self.results:
|
||||
print(f"No results found for {indicator_name}")
|
||||
return
|
||||
|
||||
result = self.results[indicator_name]
|
||||
dates = pd.to_datetime(result['dates'])
|
||||
|
||||
# Create figure with subplots
|
||||
fig, axes = plt.subplots(4, 1, figsize=(15, 16))
|
||||
fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)
|
||||
|
||||
# Plot 1: OHLC data
|
||||
ax1 = axes[0]
|
||||
ax1.plot(dates, result['highs'], label='High', alpha=0.6, color='green')
|
||||
ax1.plot(dates, result['lows'], label='Low', alpha=0.6, color='red')
|
||||
ax1.plot(dates, result['closes'], label='Close', alpha=0.8, color='blue')
|
||||
ax1.set_title('OHLC Data')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: True Range
|
||||
ax2 = axes[1]
|
||||
ax2.plot(dates, result['true_ranges'], label='True Range', alpha=0.7, color='orange')
|
||||
ax2.set_title('True Range Values')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 3: ATR comparison
|
||||
ax3 = axes[2]
|
||||
ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
|
||||
ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
|
||||
ax3.set_title(f'{indicator_name} Values Comparison')
|
||||
ax3.legend()
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 4: Difference analysis
|
||||
ax4 = axes[3]
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
|
||||
ax4.set_title(f'{indicator_name} Difference (New - Original)')
|
||||
ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
# Add statistics text
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
if len(valid_diff) > 0:
|
||||
stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Std: {np.std(valid_diff):.2e}'
|
||||
ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10)))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot
|
||||
plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Plot saved to {plot_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_all_comparisons(self):
|
||||
"""Plot comparisons for all tested indicators."""
|
||||
print("\n=== Generating Detailed Comparison Plots ===")
|
||||
|
||||
for indicator_name in self.results.keys():
|
||||
print(f"Plotting {indicator_name}...")
|
||||
self.plot_comparison(indicator_name)
|
||||
plt.close('all')
|
||||
|
||||
def generate_report(self):
|
||||
"""Generate detailed report for ATR indicators."""
|
||||
print("\n=== Generating ATR Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# ATR Indicators Comparison Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
# Summary table
|
||||
report_lines.append("## Summary Table")
|
||||
report_lines.append("| Indicator | Period | Max Diff | Mean Diff | Status |")
|
||||
report_lines.append("|-----------|--------|----------|-----------|--------|")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
|
||||
if max_diff < 1e-10:
|
||||
status = "✅ PASSED"
|
||||
elif max_diff < 1e-6:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |")
|
||||
else:
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Methodology explanation
|
||||
report_lines.append("## Methodology")
|
||||
report_lines.append("### ATR (Average True Range)")
|
||||
report_lines.append("- Uses Wilder's smoothing method: ATR = (Previous ATR * (n-1) + Current TR) / n")
|
||||
report_lines.append("- True Range = max(High-Low, |High-PrevClose|, |Low-PrevClose|)")
|
||||
report_lines.append("")
|
||||
report_lines.append("### Simple ATR")
|
||||
report_lines.append("- Uses simple moving average of True Range values")
|
||||
report_lines.append("- More responsive to recent changes than Wilder's method")
|
||||
report_lines.append("")
|
||||
|
||||
# Detailed analysis
|
||||
report_lines.append("## Detailed Analysis")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"### {indicator_name}")
|
||||
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- **Period**: {result['period']}")
|
||||
report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
|
||||
report_lines.append(f"- **Max absolute difference**: {np.max(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Mean absolute difference**: {np.mean(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
|
||||
|
||||
# ATR-specific metrics
|
||||
valid_original = np.array(result['original'])[~np.isnan(result['original'])]
|
||||
if len(valid_original) > 0:
|
||||
mean_atr = np.mean(valid_original)
|
||||
relative_error = np.mean(np.abs(valid_diff)) / mean_atr * 100
|
||||
report_lines.append(f"- **Mean ATR value**: {mean_atr:.6f}")
|
||||
report_lines.append(f"- **Relative error**: {relative_error:.2e}%")
|
||||
|
||||
# Percentile analysis
|
||||
percentiles = [1, 5, 25, 50, 75, 95, 99]
|
||||
perc_values = np.percentile(np.abs(valid_diff), percentiles)
|
||||
perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
|
||||
report_lines.append(f"- **Percentiles**: {perc_str}")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "atr_indicators_report.md"
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Report saved to {report_path}")
|
||||
|
||||
def run_tests(self):
|
||||
"""Run all ATR tests."""
|
||||
print("Starting ATR Comparison Tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run tests
|
||||
self.test_atr()
|
||||
self.test_simple_atr()
|
||||
|
||||
# Generate outputs
|
||||
self.plot_all_comparisons()
|
||||
self.generate_report()
|
||||
|
||||
print("\n✅ ATR tests completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tester = ATRComparisonTest(sample_size=3000)
|
||||
tester.run_tests()
|
||||
@@ -1,487 +0,0 @@
|
||||
"""
|
||||
Bollinger Bands Indicators Comparison Test
|
||||
|
||||
Focused testing for Bollinger Bands implementations.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
BollingerBandsState as OriginalBB,
|
||||
BollingerBandsOHLCState as OriginalBBOHLC
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
BollingerBandsState as NewBB,
|
||||
BollingerBandsOHLCState as NewBBOHLC
|
||||
)
|
||||
|
||||
|
||||
class BollingerBandsComparisonTest:
|
||||
"""Test framework for comparing Bollinger Bands implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results/bollinger_bands")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
df = pd.read_csv(self.data_file)
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def test_bollinger_bands(self, periods=[10, 20, 30], std_devs=[1.5, 2.0, 2.5]):
|
||||
"""Test Bollinger Bands implementations (Close price based)."""
|
||||
print("\n=== Testing Bollinger Bands (Close Price) ===")
|
||||
|
||||
for period in periods:
|
||||
for std_dev in std_devs:
|
||||
print(f"Testing BollingerBands({period}, {std_dev})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_bb = OriginalBB(period, std_dev)
|
||||
new_bb = NewBB(period, std_dev)
|
||||
|
||||
original_upper = []
|
||||
original_middle = []
|
||||
original_lower = []
|
||||
new_upper = []
|
||||
new_middle = []
|
||||
new_lower = []
|
||||
prices = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
prices.append(price)
|
||||
|
||||
original_bb.update(price)
|
||||
new_bb.update(price)
|
||||
|
||||
if original_bb.is_warmed_up():
|
||||
original_upper.append(original_bb.get_current_value()['upper_band'])
|
||||
original_middle.append(original_bb.get_current_value()['middle_band'])
|
||||
original_lower.append(original_bb.get_current_value()['lower_band'])
|
||||
else:
|
||||
original_upper.append(np.nan)
|
||||
original_middle.append(np.nan)
|
||||
original_lower.append(np.nan)
|
||||
|
||||
if new_bb.is_warmed_up():
|
||||
new_upper.append(new_bb.get_current_value()['upper_band'])
|
||||
new_middle.append(new_bb.get_current_value()['middle_band'])
|
||||
new_lower.append(new_bb.get_current_value()['lower_band'])
|
||||
else:
|
||||
new_upper.append(np.nan)
|
||||
new_middle.append(np.nan)
|
||||
new_lower.append(np.nan)
|
||||
|
||||
# Store results
|
||||
key = f'BB_{period}_{std_dev}'
|
||||
self.results[key] = {
|
||||
'original_upper': original_upper,
|
||||
'original_middle': original_middle,
|
||||
'original_lower': original_lower,
|
||||
'new_upper': new_upper,
|
||||
'new_middle': new_middle,
|
||||
'new_lower': new_lower,
|
||||
'prices': prices,
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period,
|
||||
'std_dev': std_dev,
|
||||
'type': 'Close'
|
||||
}
|
||||
|
||||
# Calculate differences for each band
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(locals()[f'original_{band}'])
|
||||
new = np.array(locals()[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
|
||||
print(f" {band.capitalize()} band - Max diff: {max_diff:.12f}, Mean diff: {mean_diff:.12f}")
|
||||
|
||||
# Status check for this band
|
||||
if max_diff < 1e-10:
|
||||
status = "✅ PASSED"
|
||||
elif max_diff < 1e-6:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
print(f" Status: {status}")
|
||||
else:
|
||||
print(f" {band.capitalize()} band - ❌ ERROR: No valid data points")
|
||||
|
||||
def test_bollinger_bands_ohlc(self, periods=[10, 20, 30], std_devs=[1.5, 2.0, 2.5]):
|
||||
"""Test Bollinger Bands OHLC implementations (Typical price based)."""
|
||||
print("\n=== Testing Bollinger Bands OHLC (Typical Price) ===")
|
||||
|
||||
for period in periods:
|
||||
for std_dev in std_devs:
|
||||
print(f"Testing BollingerBandsOHLC({period}, {std_dev})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_bb = OriginalBBOHLC(period, std_dev)
|
||||
new_bb = NewBBOHLC(period, std_dev)
|
||||
|
||||
original_upper = []
|
||||
original_middle = []
|
||||
original_lower = []
|
||||
new_upper = []
|
||||
new_middle = []
|
||||
new_lower = []
|
||||
typical_prices = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
typical_price = (high + low + close) / 3
|
||||
typical_prices.append(typical_price)
|
||||
|
||||
# Create OHLC dictionary for both indicators
|
||||
ohlc_data = {
|
||||
'open': row['Open'],
|
||||
'high': high,
|
||||
'low': low,
|
||||
'close': close
|
||||
}
|
||||
|
||||
original_bb.update(ohlc_data)
|
||||
new_bb.update(ohlc_data)
|
||||
|
||||
if original_bb.is_warmed_up():
|
||||
original_upper.append(original_bb.get_current_value()['upper_band'])
|
||||
original_middle.append(original_bb.get_current_value()['middle_band'])
|
||||
original_lower.append(original_bb.get_current_value()['lower_band'])
|
||||
else:
|
||||
original_upper.append(np.nan)
|
||||
original_middle.append(np.nan)
|
||||
original_lower.append(np.nan)
|
||||
|
||||
if new_bb.is_warmed_up():
|
||||
new_upper.append(new_bb.get_current_value()['upper_band'])
|
||||
new_middle.append(new_bb.get_current_value()['middle_band'])
|
||||
new_lower.append(new_bb.get_current_value()['lower_band'])
|
||||
else:
|
||||
new_upper.append(np.nan)
|
||||
new_middle.append(np.nan)
|
||||
new_lower.append(np.nan)
|
||||
|
||||
# Store results
|
||||
key = f'BBOHLC_{period}_{std_dev}'
|
||||
self.results[key] = {
|
||||
'original_upper': original_upper,
|
||||
'original_middle': original_middle,
|
||||
'original_lower': original_lower,
|
||||
'new_upper': new_upper,
|
||||
'new_middle': new_middle,
|
||||
'new_lower': new_lower,
|
||||
'prices': self.data['Close'].tolist(),
|
||||
'typical_prices': typical_prices,
|
||||
'highs': self.data['High'].tolist(),
|
||||
'lows': self.data['Low'].tolist(),
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period,
|
||||
'std_dev': std_dev,
|
||||
'type': 'OHLC'
|
||||
}
|
||||
|
||||
# Calculate differences for each band
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(locals()[f'original_{band}'])
|
||||
new = np.array(locals()[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
|
||||
print(f" {band.capitalize()} band - Max diff: {max_diff:.12f}, Mean diff: {mean_diff:.12f}")
|
||||
|
||||
# Status check for this band
|
||||
if max_diff < 1e-10:
|
||||
status = "✅ PASSED"
|
||||
elif max_diff < 1e-6:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
print(f" Status: {status}")
|
||||
else:
|
||||
print(f" {band.capitalize()} band - ❌ ERROR: No valid data points")
|
||||
|
||||
def plot_comparison(self, indicator_name: str):
|
||||
"""Plot detailed comparison for a specific indicator."""
|
||||
if indicator_name not in self.results:
|
||||
print(f"No results found for {indicator_name}")
|
||||
return
|
||||
|
||||
result = self.results[indicator_name]
|
||||
dates = pd.to_datetime(result['dates'])
|
||||
|
||||
# Create figure with subplots
|
||||
fig, axes = plt.subplots(4, 1, figsize=(15, 16))
|
||||
fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)
|
||||
|
||||
# Plot 1: Price and Bollinger Bands
|
||||
ax1 = axes[0]
|
||||
if result['type'] == 'OHLC':
|
||||
ax1.plot(dates, result['typical_prices'], label='Typical Price', alpha=0.7, color='black', linewidth=1)
|
||||
else:
|
||||
ax1.plot(dates, result['prices'], label='Close Price', alpha=0.7, color='black', linewidth=1)
|
||||
|
||||
ax1.plot(dates, result['original_upper'], label='Original Upper', alpha=0.8, color='red')
|
||||
ax1.plot(dates, result['original_middle'], label='Original Middle', alpha=0.8, color='blue')
|
||||
ax1.plot(dates, result['original_lower'], label='Original Lower', alpha=0.8, color='green')
|
||||
ax1.fill_between(dates, result['original_upper'], result['original_lower'], alpha=0.1, color='gray')
|
||||
ax1.set_title(f'{indicator_name} - Original Implementation')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: New implementation
|
||||
ax2 = axes[1]
|
||||
if result['type'] == 'OHLC':
|
||||
ax2.plot(dates, result['typical_prices'], label='Typical Price', alpha=0.7, color='black', linewidth=1)
|
||||
else:
|
||||
ax2.plot(dates, result['prices'], label='Close Price', alpha=0.7, color='black', linewidth=1)
|
||||
|
||||
ax2.plot(dates, result['new_upper'], label='New Upper', alpha=0.8, color='red', linestyle='--')
|
||||
ax2.plot(dates, result['new_middle'], label='New Middle', alpha=0.8, color='blue', linestyle='--')
|
||||
ax2.plot(dates, result['new_lower'], label='New Lower', alpha=0.8, color='green', linestyle='--')
|
||||
ax2.fill_between(dates, result['new_upper'], result['new_lower'], alpha=0.1, color='gray')
|
||||
ax2.set_title(f'{indicator_name} - New Implementation')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 3: Overlay comparison
|
||||
ax3 = axes[2]
|
||||
ax3.plot(dates, result['original_upper'], label='Original Upper', alpha=0.8, color='red')
|
||||
ax3.plot(dates, result['original_middle'], label='Original Middle', alpha=0.8, color='blue')
|
||||
ax3.plot(dates, result['original_lower'], label='Original Lower', alpha=0.8, color='green')
|
||||
ax3.plot(dates, result['new_upper'], label='New Upper', alpha=0.8, color='red', linestyle='--')
|
||||
ax3.plot(dates, result['new_middle'], label='New Middle', alpha=0.8, color='blue', linestyle='--')
|
||||
ax3.plot(dates, result['new_lower'], label='New Lower', alpha=0.8, color='green', linestyle='--')
|
||||
ax3.set_title(f'{indicator_name} - Overlay Comparison')
|
||||
ax3.legend()
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 4: Differences for all bands
|
||||
ax4 = axes[3]
|
||||
for band, color in [('upper', 'red'), ('middle', 'blue'), ('lower', 'green')]:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
ax4.plot(dates, diff, label=f'{band.capitalize()} diff', alpha=0.7, color=color)
|
||||
|
||||
ax4.set_title(f'{indicator_name} Differences (New - Original)')
|
||||
ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax4.legend()
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
# Add statistics text
|
||||
stats_lines = []
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
if len(valid_diff) > 0:
|
||||
stats_lines.append(f'{band.capitalize()}: Max={np.max(np.abs(valid_diff)):.2e}')
|
||||
|
||||
stats_text = '\n'.join(stats_lines)
|
||||
ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10)))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot
|
||||
plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Plot saved to {plot_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_all_comparisons(self):
|
||||
"""Plot comparisons for all tested indicators."""
|
||||
print("\n=== Generating Detailed Comparison Plots ===")
|
||||
|
||||
for indicator_name in self.results.keys():
|
||||
print(f"Plotting {indicator_name}...")
|
||||
self.plot_comparison(indicator_name)
|
||||
plt.close('all')
|
||||
|
||||
def generate_report(self):
|
||||
"""Generate detailed report for Bollinger Bands indicators."""
|
||||
print("\n=== Generating Bollinger Bands Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# Bollinger Bands Indicators Comparison Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
# Summary table
|
||||
report_lines.append("## Summary Table")
|
||||
report_lines.append("| Indicator | Period | Std Dev | Upper Max Diff | Middle Max Diff | Lower Max Diff | Status |")
|
||||
report_lines.append("|-----------|--------|---------|----------------|-----------------|----------------|--------|")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
max_diffs = []
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
max_diffs.append(max_diff)
|
||||
else:
|
||||
max_diffs.append(float('inf'))
|
||||
|
||||
overall_max = max(max_diffs) if max_diffs else float('inf')
|
||||
|
||||
if overall_max < 1e-10:
|
||||
status = "✅ PASSED"
|
||||
elif overall_max < 1e-6:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
|
||||
max_diff_strs = [f"{d:.2e}" if d != float('inf') else "N/A" for d in max_diffs]
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | {result['std_dev']} | "
|
||||
f"{max_diff_strs[0]} | {max_diff_strs[1]} | {max_diff_strs[2]} | {status} |")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Methodology explanation
|
||||
report_lines.append("## Methodology")
|
||||
report_lines.append("### Bollinger Bands (Close Price)")
|
||||
report_lines.append("- **Middle Band**: Simple Moving Average of Close prices")
|
||||
report_lines.append("- **Upper Band**: Middle Band + (Standard Deviation × Multiplier)")
|
||||
report_lines.append("- **Lower Band**: Middle Band - (Standard Deviation × Multiplier)")
|
||||
report_lines.append("- Uses Close price for all calculations")
|
||||
report_lines.append("")
|
||||
report_lines.append("### Bollinger Bands OHLC (Typical Price)")
|
||||
report_lines.append("- **Typical Price**: (High + Low + Close) / 3")
|
||||
report_lines.append("- **Middle Band**: Simple Moving Average of Typical prices")
|
||||
report_lines.append("- **Upper Band**: Middle Band + (Standard Deviation × Multiplier)")
|
||||
report_lines.append("- **Lower Band**: Middle Band - (Standard Deviation × Multiplier)")
|
||||
report_lines.append("- Uses Typical price for all calculations")
|
||||
report_lines.append("")
|
||||
|
||||
# Detailed analysis
|
||||
report_lines.append("## Detailed Analysis")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"### {indicator_name}")
|
||||
|
||||
report_lines.append(f"- **Type**: {result['type']}")
|
||||
report_lines.append(f"- **Period**: {result['period']}")
|
||||
report_lines.append(f"- **Standard Deviation Multiplier**: {result['std_dev']}")
|
||||
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- **{band.capitalize()} Band Analysis**:")
|
||||
report_lines.append(f" - Valid data points: {len(valid_diff)}")
|
||||
report_lines.append(f" - Max absolute difference: {np.max(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f" - Mean absolute difference: {np.mean(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f" - Standard deviation: {np.std(valid_diff):.12f}")
|
||||
|
||||
# Band-specific metrics
|
||||
valid_original = orig[~np.isnan(orig)]
|
||||
if len(valid_original) > 0:
|
||||
mean_value = np.mean(valid_original)
|
||||
relative_error = np.mean(np.abs(valid_diff)) / mean_value * 100
|
||||
report_lines.append(f" - Mean {band} value: {mean_value:.6f}")
|
||||
report_lines.append(f" - Relative error: {relative_error:.2e}%")
|
||||
|
||||
# Band width analysis
|
||||
orig_width = np.array(result['original_upper']) - np.array(result['original_lower'])
|
||||
new_width = np.array(result['new_upper']) - np.array(result['new_lower'])
|
||||
width_diff = new_width - orig_width
|
||||
valid_width_diff = width_diff[~np.isnan(width_diff)]
|
||||
|
||||
if len(valid_width_diff) > 0:
|
||||
report_lines.append(f"- **Band Width Analysis**:")
|
||||
report_lines.append(f" - Max width difference: {np.max(np.abs(valid_width_diff)):.12f}")
|
||||
report_lines.append(f" - Mean width difference: {np.mean(np.abs(valid_width_diff)):.12f}")
|
||||
|
||||
# Squeeze detection (when bands are narrow)
|
||||
valid_orig_width = orig_width[~np.isnan(orig_width)]
|
||||
if len(valid_orig_width) > 0:
|
||||
width_percentile_20 = np.percentile(valid_orig_width, 20)
|
||||
squeeze_periods = np.sum(valid_orig_width < width_percentile_20)
|
||||
report_lines.append(f" - Squeeze periods (width < 20th percentile): {squeeze_periods}")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "bollinger_bands_report.md"
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Report saved to {report_path}")
|
||||
|
||||
def run_tests(self):
|
||||
"""Run all Bollinger Bands tests."""
|
||||
print("Starting Bollinger Bands Comparison Tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run tests
|
||||
self.test_bollinger_bands()
|
||||
self.test_bollinger_bands_ohlc()
|
||||
|
||||
# Generate outputs
|
||||
self.plot_all_comparisons()
|
||||
self.generate_report()
|
||||
|
||||
print("\n✅ Bollinger Bands tests completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tester = BollingerBandsComparisonTest(sample_size=3000)
|
||||
tester.run_tests()
|
||||
@@ -1,610 +0,0 @@
|
||||
"""
|
||||
Comprehensive Indicator Comparison Test Suite
|
||||
|
||||
This module provides testing framework to compare original indicators from cycles module
|
||||
with new implementations in IncrementalTrader module to ensure mathematical equivalence.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
MovingAverageState as OriginalMA,
|
||||
ExponentialMovingAverageState as OriginalEMA,
|
||||
ATRState as OriginalATR,
|
||||
SimpleATRState as OriginalSimpleATR,
|
||||
SupertrendState as OriginalSupertrend,
|
||||
RSIState as OriginalRSI,
|
||||
SimpleRSIState as OriginalSimpleRSI,
|
||||
BollingerBandsState as OriginalBB,
|
||||
BollingerBandsOHLCState as OriginalBBOHLC
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
MovingAverageState as NewMA,
|
||||
ExponentialMovingAverageState as NewEMA,
|
||||
ATRState as NewATR,
|
||||
SimpleATRState as NewSimpleATR,
|
||||
SupertrendState as NewSupertrend,
|
||||
RSIState as NewRSI,
|
||||
SimpleRSIState as NewSimpleRSI,
|
||||
BollingerBandsState as NewBB,
|
||||
BollingerBandsOHLCState as NewBBOHLC
|
||||
)
|
||||
|
||||
|
||||
class IndicatorComparisonTester:
|
||||
"""Test framework for comparing original and new indicator implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 10000):
|
||||
"""
|
||||
Initialize the tester with data.
|
||||
|
||||
Args:
|
||||
data_file: Path to the CSV data file
|
||||
sample_size: Number of data points to use for testing (None for all data)
|
||||
"""
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results")
|
||||
self.results_dir.mkdir(exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
# Load data
|
||||
df = pd.read_csv(self.data_file)
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
# Take sample if specified
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
# Take the most recent data
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def compare_moving_averages(self, periods=[20, 50]):
|
||||
"""Compare Moving Average implementations."""
|
||||
print("\n=== Testing Moving Averages ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing MA({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_ma = OriginalMA(period)
|
||||
new_ma = NewMA(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_ma.update(price)
|
||||
new_ma.update(price)
|
||||
|
||||
original_values.append(original_ma.get_current_value() if original_ma.is_warmed_up() else np.nan)
|
||||
new_values.append(new_ma.get_current_value() if new_ma.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'MA_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_exponential_moving_averages(self, periods=[20, 50]):
|
||||
"""Compare Exponential Moving Average implementations."""
|
||||
print("\n=== Testing Exponential Moving Averages ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing EMA({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_ema = OriginalEMA(period)
|
||||
new_ema = NewEMA(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_ema.update(price)
|
||||
new_ema.update(price)
|
||||
|
||||
original_values.append(original_ema.value if original_ema.is_ready else np.nan)
|
||||
new_values.append(new_ema.value if new_ema.is_ready else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'EMA_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_atr(self, periods=[14]):
|
||||
"""Compare ATR implementations."""
|
||||
print("\n=== Testing ATR ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing ATR({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_atr = OriginalATR(period)
|
||||
new_atr = NewATR(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
|
||||
original_atr.update(high, low, close)
|
||||
new_atr.update(high, low, close)
|
||||
|
||||
original_values.append(original_atr.value if original_atr.is_ready else np.nan)
|
||||
new_values.append(new_atr.value if new_atr.is_ready else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'ATR_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_simple_atr(self, periods=[14]):
|
||||
"""Compare Simple ATR implementations."""
|
||||
print("\n=== Testing Simple ATR ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SimpleATR({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_atr = OriginalSimpleATR(period)
|
||||
new_atr = NewSimpleATR(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
|
||||
original_atr.update(high, low, close)
|
||||
new_atr.update(high, low, close)
|
||||
|
||||
original_values.append(original_atr.value if original_atr.is_ready else np.nan)
|
||||
new_values.append(new_atr.value if new_atr.is_ready else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'SimpleATR_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_supertrend(self, periods=[10], multipliers=[3.0]):
|
||||
"""Compare Supertrend implementations."""
|
||||
print("\n=== Testing Supertrend ===")
|
||||
|
||||
for period in periods:
|
||||
for multiplier in multipliers:
|
||||
print(f"Testing Supertrend({period}, {multiplier})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_st = OriginalSupertrend(period, multiplier)
|
||||
new_st = NewSupertrend(period, multiplier)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
original_trends = []
|
||||
new_trends = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
|
||||
original_st.update(high, low, close)
|
||||
new_st.update(high, low, close)
|
||||
|
||||
original_values.append(original_st.value if original_st.is_ready else np.nan)
|
||||
new_values.append(new_st.value if new_st.is_ready else np.nan)
|
||||
original_trends.append(original_st.trend if original_st.is_ready else 0)
|
||||
new_trends.append(new_st.trend if new_st.is_ready else 0)
|
||||
|
||||
# Store results
|
||||
key = f'Supertrend_{period}_{multiplier}'
|
||||
self.results[key] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'original_trend': original_trends,
|
||||
'new_trend': new_trends,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
trend_diff = np.array(new_trends) - np.array(original_trends)
|
||||
trend_matches = np.sum(trend_diff == 0) / len(trend_diff) * 100
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Trend match: {trend_matches:.2f}%")
|
||||
|
||||
def compare_rsi(self, periods=[14]):
|
||||
"""Compare RSI implementations."""
|
||||
print("\n=== Testing RSI ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing RSI({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_rsi = OriginalRSI(period)
|
||||
new_rsi = NewRSI(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_rsi.update(price)
|
||||
new_rsi.update(price)
|
||||
|
||||
original_values.append(original_rsi.value if original_rsi.is_ready else np.nan)
|
||||
new_values.append(new_rsi.value if new_rsi.is_ready else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'RSI_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_simple_rsi(self, periods=[14]):
|
||||
"""Compare Simple RSI implementations."""
|
||||
print("\n=== Testing Simple RSI ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SimpleRSI({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_rsi = OriginalSimpleRSI(period)
|
||||
new_rsi = NewSimpleRSI(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_rsi.update(price)
|
||||
new_rsi.update(price)
|
||||
|
||||
original_values.append(original_rsi.value if original_rsi.is_ready else np.nan)
|
||||
new_values.append(new_rsi.value if new_rsi.is_ready else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'SimpleRSI_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_bollinger_bands(self, periods=[20], std_devs=[2.0]):
|
||||
"""Compare Bollinger Bands implementations."""
|
||||
print("\n=== Testing Bollinger Bands ===")
|
||||
|
||||
for period in periods:
|
||||
for std_dev in std_devs:
|
||||
print(f"Testing BollingerBands({period}, {std_dev})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_bb = OriginalBB(period, std_dev)
|
||||
new_bb = NewBB(period, std_dev)
|
||||
|
||||
original_upper = []
|
||||
original_middle = []
|
||||
original_lower = []
|
||||
new_upper = []
|
||||
new_middle = []
|
||||
new_lower = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_bb.update(price)
|
||||
new_bb.update(price)
|
||||
|
||||
if original_bb.is_ready:
|
||||
original_upper.append(original_bb.upper)
|
||||
original_middle.append(original_bb.middle)
|
||||
original_lower.append(original_bb.lower)
|
||||
else:
|
||||
original_upper.append(np.nan)
|
||||
original_middle.append(np.nan)
|
||||
original_lower.append(np.nan)
|
||||
|
||||
if new_bb.is_ready:
|
||||
new_upper.append(new_bb.upper)
|
||||
new_middle.append(new_bb.middle)
|
||||
new_lower.append(new_bb.lower)
|
||||
else:
|
||||
new_upper.append(np.nan)
|
||||
new_middle.append(np.nan)
|
||||
new_lower.append(np.nan)
|
||||
|
||||
# Store results
|
||||
key = f'BB_{period}_{std_dev}'
|
||||
self.results[key] = {
|
||||
'original_upper': original_upper,
|
||||
'original_middle': original_middle,
|
||||
'original_lower': original_lower,
|
||||
'new_upper': new_upper,
|
||||
'new_middle': new_middle,
|
||||
'new_lower': new_lower,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(locals()[f'original_{band}'])
|
||||
new = np.array(locals()[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" {band.capitalize()} band - Max diff: {np.max(np.abs(valid_diff)):.10f}, "
|
||||
f"Mean diff: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
|
||||
def plot_comparison(self, indicator_name: str, save_plot: bool = True):
|
||||
"""Plot comparison between original and new indicator implementations."""
|
||||
if indicator_name not in self.results:
|
||||
print(f"No results found for {indicator_name}")
|
||||
return
|
||||
|
||||
result = self.results[indicator_name]
|
||||
dates = pd.to_datetime(result['dates'])
|
||||
|
||||
# Create figure
|
||||
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
|
||||
fig.suptitle(f'{indicator_name} - Original vs New Implementation Comparison', fontsize=16)
|
||||
|
||||
# Plot 1: Overlay comparison
|
||||
ax1 = axes[0]
|
||||
|
||||
if 'original' in result and 'new' in result:
|
||||
# Standard indicator comparison
|
||||
ax1.plot(dates, result['original'], label='Original', alpha=0.7, linewidth=1)
|
||||
ax1.plot(dates, result['new'], label='New', alpha=0.7, linewidth=1, linestyle='--')
|
||||
ax1.set_title(f'{indicator_name} Values Comparison')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: Difference
|
||||
ax2 = axes[1]
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
ax2.plot(dates, diff, color='red', alpha=0.7)
|
||||
ax2.set_title(f'{indicator_name} Difference (New - Original)')
|
||||
ax2.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
elif 'original_upper' in result:
|
||||
# Bollinger Bands comparison
|
||||
ax1.plot(dates, result['original_upper'], label='Original Upper', alpha=0.7)
|
||||
ax1.plot(dates, result['original_middle'], label='Original Middle', alpha=0.7)
|
||||
ax1.plot(dates, result['original_lower'], label='Original Lower', alpha=0.7)
|
||||
ax1.plot(dates, result['new_upper'], label='New Upper', alpha=0.7, linestyle='--')
|
||||
ax1.plot(dates, result['new_middle'], label='New Middle', alpha=0.7, linestyle='--')
|
||||
ax1.plot(dates, result['new_lower'], label='New Lower', alpha=0.7, linestyle='--')
|
||||
ax1.set_title(f'{indicator_name} Bollinger Bands Comparison')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: Differences for all bands
|
||||
ax2 = axes[1]
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
ax2.plot(dates, diff, label=f'{band.capitalize()} diff', alpha=0.7)
|
||||
ax2.set_title(f'{indicator_name} Differences (New - Original)')
|
||||
ax2.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10)))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_plot:
|
||||
plot_path = self.results_dir / f"{indicator_name}_comparison.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Plot saved to {plot_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_all_comparisons(self):
|
||||
"""Plot comparisons for all tested indicators."""
|
||||
print("\n=== Generating Comparison Plots ===")
|
||||
|
||||
for indicator_name in self.results.keys():
|
||||
print(f"Plotting {indicator_name}...")
|
||||
self.plot_comparison(indicator_name, save_plot=True)
|
||||
plt.close('all') # Close plots to save memory
|
||||
|
||||
def generate_summary_report(self):
|
||||
"""Generate a summary report of all comparisons."""
|
||||
print("\n=== Summary Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# Indicator Comparison Summary Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"## {indicator_name}")
|
||||
|
||||
if 'original' in result and 'new' in result:
|
||||
# Standard indicator
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- Max absolute difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
report_lines.append(f"- Mean absolute difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
report_lines.append(f"- Standard deviation: {np.std(valid_diff):.10f}")
|
||||
report_lines.append(f"- Valid data points: {len(valid_diff)}")
|
||||
|
||||
# Check if differences are negligible
|
||||
if np.max(np.abs(valid_diff)) < 1e-10:
|
||||
report_lines.append("- ✅ **PASSED**: Implementations are mathematically equivalent")
|
||||
elif np.max(np.abs(valid_diff)) < 1e-6:
|
||||
report_lines.append("- ⚠️ **WARNING**: Small differences detected (likely floating point precision)")
|
||||
else:
|
||||
report_lines.append("- ❌ **FAILED**: Significant differences detected")
|
||||
else:
|
||||
report_lines.append("- ❌ **ERROR**: No valid data points for comparison")
|
||||
|
||||
elif 'original_upper' in result:
|
||||
# Bollinger Bands
|
||||
all_passed = True
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
report_lines.append(f"- {band.capitalize()} band max diff: {max_diff:.10f}")
|
||||
if max_diff >= 1e-6:
|
||||
all_passed = False
|
||||
|
||||
if all_passed:
|
||||
report_lines.append("- ✅ **PASSED**: All bands are mathematically equivalent")
|
||||
else:
|
||||
report_lines.append("- ❌ **FAILED**: Significant differences in one or more bands")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "comparison_summary.md"
|
||||
with open(report_path, 'w') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Summary report saved to {report_path}")
|
||||
|
||||
# Print summary to console
|
||||
print('\n'.join(report_lines))
|
||||
|
||||
def run_all_tests(self):
|
||||
"""Run all indicator comparison tests."""
|
||||
print("Starting comprehensive indicator comparison tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run all comparisons
|
||||
self.compare_moving_averages()
|
||||
self.compare_exponential_moving_averages()
|
||||
self.compare_atr()
|
||||
self.compare_simple_atr()
|
||||
self.compare_supertrend()
|
||||
self.compare_rsi()
|
||||
self.compare_simple_rsi()
|
||||
self.compare_bollinger_bands()
|
||||
|
||||
# Generate plots and reports
|
||||
self.plot_all_comparisons()
|
||||
self.generate_summary_report()
|
||||
|
||||
print("\n✅ All tests completed! Check the test/results/ directory for detailed outputs.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the comprehensive test suite
|
||||
tester = IndicatorComparisonTester(sample_size=5000) # Use 5000 data points for faster testing
|
||||
tester.run_all_tests()
|
||||
@@ -1,549 +0,0 @@
|
||||
"""
|
||||
Comprehensive Indicator Comparison Test Suite (Fixed Interface)
|
||||
|
||||
This module provides testing framework to compare original indicators from cycles module
|
||||
with new implementations in IncrementalTrader module to ensure mathematical equivalence.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
MovingAverageState as OriginalMA,
|
||||
ExponentialMovingAverageState as OriginalEMA,
|
||||
ATRState as OriginalATR,
|
||||
SimpleATRState as OriginalSimpleATR,
|
||||
SupertrendState as OriginalSupertrend,
|
||||
RSIState as OriginalRSI,
|
||||
SimpleRSIState as OriginalSimpleRSI,
|
||||
BollingerBandsState as OriginalBB,
|
||||
BollingerBandsOHLCState as OriginalBBOHLC
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
MovingAverageState as NewMA,
|
||||
ExponentialMovingAverageState as NewEMA,
|
||||
ATRState as NewATR,
|
||||
SimpleATRState as NewSimpleATR,
|
||||
SupertrendState as NewSupertrend,
|
||||
RSIState as NewRSI,
|
||||
SimpleRSIState as NewSimpleRSI,
|
||||
BollingerBandsState as NewBB,
|
||||
BollingerBandsOHLCState as NewBBOHLC
|
||||
)
|
||||
|
||||
|
||||
class IndicatorComparisonTester:
|
||||
"""Test framework for comparing original and new indicator implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
|
||||
"""
|
||||
Initialize the tester with data.
|
||||
|
||||
Args:
|
||||
data_file: Path to the CSV data file
|
||||
sample_size: Number of data points to use for testing (None for all data)
|
||||
"""
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results")
|
||||
self.results_dir.mkdir(exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
# Load data
|
||||
df = pd.read_csv(self.data_file)
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
# Take sample if specified
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
# Take the most recent data
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def compare_moving_averages(self, periods=[20, 50]):
|
||||
"""Compare Moving Average implementations."""
|
||||
print("\n=== Testing Moving Averages ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing MA({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_ma = OriginalMA(period)
|
||||
new_ma = NewMA(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_ma.update(price)
|
||||
new_ma.update(price)
|
||||
|
||||
original_values.append(original_ma.get_current_value() if original_ma.is_warmed_up() else np.nan)
|
||||
new_values.append(new_ma.get_current_value() if new_ma.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'MA_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_exponential_moving_averages(self, periods=[20, 50]):
|
||||
"""Compare Exponential Moving Average implementations."""
|
||||
print("\n=== Testing Exponential Moving Averages ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing EMA({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_ema = OriginalEMA(period)
|
||||
new_ema = NewEMA(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_ema.update(price)
|
||||
new_ema.update(price)
|
||||
|
||||
original_values.append(original_ema.get_current_value() if original_ema.is_warmed_up() else np.nan)
|
||||
new_values.append(new_ema.get_current_value() if new_ema.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'EMA_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_atr(self, periods=[14]):
|
||||
"""Compare ATR implementations."""
|
||||
print("\n=== Testing ATR ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing ATR({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_atr = OriginalATR(period)
|
||||
new_atr = NewATR(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
ohlc = {'open': close, 'high': high, 'low': low, 'close': close}
|
||||
|
||||
original_atr.update(ohlc)
|
||||
new_atr.update(ohlc)
|
||||
|
||||
original_values.append(original_atr.get_current_value() if original_atr.is_warmed_up() else np.nan)
|
||||
new_values.append(new_atr.get_current_value() if new_atr.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'ATR_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_simple_atr(self, periods=[14]):
|
||||
"""Compare Simple ATR implementations."""
|
||||
print("\n=== Testing Simple ATR ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SimpleATR({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_atr = OriginalSimpleATR(period)
|
||||
new_atr = NewSimpleATR(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
ohlc = {'open': close, 'high': high, 'low': low, 'close': close}
|
||||
|
||||
original_atr.update(ohlc)
|
||||
new_atr.update(ohlc)
|
||||
|
||||
original_values.append(original_atr.get_current_value() if original_atr.is_warmed_up() else np.nan)
|
||||
new_values.append(new_atr.get_current_value() if new_atr.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'SimpleATR_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_supertrend(self, periods=[10], multipliers=[3.0]):
|
||||
"""Compare Supertrend implementations."""
|
||||
print("\n=== Testing Supertrend ===")
|
||||
|
||||
for period in periods:
|
||||
for multiplier in multipliers:
|
||||
print(f"Testing Supertrend({period}, {multiplier})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_st = OriginalSupertrend(period, multiplier)
|
||||
new_st = NewSupertrend(period, multiplier)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
original_trends = []
|
||||
new_trends = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
ohlc = {'open': close, 'high': high, 'low': low, 'close': close}
|
||||
|
||||
original_st.update(ohlc)
|
||||
new_st.update(ohlc)
|
||||
|
||||
# Get current values
|
||||
orig_result = original_st.get_current_value() if original_st.is_warmed_up() else None
|
||||
new_result = new_st.get_current_value() if new_st.is_warmed_up() else None
|
||||
|
||||
if orig_result:
|
||||
original_values.append(orig_result['supertrend'])
|
||||
original_trends.append(orig_result['trend'])
|
||||
else:
|
||||
original_values.append(np.nan)
|
||||
original_trends.append(0)
|
||||
|
||||
if new_result:
|
||||
new_values.append(new_result['supertrend'])
|
||||
new_trends.append(new_result['trend'])
|
||||
else:
|
||||
new_values.append(np.nan)
|
||||
new_trends.append(0)
|
||||
|
||||
# Store results
|
||||
key = f'Supertrend_{period}_{multiplier}'
|
||||
self.results[key] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'original_trend': original_trends,
|
||||
'new_trend': new_trends,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
trend_diff = np.array(new_trends) - np.array(original_trends)
|
||||
trend_matches = np.sum(trend_diff == 0) / len(trend_diff) * 100
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Trend match: {trend_matches:.2f}%")
|
||||
|
||||
def compare_rsi(self, periods=[14]):
|
||||
"""Compare RSI implementations."""
|
||||
print("\n=== Testing RSI ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing RSI({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_rsi = OriginalRSI(period)
|
||||
new_rsi = NewRSI(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_rsi.update(price)
|
||||
new_rsi.update(price)
|
||||
|
||||
original_values.append(original_rsi.get_current_value() if original_rsi.is_warmed_up() else np.nan)
|
||||
new_values.append(new_rsi.get_current_value() if new_rsi.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'RSI_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_simple_rsi(self, periods=[14]):
|
||||
"""Compare Simple RSI implementations."""
|
||||
print("\n=== Testing Simple RSI ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SimpleRSI({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_rsi = OriginalSimpleRSI(period)
|
||||
new_rsi = NewSimpleRSI(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_rsi.update(price)
|
||||
new_rsi.update(price)
|
||||
|
||||
original_values.append(original_rsi.get_current_value() if original_rsi.is_warmed_up() else np.nan)
|
||||
new_values.append(new_rsi.get_current_value() if new_rsi.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'SimpleRSI_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" Max difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
print(f" Mean difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
print(f" Std difference: {np.std(valid_diff):.10f}")
|
||||
|
||||
def compare_bollinger_bands(self, periods=[20], std_devs=[2.0]):
|
||||
"""Compare Bollinger Bands implementations."""
|
||||
print("\n=== Testing Bollinger Bands ===")
|
||||
|
||||
for period in periods:
|
||||
for std_dev in std_devs:
|
||||
print(f"Testing BollingerBands({period}, {std_dev})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_bb = OriginalBB(period, std_dev)
|
||||
new_bb = NewBB(period, std_dev)
|
||||
|
||||
original_upper = []
|
||||
original_middle = []
|
||||
original_lower = []
|
||||
new_upper = []
|
||||
new_middle = []
|
||||
new_lower = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
|
||||
original_bb.update(price)
|
||||
new_bb.update(price)
|
||||
|
||||
# Get current values
|
||||
orig_result = original_bb.get_current_value() if original_bb.is_warmed_up() else None
|
||||
new_result = new_bb.get_current_value() if new_bb.is_warmed_up() else None
|
||||
|
||||
if orig_result:
|
||||
original_upper.append(orig_result['upper_band'])
|
||||
original_middle.append(orig_result['middle_band'])
|
||||
original_lower.append(orig_result['lower_band'])
|
||||
else:
|
||||
original_upper.append(np.nan)
|
||||
original_middle.append(np.nan)
|
||||
original_lower.append(np.nan)
|
||||
|
||||
if new_result:
|
||||
new_upper.append(new_result['upper_band'])
|
||||
new_middle.append(new_result['middle_band'])
|
||||
new_lower.append(new_result['lower_band'])
|
||||
else:
|
||||
new_upper.append(np.nan)
|
||||
new_middle.append(np.nan)
|
||||
new_lower.append(np.nan)
|
||||
|
||||
# Store results
|
||||
key = f'BB_{period}_{std_dev}'
|
||||
self.results[key] = {
|
||||
'original_upper': original_upper,
|
||||
'original_middle': original_middle,
|
||||
'original_lower': original_lower,
|
||||
'new_upper': new_upper,
|
||||
'new_middle': new_middle,
|
||||
'new_lower': new_lower,
|
||||
'dates': self.data['datetime'].tolist()
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(locals()[f'original_{band}'])
|
||||
new = np.array(locals()[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
print(f" {band.capitalize()} band - Max diff: {np.max(np.abs(valid_diff)):.10f}, "
|
||||
f"Mean diff: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
|
||||
def generate_summary_report(self):
|
||||
"""Generate a summary report of all comparisons."""
|
||||
print("\n=== Summary Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# Indicator Comparison Summary Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"## {indicator_name}")
|
||||
|
||||
if 'original' in result and 'new' in result:
|
||||
# Standard indicator
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- Max absolute difference: {np.max(np.abs(valid_diff)):.10f}")
|
||||
report_lines.append(f"- Mean absolute difference: {np.mean(np.abs(valid_diff)):.10f}")
|
||||
report_lines.append(f"- Standard deviation: {np.std(valid_diff):.10f}")
|
||||
report_lines.append(f"- Valid data points: {len(valid_diff)}")
|
||||
|
||||
# Check if differences are negligible
|
||||
if np.max(np.abs(valid_diff)) < 1e-10:
|
||||
report_lines.append("- ✅ **PASSED**: Implementations are mathematically equivalent")
|
||||
elif np.max(np.abs(valid_diff)) < 1e-6:
|
||||
report_lines.append("- ⚠️ **WARNING**: Small differences detected (likely floating point precision)")
|
||||
else:
|
||||
report_lines.append("- ❌ **FAILED**: Significant differences detected")
|
||||
else:
|
||||
report_lines.append("- ❌ **ERROR**: No valid data points for comparison")
|
||||
|
||||
elif 'original_upper' in result:
|
||||
# Bollinger Bands
|
||||
all_passed = True
|
||||
for band in ['upper', 'middle', 'lower']:
|
||||
orig = np.array(result[f'original_{band}'])
|
||||
new = np.array(result[f'new_{band}'])
|
||||
diff = new - orig
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
report_lines.append(f"- {band.capitalize()} band max diff: {max_diff:.10f}")
|
||||
if max_diff >= 1e-6:
|
||||
all_passed = False
|
||||
|
||||
if all_passed:
|
||||
report_lines.append("- ✅ **PASSED**: All bands are mathematically equivalent")
|
||||
else:
|
||||
report_lines.append("- ❌ **FAILED**: Significant differences in one or more bands")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "comparison_summary.md"
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Summary report saved to {report_path}")
|
||||
|
||||
# Print summary to console
|
||||
print('\n'.join(report_lines))
|
||||
|
||||
def run_all_tests(self):
|
||||
"""Run all indicator comparison tests."""
|
||||
print("Starting comprehensive indicator comparison tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run all comparisons
|
||||
self.compare_moving_averages()
|
||||
self.compare_exponential_moving_averages()
|
||||
self.compare_atr()
|
||||
self.compare_simple_atr()
|
||||
self.compare_supertrend()
|
||||
self.compare_rsi()
|
||||
self.compare_simple_rsi()
|
||||
self.compare_bollinger_bands()
|
||||
|
||||
# Generate reports
|
||||
self.generate_summary_report()
|
||||
|
||||
print("\n✅ All tests completed! Check the test/results/ directory for detailed outputs.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the comprehensive test suite
|
||||
tester = IndicatorComparisonTester(sample_size=3000) # Use 3000 data points for faster testing
|
||||
tester.run_all_tests()
|
||||
@@ -1,335 +0,0 @@
|
||||
"""
|
||||
Moving Average Indicators Comparison Test
|
||||
|
||||
Focused testing for Moving Average and Exponential Moving Average implementations.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
MovingAverageState as OriginalMA,
|
||||
ExponentialMovingAverageState as OriginalEMA
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
MovingAverageState as NewMA,
|
||||
ExponentialMovingAverageState as NewEMA
|
||||
)
|
||||
|
||||
|
||||
class MovingAverageComparisonTest:
|
||||
"""Test framework for comparing moving average implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results/moving_averages")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
df = pd.read_csv(self.data_file)
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def test_simple_moving_average(self, periods=[5, 10, 20, 50, 100]):
|
||||
"""Test Simple Moving Average implementations."""
|
||||
print("\n=== Testing Simple Moving Average ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SMA({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_ma = OriginalMA(period)
|
||||
new_ma = NewMA(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
prices = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
prices.append(price)
|
||||
|
||||
original_ma.update(price)
|
||||
new_ma.update(price)
|
||||
|
||||
original_values.append(original_ma.get_current_value() if original_ma.is_warmed_up() else np.nan)
|
||||
new_values.append(new_ma.get_current_value() if new_ma.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'SMA_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'prices': prices,
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def test_exponential_moving_average(self, periods=[5, 10, 20, 50, 100]):
|
||||
"""Test Exponential Moving Average implementations."""
|
||||
print("\n=== Testing Exponential Moving Average ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing EMA({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_ema = OriginalEMA(period)
|
||||
new_ema = NewEMA(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
prices = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
prices.append(price)
|
||||
|
||||
original_ema.update(price)
|
||||
new_ema.update(price)
|
||||
|
||||
original_values.append(original_ema.get_current_value() if original_ema.is_warmed_up() else np.nan)
|
||||
new_values.append(new_ema.get_current_value() if new_ema.is_warmed_up() else np.nan)
|
||||
|
||||
# Store results
|
||||
self.results[f'EMA_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'prices': prices,
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def plot_comparison(self, indicator_name: str):
|
||||
"""Plot detailed comparison for a specific indicator."""
|
||||
if indicator_name not in self.results:
|
||||
print(f"No results found for {indicator_name}")
|
||||
return
|
||||
|
||||
result = self.results[indicator_name]
|
||||
dates = pd.to_datetime(result['dates'])
|
||||
|
||||
# Create figure with subplots
|
||||
fig, axes = plt.subplots(3, 1, figsize=(15, 12))
|
||||
fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)
|
||||
|
||||
# Plot 1: Price and indicators
|
||||
ax1 = axes[0]
|
||||
ax1.plot(dates, result['prices'], label='Price', alpha=0.6, color='gray')
|
||||
ax1.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
|
||||
ax1.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
|
||||
ax1.set_title(f'{indicator_name} vs Price')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: Overlay comparison (zoomed)
|
||||
ax2 = axes[1]
|
||||
ax2.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
|
||||
ax2.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
|
||||
ax2.set_title(f'{indicator_name} Values Comparison (Detailed)')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 3: Difference analysis
|
||||
ax3 = axes[2]
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
ax3.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
|
||||
ax3.set_title(f'{indicator_name} Difference (New - Original)')
|
||||
ax3.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Add statistics text
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
if len(valid_diff) > 0:
|
||||
stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Std: {np.std(valid_diff):.2e}'
|
||||
ax3.text(0.02, 0.98, stats_text, transform=ax3.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10)))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot
|
||||
plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Plot saved to {plot_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_all_comparisons(self):
|
||||
"""Plot comparisons for all tested indicators."""
|
||||
print("\n=== Generating Detailed Comparison Plots ===")
|
||||
|
||||
for indicator_name in self.results.keys():
|
||||
print(f"Plotting {indicator_name}...")
|
||||
self.plot_comparison(indicator_name)
|
||||
plt.close('all')
|
||||
|
||||
def generate_report(self):
|
||||
"""Generate detailed report for moving averages."""
|
||||
print("\n=== Generating Moving Average Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# Moving Average Indicators Comparison Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
# Summary table
|
||||
report_lines.append("## Summary Table")
|
||||
report_lines.append("| Indicator | Period | Max Diff | Mean Diff | Status |")
|
||||
report_lines.append("|-----------|--------|----------|-----------|--------|")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
|
||||
if max_diff < 1e-10:
|
||||
status = "✅ PASSED"
|
||||
elif max_diff < 1e-6:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |")
|
||||
else:
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Detailed analysis
|
||||
report_lines.append("## Detailed Analysis")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"### {indicator_name}")
|
||||
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- **Period**: {result['period']}")
|
||||
report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
|
||||
report_lines.append(f"- **Max absolute difference**: {np.max(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Mean absolute difference**: {np.mean(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
|
||||
report_lines.append(f"- **Min difference**: {np.min(valid_diff):.12f}")
|
||||
report_lines.append(f"- **Max difference**: {np.max(valid_diff):.12f}")
|
||||
|
||||
# Percentile analysis
|
||||
percentiles = [1, 5, 25, 50, 75, 95, 99]
|
||||
perc_values = np.percentile(np.abs(valid_diff), percentiles)
|
||||
perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
|
||||
report_lines.append(f"- **Percentiles**: {perc_str}")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "moving_averages_report.md"
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Report saved to {report_path}")
|
||||
|
||||
def run_tests(self):
|
||||
"""Run all moving average tests."""
|
||||
print("Starting Moving Average Comparison Tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run tests
|
||||
self.test_simple_moving_average()
|
||||
self.test_exponential_moving_average()
|
||||
|
||||
# Generate outputs
|
||||
self.plot_all_comparisons()
|
||||
self.generate_report()
|
||||
|
||||
print("\n✅ Moving Average tests completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tester = MovingAverageComparisonTest(sample_size=3000)
|
||||
tester.run_tests()
|
||||
@@ -1,401 +0,0 @@
|
||||
"""
|
||||
RSI Indicators Comparison Test
|
||||
|
||||
Focused testing for RSI and Simple RSI implementations.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
RSIState as OriginalRSI,
|
||||
SimpleRSIState as OriginalSimpleRSI
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
RSIState as NewRSI,
|
||||
SimpleRSIState as NewSimpleRSI
|
||||
)
|
||||
|
||||
|
||||
class RSIComparisonTest:
|
||||
"""Test framework for comparing RSI implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results/rsi_indicators")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
df = pd.read_csv(self.data_file)
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def test_rsi(self, periods=[7, 14, 21, 28]):
|
||||
"""Test RSI implementations (Wilder's smoothing)."""
|
||||
print("\n=== Testing RSI (Wilder's Smoothing) ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing RSI({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_rsi = OriginalRSI(period)
|
||||
new_rsi = NewRSI(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
prices = []
|
||||
price_changes = []
|
||||
|
||||
# Process data
|
||||
prev_price = None
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
prices.append(price)
|
||||
|
||||
if prev_price is not None:
|
||||
price_changes.append(price - prev_price)
|
||||
else:
|
||||
price_changes.append(0)
|
||||
|
||||
original_rsi.update(price)
|
||||
new_rsi.update(price)
|
||||
|
||||
original_values.append(original_rsi.get_current_value() if original_rsi.is_warmed_up() else np.nan)
|
||||
new_values.append(new_rsi.get_current_value() if new_rsi.is_warmed_up() else np.nan)
|
||||
|
||||
prev_price = price
|
||||
|
||||
# Store results
|
||||
self.results[f'RSI_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'prices': prices,
|
||||
'price_changes': price_changes,
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def test_simple_rsi(self, periods=[7, 14, 21, 28]):
|
||||
"""Test Simple RSI implementations (Simple moving average)."""
|
||||
print("\n=== Testing Simple RSI (Simple Moving Average) ===")
|
||||
|
||||
for period in periods:
|
||||
print(f"Testing SimpleRSI({period})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_rsi = OriginalSimpleRSI(period)
|
||||
new_rsi = NewSimpleRSI(period)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
prices = []
|
||||
price_changes = []
|
||||
|
||||
# Process data
|
||||
prev_price = None
|
||||
for _, row in self.data.iterrows():
|
||||
price = row['Close']
|
||||
prices.append(price)
|
||||
|
||||
if prev_price is not None:
|
||||
price_changes.append(price - prev_price)
|
||||
else:
|
||||
price_changes.append(0)
|
||||
|
||||
original_rsi.update(price)
|
||||
new_rsi.update(price)
|
||||
|
||||
original_values.append(original_rsi.get_current_value() if original_rsi.is_warmed_up() else np.nan)
|
||||
new_values.append(new_rsi.get_current_value() if new_rsi.is_warmed_up() else np.nan)
|
||||
|
||||
prev_price = price
|
||||
|
||||
# Store results
|
||||
self.results[f'SimpleRSI_{period}'] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'prices': prices,
|
||||
'price_changes': price_changes,
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def plot_comparison(self, indicator_name: str):
|
||||
"""Plot detailed comparison for a specific indicator."""
|
||||
if indicator_name not in self.results:
|
||||
print(f"No results found for {indicator_name}")
|
||||
return
|
||||
|
||||
result = self.results[indicator_name]
|
||||
dates = pd.to_datetime(result['dates'])
|
||||
|
||||
# Create figure with subplots
|
||||
fig, axes = plt.subplots(4, 1, figsize=(15, 16))
|
||||
fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)
|
||||
|
||||
# Plot 1: Price data
|
||||
ax1 = axes[0]
|
||||
ax1.plot(dates, result['prices'], label='Close Price', alpha=0.8, color='black', linewidth=1)
|
||||
ax1.set_title('Price Data')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: RSI comparison with levels
|
||||
ax2 = axes[1]
|
||||
ax2.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2, color='blue')
|
||||
ax2.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--', color='red')
|
||||
ax2.axhline(y=70, color='red', linestyle=':', alpha=0.7, label='Overbought (70)')
|
||||
ax2.axhline(y=30, color='green', linestyle=':', alpha=0.7, label='Oversold (30)')
|
||||
ax2.axhline(y=50, color='gray', linestyle='-', alpha=0.5, label='Midline (50)')
|
||||
ax2.set_title(f'{indicator_name} Values Comparison')
|
||||
ax2.set_ylim(0, 100)
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 3: Price changes
|
||||
ax3 = axes[2]
|
||||
positive_changes = [max(0, change) for change in result['price_changes']]
|
||||
negative_changes = [abs(min(0, change)) for change in result['price_changes']]
|
||||
ax3.plot(dates, positive_changes, label='Positive Changes', alpha=0.7, color='green')
|
||||
ax3.plot(dates, negative_changes, label='Negative Changes', alpha=0.7, color='red')
|
||||
ax3.set_title('Price Changes (Gains and Losses)')
|
||||
ax3.legend()
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 4: Difference analysis
|
||||
ax4 = axes[3]
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
|
||||
ax4.set_title(f'{indicator_name} Difference (New - Original)')
|
||||
ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
# Add statistics text
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
if len(valid_diff) > 0:
|
||||
stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Std: {np.std(valid_diff):.2e}'
|
||||
ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10)))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot
|
||||
plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Plot saved to {plot_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_all_comparisons(self):
|
||||
"""Plot comparisons for all tested indicators."""
|
||||
print("\n=== Generating Detailed Comparison Plots ===")
|
||||
|
||||
for indicator_name in self.results.keys():
|
||||
print(f"Plotting {indicator_name}...")
|
||||
self.plot_comparison(indicator_name)
|
||||
plt.close('all')
|
||||
|
||||
def generate_report(self):
|
||||
"""Generate detailed report for RSI indicators."""
|
||||
print("\n=== Generating RSI Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# RSI Indicators Comparison Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
# Summary table
|
||||
report_lines.append("## Summary Table")
|
||||
report_lines.append("| Indicator | Period | Max Diff | Mean Diff | Status |")
|
||||
report_lines.append("|-----------|--------|----------|-----------|--------|")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
|
||||
if max_diff < 1e-10:
|
||||
status = "✅ PASSED"
|
||||
elif max_diff < 1e-6:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | {max_diff:.2e} | {mean_diff:.2e} | {status} |")
|
||||
else:
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | N/A | N/A | ❌ ERROR |")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Methodology explanation
|
||||
report_lines.append("## Methodology")
|
||||
report_lines.append("### RSI (Relative Strength Index)")
|
||||
report_lines.append("- Uses Wilder's smoothing for average gains and losses")
|
||||
report_lines.append("- Average Gain = (Previous Average Gain × (n-1) + Current Gain) / n")
|
||||
report_lines.append("- Average Loss = (Previous Average Loss × (n-1) + Current Loss) / n")
|
||||
report_lines.append("- RS = Average Gain / Average Loss")
|
||||
report_lines.append("- RSI = 100 - (100 / (1 + RS))")
|
||||
report_lines.append("")
|
||||
report_lines.append("### Simple RSI")
|
||||
report_lines.append("- Uses simple moving average for average gains and losses")
|
||||
report_lines.append("- More responsive to recent price changes than Wilder's method")
|
||||
report_lines.append("")
|
||||
|
||||
# Detailed analysis
|
||||
report_lines.append("## Detailed Analysis")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"### {indicator_name}")
|
||||
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- **Period**: {result['period']}")
|
||||
report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
|
||||
report_lines.append(f"- **Max absolute difference**: {np.max(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Mean absolute difference**: {np.mean(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
|
||||
|
||||
# RSI-specific metrics
|
||||
valid_original = np.array(result['original'])[~np.isnan(result['original'])]
|
||||
if len(valid_original) > 0:
|
||||
mean_rsi = np.mean(valid_original)
|
||||
overbought_count = np.sum(valid_original > 70)
|
||||
oversold_count = np.sum(valid_original < 30)
|
||||
|
||||
report_lines.append(f"- **Mean RSI value**: {mean_rsi:.2f}")
|
||||
report_lines.append(f"- **Overbought periods (>70)**: {overbought_count} ({overbought_count/len(valid_original)*100:.1f}%)")
|
||||
report_lines.append(f"- **Oversold periods (<30)**: {oversold_count} ({oversold_count/len(valid_original)*100:.1f}%)")
|
||||
|
||||
# Price change analysis
|
||||
positive_changes = [max(0, change) for change in result['price_changes']]
|
||||
negative_changes = [abs(min(0, change)) for change in result['price_changes']]
|
||||
avg_gain = np.mean([change for change in positive_changes if change > 0]) if any(change > 0 for change in positive_changes) else 0
|
||||
avg_loss = np.mean([change for change in negative_changes if change > 0]) if any(change > 0 for change in negative_changes) else 0
|
||||
|
||||
report_lines.append(f"- **Average gain**: {avg_gain:.6f}")
|
||||
report_lines.append(f"- **Average loss**: {avg_loss:.6f}")
|
||||
if avg_loss > 0:
|
||||
report_lines.append(f"- **Gain/Loss ratio**: {avg_gain/avg_loss:.3f}")
|
||||
|
||||
# Percentile analysis
|
||||
percentiles = [1, 5, 25, 50, 75, 95, 99]
|
||||
perc_values = np.percentile(np.abs(valid_diff), percentiles)
|
||||
perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
|
||||
report_lines.append(f"- **Percentiles**: {perc_str}")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "rsi_indicators_report.md"
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Report saved to {report_path}")
|
||||
|
||||
def run_tests(self):
|
||||
"""Run all RSI tests."""
|
||||
print("Starting RSI Comparison Tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run tests
|
||||
self.test_rsi()
|
||||
self.test_simple_rsi()
|
||||
|
||||
# Generate outputs
|
||||
self.plot_all_comparisons()
|
||||
self.generate_report()
|
||||
|
||||
print("\n✅ RSI tests completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tester = RSIComparisonTest(sample_size=3000)
|
||||
tester.run_tests()
|
||||
@@ -1,374 +0,0 @@
|
||||
"""
|
||||
Supertrend Indicators Comparison Test
|
||||
|
||||
Focused testing for Supertrend implementations.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import original indicators
|
||||
from cycles.IncStrategies.indicators import (
|
||||
SupertrendState as OriginalSupertrend
|
||||
)
|
||||
|
||||
# Import new indicators
|
||||
from IncrementalTrader.strategies.indicators import (
|
||||
SupertrendState as NewSupertrend
|
||||
)
|
||||
|
||||
|
||||
class SupertrendComparisonTest:
|
||||
"""Test framework for comparing Supertrend implementations."""
|
||||
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv", sample_size: int = 5000):
|
||||
self.data_file = data_file
|
||||
self.sample_size = sample_size
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results/supertrend_indicators")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_data(self):
|
||||
"""Load and prepare the data for testing."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
df = pd.read_csv(self.data_file)
|
||||
df['datetime'] = pd.to_datetime(df['Timestamp'], unit='s')
|
||||
|
||||
if self.sample_size and len(df) > self.sample_size:
|
||||
df = df.tail(self.sample_size).reset_index(drop=True)
|
||||
|
||||
self.data = df
|
||||
print(f"Loaded {len(df)} data points from {df['datetime'].iloc[0]} to {df['datetime'].iloc[-1]}")
|
||||
|
||||
def test_supertrend(self, periods=[7, 10, 14, 21], multipliers=[2.0, 3.0, 4.0]):
|
||||
"""Test Supertrend implementations."""
|
||||
print("\n=== Testing Supertrend ===")
|
||||
|
||||
for period in periods:
|
||||
for multiplier in multipliers:
|
||||
print(f"Testing Supertrend({period}, {multiplier})...")
|
||||
|
||||
# Initialize indicators
|
||||
original_st = OriginalSupertrend(period, multiplier)
|
||||
new_st = NewSupertrend(period, multiplier)
|
||||
|
||||
original_values = []
|
||||
new_values = []
|
||||
original_trends = []
|
||||
new_trends = []
|
||||
original_signals = []
|
||||
new_signals = []
|
||||
|
||||
# Process data
|
||||
for _, row in self.data.iterrows():
|
||||
high, low, close = row['High'], row['Low'], row['Close']
|
||||
|
||||
# Create OHLC dictionary for both indicators
|
||||
ohlc_data = {
|
||||
'open': row['Open'],
|
||||
'high': high,
|
||||
'low': low,
|
||||
'close': close
|
||||
}
|
||||
|
||||
original_st.update(ohlc_data)
|
||||
new_st.update(ohlc_data)
|
||||
|
||||
original_values.append(original_st.get_current_value()['supertrend'] if original_st.is_warmed_up() else np.nan)
|
||||
new_values.append(new_st.get_current_value()['supertrend'] if new_st.is_warmed_up() else np.nan)
|
||||
original_trends.append(original_st.get_current_value()['trend'] if original_st.is_warmed_up() else 0)
|
||||
new_trends.append(new_st.get_current_value()['trend'] if new_st.is_warmed_up() else 0)
|
||||
|
||||
# Check for trend changes (signals)
|
||||
if len(original_trends) > 1:
|
||||
original_signals.append(1 if original_trends[-1] != original_trends[-2] else 0)
|
||||
new_signals.append(1 if new_trends[-1] != new_trends[-2] else 0)
|
||||
else:
|
||||
original_signals.append(0)
|
||||
new_signals.append(0)
|
||||
|
||||
# Store results
|
||||
key = f'Supertrend_{period}_{multiplier}'
|
||||
self.results[key] = {
|
||||
'original': original_values,
|
||||
'new': new_values,
|
||||
'original_trend': original_trends,
|
||||
'new_trend': new_trends,
|
||||
'original_signals': original_signals,
|
||||
'new_signals': new_signals,
|
||||
'highs': self.data['High'].tolist(),
|
||||
'lows': self.data['Low'].tolist(),
|
||||
'closes': self.data['Close'].tolist(),
|
||||
'dates': self.data['datetime'].tolist(),
|
||||
'period': period,
|
||||
'multiplier': multiplier
|
||||
}
|
||||
|
||||
# Calculate differences
|
||||
diff = np.array(new_values) - np.array(original_values)
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
# Trend comparison
|
||||
trend_diff = np.array(new_trends) - np.array(original_trends)
|
||||
trend_matches = np.sum(trend_diff == 0) / len(trend_diff) * 100
|
||||
|
||||
# Signal comparison
|
||||
signal_diff = np.array(new_signals) - np.array(original_signals)
|
||||
signal_matches = np.sum(signal_diff == 0) / len(signal_diff) * 100
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
std_diff = np.std(valid_diff)
|
||||
|
||||
print(f" Max difference: {max_diff:.12f}")
|
||||
print(f" Mean difference: {mean_diff:.12f}")
|
||||
print(f" Std difference: {std_diff:.12f}")
|
||||
print(f" Trend match: {trend_matches:.2f}%")
|
||||
print(f" Signal match: {signal_matches:.2f}%")
|
||||
|
||||
# Status check
|
||||
if max_diff < 1e-10 and trend_matches == 100:
|
||||
print(f" ✅ PASSED: Mathematically equivalent")
|
||||
elif max_diff < 1e-6 and trend_matches >= 99:
|
||||
print(f" ⚠️ WARNING: Small differences (floating point precision)")
|
||||
else:
|
||||
print(f" ❌ FAILED: Significant differences detected")
|
||||
else:
|
||||
print(f" ❌ ERROR: No valid data points")
|
||||
|
||||
def plot_comparison(self, indicator_name: str):
|
||||
"""Plot detailed comparison for a specific indicator."""
|
||||
if indicator_name not in self.results:
|
||||
print(f"No results found for {indicator_name}")
|
||||
return
|
||||
|
||||
result = self.results[indicator_name]
|
||||
dates = pd.to_datetime(result['dates'])
|
||||
|
||||
# Create figure with subplots
|
||||
fig, axes = plt.subplots(5, 1, figsize=(15, 20))
|
||||
fig.suptitle(f'{indicator_name} - Detailed Comparison Analysis', fontsize=16)
|
||||
|
||||
# Plot 1: Price and Supertrend
|
||||
ax1 = axes[0]
|
||||
ax1.plot(dates, result['closes'], label='Close Price', alpha=0.7, color='black', linewidth=1)
|
||||
ax1.plot(dates, result['original'], label='Original Supertrend', alpha=0.8, linewidth=2, color='blue')
|
||||
ax1.plot(dates, result['new'], label='New Supertrend', alpha=0.8, linewidth=2, linestyle='--', color='red')
|
||||
ax1.set_title(f'{indicator_name} vs Price')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 2: Trend comparison
|
||||
ax2 = axes[1]
|
||||
ax2.plot(dates, result['original_trend'], label='Original Trend', alpha=0.8, linewidth=2, color='blue')
|
||||
ax2.plot(dates, result['new_trend'], label='New Trend', alpha=0.8, linewidth=2, linestyle='--', color='red')
|
||||
ax2.set_title(f'{indicator_name} Trend Direction (1=Up, -1=Down)')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_ylim(-1.5, 1.5)
|
||||
|
||||
# Plot 3: Supertrend values comparison
|
||||
ax3 = axes[2]
|
||||
ax3.plot(dates, result['original'], label='Original', alpha=0.8, linewidth=2)
|
||||
ax3.plot(dates, result['new'], label='New', alpha=0.8, linewidth=2, linestyle='--')
|
||||
ax3.set_title(f'{indicator_name} Values Comparison')
|
||||
ax3.legend()
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 4: Difference analysis
|
||||
ax4 = axes[3]
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
ax4.plot(dates, diff, color='red', alpha=0.7, linewidth=1)
|
||||
ax4.set_title(f'{indicator_name} Difference (New - Original)')
|
||||
ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
# Plot 5: Signal comparison
|
||||
ax5 = axes[4]
|
||||
signal_dates = dates[1:] # Signals start from second data point
|
||||
ax5.scatter(signal_dates, np.array(result['original_signals'][1:]),
|
||||
label='Original Signals', alpha=0.7, color='blue', s=30)
|
||||
ax5.scatter(signal_dates, np.array(result['new_signals'][1:]) + 0.1,
|
||||
label='New Signals', alpha=0.7, color='red', s=30, marker='^')
|
||||
ax5.set_title(f'{indicator_name} Trend Change Signals')
|
||||
ax5.legend()
|
||||
ax5.grid(True, alpha=0.3)
|
||||
ax5.set_ylim(-0.2, 1.3)
|
||||
|
||||
# Add statistics text
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
if len(valid_diff) > 0:
|
||||
trend_diff = np.array(result['new_trend']) - np.array(result['original_trend'])
|
||||
trend_matches = np.sum(trend_diff == 0) / len(trend_diff) * 100
|
||||
|
||||
stats_text = f'Max: {np.max(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Mean: {np.mean(np.abs(valid_diff)):.2e}\n'
|
||||
stats_text += f'Trend Match: {trend_matches:.1f}%'
|
||||
ax4.text(0.02, 0.98, stats_text, transform=ax4.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(dates)//10)))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot
|
||||
plot_path = self.results_dir / f"{indicator_name}_detailed_comparison.png"
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Plot saved to {plot_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
def plot_all_comparisons(self):
|
||||
"""Plot comparisons for all tested indicators."""
|
||||
print("\n=== Generating Detailed Comparison Plots ===")
|
||||
|
||||
for indicator_name in self.results.keys():
|
||||
print(f"Plotting {indicator_name}...")
|
||||
self.plot_comparison(indicator_name)
|
||||
plt.close('all')
|
||||
|
||||
def generate_report(self):
|
||||
"""Generate detailed report for Supertrend indicators."""
|
||||
print("\n=== Generating Supertrend Report ===")
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# Supertrend Indicators Comparison Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append(f"Data file: {self.data_file}")
|
||||
report_lines.append(f"Sample size: {len(self.data)} data points")
|
||||
report_lines.append("")
|
||||
|
||||
# Summary table
|
||||
report_lines.append("## Summary Table")
|
||||
report_lines.append("| Indicator | Period | Multiplier | Max Diff | Mean Diff | Trend Match | Status |")
|
||||
report_lines.append("|-----------|--------|------------|----------|-----------|-------------|--------|")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
trend_diff = np.array(result['new_trend']) - np.array(result['original_trend'])
|
||||
trend_matches = np.sum(trend_diff == 0) / len(trend_diff) * 100
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
max_diff = np.max(np.abs(valid_diff))
|
||||
mean_diff = np.mean(np.abs(valid_diff))
|
||||
|
||||
if max_diff < 1e-10 and trend_matches == 100:
|
||||
status = "✅ PASSED"
|
||||
elif max_diff < 1e-6 and trend_matches >= 99:
|
||||
status = "⚠️ WARNING"
|
||||
else:
|
||||
status = "❌ FAILED"
|
||||
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | {result['multiplier']} | "
|
||||
f"{max_diff:.2e} | {mean_diff:.2e} | {trend_matches:.1f}% | {status} |")
|
||||
else:
|
||||
report_lines.append(f"| {indicator_name} | {result['period']} | {result['multiplier']} | "
|
||||
f"N/A | N/A | N/A | ❌ ERROR |")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Methodology explanation
|
||||
report_lines.append("## Methodology")
|
||||
report_lines.append("### Supertrend Calculation")
|
||||
report_lines.append("1. **Basic Upper Band**: (High + Low) / 2 + (Multiplier × ATR)")
|
||||
report_lines.append("2. **Basic Lower Band**: (High + Low) / 2 - (Multiplier × ATR)")
|
||||
report_lines.append("3. **Final Upper Band**: min(Basic Upper Band, Previous Final Upper Band if Close[1] <= Previous Final Upper Band)")
|
||||
report_lines.append("4. **Final Lower Band**: max(Basic Lower Band, Previous Final Lower Band if Close[1] >= Previous Final Lower Band)")
|
||||
report_lines.append("5. **Supertrend**: Final Lower Band if trend is up, Final Upper Band if trend is down")
|
||||
report_lines.append("6. **Trend**: Up if Close > Previous Supertrend, Down if Close <= Previous Supertrend")
|
||||
report_lines.append("")
|
||||
|
||||
# Detailed analysis
|
||||
report_lines.append("## Detailed Analysis")
|
||||
|
||||
for indicator_name, result in self.results.items():
|
||||
report_lines.append(f"### {indicator_name}")
|
||||
|
||||
diff = np.array(result['new']) - np.array(result['original'])
|
||||
valid_diff = diff[~np.isnan(diff)]
|
||||
|
||||
trend_diff = np.array(result['new_trend']) - np.array(result['original_trend'])
|
||||
trend_matches = np.sum(trend_diff == 0) / len(trend_diff) * 100
|
||||
|
||||
signal_diff = np.array(result['new_signals']) - np.array(result['original_signals'])
|
||||
signal_matches = np.sum(signal_diff == 0) / len(signal_diff) * 100
|
||||
|
||||
if len(valid_diff) > 0:
|
||||
report_lines.append(f"- **Period**: {result['period']}")
|
||||
report_lines.append(f"- **Multiplier**: {result['multiplier']}")
|
||||
report_lines.append(f"- **Valid data points**: {len(valid_diff)}")
|
||||
report_lines.append(f"- **Max absolute difference**: {np.max(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Mean absolute difference**: {np.mean(np.abs(valid_diff)):.12f}")
|
||||
report_lines.append(f"- **Standard deviation**: {np.std(valid_diff):.12f}")
|
||||
report_lines.append(f"- **Trend direction match**: {trend_matches:.2f}%")
|
||||
report_lines.append(f"- **Signal timing match**: {signal_matches:.2f}%")
|
||||
|
||||
# Supertrend-specific metrics
|
||||
valid_original = np.array(result['original'])[~np.isnan(result['original'])]
|
||||
if len(valid_original) > 0:
|
||||
mean_st = np.mean(valid_original)
|
||||
relative_error = np.mean(np.abs(valid_diff)) / mean_st * 100
|
||||
report_lines.append(f"- **Mean Supertrend value**: {mean_st:.6f}")
|
||||
report_lines.append(f"- **Relative error**: {relative_error:.2e}%")
|
||||
|
||||
# Count trend changes
|
||||
original_changes = np.sum(np.array(result['original_signals']))
|
||||
new_changes = np.sum(np.array(result['new_signals']))
|
||||
report_lines.append(f"- **Original trend changes**: {original_changes}")
|
||||
report_lines.append(f"- **New trend changes**: {new_changes}")
|
||||
|
||||
# Percentile analysis
|
||||
percentiles = [1, 5, 25, 50, 75, 95, 99]
|
||||
perc_values = np.percentile(np.abs(valid_diff), percentiles)
|
||||
perc_str = ", ".join([f"P{p}: {v:.2e}" for p, v in zip(percentiles, perc_values)])
|
||||
report_lines.append(f"- **Percentiles**: {perc_str}")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Save report
|
||||
report_path = self.results_dir / "supertrend_indicators_report.md"
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
print(f"Report saved to {report_path}")
|
||||
|
||||
def run_tests(self):
|
||||
"""Run all Supertrend tests."""
|
||||
print("Starting Supertrend Comparison Tests...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run tests
|
||||
self.test_supertrend()
|
||||
|
||||
# Generate outputs
|
||||
self.plot_all_comparisons()
|
||||
self.generate_report()
|
||||
|
||||
print("\n✅ Supertrend tests completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tester = SupertrendComparisonTest(sample_size=3000)
|
||||
tester.run_tests()
|
||||
@@ -1,531 +0,0 @@
|
||||
"""
|
||||
Strategy Comparison Test Framework
|
||||
|
||||
Comprehensive testing for comparing original incremental strategies from cycles/IncStrategies
|
||||
with new implementations in IncrementalTrader/strategies.
|
||||
|
||||
This test framework validates:
|
||||
1. MetaTrend Strategy: IncMetaTrendStrategy vs MetaTrendStrategy
|
||||
2. Random Strategy: IncRandomStrategy vs RandomStrategy
|
||||
3. BBRS Strategy: BBRSIncrementalState vs BBRSStrategy
|
||||
|
||||
Each test validates signal generation, mathematical equivalence, and behavioral consistency.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any
|
||||
import os
|
||||
|
||||
# Add project paths
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
sys.path.append(str(Path(__file__).parent.parent / "cycles"))
|
||||
sys.path.append(str(Path(__file__).parent.parent / "IncrementalTrader"))
|
||||
|
||||
# Import original strategies
|
||||
from cycles.IncStrategies.metatrend_strategy import IncMetaTrendStrategy
|
||||
from cycles.IncStrategies.random_strategy import IncRandomStrategy
|
||||
from cycles.IncStrategies.bbrs_incremental import BBRSIncrementalState
|
||||
|
||||
# Import new strategies
|
||||
from IncrementalTrader.strategies.metatrend import MetaTrendStrategy
|
||||
from IncrementalTrader.strategies.random import RandomStrategy
|
||||
from IncrementalTrader.strategies.bbrs import BBRSStrategy
|
||||
|
||||
class StrategyComparisonTester:
|
||||
def __init__(self, data_file: str = "data/btcusd_1-min_data.csv"):
|
||||
"""Initialize the strategy comparison tester."""
|
||||
self.data_file = data_file
|
||||
self.data = None
|
||||
self.results_dir = Path("test/results/strategies")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_data(self, limit: int = 1000) -> bool:
|
||||
"""Load and prepare test data."""
|
||||
try:
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
self.data = pd.read_csv(self.data_file)
|
||||
|
||||
# Limit data for testing
|
||||
if limit:
|
||||
self.data = self.data.head(limit)
|
||||
|
||||
print(f"Loaded {len(self.data)} data points")
|
||||
print(f"Data columns: {list(self.data.columns)}")
|
||||
print(f"Data sample:\n{self.data.head()}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading data: {e}")
|
||||
return False
|
||||
|
||||
def compare_metatrend_strategies(self) -> Dict[str, Any]:
|
||||
"""Compare IncMetaTrendStrategy vs MetaTrendStrategy."""
|
||||
print("\n" + "="*80)
|
||||
print("COMPARING METATREND STRATEGIES")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Initialize strategies with same parameters
|
||||
original_strategy = IncMetaTrendStrategy()
|
||||
new_strategy = MetaTrendStrategy()
|
||||
|
||||
# Track signals
|
||||
original_entry_signals = []
|
||||
new_entry_signals = []
|
||||
original_exit_signals = []
|
||||
new_exit_signals = []
|
||||
combined_original_signals = []
|
||||
combined_new_signals = []
|
||||
timestamps = []
|
||||
|
||||
# Process data
|
||||
for i, row in self.data.iterrows():
|
||||
timestamp = pd.Timestamp(row['Timestamp'], unit='s')
|
||||
ohlcv_data = {
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
}
|
||||
|
||||
# Update original strategy (uses update_minute_data)
|
||||
original_strategy.update_minute_data(timestamp, ohlcv_data)
|
||||
|
||||
# Update new strategy (uses process_data_point)
|
||||
new_strategy.process_data_point(timestamp, ohlcv_data)
|
||||
|
||||
# Get signals
|
||||
orig_entry = original_strategy.get_entry_signal()
|
||||
new_entry = new_strategy.get_entry_signal()
|
||||
orig_exit = original_strategy.get_exit_signal()
|
||||
new_exit = new_strategy.get_exit_signal()
|
||||
|
||||
# Store signals (both use signal_type)
|
||||
original_entry_signals.append(orig_entry.signal_type if orig_entry else "HOLD")
|
||||
new_entry_signals.append(new_entry.signal_type if new_entry else "HOLD")
|
||||
original_exit_signals.append(orig_exit.signal_type if orig_exit else "HOLD")
|
||||
new_exit_signals.append(new_exit.signal_type if new_exit else "HOLD")
|
||||
|
||||
# Combined signal logic (simplified)
|
||||
orig_signal = "BUY" if orig_entry and orig_entry.signal_type == "ENTRY" else ("SELL" if orig_exit and orig_exit.signal_type == "EXIT" else "HOLD")
|
||||
new_signal = "BUY" if new_entry and new_entry.signal_type == "ENTRY" else ("SELL" if new_exit and new_exit.signal_type == "EXIT" else "HOLD")
|
||||
|
||||
combined_original_signals.append(orig_signal)
|
||||
combined_new_signals.append(new_signal)
|
||||
timestamps.append(timestamp)
|
||||
|
||||
# Calculate consistency metrics
|
||||
entry_matches = sum(1 for o, n in zip(original_entry_signals, new_entry_signals) if o == n)
|
||||
exit_matches = sum(1 for o, n in zip(original_exit_signals, new_exit_signals) if o == n)
|
||||
combined_matches = sum(1 for o, n in zip(combined_original_signals, combined_new_signals) if o == n)
|
||||
|
||||
total_points = len(self.data)
|
||||
entry_consistency = (entry_matches / total_points) * 100
|
||||
exit_consistency = (exit_matches / total_points) * 100
|
||||
combined_consistency = (combined_matches / total_points) * 100
|
||||
|
||||
results = {
|
||||
'strategy_name': 'MetaTrend',
|
||||
'total_points': total_points,
|
||||
'entry_consistency': entry_consistency,
|
||||
'exit_consistency': exit_consistency,
|
||||
'combined_consistency': combined_consistency,
|
||||
'original_entry_signals': original_entry_signals,
|
||||
'new_entry_signals': new_entry_signals,
|
||||
'original_exit_signals': original_exit_signals,
|
||||
'new_exit_signals': new_exit_signals,
|
||||
'combined_original_signals': combined_original_signals,
|
||||
'combined_new_signals': combined_new_signals,
|
||||
'timestamps': timestamps
|
||||
}
|
||||
|
||||
print(f"Entry Signal Consistency: {entry_consistency:.2f}%")
|
||||
print(f"Exit Signal Consistency: {exit_consistency:.2f}%")
|
||||
print(f"Combined Signal Consistency: {combined_consistency:.2f}%")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error comparing MetaTrend strategies: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {}
|
||||
|
||||
def compare_random_strategies(self) -> Dict[str, Any]:
|
||||
"""Compare IncRandomStrategy vs RandomStrategy."""
|
||||
print("\n" + "="*80)
|
||||
print("COMPARING RANDOM STRATEGIES")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Initialize strategies with same seed for reproducibility
|
||||
# Original: IncRandomStrategy(weight, params)
|
||||
# New: RandomStrategy(name, weight, params)
|
||||
original_strategy = IncRandomStrategy(weight=1.0, params={"random_seed": 42})
|
||||
new_strategy = RandomStrategy(name="random", weight=1.0, params={"random_seed": 42})
|
||||
|
||||
# Track signals
|
||||
original_signals = []
|
||||
new_signals = []
|
||||
timestamps = []
|
||||
|
||||
# Process data
|
||||
for i, row in self.data.iterrows():
|
||||
timestamp = pd.Timestamp(row['Timestamp'], unit='s')
|
||||
ohlcv_data = {
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
}
|
||||
|
||||
# Update strategies
|
||||
original_strategy.update_minute_data(timestamp, ohlcv_data)
|
||||
new_strategy.process_data_point(timestamp, ohlcv_data)
|
||||
|
||||
# Get signals
|
||||
orig_signal = original_strategy.get_entry_signal() # Random strategy uses get_entry_signal
|
||||
new_signal = new_strategy.get_entry_signal()
|
||||
|
||||
# Store signals
|
||||
original_signals.append(orig_signal.signal_type if orig_signal else "HOLD")
|
||||
new_signals.append(new_signal.signal_type if new_signal else "HOLD")
|
||||
timestamps.append(timestamp)
|
||||
|
||||
# Calculate consistency metrics
|
||||
matches = sum(1 for o, n in zip(original_signals, new_signals) if o == n)
|
||||
total_points = len(self.data)
|
||||
consistency = (matches / total_points) * 100
|
||||
|
||||
results = {
|
||||
'strategy_name': 'Random',
|
||||
'total_points': total_points,
|
||||
'consistency': consistency,
|
||||
'original_signals': original_signals,
|
||||
'new_signals': new_signals,
|
||||
'timestamps': timestamps
|
||||
}
|
||||
|
||||
print(f"Signal Consistency: {consistency:.2f}%")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error comparing Random strategies: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {}
|
||||
|
||||
def compare_bbrs_strategies(self) -> Dict[str, Any]:
|
||||
"""Compare BBRSIncrementalState vs BBRSStrategy."""
|
||||
print("\n" + "="*80)
|
||||
print("COMPARING BBRS STRATEGIES")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Initialize strategies with same configuration
|
||||
# Original: BBRSIncrementalState(config)
|
||||
# New: BBRSStrategy(name, weight, params)
|
||||
original_config = {
|
||||
"timeframe_minutes": 60,
|
||||
"bb_period": 20,
|
||||
"rsi_period": 14,
|
||||
"bb_width": 0.05,
|
||||
"trending": {
|
||||
"bb_std_dev_multiplier": 2.5,
|
||||
"rsi_threshold": [30, 70]
|
||||
},
|
||||
"sideways": {
|
||||
"bb_std_dev_multiplier": 1.8,
|
||||
"rsi_threshold": [40, 60]
|
||||
},
|
||||
"SqueezeStrategy": True
|
||||
}
|
||||
|
||||
new_params = {
|
||||
"timeframe": "1h",
|
||||
"bb_period": 20,
|
||||
"rsi_period": 14,
|
||||
"bb_width_threshold": 0.05,
|
||||
"trending_bb_multiplier": 2.5,
|
||||
"sideways_bb_multiplier": 1.8,
|
||||
"trending_rsi_thresholds": [30, 70],
|
||||
"sideways_rsi_thresholds": [40, 60],
|
||||
"squeeze_strategy": True,
|
||||
"enable_logging": False
|
||||
}
|
||||
|
||||
original_strategy = BBRSIncrementalState(original_config)
|
||||
new_strategy = BBRSStrategy(name="bbrs", weight=1.0, params=new_params)
|
||||
|
||||
# Track signals
|
||||
original_signals = []
|
||||
new_signals = []
|
||||
timestamps = []
|
||||
|
||||
# Process data
|
||||
for i, row in self.data.iterrows():
|
||||
timestamp = pd.Timestamp(row['Timestamp'], unit='s')
|
||||
ohlcv_data = {
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
}
|
||||
|
||||
# Update strategies
|
||||
orig_result = original_strategy.update_minute_data(timestamp, ohlcv_data)
|
||||
new_strategy.process_data_point(timestamp, ohlcv_data)
|
||||
|
||||
# Get signals from original (returns dict with buy_signal/sell_signal)
|
||||
if orig_result and orig_result.get('buy_signal', False):
|
||||
orig_signal = "BUY"
|
||||
elif orig_result and orig_result.get('sell_signal', False):
|
||||
orig_signal = "SELL"
|
||||
else:
|
||||
orig_signal = "HOLD"
|
||||
|
||||
# Get signals from new strategy
|
||||
new_entry = new_strategy.get_entry_signal()
|
||||
new_exit = new_strategy.get_exit_signal()
|
||||
|
||||
if new_entry and new_entry.signal_type == "ENTRY":
|
||||
new_signal = "BUY"
|
||||
elif new_exit and new_exit.signal_type == "EXIT":
|
||||
new_signal = "SELL"
|
||||
else:
|
||||
new_signal = "HOLD"
|
||||
|
||||
# Store signals
|
||||
original_signals.append(orig_signal)
|
||||
new_signals.append(new_signal)
|
||||
timestamps.append(timestamp)
|
||||
|
||||
# Calculate consistency metrics
|
||||
matches = sum(1 for o, n in zip(original_signals, new_signals) if o == n)
|
||||
total_points = len(self.data)
|
||||
consistency = (matches / total_points) * 100
|
||||
|
||||
results = {
|
||||
'strategy_name': 'BBRS',
|
||||
'total_points': total_points,
|
||||
'consistency': consistency,
|
||||
'original_signals': original_signals,
|
||||
'new_signals': new_signals,
|
||||
'timestamps': timestamps
|
||||
}
|
||||
|
||||
print(f"Signal Consistency: {consistency:.2f}%")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error comparing BBRS strategies: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {}
|
||||
|
||||
def generate_report(self, results: List[Dict[str, Any]]) -> None:
|
||||
"""Generate comprehensive comparison report."""
|
||||
print("\n" + "="*80)
|
||||
print("GENERATING STRATEGY COMPARISON REPORT")
|
||||
print("="*80)
|
||||
|
||||
# Create summary report
|
||||
report_file = self.results_dir / "strategy_comparison_report.txt"
|
||||
|
||||
with open(report_file, 'w', encoding='utf-8') as f:
|
||||
f.write("Strategy Comparison Report\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"Data points tested: {results[0]['total_points'] if results else 'N/A'}\n\n")
|
||||
|
||||
for result in results:
|
||||
if not result:
|
||||
continue
|
||||
|
||||
f.write(f"Strategy: {result['strategy_name']}\n")
|
||||
f.write("-" * 30 + "\n")
|
||||
|
||||
if result['strategy_name'] == 'MetaTrend':
|
||||
f.write(f"Entry Signal Consistency: {result['entry_consistency']:.2f}%\n")
|
||||
f.write(f"Exit Signal Consistency: {result['exit_consistency']:.2f}%\n")
|
||||
f.write(f"Combined Signal Consistency: {result['combined_consistency']:.2f}%\n")
|
||||
|
||||
# Status determination
|
||||
if result['combined_consistency'] >= 95:
|
||||
status = "✅ EXCELLENT"
|
||||
elif result['combined_consistency'] >= 90:
|
||||
status = "✅ GOOD"
|
||||
elif result['combined_consistency'] >= 80:
|
||||
status = "⚠️ ACCEPTABLE"
|
||||
else:
|
||||
status = "❌ NEEDS REVIEW"
|
||||
|
||||
else:
|
||||
f.write(f"Signal Consistency: {result['consistency']:.2f}%\n")
|
||||
|
||||
# Status determination
|
||||
if result['consistency'] >= 95:
|
||||
status = "✅ EXCELLENT"
|
||||
elif result['consistency'] >= 90:
|
||||
status = "✅ GOOD"
|
||||
elif result['consistency'] >= 80:
|
||||
status = "⚠️ ACCEPTABLE"
|
||||
else:
|
||||
status = "❌ NEEDS REVIEW"
|
||||
|
||||
f.write(f"Status: {status}\n\n")
|
||||
|
||||
print(f"Report saved to: {report_file}")
|
||||
|
||||
# Generate plots for each strategy
|
||||
for result in results:
|
||||
if not result:
|
||||
continue
|
||||
self.plot_strategy_comparison(result)
|
||||
|
||||
def plot_strategy_comparison(self, result: Dict[str, Any]) -> None:
|
||||
"""Generate comparison plots for a strategy."""
|
||||
strategy_name = result['strategy_name']
|
||||
|
||||
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
|
||||
fig.suptitle(f'{strategy_name} Strategy Comparison', fontsize=16, fontweight='bold')
|
||||
|
||||
timestamps = result['timestamps']
|
||||
|
||||
if strategy_name == 'MetaTrend':
|
||||
# Plot entry signals
|
||||
axes[0].plot(timestamps, [1 if s == "ENTRY" else 0 for s in result['original_entry_signals']],
|
||||
label='Original Entry', alpha=0.7, linewidth=2)
|
||||
axes[0].plot(timestamps, [1 if s == "ENTRY" else 0 for s in result['new_entry_signals']],
|
||||
label='New Entry', alpha=0.7, linewidth=2, linestyle='--')
|
||||
axes[0].set_title(f'Entry Signals - Consistency: {result["entry_consistency"]:.2f}%')
|
||||
axes[0].set_ylabel('Entry Signal')
|
||||
axes[0].legend()
|
||||
axes[0].grid(True, alpha=0.3)
|
||||
|
||||
# Plot combined signals
|
||||
signal_map = {"BUY": 1, "SELL": -1, "HOLD": 0}
|
||||
orig_combined = [signal_map[s] for s in result['combined_original_signals']]
|
||||
new_combined = [signal_map[s] for s in result['combined_new_signals']]
|
||||
|
||||
axes[1].plot(timestamps, orig_combined, label='Original Combined', alpha=0.7, linewidth=2)
|
||||
axes[1].plot(timestamps, new_combined, label='New Combined', alpha=0.7, linewidth=2, linestyle='--')
|
||||
axes[1].set_title(f'Combined Signals - Consistency: {result["combined_consistency"]:.2f}%')
|
||||
axes[1].set_ylabel('Signal (-1=SELL, 0=HOLD, 1=BUY)')
|
||||
|
||||
else:
|
||||
# For Random and BBRS strategies
|
||||
signal_map = {"BUY": 1, "SELL": -1, "HOLD": 0}
|
||||
orig_signals = [signal_map.get(s, 0) for s in result['original_signals']]
|
||||
new_signals = [signal_map.get(s, 0) for s in result['new_signals']]
|
||||
|
||||
axes[0].plot(timestamps, orig_signals, label='Original', alpha=0.7, linewidth=2)
|
||||
axes[0].plot(timestamps, new_signals, label='New', alpha=0.7, linewidth=2, linestyle='--')
|
||||
axes[0].set_title(f'Signals - Consistency: {result["consistency"]:.2f}%')
|
||||
axes[0].set_ylabel('Signal (-1=SELL, 0=HOLD, 1=BUY)')
|
||||
|
||||
# Plot difference
|
||||
diff = [o - n for o, n in zip(orig_signals, new_signals)]
|
||||
axes[1].plot(timestamps, diff, label='Difference (Original - New)', color='red', alpha=0.7)
|
||||
axes[1].set_title('Signal Differences')
|
||||
axes[1].set_ylabel('Difference')
|
||||
axes[1].axhline(y=0, color='black', linestyle='-', alpha=0.3)
|
||||
|
||||
# Format x-axis
|
||||
for ax in axes:
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
|
||||
ax.xaxis.set_major_locator(mdates.HourLocator(interval=2))
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
|
||||
|
||||
plt.xlabel('Time')
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot
|
||||
plot_file = self.results_dir / f"{strategy_name.lower()}_strategy_comparison.png"
|
||||
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
print(f"Plot saved to: {plot_file}")
|
||||
|
||||
def main():
|
||||
"""Main test execution."""
|
||||
print("Strategy Comparison Test Framework")
|
||||
print("=" * 50)
|
||||
|
||||
# Initialize tester
|
||||
tester = StrategyComparisonTester()
|
||||
|
||||
# Load data
|
||||
if not tester.load_data(limit=1000): # Use 1000 points for testing
|
||||
print("Failed to load data. Exiting.")
|
||||
return
|
||||
|
||||
# Run comparisons
|
||||
results = []
|
||||
|
||||
# Compare MetaTrend strategies
|
||||
metatrend_result = tester.compare_metatrend_strategies()
|
||||
if metatrend_result:
|
||||
results.append(metatrend_result)
|
||||
|
||||
# Compare Random strategies
|
||||
random_result = tester.compare_random_strategies()
|
||||
if random_result:
|
||||
results.append(random_result)
|
||||
|
||||
# Compare BBRS strategies
|
||||
bbrs_result = tester.compare_bbrs_strategies()
|
||||
if bbrs_result:
|
||||
results.append(bbrs_result)
|
||||
|
||||
# Generate comprehensive report
|
||||
if results:
|
||||
tester.generate_report(results)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("STRATEGY COMPARISON SUMMARY")
|
||||
print("="*80)
|
||||
|
||||
for result in results:
|
||||
if not result:
|
||||
continue
|
||||
|
||||
strategy_name = result['strategy_name']
|
||||
|
||||
if strategy_name == 'MetaTrend':
|
||||
consistency = result['combined_consistency']
|
||||
print(f"{strategy_name}: {consistency:.2f}% consistency")
|
||||
else:
|
||||
consistency = result['consistency']
|
||||
print(f"{strategy_name}: {consistency:.2f}% consistency")
|
||||
|
||||
if consistency >= 95:
|
||||
status = "✅ EXCELLENT"
|
||||
elif consistency >= 90:
|
||||
status = "✅ GOOD"
|
||||
elif consistency >= 80:
|
||||
status = "⚠️ ACCEPTABLE"
|
||||
else:
|
||||
status = "❌ NEEDS REVIEW"
|
||||
|
||||
print(f" Status: {status}")
|
||||
|
||||
print(f"\nDetailed results saved in: {tester.results_dir}")
|
||||
else:
|
||||
print("No successful comparisons completed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,618 +0,0 @@
|
||||
"""
|
||||
Enhanced Strategy Comparison Test Framework for 2025 Data
|
||||
|
||||
Comprehensive testing for comparing original incremental strategies from cycles/IncStrategies
|
||||
with new implementations in IncrementalTrader/strategies using real 2025 data.
|
||||
|
||||
Features:
|
||||
- Interactive plots using Plotly
|
||||
- CSV export of all signals
|
||||
- Detailed signal analysis
|
||||
- Performance comparison
|
||||
- Real 2025 data (Jan-Apr)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import plotly.subplots as sp
|
||||
from plotly.offline import plot
|
||||
from datetime import datetime
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Add project paths
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / "cycles"))
|
||||
sys.path.insert(0, str(project_root / "IncrementalTrader"))
|
||||
|
||||
# Import original strategies
|
||||
from cycles.IncStrategies.metatrend_strategy import IncMetaTrendStrategy
|
||||
from cycles.IncStrategies.random_strategy import IncRandomStrategy
|
||||
from cycles.IncStrategies.bbrs_incremental import BBRSIncrementalState
|
||||
|
||||
# Import new strategies
|
||||
from IncrementalTrader.strategies.metatrend import MetaTrendStrategy
|
||||
from IncrementalTrader.strategies.random import RandomStrategy
|
||||
from IncrementalTrader.strategies.bbrs import BBRSStrategy
|
||||
|
||||
class Enhanced2025StrategyComparison:
|
||||
"""Enhanced strategy comparison framework with interactive plots and CSV export."""
|
||||
|
||||
def __init__(self, data_file: str = "data/temp_2025_data.csv"):
|
||||
"""Initialize the comparison framework."""
|
||||
self.data_file = data_file
|
||||
self.data = None
|
||||
self.results = {}
|
||||
|
||||
# Create results directory
|
||||
self.results_dir = Path("test/results/strategies_2025")
|
||||
self.results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("Enhanced 2025 Strategy Comparison Framework")
|
||||
print("=" * 60)
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""Load and prepare 2025 data."""
|
||||
print(f"Loading data from {self.data_file}...")
|
||||
|
||||
self.data = pd.read_csv(self.data_file)
|
||||
|
||||
# Convert timestamp to datetime
|
||||
self.data['DateTime'] = pd.to_datetime(self.data['Timestamp'], unit='s')
|
||||
|
||||
print(f"Data loaded: {len(self.data):,} rows")
|
||||
print(f"Date range: {self.data['DateTime'].iloc[0]} to {self.data['DateTime'].iloc[-1]}")
|
||||
print(f"Columns: {list(self.data.columns)}")
|
||||
|
||||
def compare_metatrend_strategies(self) -> Dict[str, Any]:
|
||||
"""Compare IncMetaTrendStrategy vs MetaTrendStrategy with detailed analysis."""
|
||||
print("\n" + "="*80)
|
||||
print("COMPARING METATREND STRATEGIES - 2025 DATA")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Initialize strategies
|
||||
original_strategy = IncMetaTrendStrategy(weight=1.0, params={})
|
||||
new_strategy = MetaTrendStrategy(name="metatrend", weight=1.0, params={})
|
||||
|
||||
# Track all signals and data
|
||||
signals_data = []
|
||||
price_data = []
|
||||
|
||||
print("Processing data points...")
|
||||
|
||||
# Process data
|
||||
for i, row in self.data.iterrows():
|
||||
if i % 10000 == 0:
|
||||
print(f"Processed {i:,} / {len(self.data):,} data points...")
|
||||
|
||||
timestamp = row['DateTime']
|
||||
ohlcv_data = {
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
}
|
||||
|
||||
# Update strategies
|
||||
original_strategy.update_minute_data(timestamp, ohlcv_data)
|
||||
new_strategy.process_data_point(timestamp, ohlcv_data)
|
||||
|
||||
# Get signals
|
||||
orig_entry = original_strategy.get_entry_signal()
|
||||
new_entry = new_strategy.get_entry_signal()
|
||||
orig_exit = original_strategy.get_exit_signal()
|
||||
new_exit = new_strategy.get_exit_signal()
|
||||
|
||||
# Determine combined signals
|
||||
orig_signal = "BUY" if orig_entry and orig_entry.signal_type == "ENTRY" else (
|
||||
"SELL" if orig_exit and orig_exit.signal_type == "EXIT" else "HOLD")
|
||||
new_signal = "BUY" if new_entry and new_entry.signal_type == "ENTRY" else (
|
||||
"SELL" if new_exit and new_exit.signal_type == "EXIT" else "HOLD")
|
||||
|
||||
# Store data
|
||||
signals_data.append({
|
||||
'timestamp': timestamp,
|
||||
'price': row['Close'],
|
||||
'original_entry': orig_entry.signal_type if orig_entry else "HOLD",
|
||||
'new_entry': new_entry.signal_type if new_entry else "HOLD",
|
||||
'original_exit': orig_exit.signal_type if orig_exit else "HOLD",
|
||||
'new_exit': new_exit.signal_type if new_exit else "HOLD",
|
||||
'original_combined': orig_signal,
|
||||
'new_combined': new_signal,
|
||||
'signals_match': orig_signal == new_signal
|
||||
})
|
||||
|
||||
price_data.append({
|
||||
'timestamp': timestamp,
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
})
|
||||
|
||||
# Convert to DataFrame
|
||||
signals_df = pd.DataFrame(signals_data)
|
||||
price_df = pd.DataFrame(price_data)
|
||||
|
||||
# Calculate statistics
|
||||
total_signals = len(signals_df)
|
||||
matching_signals = signals_df['signals_match'].sum()
|
||||
consistency = (matching_signals / total_signals) * 100
|
||||
|
||||
# Signal distribution
|
||||
orig_signal_counts = signals_df['original_combined'].value_counts()
|
||||
new_signal_counts = signals_df['new_combined'].value_counts()
|
||||
|
||||
# Save signals to CSV
|
||||
csv_file = self.results_dir / "metatrend_signals_2025.csv"
|
||||
signals_df.to_csv(csv_file, index=False, encoding='utf-8')
|
||||
|
||||
# Create interactive plot
|
||||
self.create_interactive_plot(signals_df, price_df, "MetaTrend", "metatrend_2025")
|
||||
|
||||
results = {
|
||||
'strategy': 'MetaTrend',
|
||||
'total_signals': total_signals,
|
||||
'matching_signals': matching_signals,
|
||||
'consistency_percentage': consistency,
|
||||
'original_signal_distribution': orig_signal_counts.to_dict(),
|
||||
'new_signal_distribution': new_signal_counts.to_dict(),
|
||||
'signals_dataframe': signals_df,
|
||||
'csv_file': str(csv_file)
|
||||
}
|
||||
|
||||
print(f"✅ MetaTrend Strategy Comparison Complete")
|
||||
print(f" Signal Consistency: {consistency:.2f}%")
|
||||
print(f" Total Signals: {total_signals:,}")
|
||||
print(f" Matching Signals: {matching_signals:,}")
|
||||
print(f" CSV Saved: {csv_file}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error in MetaTrend comparison: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {'error': str(e)}
|
||||
|
||||
def compare_random_strategies(self) -> Dict[str, Any]:
|
||||
"""Compare IncRandomStrategy vs RandomStrategy with detailed analysis."""
|
||||
print("\n" + "="*80)
|
||||
print("COMPARING RANDOM STRATEGIES - 2025 DATA")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Initialize strategies with same seed for reproducibility
|
||||
original_strategy = IncRandomStrategy(weight=1.0, params={"random_seed": 42})
|
||||
new_strategy = RandomStrategy(name="random", weight=1.0, params={"random_seed": 42})
|
||||
|
||||
# Track all signals and data
|
||||
signals_data = []
|
||||
|
||||
print("Processing data points...")
|
||||
|
||||
# Process data (use subset for Random strategy to speed up)
|
||||
subset_data = self.data.iloc[::10] # Every 10th point for Random strategy
|
||||
|
||||
for i, row in subset_data.iterrows():
|
||||
if i % 1000 == 0:
|
||||
print(f"Processed {i:,} data points...")
|
||||
|
||||
timestamp = row['DateTime']
|
||||
ohlcv_data = {
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
}
|
||||
|
||||
# Update strategies
|
||||
original_strategy.update_minute_data(timestamp, ohlcv_data)
|
||||
new_strategy.process_data_point(timestamp, ohlcv_data)
|
||||
|
||||
# Get signals
|
||||
orig_entry = original_strategy.get_entry_signal()
|
||||
new_entry = new_strategy.get_entry_signal()
|
||||
orig_exit = original_strategy.get_exit_signal()
|
||||
new_exit = new_strategy.get_exit_signal()
|
||||
|
||||
# Determine combined signals
|
||||
orig_signal = "BUY" if orig_entry and orig_entry.signal_type == "ENTRY" else (
|
||||
"SELL" if orig_exit and orig_exit.signal_type == "EXIT" else "HOLD")
|
||||
new_signal = "BUY" if new_entry and new_entry.signal_type == "ENTRY" else (
|
||||
"SELL" if new_exit and new_exit.signal_type == "EXIT" else "HOLD")
|
||||
|
||||
# Store data
|
||||
signals_data.append({
|
||||
'timestamp': timestamp,
|
||||
'price': row['Close'],
|
||||
'original_entry': orig_entry.signal_type if orig_entry else "HOLD",
|
||||
'new_entry': new_entry.signal_type if new_entry else "HOLD",
|
||||
'original_exit': orig_exit.signal_type if orig_exit else "HOLD",
|
||||
'new_exit': new_exit.signal_type if new_exit else "HOLD",
|
||||
'original_combined': orig_signal,
|
||||
'new_combined': new_signal,
|
||||
'signals_match': orig_signal == new_signal
|
||||
})
|
||||
|
||||
# Convert to DataFrame
|
||||
signals_df = pd.DataFrame(signals_data)
|
||||
|
||||
# Calculate statistics
|
||||
total_signals = len(signals_df)
|
||||
matching_signals = signals_df['signals_match'].sum()
|
||||
consistency = (matching_signals / total_signals) * 100
|
||||
|
||||
# Save signals to CSV
|
||||
csv_file = self.results_dir / "random_signals_2025.csv"
|
||||
signals_df.to_csv(csv_file, index=False, encoding='utf-8')
|
||||
|
||||
results = {
|
||||
'strategy': 'Random',
|
||||
'total_signals': total_signals,
|
||||
'matching_signals': matching_signals,
|
||||
'consistency_percentage': consistency,
|
||||
'signals_dataframe': signals_df,
|
||||
'csv_file': str(csv_file)
|
||||
}
|
||||
|
||||
print(f"✅ Random Strategy Comparison Complete")
|
||||
print(f" Signal Consistency: {consistency:.2f}%")
|
||||
print(f" Total Signals: {total_signals:,}")
|
||||
print(f" CSV Saved: {csv_file}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error in Random comparison: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {'error': str(e)}
|
||||
|
||||
def compare_bbrs_strategies(self) -> Dict[str, Any]:
|
||||
"""Compare BBRSIncrementalState vs BBRSStrategy with detailed analysis."""
|
||||
print("\n" + "="*80)
|
||||
print("COMPARING BBRS STRATEGIES - 2025 DATA")
|
||||
print("="*80)
|
||||
|
||||
try:
|
||||
# Initialize strategies
|
||||
bbrs_config = {
|
||||
"bb_period": 20,
|
||||
"bb_std": 2.0,
|
||||
"rsi_period": 14,
|
||||
"volume_ma_period": 20
|
||||
}
|
||||
|
||||
original_strategy = BBRSIncrementalState(config=bbrs_config)
|
||||
new_strategy = BBRSStrategy(name="bbrs", weight=1.0, params=bbrs_config)
|
||||
|
||||
# Track all signals and data
|
||||
signals_data = []
|
||||
|
||||
print("Processing data points...")
|
||||
|
||||
# Process data
|
||||
for i, row in self.data.iterrows():
|
||||
if i % 10000 == 0:
|
||||
print(f"Processed {i:,} / {len(self.data):,} data points...")
|
||||
|
||||
timestamp = row['DateTime']
|
||||
ohlcv_data = {
|
||||
'open': row['Open'],
|
||||
'high': row['High'],
|
||||
'low': row['Low'],
|
||||
'close': row['Close'],
|
||||
'volume': row['Volume']
|
||||
}
|
||||
|
||||
# Update strategies
|
||||
orig_result = original_strategy.update_minute_data(timestamp, ohlcv_data)
|
||||
new_strategy.process_data_point(timestamp, ohlcv_data)
|
||||
|
||||
# Get signals - original returns signals in result, new uses methods
|
||||
if orig_result is not None:
|
||||
orig_buy = orig_result.get('buy_signal', False)
|
||||
orig_sell = orig_result.get('sell_signal', False)
|
||||
else:
|
||||
orig_buy = False
|
||||
orig_sell = False
|
||||
|
||||
new_entry = new_strategy.get_entry_signal()
|
||||
new_exit = new_strategy.get_exit_signal()
|
||||
new_buy = new_entry and new_entry.signal_type == "ENTRY"
|
||||
new_sell = new_exit and new_exit.signal_type == "EXIT"
|
||||
|
||||
# Determine combined signals
|
||||
orig_signal = "BUY" if orig_buy else ("SELL" if orig_sell else "HOLD")
|
||||
new_signal = "BUY" if new_buy else ("SELL" if new_sell else "HOLD")
|
||||
|
||||
# Store data
|
||||
signals_data.append({
|
||||
'timestamp': timestamp,
|
||||
'price': row['Close'],
|
||||
'original_entry': "ENTRY" if orig_buy else "HOLD",
|
||||
'new_entry': new_entry.signal_type if new_entry else "HOLD",
|
||||
'original_exit': "EXIT" if orig_sell else "HOLD",
|
||||
'new_exit': new_exit.signal_type if new_exit else "HOLD",
|
||||
'original_combined': orig_signal,
|
||||
'new_combined': new_signal,
|
||||
'signals_match': orig_signal == new_signal
|
||||
})
|
||||
|
||||
# Convert to DataFrame
|
||||
signals_df = pd.DataFrame(signals_data)
|
||||
|
||||
# Calculate statistics
|
||||
total_signals = len(signals_df)
|
||||
matching_signals = signals_df['signals_match'].sum()
|
||||
consistency = (matching_signals / total_signals) * 100
|
||||
|
||||
# Save signals to CSV
|
||||
csv_file = self.results_dir / "bbrs_signals_2025.csv"
|
||||
signals_df.to_csv(csv_file, index=False, encoding='utf-8')
|
||||
|
||||
# Create interactive plot
|
||||
self.create_interactive_plot(signals_df, self.data, "BBRS", "bbrs_2025")
|
||||
|
||||
results = {
|
||||
'strategy': 'BBRS',
|
||||
'total_signals': total_signals,
|
||||
'matching_signals': matching_signals,
|
||||
'consistency_percentage': consistency,
|
||||
'signals_dataframe': signals_df,
|
||||
'csv_file': str(csv_file)
|
||||
}
|
||||
|
||||
print(f"✅ BBRS Strategy Comparison Complete")
|
||||
print(f" Signal Consistency: {consistency:.2f}%")
|
||||
print(f" Total Signals: {total_signals:,}")
|
||||
print(f" CSV Saved: {csv_file}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error in BBRS comparison: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return {'error': str(e)}
|
||||
|
||||
def create_interactive_plot(self, signals_df: pd.DataFrame, price_df: pd.DataFrame,
|
||||
strategy_name: str, filename: str) -> None:
|
||||
"""Create interactive Plotly chart with signals and price data."""
|
||||
print(f"Creating interactive plot for {strategy_name}...")
|
||||
|
||||
# Create subplots
|
||||
fig = sp.make_subplots(
|
||||
rows=3, cols=1,
|
||||
shared_xaxes=True,
|
||||
vertical_spacing=0.05,
|
||||
subplot_titles=(
|
||||
f'{strategy_name} Strategy - Price & Signals',
|
||||
'Signal Comparison',
|
||||
'Signal Consistency'
|
||||
),
|
||||
row_heights=[0.6, 0.2, 0.2]
|
||||
)
|
||||
|
||||
# Price chart with signals
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=price_df['timestamp'],
|
||||
y=price_df['close'],
|
||||
mode='lines',
|
||||
name='Price',
|
||||
line=dict(color='blue', width=1)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add buy signals
|
||||
buy_signals_orig = signals_df[signals_df['original_combined'] == 'BUY']
|
||||
buy_signals_new = signals_df[signals_df['new_combined'] == 'BUY']
|
||||
|
||||
if len(buy_signals_orig) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=buy_signals_orig['timestamp'],
|
||||
y=buy_signals_orig['price'],
|
||||
mode='markers',
|
||||
name='Original BUY',
|
||||
marker=dict(color='green', size=8, symbol='triangle-up')
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
if len(buy_signals_new) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=buy_signals_new['timestamp'],
|
||||
y=buy_signals_new['price'],
|
||||
mode='markers',
|
||||
name='New BUY',
|
||||
marker=dict(color='lightgreen', size=6, symbol='triangle-up')
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add sell signals
|
||||
sell_signals_orig = signals_df[signals_df['original_combined'] == 'SELL']
|
||||
sell_signals_new = signals_df[signals_df['new_combined'] == 'SELL']
|
||||
|
||||
if len(sell_signals_orig) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=sell_signals_orig['timestamp'],
|
||||
y=sell_signals_orig['price'],
|
||||
mode='markers',
|
||||
name='Original SELL',
|
||||
marker=dict(color='red', size=8, symbol='triangle-down')
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
if len(sell_signals_new) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=sell_signals_new['timestamp'],
|
||||
y=sell_signals_new['price'],
|
||||
mode='markers',
|
||||
name='New SELL',
|
||||
marker=dict(color='pink', size=6, symbol='triangle-down')
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Signal comparison chart
|
||||
signal_mapping = {'HOLD': 0, 'BUY': 1, 'SELL': -1}
|
||||
signals_df['original_numeric'] = signals_df['original_combined'].map(signal_mapping)
|
||||
signals_df['new_numeric'] = signals_df['new_combined'].map(signal_mapping)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=signals_df['timestamp'],
|
||||
y=signals_df['original_numeric'],
|
||||
mode='lines',
|
||||
name='Original Signals',
|
||||
line=dict(color='blue', width=2)
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=signals_df['timestamp'],
|
||||
y=signals_df['new_numeric'],
|
||||
mode='lines',
|
||||
name='New Signals',
|
||||
line=dict(color='red', width=1, dash='dash')
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# Signal consistency chart
|
||||
signals_df['consistency_numeric'] = signals_df['signals_match'].astype(int)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=signals_df['timestamp'],
|
||||
y=signals_df['consistency_numeric'],
|
||||
mode='lines',
|
||||
name='Signal Match',
|
||||
line=dict(color='green', width=1),
|
||||
fill='tonexty'
|
||||
),
|
||||
row=3, col=1
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
title=f'{strategy_name} Strategy Comparison - 2025 Data',
|
||||
height=800,
|
||||
showlegend=True,
|
||||
hovermode='x unified'
|
||||
)
|
||||
|
||||
# Update y-axes
|
||||
fig.update_yaxes(title_text="Price ($)", row=1, col=1)
|
||||
fig.update_yaxes(title_text="Signal", row=2, col=1, tickvals=[-1, 0, 1], ticktext=['SELL', 'HOLD', 'BUY'])
|
||||
fig.update_yaxes(title_text="Match", row=3, col=1, tickvals=[0, 1], ticktext=['No', 'Yes'])
|
||||
|
||||
# Save interactive plot
|
||||
html_file = self.results_dir / f"{filename}_interactive.html"
|
||||
plot(fig, filename=str(html_file), auto_open=False)
|
||||
|
||||
print(f" Interactive plot saved: {html_file}")
|
||||
|
||||
def generate_comprehensive_report(self) -> None:
|
||||
"""Generate comprehensive comparison report."""
|
||||
print("\n" + "="*80)
|
||||
print("GENERATING COMPREHENSIVE REPORT")
|
||||
print("="*80)
|
||||
|
||||
report_file = self.results_dir / "comprehensive_strategy_comparison_2025.md"
|
||||
|
||||
with open(report_file, 'w', encoding='utf-8') as f:
|
||||
f.write("# Comprehensive Strategy Comparison Report - 2025 Data\n\n")
|
||||
f.write(f"**Generated**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
f.write(f"**Data Period**: January 1, 2025 - April 30, 2025\n")
|
||||
f.write(f"**Total Data Points**: {len(self.data):,} minute-level OHLCV records\n\n")
|
||||
|
||||
f.write("## Executive Summary\n\n")
|
||||
f.write("This report compares the signal generation consistency between original incremental strategies ")
|
||||
f.write("from `cycles/IncStrategies` and new implementations in `IncrementalTrader/strategies` ")
|
||||
f.write("using real market data from 2025.\n\n")
|
||||
|
||||
f.write("## Strategy Comparison Results\n\n")
|
||||
|
||||
for strategy_name, results in self.results.items():
|
||||
if 'error' not in results:
|
||||
f.write(f"### {results['strategy']} Strategy\n\n")
|
||||
f.write(f"- **Signal Consistency**: {results['consistency_percentage']:.2f}%\n")
|
||||
f.write(f"- **Total Signals Compared**: {results['total_signals']:,}\n")
|
||||
f.write(f"- **Matching Signals**: {results['matching_signals']:,}\n")
|
||||
f.write(f"- **CSV Export**: `{results['csv_file']}`\n\n")
|
||||
|
||||
if 'original_signal_distribution' in results:
|
||||
f.write("**Original Strategy Signal Distribution:**\n")
|
||||
for signal, count in results['original_signal_distribution'].items():
|
||||
f.write(f"- {signal}: {count:,}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("**New Strategy Signal Distribution:**\n")
|
||||
for signal, count in results['new_signal_distribution'].items():
|
||||
f.write(f"- {signal}: {count:,}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("## Files Generated\n\n")
|
||||
f.write("### CSV Signal Exports\n")
|
||||
for csv_file in self.results_dir.glob("*_signals_2025.csv"):
|
||||
f.write(f"- `{csv_file.name}`: Complete signal history with timestamps\n")
|
||||
|
||||
f.write("\n### Interactive Plots\n")
|
||||
for html_file in self.results_dir.glob("*_interactive.html"):
|
||||
f.write(f"- `{html_file.name}`: Interactive Plotly visualization\n")
|
||||
|
||||
f.write("\n## Conclusion\n\n")
|
||||
f.write("The strategy comparison validates the migration accuracy by comparing signal generation ")
|
||||
f.write("between original and refactored implementations. High consistency percentages indicate ")
|
||||
f.write("successful preservation of strategy behavior during the refactoring process.\n")
|
||||
|
||||
print(f"✅ Comprehensive report saved: {report_file}")
|
||||
|
||||
def run_all_comparisons(self) -> None:
|
||||
"""Run all strategy comparisons."""
|
||||
print("Starting comprehensive strategy comparison with 2025 data...")
|
||||
|
||||
# Load data
|
||||
self.load_data()
|
||||
|
||||
# Run comparisons
|
||||
self.results['metatrend'] = self.compare_metatrend_strategies()
|
||||
self.results['random'] = self.compare_random_strategies()
|
||||
self.results['bbrs'] = self.compare_bbrs_strategies()
|
||||
|
||||
# Generate report
|
||||
self.generate_comprehensive_report()
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("ALL STRATEGY COMPARISONS COMPLETED")
|
||||
print("="*80)
|
||||
print(f"Results directory: {self.results_dir}")
|
||||
print("Files generated:")
|
||||
for file in sorted(self.results_dir.glob("*")):
|
||||
print(f" - {file.name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the enhanced comparison
|
||||
comparison = Enhanced2025StrategyComparison()
|
||||
comparison.run_all_comparisons()
|
||||
Reference in New Issue
Block a user