From 10bb3710542e0428f4b3d2b5e15de74e1c2c9e2d Mon Sep 17 00:00:00 2001 From: Simon Moisy Date: Tue, 13 Jan 2026 21:55:34 +0800 Subject: [PATCH] Implement Regime Reversion Strategy and remove regime_detection.py - Introduced `RegimeReversionStrategy` for ML-based regime detection and mean reversion trading. - Added feature engineering and model training logic within the new strategy. - Removed the deprecated `regime_detection.py` file to streamline the codebase. - Updated the strategy factory to include the new regime strategy configuration. --- research/regime_detection.py | 384 ---------------------------------- strategies/factory.py | 14 ++ strategies/regime_strategy.py | 280 +++++++++++++++++++++++++ 3 files changed, 294 insertions(+), 384 deletions(-) delete mode 100644 research/regime_detection.py create mode 100644 strategies/regime_strategy.py diff --git a/research/regime_detection.py b/research/regime_detection.py deleted file mode 100644 index 6468c60..0000000 --- a/research/regime_detection.py +++ /dev/null @@ -1,384 +0,0 @@ -import sys -import os -from pathlib import Path - -# Add project root to path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import pandas as pd -import numpy as np -import ta -from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import train_test_split -from sklearn.metrics import classification_report, confusion_matrix -import plotly.graph_objects as go -from plotly.subplots import make_subplots - -from engine.data_manager import DataManager -from engine.market import MarketType - -def prepare_data(symbol_a="BTC-USDT", symbol_b="ETH-USDT", timeframe="1h", limit=None, start_date=None, end_date=None): - """ - Load and align data for two assets to create a pair. 
- """ - dm = DataManager() - - print(f"Loading data for {symbol_a} and {symbol_b}...") - - # Helper to load or download - def get_df(symbol): - try: - # Try load first - df = dm.load_data("okx", symbol, timeframe, MarketType.SPOT) - except Exception: - df = dm.download_data("okx", symbol, timeframe, market_type=MarketType.SPOT) - - # If we have start/end dates, ensure we have enough data or re-download - if start_date: - mask_start = pd.Timestamp(start_date, tz='UTC') - if df.index.min() > mask_start: - print(f"Local data starts {df.index.min()}, need {mask_start}. Downloading...") - df = dm.download_data("okx", symbol, timeframe, start_date=start_date, end_date=end_date, market_type=MarketType.SPOT) - return df - - df_a = get_df(symbol_a) - df_b = get_df(symbol_b) - - # Filter by date if provided (to match CQ data range) - if start_date: - df_a = df_a[df_a.index >= pd.Timestamp(start_date, tz='UTC')] - df_b = df_b[df_b.index >= pd.Timestamp(start_date, tz='UTC')] - - if end_date: - df_a = df_a[df_a.index <= pd.Timestamp(end_date, tz='UTC')] - df_b = df_b[df_b.index <= pd.Timestamp(end_date, tz='UTC')] - - # Align DataFrames - print("Aligning data...") - common_index = df_a.index.intersection(df_b.index) - df_a = df_a.loc[common_index].copy() - df_b = df_b.loc[common_index].copy() - - if limit: - df_a = df_a.tail(limit) - df_b = df_b.tail(limit) - - return df_a, df_b - -def load_cryptoquant_data(file_path: str) -> pd.DataFrame | None: - """ - Load CryptoQuant data and prepare it for merging. - """ - if not os.path.exists(file_path): - print(f"Warning: CQ data file {file_path} not found.") - return None - - print(f"Loading CryptoQuant data from {file_path}...") - df = pd.read_csv(file_path, index_col='timestamp', parse_dates=True) - - # CQ data is usually daily (UTC 00:00). 
- # Ensure index is timezone aware to match market data - if df.index.tz is None: - df.index = df.index.tz_localize('UTC') - - return df - -def calculate_features(df_a, df_b, cq_df=None, window=24): - """ - Calculate spread, z-score, and advanced regime features including CQ data. - """ - # 1. Price Ratio (Spread) - spread = df_b['close'] / df_a['close'] - - # 2. Rolling Statistics for Z-Score - rolling_mean = spread.rolling(window=window).mean() - rolling_std = spread.rolling(window=window).std() - z_score = (spread - rolling_mean) / rolling_std - - # 3. Spread Momentum / Technicals - spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi() - spread_roc = spread.pct_change(periods=5) * 100 - - # 4. Volume Dynamics - vol_ratio = df_b['volume'] / df_a['volume'] - vol_ratio_ma = vol_ratio.rolling(window=12).mean() - - # 5. Volatility Regime - ret_a = df_a['close'].pct_change() - ret_b = df_b['close'].pct_change() - vol_a = ret_a.rolling(window=window).std() - vol_b = ret_b.rolling(window=window).std() - vol_spread_ratio = vol_b / vol_a - - # Create feature DataFrame - features = pd.DataFrame(index=spread.index) - features['spread'] = spread - features['z_score'] = z_score - features['spread_rsi'] = spread_rsi - features['spread_roc'] = spread_roc - features['vol_ratio'] = vol_ratio - features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma - features['vol_diff_ratio'] = vol_spread_ratio - - # 6. 
Merge CryptoQuant Data - if cq_df is not None: - print("Merging CryptoQuant features...") - # Forward fill daily data to hourly timestamps - # reindex features to match cq_df range or join - - # Resample CQ to hourly (ffill) - # But easier: join features with cq_df using asof or reindex - cq_aligned = cq_df.reindex(features.index, method='ffill') - - # Add derived CQ features - # Funding Diff: If ETH funding > BTC funding => ETH overheated - if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns: - cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding'] - - # Inflow Ratio: If ETH inflow >> BTC inflow => ETH dump incoming? - if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns: - # Add small epsilon to avoid div by zero - cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1) - - features = features.join(cq_aligned) - - # --- Refined Target Definition (Anytime Profit) --- - horizon = 6 - threshold = 0.005 # 0.5% profit target - z_threshold = 1.0 - - # For Short Spread (Z > 1): Did it drop below target? - # We look for the MINIMUM spread in the next 'horizon' periods - future_min = features['spread'].rolling(window=horizon).min().shift(-horizon) - target_short = features['spread'] * (1 - threshold) - success_short = (features['z_score'] > z_threshold) & (future_min < target_short) - - # For Long Spread (Z < -1): Did it rise above target? - # We look for the MAXIMUM spread in the next 'horizon' periods - future_max = features['spread'].rolling(window=horizon).max().shift(-horizon) - target_long = features['spread'] * (1 + threshold) - success_long = (features['z_score'] < -z_threshold) & (future_max > target_long) - - conditions = [success_short, success_long] - - features['target'] = np.select(conditions, [1, 1], default=0) - - return features.dropna() - -def train_regime_model(features): - """ - Train a Random Forest to predict mean reversion success. 
- """ - # Define excluded columns (targets, raw prices, intermediates) - exclude_cols = ['spread', 'horizon_ret', 'target', 'rolling_mean', 'rolling_std'] - - # Auto-select all other numeric columns as features - feature_cols = [c for c in features.columns if c not in exclude_cols] - - # Handle NaN/Inf if any slipped through - X = features[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0) - y = features['target'] - - # Split Data - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=False) - - print(f"\nTraining on {len(X_train)} samples, Testing on {len(X_test)} samples...") - print(f"Features used: {feature_cols}") - print(f"Class Balance (Target=1): {y.mean():.2%}") - - # Model - model = RandomForestClassifier( - n_estimators=200, - max_depth=6, - min_samples_leaf=20, - class_weight='balanced_subsample', - random_state=42 - ) - model.fit(X_train, y_train) - - # Evaluation - y_pred = model.predict(X_test) - y_prob = model.predict_proba(X_test)[:, 1] - - print("\n--- Model Evaluation ---") - print(classification_report(y_test, y_pred)) - - # Feature Importance - importances = pd.Series(model.feature_importances_, index=feature_cols).sort_values(ascending=False) - print("\n--- Feature Importance ---") - print(importances) - - return model, X_test, y_test, y_pred, y_prob - -def plot_interactive_results(features, y_test, y_pred, y_prob): - """ - Create an interactive HTML plot using Plotly. 
- """ - print("\nGenerating interactive plot...") - - test_idx = y_test.index - test_data = features.loc[test_idx].copy() - test_data['prob'] = y_prob - test_data['prediction'] = y_pred - test_data['actual'] = y_test - - # Create Subplots - fig = make_subplots( - rows=3, cols=1, - shared_xaxes=True, - vertical_spacing=0.05, - row_heights=[0.5, 0.25, 0.25], - subplot_titles=('Spread & Signals', 'Exchange Inflows', 'Z-Score & Probability') - ) - - # Top: Spread - fig.add_trace( - go.Scatter(x=test_data.index, y=test_data['spread'], mode='lines', name='Spread', line=dict(color='gray')), - row=1, col=1 - ) - - # Signals - # Separate Long and Short signals for clarity - # Logic: If Z-Score was High (>1), we were betting on a SHORT Spread (Reversion Down) - # If Z-Score was Low (< -1), we were betting on a LONG Spread (Reversion Up) - - # Correct Short Signals (Green Triangle Down) - tp_short = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 1) & (test_data['z_score'] > 0)] - fig.add_trace( - go.Scatter(x=tp_short.index, y=tp_short['spread'], mode='markers', name='Win: Short Spread', - marker=dict(symbol='triangle-down', size=12, color='green')), - row=1, col=1 - ) - - # Correct Long Signals (Green Triangle Up) - tp_long = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 1) & (test_data['z_score'] < 0)] - fig.add_trace( - go.Scatter(x=tp_long.index, y=tp_long['spread'], mode='markers', name='Win: Long Spread', - marker=dict(symbol='triangle-up', size=12, color='green')), - row=1, col=1 - ) - - # False Short Signals (Red Triangle Down) - fp_short = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 0) & (test_data['z_score'] > 0)] - fig.add_trace( - go.Scatter(x=fp_short.index, y=fp_short['spread'], mode='markers', name='Loss: Short Spread', - marker=dict(symbol='triangle-down', size=10, color='red')), - row=1, col=1 - ) - - # False Long Signals (Red Triangle Up) - fp_long = test_data[(test_data['prediction'] == 1) & 
(test_data['actual'] == 0) & (test_data['z_score'] < 0)] - fig.add_trace( - go.Scatter(x=fp_long.index, y=fp_long['spread'], mode='markers', name='Loss: Long Spread', - marker=dict(symbol='triangle-up', size=10, color='red')), - row=1, col=1 - ) - - # Middle: Inflows (BTC vs ETH) - if 'btc_inflow' in test_data.columns: - fig.add_trace( - go.Bar(x=test_data.index, y=test_data['btc_inflow'], name='BTC Inflow', marker_color='orange', opacity=0.6), - row=2, col=1 - ) - if 'eth_inflow' in test_data.columns: - fig.add_trace( - go.Bar(x=test_data.index, y=test_data['eth_inflow'], name='ETH Inflow', marker_color='purple', opacity=0.6), - row=2, col=1 - ) - - # Bottom: Z-Score - fig.add_trace( - go.Scatter(x=test_data.index, y=test_data['z_score'], mode='lines', name='Z-Score', line=dict(color='blue'), opacity=0.5), - row=3, col=1 - ) - fig.add_hline(y=2, line_dash="dash", line_color="red", row=3, col=1) - fig.add_hline(y=-2, line_dash="dash", line_color="green", row=3, col=1) - - # Probability (Secondary Y for Row 3) - fig.add_trace( - go.Scatter(x=test_data.index, y=test_data['prob'], mode='lines', name='Prob', line=dict(color='cyan', width=1.5), yaxis='y4'), - row=3, col=1 - ) - - fig.update_layout( - title='Regime Detection Analysis (with CryptoQuant)', - autosize=True, - height=None, - hovermode='x unified', - yaxis4=dict(title='Probability', overlaying='y3', side='right', range=[0, 1], showgrid=False), - template="plotly_dark", - margin=dict(l=10, r=10, t=40, b=10), - barmode='group' - ) - - # Update all x-axes to ensure spikes are visible everywhere - fig.update_xaxes( - showspikes=True, - spikemode='across', - spikesnap='cursor', - showline=False, - showgrid=True, - spikedash='dot', - spikecolor='white', # Make it bright to see - spikethickness=1, - ) - - fig.update_layout( - title='Regime Detection Analysis (with CryptoQuant)', - autosize=True, - height=None, - hovermode='x unified', # Keep unified hover for data reading - yaxis4=dict(title='Probability', 
overlaying='y3', side='right', range=[0, 1], showgrid=False), - template="plotly_dark", - margin=dict(l=10, r=10, t=40, b=10), - barmode='group' - ) - - output_path = "research/regime_results.html" - fig.write_html( - output_path, - config={'responsive': True, 'scrollZoom': True}, - include_plotlyjs='cdn', - full_html=True, - default_height='100vh', - default_width='100%' - ) - print(f"Interactive plot saved to {output_path}") - -def main(): - # 1. Load CQ Data first to determine valid date range - cq_path = "data/cq_training_data.csv" - cq_df = load_cryptoquant_data(cq_path) - - start_date = None - end_date = None - - if cq_df is not None and not cq_df.empty: - start_date = cq_df.index.min().strftime('%Y-%m-%d') - end_date = cq_df.index.max().strftime('%Y-%m-%d') - print(f"CryptoQuant Data Range: {start_date} to {end_date}") - - # 2. Get Market Data (Aligned to CQ range) - df_btc, df_eth = prepare_data( - "BTC-USDT", "ETH-USDT", - timeframe="1h", - start_date=start_date, - end_date=end_date - ) - - # 3. Calculate Features - print("Calculating advanced regime features...") - data = calculate_features(df_btc, df_eth, cq_df=cq_df, window=24) - - if data.empty: - print("Error: No overlapping data found between Price and CryptoQuant data.") - return - - # 4. Train & Evaluate - model, X_test, y_test, y_pred, y_prob = train_regime_model(data) - - # 5. 
import pandas as pd
import numpy as np
import ta
import vectorbt as vbt
from sklearn.ensemble import RandomForestClassifier

from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.data_manager import DataManager
from engine.logging_config import get_logger

logger = get_logger(__name__)


class RegimeReversionStrategy(BaseStrategy):
    """
    ML-Based Regime Detection & Mean Reversion Strategy.

    Logic:
    1. Tracks the ETH/BTC close-price ratio ("spread") and its rolling Z-Score
       (`z_window` bars).
    2. Uses a Random Forest model to predict if an extreme Z-Score will revert profitably.
    3. Features: Spread Technicals (RSI, ROC) + volume/volatility ratios +
       On-Chain Flows (Inflow, Funding) when the CryptoQuant CSV can be read.
    4. Entry: When Model Probability > `prob_threshold` and |Z| > 1.
    5. Exit: Z-Score crossing back through 0.

    Walk-Forward Training:
    - Trains on first `train_ratio` of data (default 70%)
    - Generates entry signals only for the remaining test period
    - Training labels look at most `horizon` bars ahead *inside* the training
      slice; rows whose look-ahead window is incomplete are discarded.

    NOTE: `model_path`, `stop_loss` and `take_profit` are stored on the
    instance but are not read by any method of this class.
    NOTE(review): the fitted model is cached on the instance, so calling
    `run` a second time on different data reuses the first fit -- confirm
    this is intended before reusing one instance across backtests.
    """

    # Minimum favourable spread move (fraction) that labels a setup as a win.
    PROFIT_THRESHOLD = 0.005
    # |Z-Score| beyond which a bar counts as an "extreme" (label and entry gate).
    Z_ENTRY = 1.0
    # Columns every downstream step relies on; never dropped as "empty".
    _CORE_COLUMNS = ('spread', 'z_score')

    def __init__(self,
                 model_path: str = "data/regime_model.pkl",
                 horizon: int = 96,           # 4 Days based on research
                 z_window: int = 24,
                 stop_loss: float = 0.06,     # 6% to survive 2% avg MAE
                 take_profit: float = 0.05,   # Swing target
                 train_ratio: float = 0.7,    # Walk-forward: train on first 70%
                 timeframe: str = "1h",
                 cq_path: str = "data/cq_training_data.csv",
                 prob_threshold: float = 0.5
                 ):
        """
        Args:
            model_path: Stored only; no model is loaded from or saved to it here.
            horizon: Look-ahead window (bars) used to label training rows.
            z_window: Rolling window (bars) for the spread Z-Score and volatility.
            stop_loss: Stored only.
            take_profit: Stored only.
            train_ratio: Leading fraction of feature rows used for training.
            timeframe: Candle timeframe requested for the BTC context data.
            cq_path: CSV with a `timestamp` column holding on-chain features.
            prob_threshold: Minimum positive-class probability to allow an entry.
        """
        super().__init__()
        self.model_path = model_path
        self.horizon = horizon
        self.z_window = z_window
        self.stop_loss = stop_loss
        self.take_profit = take_profit
        self.train_ratio = train_ratio
        self.timeframe = timeframe
        self.cq_path = cq_path
        self.prob_threshold = prob_threshold

        # Default Strategy Config
        self.default_market_type = MarketType.PERPETUAL
        self.default_leverage = 1

        self.dm = DataManager()
        self.model = None
        self.feature_cols = None
        self.train_end_idx = None  # Timestamp of the last training bar (set by run)

    def run(self, close, **kwargs):
        """
        Build features, fit the model on the leading slice and emit signals.

        The incoming `close` series is treated as the traded asset (ETH-USDT);
        BTC-USDT spot candles are loaded through the DataManager to form the
        spread.

        Args:
            close: Close-price Series of the traded asset.
            **kwargs: Only `volume` is read. Without it the volume-based
                features are left out.

        Returns:
            Tuple `(long_entries, long_exits, short_entries, short_exits)` of
            Series indexed like `close`. When context data cannot be loaded,
            the split is degenerate, or training is impossible, all four are
            the object returned by `create_empty_signals(close)`.
        """
        logger.info("Fetching BTC context data...")
        try:
            df_btc = self.dm.load_data("okx", "BTC-USDT", self.timeframe, MarketType.SPOT)
            # Align BTC to the traded asset's bars.
            df_btc = df_btc.reindex(close.index, method='ffill')
            btc_close = df_btc['close']
            btc_volume = df_btc['volume']
        except Exception as e:
            logger.error("Failed to load BTC context: %s", e)
            return self._no_signals(close)

        eth_vol = kwargs.get('volume')
        if eth_vol is None:
            # A zero-filled placeholder would turn vol_ratio_rel into 0/0 on
            # every bar and dropna() would then discard the whole frame. NaN
            # lets prepare_features drop just the volume columns instead.
            logger.warning("Volume data missing. Volume-based features will be skipped.")
            eth_vol = pd.Series(np.nan, index=close.index)

        df_a = pd.DataFrame({'close': btc_close, 'volume': btc_volume})
        df_b = pd.DataFrame({'close': close, 'volume': eth_vol})

        features = self.prepare_features(df_a, df_b, self._load_onchain())

        # Walk-forward split: leading slice trains, trailing slice is traded.
        train_size = int(len(features) * self.train_ratio)
        train_features = features.iloc[:train_size]
        test_features = features.iloc[train_size:]
        if train_features.empty or test_features.empty:
            logger.warning(
                "Not enough feature rows for a walk-forward split "
                "(train=%d, test=%d). No signals generated.",
                len(train_features), len(test_features)
            )
            return self._no_signals(close)

        train_end_date = train_features.index[-1]
        test_start_date = test_features.index[0]
        self.train_end_idx = train_end_date

        logger.info(
            "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
            len(train_features), train_end_date.strftime('%Y-%m-%d'),
            len(test_features), test_start_date.strftime('%Y-%m-%d')
        )

        if self.model is None:
            logger.info("Training Regime Model on training period only...")
            try:
                self.model, self.feature_cols = self.train_model(train_features)
            except ValueError as e:
                logger.error("Regime model training failed: %s", e)
                return self._no_signals(close)

        # reindex (rather than [] selection) tolerates a cached model whose
        # columns are absent from this run; they are zero-filled like NaNs.
        X_test = test_features.reindex(columns=self.feature_cols).fillna(0)
        X_test = X_test.replace([np.inf, -np.inf], 0)
        probs = self._positive_proba(X_test)

        # Z > 1  (Spread High, ETH Expensive) -> Short ETH
        # Z < -1 (Spread Low, ETH Cheap)      -> Long ETH
        z_test = test_features['z_score'].to_numpy()
        confident = probs > self.prob_threshold
        long_signal_test = confident & (z_test < -self.Z_ENTRY)
        short_signal_test = confident & (z_test > self.Z_ENTRY)

        # Expand to the full index; bars outside the test slice stay False.
        test_idx = test_features.index
        long_entries = pd.Series(long_signal_test, index=test_idx).reindex(
            close.index, fill_value=False)
        short_entries = pd.Series(short_signal_test, index=test_idx).reindex(
            close.index, fill_value=False)

        # Exit when Z-Score crosses back through 0 (mean reversion complete).
        z_reindexed = features['z_score'].reindex(close.index, fill_value=0)
        long_exits = z_reindexed > 0
        short_exits = z_reindexed < 0

        logger.info(
            "Generated %s long signals, %s short signals (test period only)",
            long_entries.sum(), short_entries.sum()
        )

        return long_entries, long_exits, short_entries, short_exits

    def _no_signals(self, close):
        """Return the same empty signal object for all four outputs."""
        empty = self.create_empty_signals(close)
        return empty, empty, empty, empty

    def _load_onchain(self):
        """Read the CryptoQuant CSV as a UTC-indexed frame, or None on any failure."""
        try:
            cq_df = pd.read_csv(self.cq_path, index_col='timestamp', parse_dates=True)
            if cq_df.index.tz is None:
                cq_df.index = cq_df.index.tz_localize('UTC')
        except Exception as e:
            # Best effort: the strategy still runs on price features alone.
            logger.warning(
                "CryptoQuant data unavailable (%s). Running without on-chain features.", e)
            return None
        return cq_df

    def _positive_proba(self, X):
        """Probability of class 1 per row; zeros when the model never saw class 1."""
        classes = list(self.model.classes_)
        if 1 not in classes:
            # predict_proba then has a single column, so [:, 1] would raise.
            return np.zeros(len(X))
        return self.model.predict_proba(X)[:, classes.index(1)]

    def prepare_features(self, df_btc, df_eth, cq_df: "pd.DataFrame | None" = None):
        """
        Build the feature frame from aligned BTC/ETH candles.

        Args:
            df_btc: Frame with `close` and `volume` columns (denominator asset).
            df_eth: Frame with `close` and `volume` columns (numerator asset).
            cq_df: Optional on-chain frame; forward-filled onto the candle index.
                `funding_diff` / `inflow_ratio` are derived when the matching
                `btc_*` / `eth_*` columns exist.

        Returns:
            DataFrame containing at least `spread` and `z_score`, with every
            row that still holds a NaN removed. Non-core columns that are NaN
            on every row are dropped first so that one unavailable input does
            not wipe out the whole frame.
        """
        # Align
        common = df_btc.index.intersection(df_eth.index)
        df_a = df_btc.loc[common].copy()
        df_b = df_eth.loc[common].copy()

        # Spread
        spread = df_b['close'] / df_a['close']

        # Z-Score
        rolling_mean = spread.rolling(window=self.z_window).mean()
        rolling_std = spread.rolling(window=self.z_window).std()
        z_score = (spread - rolling_mean) / rolling_std

        # Volume
        vol_ratio = df_b['volume'] / df_a['volume']
        vol_ratio_ma = vol_ratio.rolling(window=12).mean()

        # Volatility
        vol_a = df_a['close'].pct_change().rolling(window=self.z_window).std()
        vol_b = df_b['close'].pct_change().rolling(window=self.z_window).std()

        features = pd.DataFrame(index=spread.index)
        features['spread'] = spread
        features['z_score'] = z_score
        features['spread_rsi'] = ta.momentum.RSIIndicator(spread, window=14).rsi()
        features['spread_roc'] = spread.pct_change(periods=5) * 100
        features['spread_change_1h'] = spread.pct_change(periods=1)
        features['vol_ratio'] = vol_ratio
        features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
        features['vol_diff_ratio'] = vol_b / vol_a

        # CQ Merge
        if cq_df is not None:
            cq_aligned = cq_df.reindex(features.index, method='ffill')
            if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
                cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
            if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
                # +1 keeps the denominator away from zero.
                cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
            features = features.join(cq_aligned)

        empty_cols = [
            c for c in features.columns
            if c not in self._CORE_COLUMNS and features[c].isna().all()
        ]
        if empty_cols:
            logger.warning("Dropping feature columns with no data: %s", empty_cols)
            features = features.drop(columns=empty_cols)

        return features.dropna()

    def train_model(self, train_features):
        """
        Train Random Forest on training data only.

        A row is labelled 1 when its Z-Score is beyond +/-`Z_ENTRY` and the
        spread moves at least `PROFIT_THRESHOLD` in the reverting direction
        within the next `horizon` rows of `train_features`; otherwise 0.
        Rows without a full look-ahead window are excluded from fitting.

        Args:
            train_features: DataFrame containing features for training period only

        Returns:
            Tuple `(model, cols)`: the fitted classifier and the ordered list
            of feature column names it was fitted on (everything but `spread`).

        Raises:
            ValueError: If no row has a complete look-ahead window.
        """
        threshold = self.PROFIT_THRESHOLD
        horizon = self.horizon
        spread = train_features['spread']
        z_score = train_features['z_score']

        # Extremes of the NEXT `horizon` rows: a trailing rolling window shifted back.
        future_min = spread.rolling(window=horizon).min().shift(-horizon)
        future_max = spread.rolling(window=horizon).max().shift(-horizon)

        # Short Spread (Z > 1): did the spread drop below target within horizon?
        success_short = (z_score > self.Z_ENTRY) & (future_min < spread * (1 - threshold))
        # Long Spread (Z < -1): did the spread rise above target within horizon?
        success_long = (z_score < -self.Z_ENTRY) & (future_max > spread * (1 + threshold))
        targets = (success_short | success_long).astype(int).to_numpy()

        # Exclude non-feature columns
        exclude = ['spread']
        cols = [c for c in train_features.columns if c not in exclude]

        # Clean features
        X_train = train_features[cols].fillna(0)
        X_train = X_train.replace([np.inf, -np.inf], 0)

        # Rows near the end of the slice have no complete future window.
        valid_mask = future_min.notna().to_numpy() & future_max.notna().to_numpy()
        if not valid_mask.any():
            raise ValueError(
                f"No training rows with a complete {horizon}-bar look-ahead window "
                f"({len(train_features)} rows available)"
            )

        X_train_clean = X_train[valid_mask]
        targets_clean = targets[valid_mask]

        logger.info(
            "Training on %d valid samples (removed %d with incomplete future data)",
            len(X_train_clean), len(X_train) - len(X_train_clean)
        )

        model = RandomForestClassifier(
            n_estimators=300, max_depth=5, min_samples_leaf=30,
            class_weight={0: 1, 1: 3}, random_state=42
        )
        model.fit(X_train_clean, targets_clean)
        return model, cols