Cycles/test/visual_test_aggregation.py

#!/usr/bin/env python3
"""
Visual test for timeframe aggregation utilities.
This script loads BTC minute data and aggregates it to different timeframes,
then plots candlestick charts to visually verify the aggregation correctness.
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
import sys
import os
from datetime import datetime, timedelta
# Add the project root (the parent of this test/ directory) to the Python path so the
# IncrementalTrader package can be imported when the script is run directly
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from IncrementalTrader.utils import aggregate_minute_data_to_timeframe, parse_timeframe_to_minutes
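# NOTE (assumptions about IncrementalTrader.utils, inferred from how they are used below,
# not verified against the implementation):
#   - parse_timeframe_to_minutes("5min") -> 5, parse_timeframe_to_minutes("1h") -> 60
#   - aggregate_minute_data_to_timeframe(minute_data, "5min", "end") returns a list of
#     OHLCV dicts whose 'timestamp' marks the *end* of each aggregated bar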
def load_btc_data(file_path: str, date_filter: str = None, max_rows: int = None) -> pd.DataFrame:
"""
Load BTC minute data from CSV file.
Args:
file_path: Path to the CSV file
date_filter: Date to filter (e.g., "2024-01-01")
max_rows: Maximum number of rows to load
Returns:
DataFrame with OHLCV data
"""
print(f"📊 Loading BTC data from {file_path}")
try:
# Load the CSV file
df = pd.read_csv(file_path)
print(f" 📈 Loaded {len(df)} rows")
print(f" 📋 Columns: {list(df.columns)}")
# Check the first few rows to understand the format
print(f" 🔍 First few rows:")
print(df.head())
# Handle Unix timestamp format
if 'Timestamp' in df.columns:
print(f" 🕐 Converting Unix timestamps...")
df['timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')
print(f" ✅ Converted timestamps from {df['timestamp'].min()} to {df['timestamp'].max()}")
else:
# Try to identify timestamp column
timestamp_cols = ['timestamp', 'time', 'datetime', 'date']
timestamp_col = None
for col in timestamp_cols:
if col in df.columns:
timestamp_col = col
break
if timestamp_col is None:
# Try to find a column that looks like a timestamp
for col in df.columns:
if 'time' in col.lower() or 'date' in col.lower():
timestamp_col = col
break
if timestamp_col is None:
print(" ❌ Could not find timestamp column")
return None
print(f" 🕐 Using timestamp column: {timestamp_col}")
df['timestamp'] = pd.to_datetime(df[timestamp_col])
# Standardize column names
column_mapping = {}
for col in df.columns:
col_lower = col.lower()
if 'open' in col_lower:
column_mapping[col] = 'open'
elif 'high' in col_lower:
column_mapping[col] = 'high'
elif 'low' in col_lower:
column_mapping[col] = 'low'
elif 'close' in col_lower:
column_mapping[col] = 'close'
elif 'volume' in col_lower:
column_mapping[col] = 'volume'
df = df.rename(columns=column_mapping)
# Ensure we have required columns
required_cols = ['open', 'high', 'low', 'close', 'volume']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
print(f" ❌ Missing required columns: {missing_cols}")
return None
# Remove rows with zero or invalid prices
initial_len = len(df)
df = df[(df['open'] > 0) & (df['high'] > 0) & (df['low'] > 0) & (df['close'] > 0)]
if len(df) < initial_len:
print(f" 🧹 Removed {initial_len - len(df)} rows with invalid prices")
        # Filter by date if specified
        if date_filter:
            target_date = pd.to_datetime(date_filter).date()
            # Capture the available dates before filtering so we can suggest
            # alternatives if the requested date has no data
            available_dates = df['timestamp'].dt.date.unique()
            df = df[df['timestamp'].dt.date == target_date]
            print(f" 📅 Filtered to {date_filter}: {len(df)} rows")
            if len(df) == 0:
                print(f" ⚠️ No data found for {date_filter}")
                print(f" 📅 Available dates (sample): {sorted(available_dates)[:10]}")
                return None
# If no date filter, let's find a good date with lots of data
if date_filter is None:
print(f" 📅 Finding a good date with active trading...")
# Group by date and count rows
date_counts = df.groupby(df['timestamp'].dt.date).size()
# Find dates with close to 1440 minutes (full day)
good_dates = date_counts[date_counts >= 1000].index
if len(good_dates) > 0:
# Pick a recent date with good data
selected_date = good_dates[-1] # Most recent good date
df = df[df['timestamp'].dt.date == selected_date]
print(f" ✅ Auto-selected date {selected_date} with {len(df)} data points")
else:
print(f" ⚠️ No dates with sufficient data found")
# Limit rows if specified
if max_rows and len(df) > max_rows:
df = df.head(max_rows)
print(f" ✂️ Limited to {max_rows} rows")
# Sort by timestamp
df = df.sort_values('timestamp')
print(f" ✅ Final dataset: {len(df)} rows from {df['timestamp'].min()} to {df['timestamp'].max()}")
return df
except Exception as e:
print(f" ❌ Error loading data: {e}")
import traceback
traceback.print_exc()
return None
def convert_df_to_minute_data(df: pd.DataFrame) -> list:
"""Convert DataFrame to list of dictionaries for aggregation."""
minute_data = []
for _, row in df.iterrows():
minute_data.append({
'timestamp': row['timestamp'],
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume'])
})
return minute_data
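# Note: for large frames, df[cols].to_dict('records') is typically much faster than
# iterrows(); the explicit loop is fine for a visual test of this size and keeps the
# float() casts and the expected dict schema obvious.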
def plot_candlesticks(ax, data, timeframe, color='blue', alpha=0.7, width_factor=0.8):
"""
Plot candlestick chart on given axes.
Args:
ax: Matplotlib axes
data: List of OHLCV dictionaries
timeframe: Timeframe string for labeling
color: Color for the candlesticks
alpha: Transparency
width_factor: Width factor for candlesticks
"""
if not data:
return
# Calculate bar width based on timeframe
timeframe_minutes = parse_timeframe_to_minutes(timeframe)
bar_width = pd.Timedelta(minutes=timeframe_minutes * width_factor)
for bar in data:
timestamp = bar['timestamp']
open_price = bar['open']
high_price = bar['high']
low_price = bar['low']
close_price = bar['close']
# For "end" timestamp mode, the bar represents data from (timestamp - timeframe) to timestamp
bar_start = timestamp - pd.Timedelta(minutes=timeframe_minutes)
bar_end = timestamp
# Determine color based on open/close
if close_price >= open_price:
# Green/bullish candle
candle_color = 'green' if color == 'blue' else color
body_color = candle_color
else:
# Red/bearish candle
candle_color = 'red' if color == 'blue' else color
body_color = candle_color
# Draw the wick (high-low line) at the center of the time period
bar_center = bar_start + (bar_end - bar_start) / 2
ax.plot([bar_center, bar_center], [low_price, high_price],
color=candle_color, linewidth=1, alpha=alpha)
# Draw the body (open-close rectangle) spanning the time period
body_height = abs(close_price - open_price)
body_bottom = min(open_price, close_price)
if body_height > 0:
rect = Rectangle((bar_start, body_bottom),
bar_end - bar_start, body_height,
facecolor=body_color, edgecolor=candle_color,
alpha=alpha, linewidth=0.5)
ax.add_patch(rect)
else:
# Doji (open == close) - draw a horizontal line
ax.plot([bar_start, bar_end], [open_price, close_price],
color=candle_color, linewidth=2, alpha=alpha)
def create_comparison_plot(minute_data, timeframes, title="Timeframe Aggregation Comparison"):
"""
Create a comparison plot showing different timeframes.
Args:
minute_data: List of minute OHLCV data
timeframes: List of timeframes to compare
title: Plot title
"""
print(f"\n📊 Creating comparison plot for timeframes: {timeframes}")
# Aggregate data for each timeframe
aggregated_data = {}
for tf in timeframes:
print(f" 🔄 Aggregating to {tf}...")
aggregated_data[tf] = aggregate_minute_data_to_timeframe(minute_data, tf, "end")
print(f"{len(aggregated_data[tf])} bars")
# Create subplots
fig, axes = plt.subplots(len(timeframes), 1, figsize=(15, 4 * len(timeframes)))
if len(timeframes) == 1:
axes = [axes]
fig.suptitle(title, fontsize=16, fontweight='bold')
# Colors for different timeframes
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown']
for i, tf in enumerate(timeframes):
ax = axes[i]
data = aggregated_data[tf]
if data:
# Plot candlesticks
plot_candlesticks(ax, data, tf, color=colors[i % len(colors)])
# Set title and labels
ax.set_title(f"{tf} Timeframe ({len(data)} bars)", fontweight='bold')
ax.set_ylabel('Price (USD)')
# Format x-axis based on data range
if len(data) > 0:
time_range = data[-1]['timestamp'] - data[0]['timestamp']
if time_range.total_seconds() <= 24 * 3600: # Less than 24 hours
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
ax.xaxis.set_major_locator(mdates.HourLocator(interval=2))
else:
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
ax.xaxis.set_major_locator(mdates.DayLocator())
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
# Add grid
ax.grid(True, alpha=0.3)
# Add statistics
if data:
first_bar = data[0]
last_bar = data[-1]
price_change = last_bar['close'] - first_bar['open']
price_change_pct = (price_change / first_bar['open']) * 100
stats_text = f"Open: ${first_bar['open']:.2f} | Close: ${last_bar['close']:.2f} | Change: {price_change_pct:+.2f}%"
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
else:
ax.text(0.5, 0.5, f"No data for {tf}", transform=ax.transAxes,
ha='center', va='center', fontsize=14)
plt.tight_layout()
return fig
def create_overlay_plot(minute_data, timeframes, title="Timeframe Overlay Comparison"):
"""
Create an overlay plot showing multiple timeframes on the same chart.
Args:
minute_data: List of minute OHLCV data
timeframes: List of timeframes to overlay
title: Plot title
"""
print(f"\n📊 Creating overlay plot for timeframes: {timeframes}")
# Aggregate data for each timeframe
aggregated_data = {}
for tf in timeframes:
print(f" 🔄 Aggregating to {tf}...")
aggregated_data[tf] = aggregate_minute_data_to_timeframe(minute_data, tf, "end")
print(f"{len(aggregated_data[tf])} bars")
# Create single plot
fig, ax = plt.subplots(1, 1, figsize=(15, 8))
fig.suptitle(title, fontsize=16, fontweight='bold')
# Colors and alphas for different timeframes (lighter for larger timeframes)
colors = ['lightcoral', 'lightgreen', 'orange', 'lightblue'] # Reordered for better visibility
alphas = [0.9, 0.7, 0.5, 0.3] # Higher alpha for smaller timeframes
# Plot timeframes from largest to smallest (background to foreground)
sorted_timeframes = sorted(timeframes, key=parse_timeframe_to_minutes, reverse=True)
for i, tf in enumerate(sorted_timeframes):
data = aggregated_data[tf]
if data:
color_idx = timeframes.index(tf)
plot_candlesticks(ax, data, tf,
color=colors[color_idx % len(colors)],
alpha=alphas[color_idx % len(alphas)])
# Set labels and formatting
ax.set_ylabel('Price (USD)')
ax.set_xlabel('Time')
# Format x-axis based on data range
if minute_data:
time_range = minute_data[-1]['timestamp'] - minute_data[0]['timestamp']
if time_range.total_seconds() <= 24 * 3600: # Less than 24 hours
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
ax.xaxis.set_major_locator(mdates.HourLocator(interval=2))
else:
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
ax.xaxis.set_major_locator(mdates.DayLocator())
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
# Add grid
ax.grid(True, alpha=0.3)
# Add legend
legend_elements = []
for i, tf in enumerate(timeframes):
data = aggregated_data[tf]
if data:
legend_elements.append(plt.Rectangle((0,0),1,1,
facecolor=colors[i % len(colors)],
alpha=alphas[i % len(alphas)],
label=f"{tf} ({len(data)} bars)"))
ax.legend(handles=legend_elements, loc='upper left')
# Add explanation text
explanation = ("Smaller timeframes should be contained within larger timeframes.\n"
"Each bar spans its full time period (not just a point in time).")
ax.text(0.02, 0.02, explanation, transform=ax.transAxes,
verticalalignment='bottom', fontsize=10,
bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
plt.tight_layout()
return fig
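# In addition to the visual check, a numeric cross-check can catch subtle aggregation
# errors. The sketch below is illustrative and not called from main(): the helper name
# is new to this script, and the (bar_end - timeframe, bar_end] window convention is an
# assumption based on the "end" timestamp mode used above.
def check_aggregation_consistency(minute_data, timeframe, tolerance=1e-9):
    """Recompute each aggregated bar from the minute data and report mismatches."""
    timeframe_minutes = parse_timeframe_to_minutes(timeframe)
    bars = aggregate_minute_data_to_timeframe(minute_data, timeframe, "end")
    mismatches = 0
    for bar in bars:
        # Minute bars assumed to belong to this aggregated bar: strictly after
        # (bar_end - timeframe) and up to and including the bar-end timestamp.
        # minute_data is expected to be sorted by timestamp (as produced above).
        window_start = bar['timestamp'] - pd.Timedelta(minutes=timeframe_minutes)
        window = [m for m in minute_data
                  if window_start < m['timestamp'] <= bar['timestamp']]
        if not window:
            continue
        expected = {
            'open': window[0]['open'],
            'high': max(m['high'] for m in window),
            'low': min(m['low'] for m in window),
            'close': window[-1]['close'],
            'volume': sum(m['volume'] for m in window),
        }
        if any(abs(bar[key] - expected[key]) > tolerance for key in expected):
            mismatches += 1
            print(f" ⚠️ {timeframe} bar at {bar['timestamp']} differs from recomputed values")
    print(f" ✅ {timeframe}: {len(bars)} bars checked, {mismatches} mismatch(es)")
    return mismatches == 0
# Example usage (would go after minute_data is built in main()):
#     check_aggregation_consistency(minute_data, "1h")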
def main():
"""Main function to run the visual test."""
print("🚀 Visual Test for Timeframe Aggregation")
print("=" * 50)
# Configuration
data_file = "./data/btcusd_1-min_data.csv"
test_date = None # Let the script auto-select a good date
max_rows = 1440 # 24 hours of minute data
timeframes = ["5min", "15min", "30min", "1h"]
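    # These timeframe strings are assumed to be accepted by parse_timeframe_to_minutes
    # (e.g. "5min" -> 5, "1h" -> 60); adjust them to whatever formats the utility supports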
# Check if data file exists
if not os.path.exists(data_file):
print(f"❌ Data file not found: {data_file}")
print("Please ensure the BTC data file exists in the ./data/ directory")
return False
# Load data
df = load_btc_data(data_file, date_filter=test_date, max_rows=max_rows)
if df is None or len(df) == 0:
print("❌ Failed to load data or no data available")
return False
# Convert to minute data format
minute_data = convert_df_to_minute_data(df)
print(f"\n📈 Converted to {len(minute_data)} minute data points")
# Show data range
if minute_data:
start_time = minute_data[0]['timestamp']
end_time = minute_data[-1]['timestamp']
print(f"📅 Data range: {start_time} to {end_time}")
# Show sample data
print(f"📊 Sample data point:")
sample = minute_data[0]
print(f" Timestamp: {sample['timestamp']}")
print(f" OHLCV: O={sample['open']:.2f}, H={sample['high']:.2f}, L={sample['low']:.2f}, C={sample['close']:.2f}, V={sample['volume']:.0f}")
# Create comparison plots
try:
# Individual timeframe plots
fig1 = create_comparison_plot(minute_data, timeframes,
f"BTC Timeframe Comparison - {start_time.date()}")
# Overlay plot
fig2 = create_overlay_plot(minute_data, timeframes,
f"BTC Timeframe Overlay - {start_time.date()}")
# Show plots
plt.show()
print("\n✅ Visual test completed successfully!")
print("📊 Check the plots to verify:")
print(" 1. Higher timeframes contain lower timeframes")
print(" 2. OHLCV values are correctly aggregated")
print(" 3. Timestamps represent bar end times")
print(" 4. No future data leakage")
return True
except Exception as e:
print(f"❌ Error creating plots: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)