455 lines
17 KiB
Python
455 lines
17 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Visual test for timeframe aggregation utilities.
|
||
|
|
|
||
|
|
This script loads BTC minute data and aggregates it to different timeframes,
|
||
|
|
then plots candlestick charts to visually verify the aggregation correctness.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
import matplotlib.dates as mdates
|
||
|
|
from matplotlib.patches import Rectangle
|
||
|
|
import sys
|
||
|
|
import os
|
||
|
|
from datetime import datetime, timedelta
|
||
|
|
|
||
|
|
# Add the project root to Python path
|
||
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
|
|
|
||
|
|
from IncrementalTrader.utils import aggregate_minute_data_to_timeframe, parse_timeframe_to_minutes
|
||
|
|
|
||
|
|
|
||
|
|
def load_btc_data(file_path: str, date_filter: str = None, max_rows: int = None) -> pd.DataFrame:
    """
    Load BTC minute data from CSV file.

    Args:
        file_path: Path to the CSV file
        date_filter: Date to filter (e.g., "2024-01-01")
        max_rows: Maximum number of rows to load

    Returns:
        DataFrame with standardized columns ('timestamp', 'open', 'high',
        'low', 'close', 'volume'), sorted by timestamp, or None if the data
        could not be loaded or validated.
    """
    print(f"📊 Loading BTC data from {file_path}")

    try:
        # Load the CSV file
        df = pd.read_csv(file_path)
        print(f" 📈 Loaded {len(df)} rows")
        print(f" 📋 Columns: {list(df.columns)}")

        # Check the first few rows to understand the format
        print(f" 🔍 First few rows:")
        print(df.head())

        # Handle Unix timestamp format (seconds since epoch)
        if 'Timestamp' in df.columns:
            print(f" 🕐 Converting Unix timestamps...")
            df['timestamp'] = pd.to_datetime(df['Timestamp'], unit='s')
            print(f" ✅ Converted timestamps from {df['timestamp'].min()} to {df['timestamp'].max()}")
        else:
            # Try to identify a timestamp column by exact name first
            timestamp_cols = ['timestamp', 'time', 'datetime', 'date']
            timestamp_col = None

            for col in timestamp_cols:
                if col in df.columns:
                    timestamp_col = col
                    break

            if timestamp_col is None:
                # Fall back to any column whose name looks time-related
                for col in df.columns:
                    if 'time' in col.lower() or 'date' in col.lower():
                        timestamp_col = col
                        break

            if timestamp_col is None:
                print(" ❌ Could not find timestamp column")
                return None

            print(f" 🕐 Using timestamp column: {timestamp_col}")
            df['timestamp'] = pd.to_datetime(df[timestamp_col])

        # Standardize column names (e.g. 'Open' -> 'open') by substring match
        column_mapping = {}
        for col in df.columns:
            col_lower = col.lower()
            if 'open' in col_lower:
                column_mapping[col] = 'open'
            elif 'high' in col_lower:
                column_mapping[col] = 'high'
            elif 'low' in col_lower:
                column_mapping[col] = 'low'
            elif 'close' in col_lower:
                column_mapping[col] = 'close'
            elif 'volume' in col_lower:
                column_mapping[col] = 'volume'

        df = df.rename(columns=column_mapping)

        # Ensure we have required columns
        required_cols = ['open', 'high', 'low', 'close', 'volume']
        missing_cols = [col for col in required_cols if col not in df.columns]

        if missing_cols:
            print(f" ❌ Missing required columns: {missing_cols}")
            return None

        # Remove rows with zero or invalid prices
        initial_len = len(df)
        df = df[(df['open'] > 0) & (df['high'] > 0) & (df['low'] > 0) & (df['close'] > 0)]
        if len(df) < initial_len:
            print(f" 🧹 Removed {initial_len - len(df)} rows with invalid prices")

        # Filter by date if specified
        if date_filter:
            target_date = pd.to_datetime(date_filter).date()
            # BUGFIX: capture the available dates BEFORE filtering.  The old
            # code computed them from the already-filtered (empty) frame, so
            # the diagnostic below always printed an empty sample.
            available_dates = df['timestamp'].dt.date.unique()
            df = df[df['timestamp'].dt.date == target_date]
            print(f" 📅 Filtered to {date_filter}: {len(df)} rows")

            if len(df) == 0:
                print(f" ⚠️ No data found for {date_filter}")
                # Show what dates are actually present to help pick a new one
                print(f" 📅 Available dates (sample): {sorted(available_dates)[:10]}")
                return None

        # If no date filter, let's find a good date with lots of data
        if date_filter is None:
            print(f" 📅 Finding a good date with active trading...")
            # Group by date and count rows
            date_counts = df.groupby(df['timestamp'].dt.date).size()
            # Find dates with close to 1440 minutes (full day)
            good_dates = date_counts[date_counts >= 1000].index
            if len(good_dates) > 0:
                # Pick a recent date with good data
                selected_date = good_dates[-1]  # Most recent good date
                df = df[df['timestamp'].dt.date == selected_date]
                print(f" ✅ Auto-selected date {selected_date} with {len(df)} data points")
            else:
                print(f" ⚠️ No dates with sufficient data found")

        # Limit rows if specified
        if max_rows and len(df) > max_rows:
            df = df.head(max_rows)
            print(f" ✂️ Limited to {max_rows} rows")

        # Sort by timestamp
        df = df.sort_values('timestamp')

        print(f" ✅ Final dataset: {len(df)} rows from {df['timestamp'].min()} to {df['timestamp'].max()}")

        return df

    except Exception as e:
        print(f" ❌ Error loading data: {e}")
        import traceback
        traceback.print_exc()
        return None
