"""
ML Model Training Script.

Trains the regime detection Random Forest model on historical data.
Can be run manually or scheduled via cron for daily retraining.

Usage:
    uv run python train_model.py [options]

Options:
    --days DAYS          Number of days of historical data to use (default: 90)
    --pair PAIR          Base pair; its close is the spread denominator (default: BTC-USDT)
    --spread-pair PAIR   Pair whose close is divided by the base pair's (default: ETH-USDT)
    --timeframe TF       Timeframe (default: 1h)
    --market MARKET      Market type, "spot" or "perpetual" (default: perpetual)
    --output PATH        Output model path (default: data/regime_model.pkl)
    --train-ratio R      Train/test split ratio, strictly between 0 and 1 (default: 0.7)
    --horizon BARS       Prediction horizon in bars (default: 102)
    --dry-run            Run without saving model
    --download           Download latest data before training
"""
import argparse
import os
import pickle
import sys
import tempfile
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import ta
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score

from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import get_logger

logger = get_logger(__name__)

# Default configuration (from research optimization)
DEFAULT_HORIZON = 102  # bars; 4.25 days at the default 1h timeframe
DEFAULT_Z_WINDOW = 24  # bars; a 24h rolling window at the default 1h timeframe
DEFAULT_PROFIT_TARGET = 0.005  # 0.5% profit threshold
DEFAULT_Z_THRESHOLD = 1.0  # Z-score entry threshold
DEFAULT_TRAIN_RATIO = 0.7  # 70% train / 30% test
FEE_RATE = 0.001  # 0.1% round-trip fee (not referenced in this script; kept for importers)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Train the regime detection ML model"
    )
    parser.add_argument(
        "--days",
        type=int,
        default=90,
        help="Number of days of historical data to use (default: 90)"
    )
    parser.add_argument(
        "--pair",
        type=str,
        default="BTC-USDT",
        help="Base pair for context data (default: BTC-USDT)"
    )
    parser.add_argument(
        "--spread-pair",
        type=str,
        default="ETH-USDT",
        help="Spread pair to trade (default: ETH-USDT)"
    )
    parser.add_argument(
        "--timeframe",
        type=str,
        default="1h",
        help="Timeframe (default: 1h)"
    )
    parser.add_argument(
        "--market",
        type=str,
        choices=["spot", "perpetual"],
        default="perpetual",
        help="Market type (default: perpetual)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="data/regime_model.pkl",
        help="Output model path (default: data/regime_model.pkl)"
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=DEFAULT_TRAIN_RATIO,
        help="Train/test split ratio (default: 0.7)"
    )
    parser.add_argument(
        "--horizon",
        type=int,
        default=DEFAULT_HORIZON,
        help="Prediction horizon in bars (default: 102)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Run without saving model"
    )
    parser.add_argument(
        "--download",
        action="store_true",
        help="Download latest data before training"
    )
    return parser.parse_args()


def download_data(dm: DataManager, pair: str, timeframe: str, market_type: MarketType):
    """
    Download the latest OKX data for one pair.

    Raises:
        Exception: whatever ``dm.download_data`` raises, re-raised after logging.
    """
    logger.info(f"Downloading latest data for {pair}...")
    try:
        dm.download_data("okx", pair, timeframe, market_type)
        logger.info(f"Downloaded {pair} data")
    except Exception as e:
        logger.error(f"Failed to download {pair}: {e}")
        raise


