# Cycles/test/plot_signal_comparison.py

"""
Visual Signal Comparison Plot
This script creates comprehensive plots comparing:
1. Price data with signals overlaid
2. Meta-trend values over time
3. Individual Supertrend indicators
4. Signal timing comparison
Compares the original strategy (in both its buggy and fixed forms) with the incremental strategy.
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
import seaborn as sns
import logging
from typing import Dict, List, Tuple
import os
import sys
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cycles.strategies.default_strategy import DefaultStrategy
from cycles.IncStrategies.metatrend_strategy import IncMetaTrendStrategy
from cycles.IncStrategies.indicators.supertrend import SupertrendCollection
from cycles.utils.storage import Storage
from cycles.strategies.base import StrategySignal
# Module-wide logging: INFO and above, timestamped lines.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Global plot appearance shared by every figure produced by this script.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
class FixedDefaultStrategy(DefaultStrategy):
    """DefaultStrategy with the exit condition bug fixed.

    An EXIT is produced when the meta-trend moves *into* -1 from any other
    value, i.e. ``prev != -1 and curr == -1``.
    """

    def get_exit_signal(self, backtester, df_index: int) -> StrategySignal:
        """Return an EXIT signal on a transition into a -1 meta-trend.

        Args:
            backtester: Not read here; kept for signature compatibility with
                the parent strategy.
            df_index: Position in ``self.meta_trend`` to evaluate.

        Returns:
            ``StrategySignal("EXIT", confidence=1.0, ...)`` on a transition into
            -1, otherwise a HOLD signal with confidence 0.0.
        """
        # Guards, evaluated left to right: strategy not ready, no previous bar,
        # meta-trend missing, or index past its end -> nothing to decide.
        cannot_evaluate = (
            not self.initialized
            or df_index < 1
            or not hasattr(self, 'meta_trend')
            or df_index >= len(self.meta_trend)
        )
        if cannot_evaluate:
            return StrategySignal("HOLD", 0.0)

        previous = self.meta_trend[df_index - 1]
        current = self.meta_trend[df_index]

        # Corrected rule: compare the previous value against -1 (the buggy
        # parent compared it against 1).
        if not (previous != -1 and current == -1):
            return StrategySignal("HOLD", confidence=0.0)
        return StrategySignal("EXIT", confidence=1.0,
                              metadata={"type": "META_TREND_EXIT_SIGNAL"})
class SignalPlotter:
    """Create comprehensive signal comparison plots.

    Runs the original (buggy) ``DefaultStrategy``, the exit-fixed
    ``FixedDefaultStrategy`` and the incremental ``IncMetaTrendStrategy`` over
    the same trailing window of data. It then draws four panels: price with
    signals, meta-trends, individual Supertrends (incremental only) and
    signal timing.
    """

    # Number of trailing rows every strategy is evaluated on and plotted.
    WINDOW_SIZE = 200

    def __init__(self):
        """Initialize the plotter with empty result containers."""
        self.storage = Storage(logging=logger)
        self.test_data = None
        self.original_signals = []
        self.fixed_original_signals = []
        self.incremental_signals = []
        self.original_meta_trend = None
        self.fixed_original_meta_trend = None
        self.incremental_meta_trend = []
        self.individual_trends = []

    # ------------------------------------------------------------------ #
    # Small helpers shared by the runners and plotting panels
    # ------------------------------------------------------------------ #
    def _analysis_window(self) -> pd.DataFrame:
        """Return the trailing ``WINDOW_SIZE`` rows of ``self.test_data``.

        ``DataFrame.tail`` returns the whole frame when it is shorter than the
        window, so short data sets are used in full.
        """
        return self.test_data.tail(self.WINDOW_SIZE)

    @staticmethod
    def _make_signal(index: int, data_start_index: int, timestamp, close,
                     signal_type: str, confidence, source: str) -> Dict:
        """Build the signal record produced by every strategy runner."""
        return {
            'index': index,
            'global_index': data_start_index + index,
            'timestamp': timestamp,
            'close': close,
            'signal_type': signal_type,
            'confidence': confidence,
            'source': source,
        }

    @staticmethod
    def _filter_signals(signals: List[Dict], signal_type: str, limit: int) -> List[Dict]:
        """Return signals of ``signal_type`` whose window ``index`` is below ``limit``."""
        return [s for s in signals if s['signal_type'] == signal_type and s['index'] < limit]

    @staticmethod
    def _format_time_axis(ax) -> None:
        """Apply the shared HH:MM time formatting to ``ax``'s x-axis."""
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        # The plotted window is only WINDOW_SIZE one-minute bars (~3.3 h), so a
        # fixed 2-hour locator produced just one or two ticks; let matplotlib
        # choose a spacing that fits the actual range.
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    # ------------------------------------------------------------------ #
    # Data loading and strategy runs
    # ------------------------------------------------------------------ #
    def load_and_prepare_data(self, limit: int = 1000,
                              filename: str = "btcusd_1-min_data.csv",
                              start_date="2024-12-31",
                              end_date="2025-01-01") -> pd.DataFrame:
        """Load test data into ``self.test_data``.

        Args:
            limit: Keep at most this many trailing rows.
            filename: Data file passed to ``Storage.load_data``.
            start_date: Range start (anything ``pd.to_datetime`` accepts).
            end_date: Range end (anything ``pd.to_datetime`` accepts).

        Returns:
            The loaded frame with its index reset into columns. The rest of
            the class expects this to yield a ``timestamp`` column.

        Raises:
            Re-raises any error from loading after logging it with traceback.
        """
        logger.info(f"Loading and preparing data (limit: {limit} points)")
        try:
            df = self.storage.load_data(filename, pd.to_datetime(start_date), pd.to_datetime(end_date))
            if len(df) > limit:
                df = df.tail(limit)
                logger.info(f"Limited data to last {limit} points")

            # Move the index (expected to be the timestamp) into a column.
            df_with_timestamp = df.reset_index()
            self.test_data = df_with_timestamp
            logger.info(f"Loaded {len(df_with_timestamp)} data points")
            logger.info(f"Date range: {df_with_timestamp['timestamp'].min()} to {df_with_timestamp['timestamp'].max()}")
            return df_with_timestamp
        except Exception as e:
            logger.exception(f"Failed to load test data: {e}")
            raise

    def run_original_strategy(self, use_fixed: bool = False) -> Tuple[List[Dict], np.ndarray, int]:
        """Run the original (or exit-fixed) strategy on the analysis window.

        Args:
            use_fixed: Use ``FixedDefaultStrategy`` instead of ``DefaultStrategy``.

        Returns:
            ``(signals, meta_trend, data_start_index)``. Here ``meta_trend`` is
            the strategy's ``meta_trend`` attribute after initialization, and
            ``data_start_index`` is the window's offset into ``self.test_data``.
        """
        strategy_name = "FIXED Original" if use_fixed else "Original (Buggy)"
        source = 'fixed_original' if use_fixed else 'original'
        logger.info(f"Running {strategy_name} DefaultStrategy...")

        # The original strategy works on a timestamp-indexed frame limited to
        # the same trailing window that is plotted.
        original_data_used = self._analysis_window().set_index('timestamp')
        data_start_index = len(self.test_data) - len(original_data_used)

        class MockBacktester:
            """Minimal stand-in exposing the frames the strategy reads."""

            def __init__(self, df):
                self.original_df = df
                self.min1_df = df
                self.strategies = {}

        backtester = MockBacktester(original_data_used)

        strategy_cls = FixedDefaultStrategy if use_fixed else DefaultStrategy
        strategy = strategy_cls(weight=1.0, params={
            "stop_loss_pct": 0.03,
            "timeframe": "1min"
        })
        strategy.initialize(backtester)

        signals: List[Dict] = []
        meta_trend = strategy.meta_trend
        closes = original_data_used['close']
        for i, timestamp in enumerate(original_data_used.index):
            close = closes.iloc[i]

            entry_signal = strategy.get_entry_signal(backtester, i)
            if entry_signal.signal_type == "ENTRY":
                signals.append(self._make_signal(i, data_start_index, timestamp, close,
                                                 'ENTRY', entry_signal.confidence, source))

            exit_signal = strategy.get_exit_signal(backtester, i)
            if exit_signal.signal_type == "EXIT":
                signals.append(self._make_signal(i, data_start_index, timestamp, close,
                                                 'EXIT', exit_signal.confidence, source))

        logger.info(f"{strategy_name} generated {len(signals)} signals")
        return signals, meta_trend, data_start_index

    def run_incremental_strategy(self, data_start_index: int = 0) -> Tuple[List[Dict], List[int], List[List[int]]]:
        """Feed the analysis window bar-by-bar into ``IncMetaTrendStrategy``.

        Args:
            data_start_index: Offset of the window in ``self.test_data``. It is
                used only to fill each signal's ``global_index``.

        Returns:
            ``(signals, meta_trends, individual_trends)``, with one meta-trend
            value and one list of Supertrend trends per processed bar.
        """
        logger.info("Running Incremental IncMetaTrendStrategy...")

        strategy = IncMetaTrendStrategy("metatrend", weight=1.0, params={
            "timeframe": "1min",
            "enable_logging": False
        })

        # Same window as the original strategy so results are comparable.
        test_data_subset = self._analysis_window()

        signals: List[Dict] = []
        meta_trends = []
        individual_trends_list = []

        for idx, (_, row) in enumerate(test_data_subset.iterrows()):
            ohlc = {
                'open': row['open'],
                'high': row['high'],
                'low': row['low'],
                'close': row['close']
            }
            strategy.calculate_on_data(ohlc, row['timestamp'])

            meta_trends.append(strategy.get_current_meta_trend())

            individual_states = strategy.get_individual_supertrend_states()
            if individual_states and len(individual_states) >= 3:
                individual_trends = [state.get('current_trend', 0) for state in individual_states]
            else:
                individual_trends = [0, 0, 0]  # placeholder when states are unavailable
            individual_trends_list.append(individual_trends)

            entry_signal = strategy.get_entry_signal()
            if entry_signal.signal_type == "ENTRY":
                signals.append(self._make_signal(idx, data_start_index, row['timestamp'], row['close'],
                                                 'ENTRY', entry_signal.confidence, 'incremental'))

            exit_signal = strategy.get_exit_signal()
            if exit_signal.signal_type == "EXIT":
                signals.append(self._make_signal(idx, data_start_index, row['timestamp'], row['close'],
                                                 'EXIT', exit_signal.confidence, 'incremental'))

        logger.info(f"Incremental strategy generated {len(signals)} signals")
        return signals, meta_trends, individual_trends_list

    # ------------------------------------------------------------------ #
    # Plotting
    # ------------------------------------------------------------------ #
    def create_comprehensive_plot(self, save_path: str = "results/signal_comparison_plot.png"):
        """Load data, run all strategies, draw the 4-panel figure, save and show it.

        Args:
            save_path: Output image path. Its parent directory is created if
                needed.
        """
        logger.info("Creating comprehensive comparison plot...")

        self.load_and_prepare_data(limit=2000)

        self.original_signals, self.original_meta_trend, data_start_index = self.run_original_strategy(use_fixed=False)
        self.fixed_original_signals, self.fixed_original_meta_trend, _ = self.run_original_strategy(use_fixed=True)
        self.incremental_signals, self.incremental_meta_trend, self.individual_trends = self.run_incremental_strategy(data_start_index)

        plot_data = self._analysis_window().copy()
        plot_data['timestamp'] = pd.to_datetime(plot_data['timestamp'])

        fig, axes = plt.subplots(4, 1, figsize=(16, 20))
        fig.suptitle('MetaTrend Strategy Signal Comparison', fontsize=16, fontweight='bold')

        self._plot_price_with_signals(axes[0], plot_data)
        self._plot_meta_trends(axes[1], plot_data)
        self._plot_individual_supertrends(axes[2], plot_data)
        self._plot_signal_timing(axes[3], plot_data)

        plt.tight_layout()
        # Create the directory of the requested path. A hard-coded "results"
        # directory broke any save_path pointing elsewhere.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {save_path}")
        plt.show()
        # Release the figure; non-interactive backends never close it otherwise.
        plt.close(fig)

    def _plot_price_with_signals(self, ax, plot_data):
        """Plot the close price with each strategy's entry/exit markers."""
        ax.set_title('Price Chart with Trading Signals', fontsize=14, fontweight='bold')

        ax.plot(plot_data['timestamp'], plot_data['close'],
                color='black', linewidth=1, label='BTC Price', alpha=0.8)

        signal_markers = {'ENTRY': '^', 'EXIT': 'v'}
        signal_sizes = {'ENTRY': 100, 'EXIT': 80}
        # (signals, legend prefix, color per signal type, extra scatter styling)
        series = [
            (self.original_signals, "Original",
             {'ENTRY': 'red', 'EXIT': 'darkred'}, {'alpha': 0.7}),
            (self.fixed_original_signals, "Fixed",
             {'ENTRY': 'blue', 'EXIT': 'darkblue'},
             {'alpha': 0.7, 'edgecolors': 'white', 'linewidth': 1}),
            (self.incremental_signals, "Incremental",
             {'ENTRY': 'green', 'EXIT': 'darkgreen'},
             {'alpha': 0.8, 'edgecolors': 'black', 'linewidth': 0.5}),
        ]

        timestamps = plot_data['timestamp']
        for signals, prefix, colors, style in series:
            for signal_type in ('ENTRY', 'EXIT'):
                selected = self._filter_signals(signals, signal_type, len(plot_data))
                if not selected:
                    continue
                # One scatter per (strategy, signal type), so every combination
                # that occurs gets exactly one legend entry. Previously only a
                # strategy's first signal was labelled, hiding the other type.
                ax.scatter(timestamps.iloc[[s['index'] for s in selected]],
                           [s['close'] for s in selected],
                           c=colors[signal_type],
                           marker=signal_markers[signal_type],
                           s=signal_sizes[signal_type],
                           label=f"{prefix} {signal_type}",
                           **style)

        ax.set_ylabel('Price (USD)')
        ax.legend(loc='upper left', fontsize=10)
        ax.grid(True, alpha=0.3)
        self._format_time_axis(ax)

    def _plot_meta_trends(self, ax, plot_data):
        """Plot the three meta-trend series against the window timestamps."""
        ax.set_title('Meta-Trend Comparison', fontsize=14, fontweight='bold')

        timestamps = plot_data['timestamp']

        if self.original_meta_trend is not None:
            ax.plot(timestamps, self.original_meta_trend,
                    color='red', linewidth=2, alpha=0.7,
                    label='Original (Buggy)', marker='o', markersize=3)

        if self.fixed_original_meta_trend is not None:
            ax.plot(timestamps, self.fixed_original_meta_trend,
                    color='blue', linewidth=2, alpha=0.7,
                    label='Fixed Original', marker='s', markersize=3)

        if self.incremental_meta_trend:
            ax.plot(timestamps, self.incremental_meta_trend,
                    color='green', linewidth=2, alpha=0.8,
                    label='Incremental', marker='D', markersize=3)

        # Reference lines for the three trend levels.
        ax.axhline(y=1, color='lightgreen', linestyle='--', alpha=0.5, label='Uptrend')
        ax.axhline(y=0, color='gray', linestyle='-', alpha=0.5, label='Neutral')
        ax.axhline(y=-1, color='lightcoral', linestyle='--', alpha=0.5, label='Downtrend')

        ax.set_ylabel('Meta-Trend Value')
        ax.set_ylim(-1.5, 1.5)
        ax.legend(loc='upper left', fontsize=10)
        ax.grid(True, alpha=0.3)
        self._format_time_axis(ax)

    def _plot_individual_supertrends(self, ax, plot_data):
        """Plot each incremental Supertrend's trend value over time."""
        ax.set_title('Individual Supertrend Indicators (Incremental)', fontsize=14, fontweight='bold')

        if not self.individual_trends:
            ax.text(0.5, 0.5, 'No individual trend data available',
                    transform=ax.transAxes, ha='center', va='center')
            return

        timestamps = plot_data['timestamp']
        individual_trends_array = np.array(self.individual_trends)

        # NOTE(review): these (period, multiplier) pairs are legend text only and
        # are assumed to match IncMetaTrendStrategy's configuration — verify.
        supertrend_configs = [(12, 3.0), (10, 1.0), (11, 2.0)]
        colors = ['purple', 'orange', 'brown']

        for i, ((period, multiplier), color) in enumerate(zip(supertrend_configs, colors)):
            if i < individual_trends_array.shape[1]:
                ax.plot(timestamps, individual_trends_array[:, i],
                        color=color, linewidth=1.5, alpha=0.8,
                        label=f'ST{i+1} (P={period}, M={multiplier})',
                        marker='o', markersize=2)

        ax.axhline(y=1, color='lightgreen', linestyle='--', alpha=0.5)
        ax.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
        ax.axhline(y=-1, color='lightcoral', linestyle='--', alpha=0.5)

        ax.set_ylabel('Supertrend Value')
        ax.set_ylim(-1.5, 1.5)
        ax.legend(loc='upper left', fontsize=10)
        ax.grid(True, alpha=0.3)
        self._format_time_axis(ax)

    def _plot_signal_timing(self, ax, plot_data):
        """Draw one row per strategy with a bar and marker at each signal time."""
        ax.set_title('Signal Timing Comparison', fontsize=14, fontweight='bold')

        timestamps = plot_data['timestamp']
        # (row y-position, row label, color, signals)
        rows = [
            (3, 'Original (Buggy)', 'red', self.original_signals),
            (2, 'Fixed Original', 'blue', self.fixed_original_signals),
            (1, 'Incremental', 'green', self.incremental_signals),
        ]

        for y_pos, _label, color, signals in rows:
            for signal_type, marker in (('ENTRY', '^'), ('EXIT', 'v')):
                # De-duplicated, time-ordered window positions of this signal type.
                indices = sorted({s['index'] for s in
                                  self._filter_signals(signals, signal_type, len(timestamps))})
                if not indices:
                    continue
                x = timestamps.iloc[indices]
                # vlines takes ymin/ymax in data coordinates, so each bar stays
                # centred on its row. The previous axvline call needed axes
                # fractions, and its "/4" conversion did not match the
                # (0.5, 3.5) y-limits, so the bars were misaligned.
                ax.vlines(x, y_pos - 0.4, y_pos + 0.4, colors=color, linewidth=3, alpha=0.8)
                ax.scatter(x, [y_pos] * len(indices), marker=marker, s=50, color=color, alpha=0.8)

        ax.set_yticks([row[0] for row in rows])
        ax.set_yticklabels([row[1] for row in rows])
        ax.set_ylabel('Strategy')
        ax.set_ylim(0.5, 3.5)
        ax.grid(True, alpha=0.3)
        self._format_time_axis(ax)

        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], marker='^', color='gray', linestyle='None', markersize=8, label='Entry Signal'),
            Line2D([0], [0], marker='v', color='gray', linestyle='None', markersize=8, label='Exit Signal')
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
def main():
    """Entry point: build, save and display the signal comparison figure."""
    SignalPlotter().create_comprehensive_plot()


if __name__ == "__main__":
    main()