|
||
|
|
def convert_df_to_minute_data(df: pd.DataFrame) -> list:
    """Convert DataFrame to list of dictionaries for aggregation.

    Args:
        df: DataFrame containing 'timestamp', 'open', 'high', 'low',
            'close' and 'volume' columns.

    Returns:
        List of dicts (one per row, in row order) with those six keys;
        price/volume values are coerced to plain floats.
    """
    # Iterate the columns with zip() instead of DataFrame.iterrows():
    # same output, but avoids the per-row Series construction that makes
    # iterrows() orders of magnitude slower on minute-level datasets.
    return [
        {
            'timestamp': ts,
            'open': float(o),
            'high': float(h),
            'low': float(lo),
            'close': float(c),
            'volume': float(v),
        }
        for ts, o, h, lo, c, v in zip(
            df['timestamp'], df['open'], df['high'],
            df['low'], df['close'], df['volume'],
        )
    ]
|
||
|
|
def plot_candlesticks(ax, data, timeframe, color='blue', alpha=0.7, width_factor=0.8):
    """
    Plot candlestick chart on given axes.

    Args:
        ax: Matplotlib axes
        data: List of OHLCV dictionaries
        timeframe: Timeframe string for labeling
        color: Color for the candlesticks
        alpha: Transparency
        width_factor: Width factor for candlesticks
    """
    if not data:
        return

    # Candle width derived from the timeframe length
    tf_minutes = parse_timeframe_to_minutes(timeframe)
    bar_width = pd.Timedelta(minutes=tf_minutes * width_factor)

    span = pd.Timedelta(minutes=tf_minutes)
    for candle in data:
        o = candle['open']
        h = candle['high']
        lo = candle['low']
        c = candle['close']

        # Timestamps are bar-END times, so each candle covers [end - span, end]
        end = candle['timestamp']
        start = end - span

        # Default palette maps bullish -> green, bearish -> red; an explicit
        # non-default color overrides both directions.
        if color == 'blue':
            candle_color = 'green' if c >= o else 'red'
        else:
            candle_color = color
        body_color = candle_color

        # Wick: vertical high-low line centered in the time period
        center = start + (end - start) / 2
        ax.plot([center, center], [lo, h],
                color=candle_color, linewidth=1, alpha=alpha)

        # Body: open-close rectangle spanning the full time period
        bottom = min(o, c)
        height = abs(c - o)
        if height > 0:
            ax.add_patch(Rectangle((start, bottom),
                                   end - start, height,
                                   facecolor=body_color, edgecolor=candle_color,
                                   alpha=alpha, linewidth=0.5))
        else:
            # Doji (open == close): draw a flat line instead of a rectangle
            ax.plot([start, end], [o, c],
                    color=candle_color, linewidth=2, alpha=alpha)