def load_data(
    dm: DataManager,
    base_pair: str,
    spread_pair: str,
    timeframe: str,
    market_type: MarketType,
    days: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load both pairs, keep the last ``days`` days and align them on common bars.

    Returns:
        ``(df_base, df_spread)`` restricted to the intersection of their indexes.
        Both frames may be empty if the pairs have no overlapping bars.
    """
    df_base = dm.load_data("okx", base_pair, timeframe, market_type)
    df_spread = dm.load_data("okx", spread_pair, timeframe, market_type)

    # Filter to last N days.
    # NOTE(review): assumes the loaded index is tz-aware so it compares with a UTC
    # Timestamp; a tz-naive index would raise TypeError here (caught by main).
    end_date = pd.Timestamp.now(tz="UTC")
    start_date = end_date - timedelta(days=days)

    df_base = df_base[(df_base.index >= start_date) & (df_base.index <= end_date)]
    df_spread = df_spread[(df_spread.index >= start_date) & (df_spread.index <= end_date)]

    # Align indices so every row has a bar for both pairs
    common = df_base.index.intersection(df_spread.index)
    df_base = df_base.loc[common]
    df_spread = df_spread.loc[common]

    if common.empty:
        # Avoid the misleading "from NaT to NaT" message; main() rejects short data.
        logger.warning(
            f"No overlapping bars for {base_pair} and {spread_pair} "
            f"in the last {days} days"
        )
    else:
        logger.info(
            f"Loaded {len(common)} bars from {common.min()} to {common.max()}"
        )
    return df_base, df_spread


def load_cryptoquant_data() -> pd.DataFrame | None:
    """
    Load CryptoQuant on-chain data from ``data/cq_training_data.csv`` if present.

    Returns:
        A DataFrame indexed by a UTC timestamp, sorted and free of duplicate
        timestamps, or ``None`` if the file is missing or cannot be read.
    """
    try:
        cq_path = Path("data/cq_training_data.csv")
        if not cq_path.exists():
            logger.info("CryptoQuant data not found, skipping on-chain features")
            return None

        cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
        if cq_df.index.tz is None:
            cq_df.index = cq_df.index.tz_localize('UTC')

        # reindex(method='ffill') in calculate_features requires a unique,
        # monotonically increasing index; enforce that here where errors are
        # handled, rather than crashing later.
        cq_df = cq_df[~cq_df.index.duplicated(keep='last')].sort_index()

        logger.info(f"Loaded CryptoQuant data: {len(cq_df)} rows")
        return cq_df
    except Exception as e:
        logger.warning(f"Could not load CryptoQuant data: {e}")
        return None


def calculate_features(
    df_base: pd.DataFrame,
    df_spread: pd.DataFrame,
    cq_df: pd.DataFrame | None = None,
    z_window: int = DEFAULT_Z_WINDOW
) -> pd.DataFrame:
    """
    Build the model's feature matrix from two aligned OHLCV frames.

    The "spread" is ``df_spread['close'] / df_base['close']``. Spread z-score,
    RSI, rate of change, volume ratios and a relative-volatility ratio are
    derived from it. CryptoQuant columns, plus ``funding_diff`` and
    ``inflow_ratio`` when their inputs exist, are joined when ``cq_df`` is given.

    Returns:
        Feature DataFrame with all NaN-containing rows dropped. Infinite values,
        e.g. from zero volume, are NOT removed here.
    """
    spread = df_spread['close'] / df_base['close']

    # Z-Score
    rolling_mean = spread.rolling(window=z_window).mean()
    rolling_std = spread.rolling(window=z_window).std()
    z_score = (spread - rolling_mean) / rolling_std

    # Technicals
    spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
    spread_roc = spread.pct_change(periods=5) * 100
    spread_change_1h = spread.pct_change(periods=1)

    # Volume
    vol_ratio = df_spread['volume'] / df_base['volume']
    vol_ratio_ma = vol_ratio.rolling(window=12).mean()

    # Volatility
    ret_base = df_base['close'].pct_change()
    ret_spread = df_spread['close'].pct_change()
    vol_base = ret_base.rolling(window=z_window).std()
    vol_spread = ret_spread.rolling(window=z_window).std()
    vol_spread_ratio = vol_spread / vol_base

    features = pd.DataFrame(index=spread.index)
    features['spread'] = spread
    features['z_score'] = z_score
    features['spread_rsi'] = spread_rsi
    features['spread_roc'] = spread_roc
    features['spread_change_1h'] = spread_change_1h
    features['vol_ratio'] = vol_ratio
    features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
    features['vol_diff_ratio'] = vol_spread_ratio

    # Add CQ features if available
    if cq_df is not None:
        # NOTE(review): forward-filling assumes each CQ row is known at its own
        # timestamp. If rows are stamped at period start (e.g. daily aggregates),
        # this leaks future information into the features. Confirm against the
        # data source.
        cq_aligned = cq_df.reindex(features.index, method='ffill')
        if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
            cq_aligned['funding_diff'] = (
                cq_aligned['eth_funding'] - cq_aligned['btc_funding']
            )
        if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
            # +1 in the denominator guards against division by zero inflow
            cq_aligned['inflow_ratio'] = (
                cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
            )
        features = features.join(cq_aligned)

    return features.dropna()


def calculate_targets(
    features: pd.DataFrame,
    horizon: int,
    profit_target: float = DEFAULT_PROFIT_TARGET,
    z_threshold: float = DEFAULT_Z_THRESHOLD
) -> tuple[np.ndarray, pd.Series]:
    """
    Label each bar 1 if a mean-reversion trade entered there would have hit
    ``profit_target`` within the next ``horizon`` bars, otherwise 0.

    A short is considered when ``z_score > z_threshold``, a long when
    ``z_score < -z_threshold``. The future window for row ``i`` is
    ``spread[i+1 .. i+horizon]``.

    Returns:
        ``(targets, valid_mask)``. ``targets`` is an int array aligned with
        ``features``. ``valid_mask`` is True only for rows whose full future
        window lies inside the data.
    """
    spread = features['spread']
    z_score = features['z_score']

    # rolling(horizon) at t+horizon spans t+1..t+horizon; shifting back by
    # `horizon` aligns that forward-looking window with row t.
    # For Short (Z > threshold): did spread drop below target?
    future_min = spread.rolling(window=horizon).min().shift(-horizon)
    target_short = spread * (1 - profit_target)
    success_short = (z_score > z_threshold) & (future_min < target_short)

    # For Long (Z < -threshold): did spread rise above target?
    future_max = spread.rolling(window=horizon).max().shift(-horizon)
    target_long = spread * (1 + profit_target)
    success_long = (z_score < -z_threshold) & (future_max > target_long)

    targets = np.select([success_short, success_long], [1, 1], default=0)

    # Create valid mask (rows with complete future data)
    valid_mask = future_min.notna() & future_max.notna()

    return targets, valid_mask


def train_model(
    features: pd.DataFrame,
    train_ratio: float = DEFAULT_TRAIN_RATIO,
    horizon: int = DEFAULT_HORIZON
) -> tuple[RandomForestClassifier, list[str], dict]:
    """
    Train Random Forest model with a purged walk-forward split.

    The first ``train_ratio`` fraction of rows is the training period and the
    rest is the test period. Training rows whose label window reaches into the
    test period (the last ``horizon`` training rows) are purged, so the test
    F1 is not inflated by leaked future prices.

    Args:
        features: DataFrame with calculated features (must contain 'spread'
            and 'z_score')
        train_ratio: Fraction of data to use for training, strictly in (0, 1)
        horizon: Prediction horizon in bars, >= 1

    Returns:
        Tuple of (trained model, feature columns, metrics dict)

    Raises:
        ValueError: on an invalid ``train_ratio``/``horizon``, an empty split,
            or fewer than 100 usable training samples.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError(
            f"train_ratio must be strictly between 0 and 1, got {train_ratio}"
        )
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1 bar, got {horizon}")

    # Calculate targets
    targets, valid_mask = calculate_targets(features, horizon)

    # Walk-forward split
    n_samples = len(features)
    train_size = int(n_samples * train_ratio)
    if train_size == 0 or train_size >= n_samples:
        raise ValueError(
            f"train_ratio={train_ratio} leaves an empty train or test split "
            f"for {n_samples} rows"
        )

    train_features = features.iloc[:train_size]
    test_features = features.iloc[train_size:]
    train_targets = targets[:train_size]
    test_targets = targets[train_size:]

    # Purge: the label of row i looks at rows i+1..i+horizon, so any training
    # row with i + horizon >= train_size was labelled with test-period prices.
    train_valid = valid_mask.iloc[:train_size].copy()
    purge_start = max(train_size - horizon, 0)
    purged = int(train_valid.iloc[purge_start:].sum())
    train_valid.iloc[purge_start:] = False
    test_valid = valid_mask.iloc[train_size:]

    # Prepare training data ('spread' is a raw price level, not a feature)
    exclude = ['spread']
    feature_cols = [c for c in features.columns if c not in exclude]

    X_train = train_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    X_train_valid = X_train[train_valid]
    y_train_valid = train_targets[train_valid.to_numpy()]

    if len(X_train_valid) < 100:
        raise ValueError(
            f"Not enough training data: {len(X_train_valid)} samples (need >= 100)"
        )

    logger.info(
        f"Training on {len(X_train_valid)} samples "
        f"({purged} boundary rows purged to prevent label leakage)..."
    )

    # Train model
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=5,
        min_samples_leaf=30,
        class_weight={0: 1, 1: 3},
        random_state=42
    )
    model.fit(X_train_valid, y_train_valid)

    # Evaluate on test set
    X_test = test_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    predictions = model.predict(X_test)

    # Only evaluate on test rows whose future window is complete
    test_valid_mask = test_valid.to_numpy()
    y_test_valid = test_targets[test_valid_mask]
    pred_valid = predictions[test_valid_mask]

    if len(y_test_valid) == 0:
        # Happens when horizon >= test length: no test row has a complete label.
        logger.warning(
            f"No test rows have a complete {horizon}-bar future window "
            f"(test length {len(test_features)}); reporting F1 = 0.0"
        )
        f1 = 0.0
    else:
        f1 = float(f1_score(y_test_valid, pred_valid, zero_division=0))
        logger.info(
            "Test classification report:\n"
            + classification_report(
                y_test_valid, pred_valid, labels=[0, 1], zero_division=0
            )
        )

    metrics = {
        'train_samples': len(X_train_valid),
        'test_samples': len(X_test),
        'test_samples_evaluated': len(y_test_valid),
        'purged_samples': purged,
        'f1_score': f1,
        'train_end': train_features.index[-1].isoformat(),
        'test_start': test_features.index[0].isoformat(),
        'horizon': horizon,
        'feature_cols': feature_cols,
    }

    logger.info(f"Model trained. F1 Score: {f1:.3f}")
    logger.info(
        f"Train period: {train_features.index[0]} to {train_features.index[-1]}"
    )
    logger.info(
        f"Test period: {test_features.index[0]} to {test_features.index[-1]}"
    )

    return model, feature_cols, metrics


