# Cycles/test/test_signal_comparison_fixed.py
# Vasily.onl ba78539cbb Add incremental MetaTrend strategy implementation
# - Introduced `IncMetaTrendStrategy` for real-time processing of the MetaTrend trading strategy, utilizing three Supertrend indicators.
# - Added comprehensive documentation in `METATREND_IMPLEMENTATION.md` detailing architecture, key components, and usage examples.
# - Updated `__init__.py` to include the new strategy in the strategy registry.
# - Created tests to compare the incremental strategy's signals against the original implementation, ensuring mathematical equivalence.
# - Developed visual comparison scripts to analyze performance and signal accuracy between original and incremental strategies.
# 2025-05-26 16:09:32 +08:00
#
# 394 lines
# 16 KiB
# Python

"""
Signal Comparison Test (Fixed Original Strategy)
This test compares signals between:
1. Original DefaultStrategy (with exit condition bug FIXED)
2. Incremental IncMetaTrendStrategy
The original strategy has a bug in get_exit_signal where it checks:
if prev_trend != 1 and curr_trend == -1:
But it should check:
if prev_trend != -1 and curr_trend == -1:
This test fixes that bug to see if the strategies match when both are correct.
"""
import pandas as pd
import numpy as np
import logging
from typing import Dict, List, Tuple
import os
import sys
# Make the project root (the parent of this file's directory) importable so
# the `cycles.*` imports below can be found.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_project_root)
from cycles.strategies.default_strategy import DefaultStrategy
from cycles.IncStrategies.metatrend_strategy import IncMetaTrendStrategy
from cycles.utils.storage import Storage
from cycles.strategies.base import StrategySignal
# Configure logging: INFO level, lines shaped "<time> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-level logger used throughout this script
class FixedDefaultStrategy(DefaultStrategy):
    """DefaultStrategy with the exit condition bug fixed.

    The exit rule fires only on a genuine transition INTO the -1 meta-trend
    state: the previous value must differ from -1 and the current one must
    equal -1.
    """

    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """Return an EXIT signal when the meta-trend turns to -1 (corrected logic).

        Args:
            backtester: Not used by this implementation.
            df_index: Position of the bar being evaluated in ``self.meta_trend``.

        Returns:
            ``StrategySignal("EXIT", confidence=1.0, ...)`` when the previous
            meta-trend value is not -1 and the current one is -1. Otherwise a
            HOLD signal with confidence 0.0. HOLD is also returned when the
            strategy is not initialized, when ``df_index < 1``, or when
            ``meta_trend`` is missing or too short for ``df_index``.
        """
        # Guard clauses, in the same order as the checks they replace. The
        # first bar has no previous value to compare against.
        ready = (
            self.initialized
            and df_index >= 1
            and hasattr(self, 'meta_trend')
            and df_index < len(self.meta_trend)
        )
        if not ready:
            return StrategySignal("HOLD", 0.0)

        previous = self.meta_trend[df_index - 1]
        current = self.meta_trend[df_index]

        # FIXED: compare the previous value against -1 (the buggy rule used 1).
        entered_downtrend = previous != -1 and current == -1
        if entered_downtrend:
            return StrategySignal("EXIT", confidence=1.0,
                                  metadata={"type": "META_TREND_EXIT_SIGNAL"})
        return StrategySignal("HOLD", confidence=0.0)