||
|
|
def create_comparison_plot(minute_data, timeframes, title="Timeframe Aggregation Comparison"):
    """
    Create a comparison plot showing different timeframes.

    Builds one stacked subplot per timeframe; aggregation uses "end"-stamped
    bars (each bar's timestamp marks the end of its period).

    Args:
        minute_data: List of minute OHLCV data
        timeframes: List of timeframes to compare
        title: Plot title

    Returns:
        The matplotlib Figure containing the stacked subplots.
    """
    print(f"\n📊 Creating comparison plot for timeframes: {timeframes}")

    # Aggregate data for each timeframe
    aggregated_data = {}
    for tf in timeframes:
        print(f" 🔄 Aggregating to {tf}...")
        aggregated_data[tf] = aggregate_minute_data_to_timeframe(minute_data, tf, "end")
        print(f" ✅ {len(aggregated_data[tf])} bars")

    # Create subplots (one row per timeframe)
    fig, axes = plt.subplots(len(timeframes), 1, figsize=(15, 4 * len(timeframes)))
    if len(timeframes) == 1:
        # plt.subplots returns a bare Axes (not an array) for a single row
        axes = [axes]

    fig.suptitle(title, fontsize=16, fontweight='bold')

    # Colors for different timeframes
    colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown']

    for i, tf in enumerate(timeframes):
        ax = axes[i]
        data = aggregated_data[tf]

        if data:
            # Plot candlesticks
            plot_candlesticks(ax, data, tf, color=colors[i % len(colors)])

            # Set title and labels
            ax.set_title(f"{tf} Timeframe ({len(data)} bars)", fontweight='bold')
            ax.set_ylabel('Price (USD)')

            # Format x-axis based on data range: hourly ticks for intraday
            # ranges, daily ticks otherwise
            if len(data) > 0:
                time_range = data[-1]['timestamp'] - data[0]['timestamp']
                if time_range.total_seconds() <= 24 * 3600:  # Less than 24 hours
                    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
                    ax.xaxis.set_major_locator(mdates.HourLocator(interval=2))
                else:
                    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
                    ax.xaxis.set_major_locator(mdates.DayLocator())

            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

            # Add grid
            ax.grid(True, alpha=0.3)

            # Add a summary box: overall open -> close change across the
            # plotted range for this timeframe
            if data:
                first_bar = data[0]
                last_bar = data[-1]
                price_change = last_bar['close'] - first_bar['open']
                price_change_pct = (price_change / first_bar['open']) * 100

                stats_text = f"Open: ${first_bar['open']:.2f} | Close: ${last_bar['close']:.2f} | Change: {price_change_pct:+.2f}%"
                ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        else:
            # No bars for this timeframe: leave a placeholder message
            ax.text(0.5, 0.5, f"No data for {tf}", transform=ax.transAxes,
                    ha='center', va='center', fontsize=14)

    plt.tight_layout()
    return fig
