Files
lowkey_backtest/live_trading/live_regime_strategy.py
Simon Moisy 582a43cd4a Remove deprecated training scripts and Systemd service files
- Deleted `install_cron.sh`, `setup_schedule.sh`, and `train_daily.sh` as part of the transition to a new scheduling mechanism.
- Removed associated Systemd service and timer files for daily model training.
- Updated `live_regime_strategy.py` and `main.py` to reflect changes in model training and scheduling logic.
- Adjusted `regime_strategy.py` to align with new target calculation methods and updated optimal parameters.
- Enhanced `regime_detection.py` to incorporate path-dependent labeling for target calculations.
2026-01-18 14:35:46 +08:00

396 lines
14 KiB
Python

"""
Live Regime Reversion Strategy.
Adapts the backtest regime strategy for live trading.
Uses a pre-trained ML model or trains on historical data.
"""
import logging
import pickle
import time
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from .config import TradingConfig, PathConfig
logger = logging.getLogger(__name__)
class LiveRegimeStrategy:
    """
    Live trading implementation of the ML-based regime detection
    and mean reversion strategy.

    Logic:
        1. Calculates BTC/ETH spread Z-Score
        2. Uses Random Forest to predict reversion probability
        3. Applies funding rate filter
        4. Generates long/short signals on ETH perpetual
    """

    # The in-memory model is retrained once it is older than this (24h).
    _RETRAIN_INTERVAL_SECONDS = 24 * 3600
    # Favourable spread move (0.5%) that labels a training sample as a win.
    _PROFIT_TARGET = 0.005
    # Minimum number of look-ahead-complete rows required to fit a model.
    _MIN_TRAINING_ROWS = 100
    # Minimum number of feature rows before a first model is trained on demand.
    _MIN_ROWS_FOR_FIRST_TRAINING = 200
    # |z| level on the far side of zero at which an open position is flagged
    # for exit.
    _OVERSHOOT_EXIT_Z = 0.5
    # Raw columns that are never fed to the model as features.
    _NON_FEATURE_COLUMNS = ('spread', 'btc_close', 'eth_close', 'eth_volume')

    def __init__(
        self,
        trading_config: "TradingConfig",
        path_config: "PathConfig",
    ):
        """
        Store the configuration and try to load a previously saved model.

        Args:
            trading_config: Trading parameters (thresholds, SL/TP, sizing)
            path_config: Paths; ``model_path`` is where the model is pickled
        """
        self.config = trading_config
        self.paths = path_config
        self.model: Optional[RandomForestClassifier] = None
        self.feature_cols: Optional[list] = None
        self.horizon: int = 54  # Default horizon
        self._last_model_load_time: float = 0.0
        self._last_train_time: float = 0.0
        self._load_or_train_model()

    def reload_model_if_changed(self) -> None:
        """Check if model file has changed and reload if necessary."""
        if not self.paths.model_path.exists():
            return
        try:
            mtime = self.paths.model_path.stat().st_mtime
            if mtime > self._last_model_load_time:
                logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
                self._load_or_train_model()
        except Exception as e:
            # Best effort: a failed check must not interrupt the trading loop.
            logger.warning(f"Error checking model file: {e}")

    def _load_or_train_model(self) -> None:
        """
        Load the pickled model from ``paths.model_path`` if it exists.

        Despite the name, no training happens here; when no model can be
        loaded, training is deferred to the first data batch. The loaded
        state is committed only after the whole payload has been read, so a
        corrupt or incomplete file never leaves a model paired with the
        wrong feature columns.
        """
        model_path = self.paths.model_path
        if not model_path.exists():
            logger.info("No pre-trained model found. Will train on first data batch.")
            return

        try:
            # Record the mtime before reading so an unreadable file is not
            # retried on every reload check until it changes again.
            self._last_model_load_time = model_path.stat().st_mtime
            with open(model_path, 'rb') as f:
                # NOTE(security): pickle can execute arbitrary code while
                # loading; the model file must come from a trusted source.
                saved = pickle.load(f)

            model = saved['model']
            feature_cols = saved['feature_cols']
            metrics = saved.get('metrics') or {}
            train_time = saved.get('timestamp', self._last_model_load_time)
        except Exception as e:
            logger.warning(f"Could not load model: {e}")
            if self.model is None:
                logger.info("No pre-trained model found. Will train on first data batch.")
            else:
                logger.info("Keeping the previously loaded model.")
            return

        # Everything was read successfully: commit the new state together.
        self.model = model
        self.feature_cols = feature_cols
        self._last_train_time = train_time
        if 'horizon' in metrics:
            self.horizon = metrics['horizon']
            logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
        else:
            logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")

    def save_model(self) -> None:
        """
        Save trained model to file.

        The payload is written to a temporary sibling file and then renamed
        over the target, so a concurrent reader never sees a half-written
        pickle. Failures are logged, not raised.
        """
        if self.model is None:
            return

        tmp_path = None
        try:
            model_path = Path(self.paths.model_path)
            tmp_path = model_path.with_name(model_path.name + '.tmp')
            model_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, 'wb') as f:
                pickle.dump({
                    'model': self.model,
                    'feature_cols': self.feature_cols,
                    'metrics': {'horizon': self.horizon},  # Save horizon
                    'timestamp': time.time(),
                }, f)
            tmp_path.replace(model_path)
            logger.info(f"Saved model to {self.paths.model_path}")
        except Exception as e:
            logger.error(f"Could not save model: {e}")
            if tmp_path is not None:
                try:
                    tmp_path.unlink(missing_ok=True)
                except OSError:
                    pass  # Nothing more can be done about a stale temp file.

    def check_retrain(self, features: pd.DataFrame) -> None:
        """
        Check if model needs retraining (older than 24h).

        The training timestamp is refreshed even when training is skipped
        for lack of data, so an attempt is made at most once per interval.

        Args:
            features: DataFrame with calculated features
        """
        if time.time() - self._last_train_time > self._RETRAIN_INTERVAL_SECONDS:
            logger.info("Model is older than 24h. Retraining...")
            self.train_model(features)
            self._last_train_time = time.time()

    @staticmethod
    def _first_hit(mask: np.ndarray) -> int:
        """Return the index of the first True in ``mask``, or ``len(mask)`` if none."""
        return int(np.argmax(mask)) if mask.any() else len(mask)

    @classmethod
    def _label_path(
        cls,
        entry_price: float,
        future_prices: np.ndarray,
        is_short: bool,
        profit_target: float,
        stop_loss_pct: float,
    ) -> int:
        """
        Label one candidate entry: 1 if take-profit is hit before stop-loss.

        Args:
            entry_price: Spread value at the entry bar
            future_prices: Spread values of the look-ahead window
            is_short: True for a short entry, False for a long entry
            profit_target: Fractional favourable move that counts as a win
            stop_loss_pct: Fractional adverse move that counts as a loss

        Returns:
            1 when the target is reached strictly before the stop, else 0
            (including when the target is never reached).
        """
        if is_short:
            hit_tp = future_prices <= entry_price * (1 - profit_target)
            hit_sl = future_prices >= entry_price * (1 + stop_loss_pct)
        else:
            hit_tp = future_prices >= entry_price * (1 + profit_target)
            hit_sl = future_prices <= entry_price * (1 - stop_loss_pct)
        return 1 if cls._first_hit(hit_tp) < cls._first_hit(hit_sl) else 0

    def train_model(self, features: pd.DataFrame) -> None:
        """
        Train the Random Forest model on historical data.

        Targets are path-dependent: a row whose z-score is beyond the entry
        threshold is labelled 1 when the spread reaches the profit target
        before the stop-loss within ``self.horizon`` bars; every other row
        is labelled 0. ``self.model`` and ``self.feature_cols`` are replaced
        only after a successful fit, so an aborted retrain leaves the
        previous model usable.

        Args:
            features: DataFrame with calculated features; must contain the
                'spread' and 'z_score' columns
        """
        logger.info(f"Training model on {len(features)} samples...")

        z_thresh = self.config.z_entry_threshold
        stop_loss_pct = self.config.stop_loss_pct
        horizon = int(self.horizon)

        # Calculate targets path-dependently
        spread = features['spread'].values
        z_score = features['z_score'].values
        n = len(spread)

        targets = np.zeros(n, dtype=int)
        candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
        for i in candidates:
            if i + horizon >= n:
                continue  # Not enough future bars to evaluate the outcome.
            targets[i] = self._label_path(
                entry_price=spread[i],
                future_prices=spread[i + 1:i + 1 + horizon],
                is_short=bool(z_score[i] > z_thresh),
                profit_target=self._PROFIT_TARGET,
                stop_loss_pct=stop_loss_pct,
            )

        # Exclude non-feature columns
        feature_cols = [
            c for c in features.columns if c not in self._NON_FEATURE_COLUMNS
        ]

        # Clean features
        X = features[feature_cols].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        # Use rows where we had enough data to look ahead. The bound is
        # clamped at 0 so that fewer rows than the horizon yields an empty
        # training set rather than a negative slice.
        n_valid = max(n - horizon, 0)
        X_clean = X.iloc[:n_valid]
        y_clean = targets[:n_valid]

        if len(X_clean) < self._MIN_TRAINING_ROWS:
            logger.warning("Not enough data to train model")
            return

        # Train model
        model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42,
        )
        model.fit(X_clean, y_clean)

        # Commit only once fitting succeeded.
        self.model = model
        self.feature_cols = feature_cols
        logger.info(f"Model trained on {len(X_clean)} samples")
        self.save_model()

    def _positive_class_probability(self, X: pd.DataFrame) -> float:
        """
        Return the predicted probability of class 1 for the first row of X.

        A model fitted on data containing a single class exposes only one
        probability column; in that case class 1 is looked up in
        ``classes_`` and 0.0 is returned when it is absent.
        """
        proba = np.asarray(self.model.predict_proba(X))
        classes = list(getattr(self.model, 'classes_', [0, 1]))
        if 1 not in classes:
            return 0.0
        return float(proba[0, classes.index(1)])

    def generate_signal(
        self,
        features: pd.DataFrame,
        current_funding: dict,
        position_side: Optional[str] = None
    ) -> dict:
        """
        Generate trading signal from latest features.

        Args:
            features: DataFrame with calculated features
            current_funding: Dictionary with funding rate data; only the
                'btc_funding' key is read (missing or None counts as 0)
            position_side: Current position side ('long', 'short', or None)

        Returns:
            Signal dictionary with action ('hold', 'entry' or 'check_exit'),
            side, probability, z_score, eth_price, btc_price and reason.
            Early 'hold' results returned before a prediction could be made
            carry only 'action' and 'reason'.
        """
        if features is None or len(features) == 0:
            return {'action': 'hold', 'reason': 'no_feature_data'}

        # Check if retraining is needed
        self.check_retrain(features)

        if self.model is None:
            # Train model if not available
            if len(features) < self._MIN_ROWS_FOR_FIRST_TRAINING:
                return {'action': 'hold', 'reason': 'model_not_trained'}
            self.train_model(features)
            if self.model is None:
                return {'action': 'hold', 'reason': 'insufficient_data_for_training'}

        # Get latest row
        latest = features.iloc[-1]
        z_score = latest['z_score']
        eth_price = latest['eth_close']
        btc_price = latest['btc_close']

        # Prepare features for prediction
        X = features[self.feature_cols].iloc[[-1]].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        # Get prediction probability
        prob = self._positive_class_probability(X)

        # Apply thresholds
        z_thresh = self.config.z_entry_threshold
        prob_thresh = self.config.model_prob_threshold

        # Determine signal direction
        signal = {
            'action': 'hold',
            'side': None,
            'probability': prob,
            'z_score': z_score,
            'eth_price': eth_price,
            'btc_price': btc_price,
            'reason': '',
        }

        # Check for entry conditions
        if prob > prob_thresh:
            if z_score > z_thresh:
                # Spread high (ETH expensive relative to BTC) -> Short ETH
                signal['action'] = 'entry'
                signal['side'] = 'short'
                signal['reason'] = f'z_score={z_score:.2f}>threshold, prob={prob:.2f}'
            elif z_score < -z_thresh:
                # Spread low (ETH cheap relative to BTC) -> Long ETH
                signal['action'] = 'entry'
                signal['side'] = 'long'
                signal['reason'] = f'z_score={z_score:.2f}<-threshold, prob={prob:.2f}'
            else:
                signal['reason'] = f'z_score={z_score:.2f} within threshold'
        else:
            signal['reason'] = f'prob={prob:.2f}<threshold'

        # Apply funding rate filter
        if signal['action'] == 'entry':
            btc_funding = (current_funding or {}).get('btc_funding') or 0
            funding_thresh = self.config.funding_threshold
            if signal['side'] == 'long' and btc_funding > funding_thresh:
                # High positive funding = overheated, don't go long
                signal['action'] = 'hold'
                signal['reason'] = f'funding_filter_blocked_long (funding={btc_funding:.4f})'
            elif signal['side'] == 'short' and btc_funding < -funding_thresh:
                # High negative funding = oversold, don't go short
                signal['action'] = 'hold'
                signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'

        # Check for exit conditions (Overshoot Logic)
        if signal['action'] == 'hold' and position_side:
            exit_z = self._OVERSHOOT_EXIT_Z
            # If Long, exit if Z > 0.5 (Reverted past 0 to +0.5)
            if position_side == 'long' and z_score > exit_z:
                signal['action'] = 'check_exit'
                signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > {exit_z})'
            # If Short, exit if Z < -0.5 (Reverted past 0 to -0.5)
            elif position_side == 'short' and z_score < -exit_z:
                signal['action'] = 'check_exit'
                signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -{exit_z})'

        logger.info(
            f"Signal: {signal['action']} {signal['side'] or ''} "
            f"(prob={prob:.2f}, z={z_score:.2f}, reason={signal['reason']})"
        )
        return signal

    def calculate_position_size(
        self,
        signal: dict,
        available_usdt: float
    ) -> float:
        """
        Calculate position size based on signal confidence.

        Args:
            signal: Signal dictionary with probability (0.5 is assumed when
                the key is missing or None)
            available_usdt: Available USDT balance

        Returns:
            Position size in USDT, or 0.0 when the final size would fall
            below ``config.min_position_usdt``
        """
        prob = signal.get('probability', 0.5)
        if prob is None:
            prob = 0.5

        # Base size: if max_position_usdt <= 0, use all available funds
        if self.config.max_position_usdt <= 0:
            base_size = available_usdt
        else:
            base_size = min(available_usdt, self.config.max_position_usdt)

        # Scale by probability (1.0x at 0.5 prob, up to 1.6x at 0.8 prob)
        # NOTE(review): the scaled size can exceed max_position_usdt (up to
        # 2x) - confirm this is intended.
        scale = 1.0 + (prob - 0.5) * 2.0
        scale = max(1.0, min(scale, 2.0))  # Clamp between 1x and 2x

        # Leave a 5% balance buffer, then enforce the minimum on the capped
        # value so the returned size is never below the minimum.
        size = min(base_size * scale, available_usdt * 0.95)
        if size < self.config.min_position_usdt:
            return 0.0
        return size

    def calculate_sl_tp(
        self,
        entry_price: Optional[float],
        side: str
    ) -> tuple[Optional[float], Optional[float]]:
        """
        Calculate stop-loss and take-profit prices.

        Args:
            entry_price: Entry price
            side: "long" or "short"

        Returns:
            Tuple of (stop_loss_price, take_profit_price), or (None, None) if
            entry_price is invalid (None, non-finite, or not positive)

        Raises:
            ValueError: If side is not "long" or "short"
        """
        if entry_price is None or not np.isfinite(entry_price) or entry_price <= 0:
            logger.error(
                f"Invalid entry_price for SL/TP calculation: {entry_price}"
            )
            return None, None

        if side not in ("long", "short"):
            raise ValueError(f"Invalid side: {side}. Must be 'long' or 'short'")

        sl_pct = self.config.stop_loss_pct
        tp_pct = self.config.take_profit_pct

        if side == "long":
            stop_loss = entry_price * (1 - sl_pct)
            take_profit = entry_price * (1 + tp_pct)
        else:  # short
            stop_loss = entry_price * (1 + sl_pct)
            take_profit = entry_price * (1 - tp_pct)

        return stop_loss, take_profit