"""
Live Regime Reversion Strategy.

Adapts the backtest regime strategy for live trading.
Uses a pre-trained ML model or trains on historical data.
"""

import logging
import pickle
import time
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from .config import TradingConfig, PathConfig

logger = logging.getLogger(__name__)


class LiveRegimeStrategy:
    """
    Live trading implementation of the ML-based regime detection
    and mean reversion strategy.

    Logic:
    1. Calculates BTC/ETH spread Z-Score
    2. Uses Random Forest to predict reversion probability
    3. Applies funding rate filter
    4. Generates long/short signals on ETH perpetual
    """

    # Retrain once the current model is older than this many seconds.
    _RETRAIN_INTERVAL_S = 24 * 3600
    # Minimum number of labelled rows required to fit the classifier.
    _MIN_TRAIN_ROWS = 100
    # Minimum number of feature rows before generate_signal bootstraps a model.
    _MIN_BOOTSTRAP_ROWS = 200
    # Take-profit distance (fraction of entry spread) used to label samples.
    _LABEL_PROFIT_TARGET = 0.005

    def __init__(
        self,
        trading_config: TradingConfig,
        path_config: PathConfig
    ):
        """
        Store the configuration and try to load a saved model.

        Args:
            trading_config: Thresholds and sizing parameters
            path_config: Provides ``model_path``, the pickle file location
        """
        self.config = trading_config
        self.paths = path_config

        self.model: Optional[RandomForestClassifier] = None
        self.feature_cols: Optional[list] = None
        self.horizon: int = 54  # Default horizon
        self._last_model_load_time: float = 0.0
        self._last_train_time: float = 0.0

        self._load_or_train_model()

    def reload_model_if_changed(self) -> None:
        """Check if model file has changed and reload if necessary."""
        if not self.paths.model_path.exists():
            return

        try:
            mtime = self.paths.model_path.stat().st_mtime
            if mtime > self._last_model_load_time:
                logger.info(
                    f"Model file changed, reloading... "
                    f"(last: {self._last_model_load_time}, new: {mtime})"
                )
                self._load_or_train_model()
        except Exception as e:
            logger.warning(f"Error checking model file: {e}")

    def _load_or_train_model(self) -> None:
        """
        Load the pre-trained model from ``paths.model_path`` if it exists.

        Despite the name, no training happens here: when nothing can be
        loaded the model stays unset and is trained later from live data.
        The saved payload is read completely before any attribute is
        updated, so a corrupt or incomplete file never leaves the instance
        with a new model paired with stale feature columns.
        """
        model_path = self.paths.model_path

        if model_path.exists():
            try:
                # Recorded before reading so that a broken file is not
                # re-read on every reload check until it changes again.
                self._last_model_load_time = model_path.stat().st_mtime
                # SECURITY: pickle executes arbitrary code on load; the
                # model file must only ever come from a trusted location.
                with open(model_path, 'rb') as f:
                    saved = pickle.load(f)

                model = saved['model']
                feature_cols = saved['feature_cols']
                metrics = saved.get('metrics') or {}
                has_horizon = 'horizon' in metrics
                horizon = metrics['horizon'] if has_horizon else self.horizon
                train_time = saved.get('timestamp', self._last_model_load_time)
            except Exception as e:
                logger.warning(f"Could not load model: {e}")
            else:
                self.model = model
                self.feature_cols = feature_cols
                self.horizon = horizon
                self._last_train_time = train_time

                if has_horizon:
                    logger.info(f"Loaded model from {model_path} (horizon={self.horizon})")
                else:
                    logger.info(f"Loaded model from {model_path} (default horizon={self.horizon})")
                return

        if self.model is None:
            logger.info("No pre-trained model found. Will train on first data batch.")

    def save_model(self) -> None:
        """Save trained model to file (no-op when no model is trained)."""
        if self.model is None:
            return

        try:
            # A missing directory would otherwise make every save fail.
            self.paths.model_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.paths.model_path, 'wb') as f:
                pickle.dump({
                    'model': self.model,
                    'feature_cols': self.feature_cols,
                    'metrics': {'horizon': self.horizon},  # Save horizon
                    'timestamp': time.time()
                }, f)
            logger.info(f"Saved model to {self.paths.model_path}")
        except Exception as e:
            logger.error(f"Could not save model: {e}")

    def check_retrain(self, features: pd.DataFrame) -> None:
        """
        Check if model needs retraining (older than 24h).

        The timer is reset even when training bails out (for example for
        lack of data), so a failing retrain is attempted at most once per
        interval instead of on every call.

        Args:
            features: DataFrame with calculated features
        """
        if time.time() - self._last_train_time > self._RETRAIN_INTERVAL_S:
            logger.info("Model is older than 24h. Retraining...")
            self.train_model(features)
            self._last_train_time = time.time()

    @staticmethod
    def _label_outcome(
        future_prices: np.ndarray,
        take_profit: float,
        stop_loss: float,
        is_short: bool
    ) -> int:
        """Return 1 if take-profit is reached strictly before stop-loss, else 0."""
        if is_short:
            hit_tp = future_prices <= take_profit
            hit_sl = future_prices >= stop_loss
        else:
            hit_tp = future_prices >= take_profit
            hit_sl = future_prices <= stop_loss

        if not np.any(hit_tp):
            return 0
        if not np.any(hit_sl):
            return 1
        # argmax on a boolean array gives the index of the first True.
        return int(np.argmax(hit_tp) < np.argmax(hit_sl))

    def train_model(self, features: pd.DataFrame) -> None:
        """
        Train the Random Forest model on historical data.

        Rows whose |z_score| exceeds the entry threshold are labelled 1 when
        the spread reaches the profit target before the stop level within
        ``self.horizon`` rows, and 0 otherwise; every other row is labelled 0.
        The new model and its feature columns replace the current ones only
        after a successful fit, and the model is then saved.

        Args:
            features: DataFrame with calculated features
        """
        logger.info(f"Training model on {len(features)} samples...")

        n = len(features)
        horizon = self.horizon
        # Only rows with a full look-ahead window can be labelled. Clamp so
        # that horizon > n yields 0 rows instead of a negative slice bound.
        n_labelled = min(max(n - horizon, 0), n)
        if n_labelled < self._MIN_TRAIN_ROWS:
            logger.warning("Not enough data to train model")
            return

        z_thresh = self.config.z_entry_threshold
        profit_target = self._LABEL_PROFIT_TARGET  # 0.5% profit threshold
        stop_loss_pct = self.config.stop_loss_pct

        # Calculate targets path-dependently
        # NOTE(review): the percentage levels below assume `spread` is
        # positive; for a zero or negative spread the take-profit and stop
        # levels end up on the wrong side of the entry -- TODO confirm.
        spread = features['spread'].values
        z_score = features['z_score'].values
        targets = np.zeros(n, dtype=int)

        candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
        for i in candidates[candidates < n_labelled]:
            entry_price = spread[i]
            future_prices = spread[i + 1: i + 1 + horizon]
            is_short = z_score[i] > z_thresh
            direction = -1.0 if is_short else 1.0
            targets[i] = self._label_outcome(
                future_prices,
                take_profit=entry_price * (1 + direction * profit_target),
                stop_loss=entry_price * (1 - direction * stop_loss_pct),
                is_short=is_short,
            )

        # Exclude non-feature columns
        exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
        feature_cols = [c for c in features.columns if c not in exclude]

        # Clean features
        X = features[feature_cols].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        X_clean = X.iloc[:n_labelled]
        y_clean = targets[:n_labelled]

        # A forest fitted on a single class exposes only one probability
        # column, so it cannot provide the class-1 probability used for
        # signals.
        if np.unique(y_clean).size < 2:
            logger.warning("Training labels contain a single class; skipping training")
            return

        # Train model
        model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42
        )
        model.fit(X_clean, y_clean)

        # Commit only after a successful fit so model and columns stay paired.
        self.model = model
        self.feature_cols = feature_cols

        logger.info(f"Model trained on {len(X_clean)} samples")
        self.save_model()

    def generate_signal(
        self,
        features: pd.DataFrame,
        current_funding: dict,
        position_side: Optional[str] = None
    ) -> dict:
        """
        Generate trading signal from latest features.

        Args:
            features: DataFrame with calculated features
            current_funding: Dictionary with funding rate data
            position_side: Current position side ('long', 'short', or None)

        Returns:
            Signal dictionary with action, side, confidence, etc. The early
            "hold" returns (no data / no model) contain only ``action`` and
            ``reason``.
        """
        # Nothing to evaluate; also avoids resetting the retrain timer on
        # an empty batch.
        if features.empty:
            return {'action': 'hold', 'reason': 'no_feature_data'}

        # Check if retraining is needed
        self.check_retrain(features)

        if self.model is None:
            # Train model if not available
            if len(features) >= self._MIN_BOOTSTRAP_ROWS:
                self.train_model(features)
            else:
                return {'action': 'hold', 'reason': 'model_not_trained'}

        if self.model is None:
            return {'action': 'hold', 'reason': 'insufficient_data_for_training'}

        # Get latest row
        latest = features.iloc[-1]
        z_score = latest['z_score']
        eth_price = latest['eth_close']
        btc_price = latest['btc_close']

        # Prepare features for prediction
        X = features[self.feature_cols].iloc[[-1]].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        # Get prediction probability of class 1. Columns follow
        # model.classes_, so look the column up instead of assuming index 1
        # (a loaded model may have been fitted on a single class).
        proba_row = self.model.predict_proba(X)[0]
        classes = list(self.model.classes_)
        prob = float(proba_row[classes.index(1)]) if 1 in classes else 0.0

        # Apply thresholds
        z_thresh = self.config.z_entry_threshold
        prob_thresh = self.config.model_prob_threshold

        # Determine signal direction
        signal = {
            'action': 'hold',
            'side': None,
            'probability': prob,
            'z_score': z_score,
            'eth_price': eth_price,
            'btc_price': btc_price,
            'reason': '',
        }

        # Check for entry conditions
        if prob > prob_thresh:
            if z_score > z_thresh:
                # Spread high (ETH expensive relative to BTC) -> Short ETH
                signal['action'] = 'entry'
                signal['side'] = 'short'
                signal['reason'] = f'z_score={z_score:.2f}>threshold, prob={prob:.2f}'
            elif z_score < -z_thresh:
                # Spread low (ETH cheap relative to BTC) -> Long ETH
                signal['action'] = 'entry'
                signal['side'] = 'long'
                signal['reason'] = f'z_score={z_score:.2f}<-threshold, prob={prob:.2f}'
            else:
                signal['reason'] = f'z_score={z_score:.2f} within threshold'
        else:
            signal['reason'] = f'prob={prob:.2f} < threshold={prob_thresh:.2f}'

        # Funding-rate filter.
        # NOTE(review): the source text of this method was corrupted between
        # "prob={prob:.2f} <" and "> funding_thresh" (everything between a
        # '<' and the following '>' was lost). The low-probability reason
        # string above, the `current_funding` key and the config attribute
        # name below are reconstructed guesses -- TODO confirm against the
        # original file before deploying.
        btc_funding = (current_funding or {}).get('btc_funding', 0.0) or 0.0
        funding_thresh = self.config.funding_threshold

        if signal['side'] == 'long' and btc_funding > funding_thresh:
            # High positive funding = overheated, don't go long
            signal['action'] = 'hold'
            signal['reason'] = f'funding_filter_blocked_long (funding={btc_funding:.4f})'
        elif signal['side'] == 'short' and btc_funding < -funding_thresh:
            # High negative funding = oversold, don't go short
            signal['action'] = 'hold'
            signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'

        # Check for exit conditions (Overshoot Logic)
        if signal['action'] == 'hold' and position_side:
            # If Long, exit if Z > 0.5 (Reverted past 0 to +0.5)
            if position_side == 'long' and z_score > 0.5:
                signal['action'] = 'check_exit'
                signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > 0.5)'
            # If Short, exit if Z < -0.5 (Reverted past 0 to -0.5)
            elif position_side == 'short' and z_score < -0.5:
                signal['action'] = 'check_exit'
                signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -0.5)'

        logger.info(
            f"Signal: {signal['action']} {signal['side'] or ''} "
            f"(prob={prob:.2f}, z={z_score:.2f}, reason={signal['reason']})"
        )

        return signal

    def calculate_position_size(
        self,
        signal: dict,
        available_usdt: float
    ) -> float:
        """
        Calculate position size based on signal confidence.

        Args:
            signal: Signal dictionary with probability
            available_usdt: Available USDT balance

        Returns:
            Position size in USDT, or 0.0 when the final size (after the
            balance buffer) is below ``min_position_usdt``
        """
        prob = signal.get('probability', 0.5)

        # Base size: if max_position_usdt <= 0, use all available funds
        max_position = self.config.max_position_usdt
        if max_position <= 0:
            base_size = available_usdt
        else:
            base_size = min(available_usdt, max_position)

        # Scale by probability: 1.0x at prob <= 0.5, 1.6x at 0.8, 2.0x at 1.0
        scale = 1.0 + (prob - 0.5) * 2.0
        scale = max(1.0, min(scale, 2.0))  # Clamp between 1x and 2x

        # NOTE(review): scaling is applied after the max_position_usdt cap,
        # so the result can exceed max_position_usdt -- confirm intended.
        size = min(base_size * scale, available_usdt * 0.95)  # Leave 5% buffer

        # Check the minimum on the final size, after the buffer is applied,
        # so a size below the minimum is never returned.
        if size < self.config.min_position_usdt:
            return 0.0

        return size

    def calculate_sl_tp(
        self,
        entry_price: Optional[float],
        side: str
    ) -> tuple[Optional[float], Optional[float]]:
        """
        Calculate stop-loss and take-profit prices.

        Args:
            entry_price: Entry price
            side: "long" or "short"

        Returns:
            Tuple of (stop_loss_price, take_profit_price), or (None, None)
            if entry_price is invalid

        Raises:
            ValueError: If side is not "long" or "short" (checked only
                after entry_price has been validated)
        """
        if entry_price is None or entry_price <= 0:
            logger.error(
                f"Invalid entry_price for SL/TP calculation: {entry_price}"
            )
            return None, None

        if side not in ("long", "short"):
            raise ValueError(f"Invalid side: {side}. Must be 'long' or 'short'")

        sl_pct = self.config.stop_loss_pct
        tp_pct = self.config.take_profit_pct

        if side == "long":
            stop_loss = entry_price * (1 - sl_pct)
            take_profit = entry_price * (1 + tp_pct)
        else:  # short
            stop_loss = entry_price * (1 + sl_pct)
            take_profit = entry_price * (1 - tp_pct)

        return stop_loss, take_profit