class SignalComparisonTestFixed:
    """Test to compare signals between fixed original and incremental strategies.

    :meth:`run_signal_test` drives the whole flow:
    1. load data;
    2. run both strategies over the same trailing window;
    3. compare the ENTRY/EXIT signals they emit;
    4. print a report.
    """

    # Both strategies are evaluated on (at most) this many trailing rows of the
    # loaded data, mirroring the 200-point window used for the original strategy.
    STRATEGY_WINDOW = 200

    # Comparison keys that must all be True for the strategies to "match".
    _MATCH_KEYS = (
        'entry_count_match',
        'exit_count_match',
        'entry_indices_match',
        'exit_indices_match',
    )

    # Whether each test_*_signals method has completed. These are tracked
    # separately from the signal lists because a strategy emitting no signals
    # is a valid outcome, not "has not run yet".
    _original_ran = False
    _incremental_ran = False

    def __init__(self):
        """Initialize the signal comparison test."""
        self.storage = Storage(logging=logger)
        self.test_data = None
        self.original_signals = []
        self.incremental_signals = []
        self._original_ran = False
        self._incremental_ran = False

    def load_test_data(
        self,
        limit: int = 500,
        filename: str = "btcusd_1-min_data.csv",
        start_date="2022-12-31",
        end_date="2023-01-01",
    ) -> pd.DataFrame:
        """Load a small dataset for signal testing.

        Args:
            limit: Keep at most this many of the most recent rows.
            filename: Data file passed to ``Storage.load_data``.
            start_date: Start of the requested range. Accepts anything
                ``pd.to_datetime`` accepts.
            end_date: End of the requested range. Accepts anything
                ``pd.to_datetime`` accepts.

        Returns:
            The data with its index reset into columns; a ``timestamp``
            column is expected. The frame is also stored on ``self.test_data``.

        Raises:
            Any exception raised while loading is logged and re-raised.
        """
        logger.info(f"Loading test data (limit: {limit} points)")
        try:
            df = self.storage.load_data(
                filename, pd.to_datetime(start_date), pd.to_datetime(end_date)
            )
            if len(df) > limit:
                df = df.tail(limit)
                logger.info(f"Limited data to last {limit} points")

            # Reset index to get timestamp as column
            df_with_timestamp = df.reset_index()
            self.test_data = df_with_timestamp

            logger.info(f"Loaded {len(df_with_timestamp)} data points")
            logger.info(
                f"Date range: {df_with_timestamp['timestamp'].min()} "
                f"to {df_with_timestamp['timestamp'].max()}"
            )
            return df_with_timestamp
        except Exception as e:
            logger.error(f"Failed to load test data: {e}")
            raise

    def _strategy_window(self) -> Tuple[pd.DataFrame, int]:
        """Return the trailing evaluation window of ``self.test_data`` and the
        position of its first row within the full loaded dataset."""
        if self.test_data is None:
            raise ValueError("Test data not loaded; call load_test_data() first")
        start_index = max(len(self.test_data) - self.STRATEGY_WINDOW, 0)
        return self.test_data.iloc[start_index:], start_index

    @staticmethod
    def _signal_record(index: int, start_index: int, timestamp, close,
                       signal, signal_type: str, source: str) -> Dict:
        """Build the dict stored for one emitted signal."""
        return {
            'index': index,
            'global_index': start_index + index,
            'timestamp': timestamp,
            'close': close,
            'signal_type': signal_type,
            'confidence': signal.confidence,
            'metadata': signal.metadata,
            'source': source,
        }

    def test_fixed_original_strategy_signals(self) -> List[Dict]:
        """Test FIXED original DefaultStrategy and extract all signals.

        Returns:
            One record per ENTRY/EXIT signal, in bar order. The records are
            also stored on ``self.original_signals``.
        """
        logger.info("Testing FIXED Original DefaultStrategy signals...")

        window, data_start_index = self._strategy_window()
        # The original strategy works on a timestamp-indexed frame.
        original_data_used = window.set_index('timestamp')

        class MockBacktester:
            """Minimal stand-in backtester carrying the data frames and an
            empty strategy registry."""

            def __init__(self, df):
                self.original_df = df
                self.min1_df = df
                self.strategies = {}

        backtester = MockBacktester(original_data_used)

        strategy = FixedDefaultStrategy(weight=1.0, params={
            "stop_loss_pct": 0.03,
            "timeframe": "1min",
        })
        strategy.initialize(backtester)

        # Simulate the strategy bar by bar, checking entry before exit.
        signals = []
        closes = original_data_used['close']
        for i, timestamp in enumerate(original_data_used.index):
            close = closes.iloc[i]

            entry_signal = strategy.get_entry_signal(backtester, i)
            if entry_signal.signal_type == "ENTRY":
                signals.append(self._signal_record(
                    i, data_start_index, timestamp, close,
                    entry_signal, 'ENTRY', 'fixed_original'))

            exit_signal = strategy.get_exit_signal(backtester, i)
            if exit_signal.signal_type == "EXIT":
                signals.append(self._signal_record(
                    i, data_start_index, timestamp, close,
                    exit_signal, 'EXIT', 'fixed_original'))

        self.original_signals = signals
        self._original_ran = True
        logger.info(f"Fixed original strategy generated {len(signals)} signals")
        return signals

    def test_incremental_strategy_signals(self) -> List[Dict]:
        """Test incremental IncMetaTrendStrategy and extract all signals.

        The same trailing window as the original strategy is fed to the
        incremental strategy one row at a time.

        Returns:
            One record per ENTRY/EXIT signal, in bar order. The records are
            also stored on ``self.incremental_signals``.
        """
        logger.info("Testing Incremental IncMetaTrendStrategy signals...")

        strategy = IncMetaTrendStrategy("metatrend", weight=1.0, params={
            "timeframe": "1min",
            "enable_logging": False,
        })

        test_data_subset, data_start_index = self._strategy_window()

        signals = []
        for idx, (_, row) in enumerate(test_data_subset.iterrows()):
            ohlc = {
                'open': row['open'],
                'high': row['high'],
                'low': row['low'],
                'close': row['close'],
            }
            # Update strategy with new data point
            strategy.calculate_on_data(ohlc, row['timestamp'])

            entry_signal = strategy.get_entry_signal()
            if entry_signal.signal_type == "ENTRY":
                signals.append(self._signal_record(
                    idx, data_start_index, row['timestamp'], row['close'],
                    entry_signal, 'ENTRY', 'incremental'))

            exit_signal = strategy.get_exit_signal()
            if exit_signal.signal_type == "EXIT":
                signals.append(self._signal_record(
                    idx, data_start_index, row['timestamp'], row['close'],
                    exit_signal, 'EXIT', 'incremental'))

        self.incremental_signals = signals
        self._incremental_ran = True
        logger.info(f"Incremental strategy generated {len(signals)} signals")
        return signals

    @staticmethod
    def _of_type(signals: List[Dict], signal_type: str) -> List[Dict]:
        """Return the records in ``signals`` whose type is ``signal_type``."""
        return [s for s in signals if s['signal_type'] == signal_type]

    def compare_signals(self) -> Dict:
        """Compare signals between fixed original and incremental strategies.

        Returns:
            A dict containing:
            - per-type and total counts for both strategies;
            - ``*_count_match`` and ``*_indices_match`` booleans;
            - ``entry_index_diff`` / ``exit_index_diff``: sets of bar indices
              where only one strategy emitted that signal type.

        Raises:
            ValueError: If either signal test has not been run yet.
        """
        logger.info("Comparing signals between strategies...")

        # An empty signal list is a legitimate result, so "has run" is judged
        # by the ran-flags (non-empty lists also count, for callers that set
        # the lists directly).
        original_ready = self._original_ran or bool(self.original_signals)
        incremental_ready = self._incremental_ran or bool(self.incremental_signals)
        if not (original_ready and incremental_ready):
            raise ValueError("Must run both signal tests before comparison")

        orig_entry = self._of_type(self.original_signals, 'ENTRY')
        orig_exit = self._of_type(self.original_signals, 'EXIT')
        inc_entry = self._of_type(self.incremental_signals, 'ENTRY')
        inc_exit = self._of_type(self.incremental_signals, 'EXIT')

        comparison = {
            'original_total': len(self.original_signals),
            'incremental_total': len(self.incremental_signals),
            'original_entry_count': len(orig_entry),
            'original_exit_count': len(orig_exit),
            'incremental_entry_count': len(inc_entry),
            'incremental_exit_count': len(inc_exit),
            'entry_count_match': len(orig_entry) == len(inc_entry),
            'exit_count_match': len(orig_exit) == len(inc_exit),
            'total_count_match': len(self.original_signals) == len(self.incremental_signals),
        }

        # Compare signal timing (by window-relative bar index)
        orig_entry_indices = {s['index'] for s in orig_entry}
        orig_exit_indices = {s['index'] for s in orig_exit}
        inc_entry_indices = {s['index'] for s in inc_entry}
        inc_exit_indices = {s['index'] for s in inc_exit}

        comparison.update({
            'entry_indices_match': orig_entry_indices == inc_entry_indices,
            'exit_indices_match': orig_exit_indices == inc_exit_indices,
            'entry_index_diff': orig_entry_indices.symmetric_difference(inc_entry_indices),
            'exit_index_diff': orig_exit_indices.symmetric_difference(inc_exit_indices),
        })
        return comparison

    @staticmethod
    def _print_signal_list(signals: List[Dict]) -> None:
        """Print one formatted line per signal record."""
        for signal in signals:
            print(f"Index {signal['index']:3d} | {signal['timestamp']} | "
                  f"{signal['signal_type']:5s} | Price: {signal['close']:8.2f} | "
                  f"Conf: {signal['confidence']:.2f}")

    @staticmethod
    def _types_by_index(signals: List[Dict]) -> Dict:
        """Map each bar index to its signal types joined with '+' (e.g. 'ENTRY+EXIT')."""
        by_index = {}
        for signal in signals:
            by_index.setdefault(signal['index'], []).append(signal['signal_type'])
        return {idx: "+".join(types) for idx, types in by_index.items()}

    def print_signal_details(self):
        """Print detailed signal information for analysis.

        This prints both signal lists, then a side-by-side table with one row
        per bar index. Each cell shows every signal emitted at that bar, so a
        same-bar ENTRY and EXIT appear as "ENTRY+EXIT". The Match column is
        ✅ when both strategies emitted the same signals at that bar, ❌
        otherwise.
        """
        print("\n" + "=" * 80)
        print("DETAILED SIGNAL COMPARISON (FIXED ORIGINAL)")
        print("=" * 80)

        print(f"\n📊 FIXED ORIGINAL STRATEGY SIGNALS ({len(self.original_signals)} total)")
        print("-" * 60)
        self._print_signal_list(self.original_signals)

        print(f"\n📊 INCREMENTAL STRATEGY SIGNALS ({len(self.incremental_signals)} total)")
        print("-" * 60)
        self._print_signal_list(self.incremental_signals)

        print("\n🔄 SIDE-BY-SIDE COMPARISON")
        print("-" * 80)
        print(f"{'Index':<6} {'Fixed Original':<20} {'Incremental':<20} {'Match':<8}")
        print("-" * 80)

        orig_by_index = self._types_by_index(self.original_signals)
        inc_by_index = self._types_by_index(self.incremental_signals)
        for idx in sorted(orig_by_index.keys() | inc_by_index.keys()):
            orig_str = orig_by_index.get(idx, "---")
            inc_str = inc_by_index.get(idx, "---")
            match_str = "✅" if orig_str == inc_str else "❌"
            print(f"{idx:<6} {orig_str:<20} {inc_str:<20} {match_str:<8}")

    @staticmethod
    def _print_comparison_summary(comparison: Dict) -> None:
        """Print signal counts, match flags and index differences."""
        print("\n" + "=" * 80)
        print("FIXED SIGNAL COMPARISON RESULTS")
        print("=" * 80)

        print("\n📊 SIGNAL COUNTS:")
        print(f"Fixed Original Strategy: {comparison['original_entry_count']} entries, {comparison['original_exit_count']} exits")
        print(f"Incremental Strategy: {comparison['incremental_entry_count']} entries, {comparison['incremental_exit_count']} exits")

        print("\n✅ MATCHES:")
        for label, key in (
            ("Entry count match", 'entry_count_match'),
            ("Exit count match", 'exit_count_match'),
            ("Entry timing match", 'entry_indices_match'),
            ("Exit timing match", 'exit_indices_match'),
        ):
            print(f"{label}: {'✅ YES' if comparison[key] else '❌ NO'}")

        if comparison['entry_index_diff']:
            print(f"\n❌ Entry signal differences at indices: {sorted(comparison['entry_index_diff'])}")
        if comparison['exit_index_diff']:
            print(f"❌ Exit signal differences at indices: {sorted(comparison['exit_index_diff'])}")

    def run_signal_test(self, limit: int = 500) -> bool:
        """Run the complete signal comparison test.

        Args:
            limit: Number of most recent data points to load.

        Returns:
            True if entry/exit counts and timings match exactly. False if they
            differ, or if any step raised; in that case the error is logged
            with its traceback.
        """
        logger.info("=" * 80)
        logger.info("STARTING FIXED SIGNAL COMPARISON TEST")
        logger.info("=" * 80)

        try:
            self.load_test_data(limit)

            self.test_fixed_original_strategy_signals()
            self.test_incremental_strategy_signals()

            comparison = self.compare_signals()
            self._print_comparison_summary(comparison)
            self.print_signal_details()

            overall_match = all(comparison[key] for key in self._MATCH_KEYS)
            print(f"\n🏆 OVERALL RESULT: {'✅ SIGNALS MATCH PERFECTLY' if overall_match else '❌ SIGNALS DIFFER'}")
            return overall_match
        except Exception as e:
            # Top-level boundary: report any failure as a failed test.
            logger.exception(f"Signal test failed: {e}")
            return False
def main():
    """Entry point: run the fixed signal comparison test on 500 data points.

    Returns:
        True when the fixed original and incremental strategies produce
        matching signals, False otherwise.
    """
    comparison_test = SignalComparisonTestFixed()
    return comparison_test.run_signal_test(limit=500)


if __name__ == "__main__":
    # Exit status 0 means the signals matched; 1 means a mismatch or failure.
    success = main()
    sys.exit(0 if success else 1)