||
|
|
def create_overlay_plot(minute_data, timeframes, title="Timeframe Overlay Comparison"):
    """
    Create an overlay plot showing multiple timeframes on the same chart.

    Larger timeframes are drawn first (background) so smaller ones remain
    visible on top; bars are "end"-stamped, matching the aggregation mode.

    Args:
        minute_data: List of minute OHLCV data
        timeframes: List of timeframes to overlay
        title: Plot title

    Returns:
        The matplotlib Figure containing the overlay chart.
    """
    print(f"\n📊 Creating overlay plot for timeframes: {timeframes}")

    # Aggregate data for each timeframe
    aggregated_data = {}
    for tf in timeframes:
        print(f" 🔄 Aggregating to {tf}...")
        aggregated_data[tf] = aggregate_minute_data_to_timeframe(minute_data, tf, "end")
        print(f" ✅ {len(aggregated_data[tf])} bars")

    # Create single plot
    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    fig.suptitle(title, fontsize=16, fontweight='bold')

    # Colors and alphas for different timeframes (lighter for larger timeframes)
    colors = ['lightcoral', 'lightgreen', 'orange', 'lightblue']  # Reordered for better visibility
    alphas = [0.9, 0.7, 0.5, 0.3]  # Higher alpha for smaller timeframes

    # Plot timeframes from largest to smallest (background to foreground)
    sorted_timeframes = sorted(timeframes, key=parse_timeframe_to_minutes, reverse=True)

    for i, tf in enumerate(sorted_timeframes):
        data = aggregated_data[tf]
        if data:
            # Color/alpha are indexed by the caller's original ordering (not
            # the draw order) so each timeframe keeps a stable color.
            color_idx = timeframes.index(tf)
            plot_candlesticks(ax, data, tf,
                              color=colors[color_idx % len(colors)],
                              alpha=alphas[color_idx % len(alphas)])

    # Set labels and formatting
    ax.set_ylabel('Price (USD)')
    ax.set_xlabel('Time')

    # Format x-axis based on data range: hourly ticks for intraday ranges,
    # daily ticks otherwise
    if minute_data:
        time_range = minute_data[-1]['timestamp'] - minute_data[0]['timestamp']
        if time_range.total_seconds() <= 24 * 3600:  # Less than 24 hours
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
            ax.xaxis.set_major_locator(mdates.HourLocator(interval=2))
        else:
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
            ax.xaxis.set_major_locator(mdates.DayLocator())

    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

    # Add grid
    ax.grid(True, alpha=0.3)

    # Add legend using proxy rectangles (candlesticks carry no artist labels)
    legend_elements = []
    for i, tf in enumerate(timeframes):
        data = aggregated_data[tf]
        if data:
            legend_elements.append(plt.Rectangle((0,0),1,1,
                                                 facecolor=colors[i % len(colors)],
                                                 alpha=alphas[i % len(alphas)],
                                                 label=f"{tf} ({len(data)} bars)"))

    ax.legend(handles=legend_elements, loc='upper left')

    # Add explanation text
    explanation = ("Smaller timeframes should be contained within larger timeframes.\n"
                   "Each bar spans its full time period (not just a point in time).")
    ax.text(0.02, 0.02, explanation, transform=ax.transAxes,
            verticalalignment='bottom', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.tight_layout()
    return fig
||
|
|
def main():
    """Run the visual aggregation test end-to-end.

    Loads minute data, aggregates it to several timeframes, and displays
    the comparison and overlay candlestick plots.

    Returns:
        True on success, False on any failure.
    """
    print("🚀 Visual Test for Timeframe Aggregation")
    print("=" * 50)

    # Configuration
    data_file = "./data/btcusd_1-min_data.csv"
    test_date = None  # Let the script auto-select a good date
    max_rows = 1440  # 24 hours of minute data
    timeframes = ["5min", "15min", "30min", "1h"]

    # Bail out early if the data file is missing
    if not os.path.exists(data_file):
        print(f"❌ Data file not found: {data_file}")
        print("Please ensure the BTC data file exists in the ./data/ directory")
        return False

    # Load data
    frame = load_btc_data(data_file, date_filter=test_date, max_rows=max_rows)
    if frame is None or len(frame) == 0:
        print("❌ Failed to load data or no data available")
        return False

    # Convert to minute data format
    bars = convert_df_to_minute_data(frame)
    print(f"\n📈 Converted to {len(bars)} minute data points")

    # Show data range and a sample point
    if bars:
        start_time = bars[0]['timestamp']
        end_time = bars[-1]['timestamp']
        print(f"📅 Data range: {start_time} to {end_time}")

        sample = bars[0]
        print(f"📊 Sample data point:")
        print(f" Timestamp: {sample['timestamp']}")
        print(f" OHLCV: O={sample['open']:.2f}, H={sample['high']:.2f}, L={sample['low']:.2f}, C={sample['close']:.2f}, V={sample['volume']:.0f}")

    # Create comparison plots
    try:
        # Individual timeframe plots
        fig1 = create_comparison_plot(bars, timeframes,
                                      f"BTC Timeframe Comparison - {start_time.date()}")

        # Overlay plot
        fig2 = create_overlay_plot(bars, timeframes,
                                   f"BTC Timeframe Overlay - {start_time.date()}")

        # Show plots
        plt.show()

        for line in (
            "\n✅ Visual test completed successfully!",
            "📊 Check the plots to verify:",
            " 1. Higher timeframes contain lower timeframes",
            " 2. OHLCV values are correctly aggregated",
            " 3. Timestamps represent bar end times",
            " 4. No future data leakage",
        ):
            print(line)

        return True

    except Exception as e:
        print(f"❌ Error creating plots: {e}")
        import traceback
        traceback.print_exc()
        return False
||
|
|
if __name__ == "__main__":
    # Exit code 0 on success, 1 on any failure
    sys.exit(0 if main() else 1)