def _atomic_write_bytes(path: Path, payload: bytes) -> None:
    """Write ``payload`` to ``path`` via a temp file + rename so readers never see a partial file."""
    fd, tmp_name = tempfile.mkstemp(
        dir=path.parent, prefix=f".{path.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(payload)
        os.replace(tmp_name, path)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise


def save_model(
    model: RandomForestClassifier,
    feature_cols: list[str],
    metrics: dict,
    output_path: str,
    versioned: bool = True
):
    """
    Save trained model to file as a pickled dict.

    The dict has the keys 'model', 'feature_cols', 'metrics' and 'trained_at'.
    Files are written atomically (temp file + ``os.replace``). A process that
    loads the model while a retrain is running therefore sees either the old
    or the new file, never a truncated one.

    Args:
        model: Trained model
        feature_cols: List of feature column names
        metrics: Training metrics
        output_path: Output file path
        versioned: If True, also save ``<output stem>_<YYYYmmdd_HHMMSS>.pkl``
            next to the main file (``regime_model_...`` for the default path)
    """
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    # One timestamp so 'trained_at' and the versioned filename agree
    now = datetime.now()
    data = {
        'model': model,
        'feature_cols': feature_cols,
        'metrics': metrics,
        'trained_at': now.isoformat(),
    }
    # Serialize once; both files get identical bytes.
    # Consumers must only unpickle these files from trusted locations.
    payload = pickle.dumps(data)

    # Save main model file
    _atomic_write_bytes(output, payload)
    logger.info(f"Saved model to {output}")

    # Save versioned copy
    if versioned:
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        versioned_path = output.parent / f"{output.stem}_{timestamp}.pkl"
        _atomic_write_bytes(versioned_path, payload)
        logger.info(f"Saved versioned model to {versioned_path}")


def main():
    """Main training function: download (optional), load, featurize, train, report, save."""
    args = parse_args()

    market_type = MarketType.PERPETUAL if args.market == "perpetual" else MarketType.SPOT
    dm = DataManager()

    # Download latest data if requested
    if args.download:
        download_data(dm, args.pair, args.timeframe, market_type)
        download_data(dm, args.spread_pair, args.timeframe, market_type)

    # Load data
    try:
        df_base, df_spread = load_data(
            dm, args.pair, args.spread_pair, args.timeframe, market_type, args.days
        )
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        logger.info("Try running with --download flag to fetch latest data")
        sys.exit(1)

    # Load on-chain data
    cq_df = load_cryptoquant_data()

    # Calculate features
    features = calculate_features(df_base, df_spread, cq_df)
    logger.info(
        f"Calculated {len(features)} feature rows with {len(features.columns)} columns"
    )

    if len(features) < 200:
        logger.error(f"Not enough data: {len(features)} rows (need >= 200)")
        sys.exit(1)

    # Train model
    try:
        model, feature_cols, metrics = train_model(
            features, args.train_ratio, args.horizon
        )
    except ValueError as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)

    # Print metrics summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Train samples: {metrics['train_samples']}")
    print(f"Test samples: {metrics['test_samples']}")
    print(f"F1 Score: {metrics['f1_score']:.3f}")
    print(f"Horizon: {metrics['horizon']} bars")
    print(f"Features: {len(feature_cols)}")
    print("=" * 60)

    # Save model
    if not args.dry_run:
        save_model(model, feature_cols, metrics, args.output)
    else:
        logger.info("Dry run - model not saved")

    return 0


if __name__ == "__main__":
    sys.exit(main())