- Deleted `install_cron.sh`, `setup_schedule.sh`, and `train_daily.sh` as part of the transition to a new scheduling mechanism. - Removed associated Systemd service and timer files for daily model training. - Updated `live_regime_strategy.py` and `main.py` to reflect changes in model training and scheduling logic. - Adjusted `regime_strategy.py` to align with new target calculation methods and updated optimal parameters. - Enhanced `regime_detection.py` to incorporate path-dependent labeling for target calculations.
396 lines
14 KiB
Python
396 lines
14 KiB
Python
"""
|
|
Live Regime Reversion Strategy.
|
|
|
|
Adapts the backtest regime strategy for live trading.
|
|
Uses a pre-trained ML model or trains on historical data.
|
|
"""
|
|
import logging
|
|
import pickle
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
from .config import TradingConfig, PathConfig
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LiveRegimeStrategy:
    """
    Live trading implementation of the ML-based regime detection
    and mean reversion strategy.

    Logic:
    1. Reads the BTC/ETH spread Z-Score from the precomputed features
    2. Uses Random Forest to predict reversion probability
    3. Applies funding rate filter
    4. Generates long/short signals on ETH perpetual
    """

    # Relative spread move (0.5%) counted as a successful reversion when
    # labelling training targets. Separate from config.take_profit_pct, which
    # calculate_sl_tp() uses for live take-profit prices.
    TRAIN_PROFIT_TARGET: float = 0.005
    # Seconds after the last training attempt before check_retrain() retrains.
    RETRAIN_INTERVAL_S: float = 24 * 3600
    # Minimum number of labelled rows required to fit a model.
    MIN_TRAIN_SAMPLES: int = 100
    # Minimum feature rows before generate_signal() trains a missing model.
    MIN_BOOTSTRAP_ROWS: int = 200
    # Z-score past zero (on the opposite side of the entry) at which an open
    # position is flagged for exit ("overshoot" exit).
    OVERSHOOT_EXIT_Z: float = 0.5
    # Feature-frame columns that are not used as model inputs.
    _NON_FEATURE_COLS = ('spread', 'btc_close', 'eth_close', 'eth_volume')

    def __init__(
        self,
        trading_config: TradingConfig,
        path_config: PathConfig
    ):
        """
        Args:
            trading_config: Strategy thresholds and position sizing limits.
            path_config: Filesystem locations; ``model_path`` is where the
                model is loaded from and saved to.
        """
        self.config = trading_config
        self.paths = path_config
        self.model: Optional[RandomForestClassifier] = None
        self.feature_cols: Optional[list[str]] = None
        # Look-ahead window (rows) used to label training targets; replaced by
        # the horizon stored with a saved model, when present.
        self.horizon: int = 54
        # st_mtime of the model file when it was last read or written here.
        self._last_model_load_time: float = 0.0
        # Epoch time of the last training (attempt), used by check_retrain().
        self._last_train_time: float = 0.0
        self._load_or_train_model()

    def reload_model_if_changed(self) -> None:
        """Reload the model if the file on disk is newer than the last one seen.

        Best-effort: errors while checking the file are logged and ignored so
        callers are not interrupted by filesystem problems.
        """
        model_path = self.paths.model_path
        if not model_path.exists():
            return

        try:
            mtime = model_path.stat().st_mtime
            if mtime > self._last_model_load_time:
                logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
                self._load_or_train_model()
        except Exception as e:
            logger.warning(f"Error checking model file: {e}")

    def _load_or_train_model(self) -> None:
        """Load the pickled model bundle from ``paths.model_path`` if present.

        On success, model, feature columns, horizon (when stored) and the
        last-train timestamp are updated together. On failure the current
        in-memory state is left untouched. Despite the name, this method never
        trains; when no model is available, training happens later in
        generate_signal() / check_retrain().
        """
        model_path = self.paths.model_path
        if not model_path.exists():
            logger.info("No pre-trained model found. Will train on first data batch.")
            return

        try:
            # Record the mtime before reading: if this file turns out to be
            # unreadable, reload_model_if_changed() won't retry it until the
            # file is rewritten (which bumps its mtime).
            self._last_model_load_time = model_path.stat().st_mtime
            # NOTE: pickle can execute arbitrary code on load; only point
            # model_path at files produced by this application.
            with open(model_path, 'rb') as f:
                saved = pickle.load(f)
            model = saved['model']
            feature_cols = saved['feature_cols']
            metrics = saved.get('metrics') or {}
            timestamp = saved.get('timestamp', self._last_model_load_time)
        except Exception as e:
            logger.warning(f"Could not load model: {e}")
            if self.model is None:
                logger.info("No pre-trained model found. Will train on first data batch.")
            else:
                logger.info("Keeping previously loaded model.")
            return

        # Commit everything at once so model and feature_cols never mismatch.
        self.model = model
        self.feature_cols = feature_cols

        if 'horizon' in metrics:
            self.horizon = metrics['horizon']
            logger.info(f"Loaded model from {model_path} (horizon={self.horizon})")
        else:
            logger.info(f"Loaded model from {model_path} (default horizon={self.horizon})")

        self._last_train_time = timestamp

    def save_model(self) -> None:
        """Persist model, feature columns, horizon and a timestamp to disk.

        The pickle is written to a sibling ``.tmp`` file and then renamed over
        ``paths.model_path`` (os.replace semantics), so a concurrent reader
        never sees a partially written file. Errors are logged, not raised.
        """
        if self.model is None:
            return

        model_path = self.paths.model_path
        tmp_path = model_path.with_name(model_path.name + '.tmp')
        try:
            model_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, 'wb') as f:
                pickle.dump({
                    'model': self.model,
                    'feature_cols': self.feature_cols,
                    'metrics': {'horizon': self.horizon},  # Save horizon
                    'timestamp': time.time()
                }, f)
            tmp_path.replace(model_path)
            # Remember our own write so reload_model_if_changed() does not
            # immediately reload the model we just saved.
            self._last_model_load_time = model_path.stat().st_mtime
            logger.info(f"Saved model to {model_path}")
        except Exception as e:
            logger.error(f"Could not save model: {e}")
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup of the temporary file

    def check_retrain(self, features: pd.DataFrame) -> None:
        """Retrain if the last training attempt is older than RETRAIN_INTERVAL_S.

        The timestamp is refreshed after every attempt, successful or not, so a
        failed retrain (e.g. too little data) is not retried until the
        interval elapses again.
        """
        if time.time() - self._last_train_time > self.RETRAIN_INTERVAL_S:
            logger.info(f"Model is older than {self.RETRAIN_INTERVAL_S / 3600:g}h. Retraining...")
            self.train_model(features)
            self._last_train_time = time.time()

    @staticmethod
    def _compute_targets(
        spread: np.ndarray,
        z_score: np.ndarray,
        z_thresh: float,
        horizon: int,
        profit_target: float,
        stop_loss_pct: float
    ) -> np.ndarray:
        """Label each row 1 if its trade would hit take-profit before stop-loss.

        Only rows whose z-score is beyond +/-z_thresh and that have a full
        ``horizon``-row look-ahead window can be labelled 1; all others are 0.
        A z-score above the threshold is treated as a short on the spread,
        otherwise as a long.
        """
        n = len(spread)
        targets = np.zeros(n, dtype=int)
        candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]

        for i in candidates:
            if i + horizon >= n:
                continue  # not enough future rows to judge the outcome

            entry = spread[i]
            window = spread[i + 1 : i + 1 + horizon]

            if z_score[i] > z_thresh:  # Short: profit when the spread falls
                hit_tp = window <= entry * (1 - profit_target)
                hit_sl = window >= entry * (1 + stop_loss_pct)
            else:  # Long: profit when the spread rises
                hit_tp = window >= entry * (1 + profit_target)
                hit_sl = window <= entry * (1 - stop_loss_pct)

            if not hit_tp.any():
                continue  # never reached take-profit -> label stays 0
            # Success if stop-loss is never hit, or take-profit is hit first.
            if not hit_sl.any() or np.argmax(hit_tp) < np.argmax(hit_sl):
                targets[i] = 1

        return targets

    def train_model(self, features: pd.DataFrame) -> None:
        """
        Train the Random Forest model on historical data.

        Rows with a z-score beyond +/-``config.z_entry_threshold`` are labelled
        1 when, within the next ``self.horizon`` rows, the spread reaches the
        TRAIN_PROFIT_TARGET take-profit level before the
        ``config.stop_loss_pct`` stop level; all other rows are labelled 0.
        Only rows with a full look-ahead window are used for fitting.

        The in-memory model and feature columns are replaced only after a
        successful fit; with too little data or a single label class the
        previous model is kept. A newly fitted model is saved via save_model().

        Args:
            features: DataFrame with calculated features; must contain
                ``spread`` and ``z_score`` columns.
        """
        logger.info(f"Training model on {len(features)} samples...")

        horizon = self.horizon
        n = len(features)
        targets = self._compute_targets(
            features['spread'].to_numpy(),
            features['z_score'].to_numpy(),
            self.config.z_entry_threshold,
            horizon,
            self.TRAIN_PROFIT_TARGET,
            self.config.stop_loss_pct,
        )

        # Computed locally; only committed to self after a successful fit.
        feature_cols = [c for c in features.columns if c not in self._NON_FEATURE_COLS]

        # Clean features
        X = features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)

        # Use rows where we had enough data to look ahead. max() guards the
        # n < horizon case, where a negative slice stop would otherwise mark
        # early rows (which have no look-ahead) as valid.
        n_valid = max(n - horizon, 0)
        X_clean = X.iloc[:n_valid]
        y_clean = targets[:n_valid]

        if len(X_clean) < self.MIN_TRAIN_SAMPLES:
            logger.warning("Not enough data to train model")
            return

        if len(np.unique(y_clean)) < 2:
            # A single-class model cannot yield P(class 1) in generate_signal().
            logger.warning("Training labels contain a single class; skipping training")
            return

        model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42
        )
        model.fit(X_clean, y_clean)

        self.model = model
        self.feature_cols = feature_cols

        logger.info(f"Model trained on {len(X_clean)} samples")
        self.save_model()

    def _positive_class_probability(self, X: pd.DataFrame) -> float:
        """Return the model's probability of class 1 for the first row of X.

        Looks the class up in ``model.classes_`` rather than assuming column 1,
        and returns 0.0 if the model was never trained on class 1.
        """
        proba = self.model.predict_proba(X)
        positive = np.flatnonzero(self.model.classes_ == 1)
        if positive.size == 0:
            return 0.0
        return float(proba[0, positive[0]])

    def generate_signal(
        self,
        features: pd.DataFrame,
        current_funding: dict,
        position_side: Optional[str] = None
    ) -> dict:
        """
        Generate trading signal from latest features.

        Args:
            features: DataFrame with calculated features
            current_funding: Dictionary with funding rate data; ``btc_funding``
                is read, and a missing/None value is treated as 0
            position_side: Current position side ('long', 'short', or None)

        Returns:
            Signal dictionary with action, side, confidence, etc. The early
            'hold' returns (no features / no model) carry only 'action' and
            'reason'.
        """
        if features.empty:
            return {'action': 'hold', 'reason': 'no_features'}

        # Check if retraining is needed
        self.check_retrain(features)

        if self.model is None:
            # Train model if not available
            if len(features) >= self.MIN_BOOTSTRAP_ROWS:
                self.train_model(features)
            else:
                return {'action': 'hold', 'reason': 'model_not_trained'}

        if self.model is None:
            return {'action': 'hold', 'reason': 'insufficient_data_for_training'}

        # Get latest row
        latest = features.iloc[-1]
        z_score = latest['z_score']
        eth_price = latest['eth_close']
        btc_price = latest['btc_close']

        # Prepare features for prediction (same cleaning as in training)
        X = features[self.feature_cols].iloc[[-1]].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        prob = self._positive_class_probability(X)

        z_thresh = self.config.z_entry_threshold
        prob_thresh = self.config.model_prob_threshold

        signal = {
            'action': 'hold',
            'side': None,
            'probability': prob,
            'z_score': z_score,
            'eth_price': eth_price,
            'btc_price': btc_price,
            'reason': '',
        }

        # Check for entry conditions
        if prob > prob_thresh:
            if z_score > z_thresh:
                # Spread high (ETH expensive relative to BTC) -> Short ETH
                signal['action'] = 'entry'
                signal['side'] = 'short'
                signal['reason'] = f'z_score={z_score:.2f}>threshold, prob={prob:.2f}'
            elif z_score < -z_thresh:
                # Spread low (ETH cheap relative to BTC) -> Long ETH
                signal['action'] = 'entry'
                signal['side'] = 'long'
                signal['reason'] = f'z_score={z_score:.2f}<-threshold, prob={prob:.2f}'
            else:
                signal['reason'] = f'z_score={z_score:.2f} within threshold'
        else:
            signal['reason'] = f'prob={prob:.2f}<threshold'

        # Apply funding rate filter
        if signal['action'] == 'entry':
            btc_funding = (current_funding or {}).get('btc_funding', 0)
            if btc_funding is None:
                btc_funding = 0  # unknown funding: same as the missing-key default
            funding_thresh = self.config.funding_threshold

            if signal['side'] == 'long' and btc_funding > funding_thresh:
                # High positive funding = overheated, don't go long
                signal['action'] = 'hold'
                signal['reason'] = f'funding_filter_blocked_long (funding={btc_funding:.4f})'
            elif signal['side'] == 'short' and btc_funding < -funding_thresh:
                # High negative funding = oversold, don't go short
                signal['action'] = 'hold'
                signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'

        # Check for exit conditions (Overshoot Logic): only when no entry is
        # being signalled and a position is open.
        if signal['action'] == 'hold' and position_side:
            exit_z = self.OVERSHOOT_EXIT_Z
            # If Long, exit once Z has reverted past 0 to +exit_z
            if position_side == 'long' and z_score > exit_z:
                signal['action'] = 'check_exit'
                signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > {exit_z})'
            # If Short, exit once Z has reverted past 0 to -exit_z
            elif position_side == 'short' and z_score < -exit_z:
                signal['action'] = 'check_exit'
                signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -{exit_z})'

        logger.info(
            f"Signal: {signal['action']} {signal['side'] or ''} "
            f"(prob={prob:.2f}, z={z_score:.2f}, reason={signal['reason']})"
        )

        return signal

    def calculate_position_size(
        self,
        signal: dict,
        available_usdt: float
    ) -> float:
        """
        Calculate position size based on signal confidence.

        The base size is ``min(available_usdt, max_position_usdt)`` (or all of
        ``available_usdt`` when ``max_position_usdt <= 0``), scaled by
        ``1 + 2 * (prob - 0.5)`` clamped to [1x, 2x], then capped at 95% of
        ``available_usdt``. Sizes below ``min_position_usdt`` *after* that cap
        return 0.0.

        Args:
            signal: Signal dictionary with probability
            available_usdt: Available USDT balance

        Returns:
            Position size in USDT
        """
        prob = signal.get('probability', 0.5)

        # Base size: if max_position_usdt <= 0, use all available funds
        if self.config.max_position_usdt <= 0:
            base_size = available_usdt
        else:
            base_size = min(available_usdt, self.config.max_position_usdt)

        # Scale by probability (1.0x at 0.5 prob, up to 1.6x at 0.8 prob)
        scale = 1.0 + (prob - 0.5) * 2.0
        scale = max(1.0, min(scale, 2.0))  # Clamp between 1x and 2x

        # NOTE(review): scaling is applied after the max_position_usdt cap, so
        # a position can be up to 2x max_position_usdt (still bounded by the
        # 95% balance buffer). Confirm this is intended.
        # Leave 5% buffer; applied before the minimum check so the returned
        # size can never fall below min_position_usdt.
        size = min(base_size * scale, available_usdt * 0.95)

        # Ensure minimum position size
        if size < self.config.min_position_usdt:
            return 0.0

        return size

    def calculate_sl_tp(
        self,
        entry_price: Optional[float],
        side: str
    ) -> tuple[Optional[float], Optional[float]]:
        """
        Calculate stop-loss and take-profit prices.

        Args:
            entry_price: Entry price
            side: "long" or "short"

        Returns:
            Tuple of (stop_loss_price, take_profit_price), or (None, None) if
            entry_price is invalid

        Raises:
            ValueError: If side is not "long" or "short"
        """
        if entry_price is None or entry_price <= 0:
            logger.error(
                f"Invalid entry_price for SL/TP calculation: {entry_price}"
            )
            return None, None

        if side not in ("long", "short"):
            raise ValueError(f"Invalid side: {side}. Must be 'long' or 'short'")

        sl_pct = self.config.stop_loss_pct
        tp_pct = self.config.take_profit_pct

        if side == "long":
            stop_loss = entry_price * (1 - sl_pct)
            take_profit = entry_price * (1 + tp_pct)
        else:  # short
            stop_loss = entry_price * (1 + sl_pct)
            take_profit = entry_price * (1 - tp_pct)

        return stop_loss, take_profit
|