import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np


class BacktestCharts:
    """Charting helpers for visualising backtest output."""

    @staticmethod
    def plot(df, meta_trend, show=True):
        """
        Plot the close price line chart with a trend bar at the bottom.

        The top panel is a line chart of ``df['close']``. The bottom panel is a
        thin bar coloured green where the trend is 1 and red where it is 0.
        Because the bar is drawn in its own axes that shares the x-axis with
        the price panel, it stays at the bottom while zooming/panning. A dashed
        black vertical line marks every trend change on the price panel.

        Parameters
        ----------
        df : pandas.DataFrame
            Must contain a ``'close'`` column. The x-axis is ``df.index``,
            unless the index is not a ``DatetimeIndex`` and a ``'timestamp'``
            column exists, in which case that column is used.
        meta_trend : array-like
            One value per row of ``df`` (matched by position, not by index).
            Non-numeric values are converted to float. NaNs are forward-filled,
            and NaNs before the first known value become 0. Values are then
            truncated to int: 1 is drawn green, anything else red.
        show : bool, default True
            Call ``plt.show()`` after drawing. Pass False to keep the figure
            for saving or further customisation.

        Returns
        -------
        matplotlib.figure.Figure
            The created figure.

        Raises
        ------
        ValueError
            If ``meta_trend`` does not have one value per row of ``df``.
        """
        x_data = BacktestCharts._x_axis_data(df)
        x_vals = x_data.to_numpy() if hasattr(x_data, 'to_numpy') else np.asarray(x_data)

        trend = BacktestCharts._clean_trend(meta_trend)
        # Validate before creating the figure so a bad call leaves no stray window.
        if len(trend) != len(x_vals):
            raise ValueError(
                f"meta_trend has {len(trend)} values but df has {len(x_vals)} rows"
            )
        regions = BacktestCharts._trend_regions(trend)

        fig, (ax_price, ax_bar) = plt.subplots(
            nrows=2,
            ncols=1,
            figsize=(16, 8),
            sharex=True,
            gridspec_kw={'height_ratios': [12, 1]},
        )

        sns.lineplot(x=x_data, y=df['close'], label='Close Price', color='blue', ax=ax_price)
        ax_price.set_title('Close Price with Trend Bar (Green=1, Red=0)')
        ax_price.set_ylabel('Price')
        ax_price.grid(True, alpha=0.3)
        ax_price.legend()

        # Dashed black line at the start of every region except the first,
        # i.e. at each point where the trend flips.
        for start, _, _ in regions[1:]:
            ax_price.axvline(x=x_vals[start], color='black', linestyle='--', alpha=0.7, linewidth=1)

        last_idx = len(x_vals) - 1
        for start, end, value in regions:
            color = '#089981' if value == 1 else '#F23645'
            # Extend each span to the first sample of the next region so that
            # adjacent spans touch without gaps; the final region can only
            # reach the last sample.
            x_end = x_vals[min(end + 1, last_idx)]
            ax_bar.axvspan(x_vals[start], x_end, color=color, alpha=1, ymin=0, ymax=1)

        ax_bar.set_ylim(0, 1)
        ax_bar.set_yticks([])
        ax_bar.set_ylabel('Trend')
        ax_bar.set_xlabel('Time')
        ax_bar.grid(False)
        ax_bar.set_title('Meta Trend')

        fig.tight_layout(h_pad=0.1)
        if show:
            plt.show()
        return fig

    @staticmethod
    def _x_axis_data(df):
        """Return the x-axis values: a 'timestamp' column for non-datetime indexes, else the index."""
        if 'timestamp' in df.columns and not isinstance(df.index, pd.DatetimeIndex):
            return df['timestamp']
        return df.index

    @staticmethod
    def _clean_trend(meta_trend):
        """Return meta_trend as an int array containing only 0 and 1."""
        arr = np.asarray(meta_trend)
        if not np.issubdtype(arr.dtype, np.number):
            arr = pd.Series(arr).astype(float).to_numpy()
        if np.isnan(arr).any():
            # Carry the last known trend forward; leading NaNs have no
            # predecessor and default to 0 (red).
            arr = pd.Series(arr).ffill().fillna(0).to_numpy()
        return np.where(arr.astype(int) == 1, 1, 0)

    @staticmethod
    def _trend_regions(trend):
        """Split trend into (start, end, value) runs of equal consecutive values, end inclusive."""
        if len(trend) == 0:
            return []
        # Indices where the value differs from the previous sample start a new run.
        starts = np.concatenate(([0], np.flatnonzero(np.diff(trend)) + 1))
        ends = np.concatenate((starts[1:] - 1, [len(trend) - 1]))
        return [(int(s), int(e), int(trend[s])) for s, e in zip(starts, ends)]