Remove deprecated training scripts and Systemd service files

- Deleted `install_cron.sh`, `setup_schedule.sh`, `train_daily.sh`, and `train_model.py` as part of the transition to a new scheduling mechanism.
- Removed associated Systemd service and timer files for daily model training.
- Updated `live_regime_strategy.py` and `main.py` to reflect changes in model training and scheduling logic.
- Adjusted `regime_strategy.py` to align with new target calculation methods and updated optimal parameters.
- Enhanced `regime_detection.py` to incorporate path-dependent labeling for target calculations.
This commit is contained in:
2026-01-18 14:35:46 +08:00
parent b5550f4ff4
commit 582a43cd4a
10 changed files with 285 additions and 638 deletions

View File

@@ -1,29 +0,0 @@
#!/bin/bash
# Install cron job for daily model training
# Runs daily at 00:30
PROJECT_DIR="/home/tamaya/Documents/Work/TCP/lowkey_backtest_live"
SCRIPT_PATH="$PROJECT_DIR/train_daily.sh"
LOG_PATH="$PROJECT_DIR/logs/training.log"

# Abort early when the training script is missing.
if ! [ -f "$SCRIPT_PATH" ]; then
    echo "Error: $SCRIPT_PATH not found!"
    exit 1
fi

# Ensure cron can execute the script.
chmod +x "$SCRIPT_PATH"

# Cron schedule "30 0 * * *" fires at 00:30 every day.
CRON_CMD="30 0 * * * $SCRIPT_PATH >> $LOG_PATH 2>&1"

# Nothing to do when an entry for this script is already installed.
if crontab -l 2>/dev/null | grep -F "$SCRIPT_PATH"; then
    echo "Cron job already exists."
    exit 0
fi

# Append the new entry to the existing crontab (if any).
(crontab -l 2>/dev/null; echo "$CRON_CMD") | crontab -
echo "Cron job installed successfully:"
echo "$CRON_CMD"

View File

@@ -6,6 +6,7 @@ Uses a pre-trained ML model or trains on historical data.
"""
import logging
import pickle
import time
from pathlib import Path
from typing import Optional
@@ -39,8 +40,9 @@ class LiveRegimeStrategy:
self.paths = path_config
self.model: Optional[RandomForestClassifier] = None
self.feature_cols: Optional[list] = None
self.horizon: int = 102 # Default horizon
self.horizon: int = 54 # Default horizon
self._last_model_load_time: float = 0.0
self._last_train_time: float = 0.0
self._load_or_train_model()
def reload_model_if_changed(self) -> None:
@@ -72,6 +74,13 @@ class LiveRegimeStrategy:
logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
else:
logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
# Load timestamp if available
if 'timestamp' in saved:
self._last_train_time = saved['timestamp']
else:
self._last_train_time = self._last_model_load_time
return
except Exception as e:
logger.warning(f"Could not load model: {e}")
@@ -88,12 +97,20 @@ class LiveRegimeStrategy:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
'metrics': {'horizon': self.horizon} # Save horizon
'metrics': {'horizon': self.horizon}, # Save horizon
'timestamp': time.time()
}, f)
logger.info(f"Saved model to {self.paths.model_path}")
except Exception as e:
logger.error(f"Could not save model: {e}")
def check_retrain(self, features: pd.DataFrame) -> None:
"""Check if model needs retraining (older than 24h)."""
if time.time() - self._last_train_time > 24 * 3600:
logger.info("Model is older than 24h. Retraining...")
self.train_model(features)
self._last_train_time = time.time()
def train_model(self, features: pd.DataFrame) -> None:
"""
Train the Random Forest model on historical data.
@@ -106,18 +123,61 @@ class LiveRegimeStrategy:
z_thresh = self.config.z_entry_threshold
horizon = self.horizon
profit_target = 0.005 # 0.5% profit threshold
stop_loss_pct = self.config.stop_loss_pct
# Define targets
future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
# Calculate targets path-dependently
spread = features['spread'].values
z_score = features['z_score'].values
n = len(spread)
target_short = features['spread'] * (1 - profit_target)
target_long = features['spread'] * (1 + profit_target)
targets = np.zeros(n, dtype=int)
success_short = (features['z_score'] > z_thresh) & (future_min < target_short)
success_long = (features['z_score'] < -z_thresh) & (future_max > target_long)
candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
targets = np.select([success_short, success_long], [1, 1], default=0)
for i in candidates:
if i + horizon >= n:
continue
entry_price = spread[i]
future_prices = spread[i+1 : i+1+horizon]
if z_score[i] > z_thresh: # Short
target_price = entry_price * (1 - profit_target)
stop_price = entry_price * (1 + stop_loss_pct)
hit_tp = future_prices <= target_price
hit_sl = future_prices >= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
else: # Long
target_price = entry_price * (1 + profit_target)
stop_price = entry_price * (1 - stop_loss_pct)
hit_tp = future_prices >= target_price
hit_sl = future_prices <= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
# Exclude non-feature columns
exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
@@ -127,8 +187,10 @@ class LiveRegimeStrategy:
X = features[self.feature_cols].fillna(0)
X = X.replace([np.inf, -np.inf], 0)
# Remove rows with invalid targets
valid_mask = ~np.isnan(targets) & future_min.notna().values & future_max.notna().values
# Use rows where we had enough data to look ahead
valid_mask = np.zeros(n, dtype=bool)
valid_mask[:n-horizon] = True
X_clean = X[valid_mask]
y_clean = targets[valid_mask]
@@ -152,7 +214,8 @@ class LiveRegimeStrategy:
def generate_signal(
self,
features: pd.DataFrame,
current_funding: dict
current_funding: dict,
position_side: Optional[str] = None
) -> dict:
"""
Generate trading signal from latest features.
@@ -160,10 +223,14 @@ class LiveRegimeStrategy:
Args:
features: DataFrame with calculated features
current_funding: Dictionary with funding rate data
position_side: Current position side ('long', 'short', or None)
Returns:
Signal dictionary with action, side, confidence, etc.
"""
# Check if retraining is needed
self.check_retrain(features)
if self.model is None:
# Train model if not available
if len(features) >= 200:
@@ -233,12 +300,17 @@ class LiveRegimeStrategy:
signal['action'] = 'hold'
signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'
# Check for exit conditions (mean reversion complete)
if signal['action'] == 'hold':
# Z-score crossed back through 0
if abs(z_score) < 0.3:
# Check for exit conditions (Overshoot Logic)
if signal['action'] == 'hold' and position_side:
# Overshoot Logic
# If Long, exit if Z > 0.5 (Reverted past 0 to +0.5)
if position_side == 'long' and z_score > 0.5:
signal['action'] = 'check_exit'
signal['reason'] = f'z_score_reverted_to_mean ({z_score:.2f})'
signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > 0.5)'
# If Short, exit if Z < -0.5 (Reverted past 0 to -0.5)
elif position_side == 'short' and z_score < -0.5:
signal['action'] = 'check_exit'
signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -0.5)'
logger.info(
f"Signal: {signal['action']} {signal['side'] or ''} "

View File

@@ -206,11 +206,16 @@ class LiveTradingBot:
# 3. Sync with exchange positions
self.position_manager.sync_with_exchange()
# Get current position side for signal generation
symbol = self.trading_config.eth_symbol
position = self.position_manager.get_position_for_symbol(symbol)
position_side = position.side if position else None
# 4. Get current funding rates
funding = self.data_feed.get_current_funding_rates()
# 5. Generate trading signal
sig = self.strategy.generate_signal(features, funding)
sig = self.strategy.generate_signal(features, funding, position_side=position_side)
# 6. Update shared state with strategy info
self._update_strategy_state(sig, funding)

View File

@@ -32,6 +32,7 @@ logger = get_logger(__name__)
# Configuration
TRAIN_RATIO = 0.7 # 70% train, 30% test
PROFIT_THRESHOLD = 0.005 # 0.5% profit target
STOP_LOSS_PCT = 0.06 # 6% stop loss
Z_WINDOW = 24
FEE_RATE = 0.001 # 0.1% round-trip fee
DEFAULT_DAYS = 90 # Default lookback period in days
@@ -139,26 +140,74 @@ def calculate_features(df_btc, df_eth, cq_df=None):
def calculate_targets(features, horizon):
"""Calculate target labels for a given horizon."""
spread = features['spread']
z_score = features['z_score']
"""
Calculate target labels for a given horizon.
# For Short (Z > 1): Did spread drop below target?
future_min = spread.rolling(window=horizon).min().shift(-horizon)
target_short = spread * (1 - PROFIT_THRESHOLD)
success_short = (z_score > 1.0) & (future_min < target_short)
Uses path-dependent labeling: Success is hitting Profit Target BEFORE Stop Loss.
"""
spread = features['spread'].values
z_score = features['z_score'].values
n = len(spread)
# For Long (Z < -1): Did spread rise above target?
future_max = spread.rolling(window=horizon).max().shift(-horizon)
target_long = spread * (1 + PROFIT_THRESHOLD)
success_long = (z_score < -1.0) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
targets = np.zeros(n, dtype=int)
# Create valid mask (rows with complete future data)
valid_mask = future_min.notna() & future_max.notna()
valid_mask = np.zeros(n, dtype=bool)
valid_mask[:n-horizon] = True
return targets, valid_mask, future_min, future_max
# Only iterate relevant rows for efficiency
candidates = np.where((z_score > 1.0) | (z_score < -1.0))[0]
for i in candidates:
if i + horizon >= n:
continue
entry_price = spread[i]
future_prices = spread[i+1 : i+1+horizon]
if z_score[i] > 1.0: # Short
target_price = entry_price * (1 - PROFIT_THRESHOLD)
stop_price = entry_price * (1 + STOP_LOSS_PCT)
# Identify first hit indices
hit_tp = future_prices <= target_price
hit_sl = future_prices >= stop_price
if not np.any(hit_tp):
targets[i] = 0 # Target never hit
elif not np.any(hit_sl):
targets[i] = 1 # Target hit, SL never hit
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
# Success if TP hit before SL
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
else: # Long
target_price = entry_price * (1 + PROFIT_THRESHOLD)
stop_price = entry_price * (1 - STOP_LOSS_PCT)
hit_tp = future_prices >= target_price
hit_sl = future_prices <= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
return targets, pd.Series(valid_mask, index=features.index), None, None
def calculate_mae(features, predictions, test_idx, horizon):
@@ -197,7 +246,7 @@ def calculate_mae(features, predictions, test_idx, horizon):
def calculate_net_profit(features, predictions, test_idx, horizon):
"""
Calculate estimated net profit including fees.
Enforces 'one trade at a time' to avoid inflating returns with overlapping signals.
Enforces 'one trade at a time' and simulates SL/TP exits.
"""
test_features = features.loc[test_idx]
spread = test_features['spread']
@@ -209,6 +258,9 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
# Track when we are free to trade again
next_trade_idx = 0
# Pre-calculate indices for speed
all_indices = features.index
for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
# Skip if we are still in a trade
if i < next_trade_idx:
@@ -221,29 +273,76 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
z = z_score.loc[idx]
# Get future spread values
future_idx = features.index.get_loc(idx)
future_end = min(future_idx + horizon, len(features))
future_spreads = features['spread'].iloc[future_idx:future_end]
current_loc = features.index.get_loc(idx)
future_end_loc = min(current_loc + horizon, len(features))
future_spreads = features['spread'].iloc[current_loc+1 : future_end_loc]
if len(future_spreads) < 2:
if len(future_spreads) < 1:
continue
# Calculate PnL based on direction
if z > 1.0: # Short trade - profit if spread drops
exit_spread = future_spreads.iloc[-1] # Exit at horizon
pnl = (entry_spread - exit_spread) / entry_spread
else: # Long trade - profit if spread rises
exit_spread = future_spreads.iloc[-1]
pnl = (exit_spread - entry_spread) / entry_spread
pnl = 0.0
trade_duration = len(future_spreads)
if z > 1.0: # Short trade
tp_price = entry_spread * (1 - PROFIT_THRESHOLD)
sl_price = entry_spread * (1 + STOP_LOSS_PCT)
hit_tp = future_spreads <= tp_price
hit_sl = future_spreads >= sl_price
# Check what happened first
first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
if first_sl < first_tp and first_sl < 99999:
# Stopped out
exit_price = future_spreads.iloc[first_sl] # Approx SL price
# Use exact SL price for realistic simulation? Or close
# Let's use the close price of the bar where it crossed
pnl = (entry_spread - exit_price) / entry_spread
trade_duration = first_sl + 1
elif first_tp < first_sl and first_tp < 99999:
# Take profit
exit_price = future_spreads.iloc[first_tp]
pnl = (entry_spread - exit_price) / entry_spread
trade_duration = first_tp + 1
else:
# Held to horizon
exit_price = future_spreads.iloc[-1]
pnl = (entry_spread - exit_price) / entry_spread
else: # Long trade
tp_price = entry_spread * (1 + PROFIT_THRESHOLD)
sl_price = entry_spread * (1 - STOP_LOSS_PCT)
hit_tp = future_spreads >= tp_price
hit_sl = future_spreads <= sl_price
first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
if first_sl < first_tp and first_sl < 99999:
# Stopped out
exit_price = future_spreads.iloc[first_sl]
pnl = (exit_price - entry_spread) / entry_spread
trade_duration = first_sl + 1
elif first_tp < first_sl and first_tp < 99999:
# Take profit
exit_price = future_spreads.iloc[first_tp]
pnl = (exit_price - entry_spread) / entry_spread
trade_duration = first_tp + 1
else:
# Held to horizon
exit_price = future_spreads.iloc[-1]
pnl = (exit_price - entry_spread) / entry_spread
# Subtract fees
net_pnl = pnl - FEE_RATE
total_pnl += net_pnl
n_trades += 1
# Set next available trade index (simple non-overlapping logic)
# We assume we hold for 'horizon' bars
next_trade_idx = i + horizon
# Set next available trade index
next_trade_idx = i + trade_duration
return total_pnl, n_trades
@@ -321,7 +420,7 @@ def test_horizons(features, horizons):
print("\n" + "=" * 80)
print("WALK-FORWARD HORIZON OPTIMIZATION")
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Stop Loss: {STOP_LOSS_PCT*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
print("=" * 80)
for h in horizons:

View File

@@ -1,16 +0,0 @@
#!/bin/bash
# Setup script for Systemd Timer (Daily Training)
SERVICE_FILE="tasks/lowkey-training.service"
TIMER_FILE="tasks/lowkey-training.timer"
SYSTEMD_DIR="/etc/systemd/system"

# Print the manual installation steps; variables expand inside the heredoc.
cat <<EOF
To install the daily training schedule, please run the following commands:

sudo cp $SERVICE_FILE $SYSTEMD_DIR/
sudo cp $TIMER_FILE $SYSTEMD_DIR/
sudo systemctl daemon-reload
sudo systemctl enable --now lowkey-training.timer

To check the status:
systemctl list-timers --all | grep lowkey
EOF

View File

@@ -30,7 +30,7 @@ class RegimeReversionStrategy(BaseStrategy):
# Optimal parameters from walk-forward research (2025-10 to 2025-12)
# Research: research/horizon_optimization_results.csv
OPTIMAL_HORIZON = 102 # 4.25 days - best Net PnL (+232%)
OPTIMAL_HORIZON = 54 # Updated from 102h based on corrected labeling
OPTIMAL_Z_WINDOW = 24 # 24h rolling window for spread Z-score
OPTIMAL_TRAIN_RATIO = 0.7 # 70% train / 30% test split
OPTIMAL_PROFIT_TARGET = 0.005 # 0.5% profit threshold for target definition
@@ -321,21 +321,64 @@ class RegimeReversionStrategy(BaseStrategy):
train_features: DataFrame containing features for training period only
"""
threshold = self.profit_target
stop_loss_pct = self.stop_loss
horizon = self.horizon
z_thresh = self.z_entry_threshold
# Define targets using ONLY training data
# For Short Spread (Z > threshold): Did spread drop below target within horizon?
future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
target_short = train_features['spread'] * (1 - threshold)
success_short = (train_features['z_score'] > z_thresh) & (future_min < target_short)
# Calculate targets path-dependently (checking SL before TP)
spread = train_features['spread'].values
z_score = train_features['z_score'].values
n = len(spread)
# For Long Spread (Z < -threshold): Did spread rise above target within horizon?
future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
target_long = train_features['spread'] * (1 + threshold)
success_long = (train_features['z_score'] < -z_thresh) & (future_max > target_long)
targets = np.zeros(n, dtype=int)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Only iterate relevant rows for efficiency
candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
for i in candidates:
if i + horizon >= n:
continue
entry_price = spread[i]
future_prices = spread[i+1 : i+1+horizon]
if z_score[i] > z_thresh: # Short
target_price = entry_price * (1 - threshold)
stop_price = entry_price * (1 + stop_loss_pct)
hit_tp = future_prices <= target_price
hit_sl = future_prices >= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
else: # Long
target_price = entry_price * (1 + threshold)
stop_price = entry_price * (1 - stop_loss_pct)
hit_tp = future_prices >= target_price
hit_sl = future_prices <= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
# Build model
model = RandomForestClassifier(
@@ -351,10 +394,9 @@ class RegimeReversionStrategy(BaseStrategy):
X_train = train_features[cols].fillna(0)
X_train = X_train.replace([np.inf, -np.inf], 0)
# Remove rows with NaN targets (from rolling window at end of training period)
valid_mask = ~np.isnan(targets) & ~np.isinf(targets)
# Also check for rows where future data doesn't exist (shift created NaNs)
valid_mask = valid_mask & (future_min.notna().values) & (future_max.notna().values)
# Use rows where we had enough data to look ahead
valid_mask = np.zeros(n, dtype=bool)
valid_mask[:n-horizon] = True
X_train_clean = X_train[valid_mask]
targets_clean = targets[valid_mask]

View File

@@ -1,15 +0,0 @@
# Systemd service: one-shot daily model training for the lowkey backtest project.
# Activated by the matching lowkey-training.timer unit.
[Unit]
Description=Lowkey Backtest Daily Model Training
# Start only after networking is up (training downloads fresh market data).
After=network.target
[Service]
# oneshot: runs train_daily.sh to completion on each activation; no daemon.
Type=oneshot
WorkingDirectory=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live
ExecStart=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/train_daily.sh
# Run as the unprivileged project owner, not root.
User=tamaya
Group=tamaya
# Append both stdout and stderr to the shared training log file.
StandardOutput=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
StandardError=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
[Install]
WantedBy=multi-user.target

View File

@@ -1,10 +0,0 @@
# Systemd timer: triggers lowkey-training.service once per day.
[Unit]
Description=Run Lowkey Backtest Training Daily
[Timer]
# Fire at 00:30 every day (OnCalendar uses *-*-* date wildcards).
OnCalendar=*-*-* 00:30:00
# Persistent=true: run immediately at next boot if a scheduled run was missed.
Persistent=true
Unit=lowkey-training.service
[Install]
WantedBy=timers.target

View File

@@ -1,50 +0,0 @@
#!/bin/bash
# Daily ML Model Training Script
#
# Downloads fresh data and retrains the regime detection model.
# Can be run manually or scheduled via cron.
#
# Usage:
#   ./train_daily.sh                   # Full workflow
#   ./train_daily.sh --skip-research   # Skip research validation
set -e  # Exit on error

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

# Parse flags. Fix: --skip-research was documented in the usage header but
# never actually parsed or honored; unknown flags remain ignored as before.
SKIP_RESEARCH=false
for arg in "$@"; do
    case "$arg" in
        --skip-research) SKIP_RESEARCH=true ;;
    esac
done

LOG_DIR="logs"
mkdir -p "$LOG_DIR"

TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$TIMESTAMP] Starting daily training..."

# 1. Download fresh data
echo "Downloading BTC-USDT 1h data..."
uv run python main.py download -p BTC-USDT -t 1h
echo "Downloading ETH-USDT 1h data..."
uv run python main.py download -p ETH-USDT -t 1h

# 2. Research optimization (find best horizon), unless explicitly skipped
if [ "$SKIP_RESEARCH" = false ]; then
    echo "Running research optimization..."
    uv run python research/regime_detection.py --output-horizon data/optimal_horizon.txt
else
    echo "Skipping research optimization (--skip-research)."
fi

# 3. Read best horizon (falls back to the default when research was skipped
#    or produced no output file)
if [[ -f "data/optimal_horizon.txt" ]]; then
    BEST_HORIZON=$(cat data/optimal_horizon.txt)
    echo "Found optimal horizon: ${BEST_HORIZON} bars"
else
    BEST_HORIZON=102
    echo "Warning: Could not find optimal horizon file. Using default: ${BEST_HORIZON}"
fi

# 4. Train model
echo "Training ML model with horizon ${BEST_HORIZON}..."
uv run python train_model.py --horizon "$BEST_HORIZON"

# 5. Cleanup
rm -f data/optimal_horizon.txt

TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$TIMESTAMP] Daily training complete."

View File

@@ -1,451 +0,0 @@
"""
ML Model Training Script.
Trains the regime detection Random Forest model on historical data.
Can be run manually or scheduled via cron for daily retraining.
Usage:
uv run python train_model.py [options]
Options:
--days DAYS Number of days of historical data to use (default: 90)
--pair PAIR Trading pair for context (default: BTC-USDT)
--timeframe TF Timeframe (default: 1h)
--output PATH Output model path (default: data/regime_model.pkl)
--train-ratio R Train/test split ratio (default: 0.7)
--dry-run Run without saving model
"""
import argparse
import pickle
import sys
from datetime import datetime, timedelta
from pathlib import Path
import numpy as np
import pandas as pd
import ta
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score
from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import get_logger
# Module-level logger for this training script.
logger = get_logger(__name__)
# Default configuration (from research optimization)
DEFAULT_HORIZON = 102  # Prediction horizon in 1h bars (4.25 days) - optimal from research
DEFAULT_Z_WINDOW = 24  # 24h rolling window for the spread Z-score
DEFAULT_PROFIT_TARGET = 0.005  # 0.5% profit threshold used to label targets
DEFAULT_Z_THRESHOLD = 1.0  # |Z-score| entry threshold
DEFAULT_TRAIN_RATIO = 0.7  # 70% train / 30% test walk-forward split
FEE_RATE = 0.001  # 0.1% round-trip fee
def parse_args():
    """Parse and return the command line arguments for training."""
    ap = argparse.ArgumentParser(description="Train the regime detection ML model")
    ap.add_argument("--days", type=int, default=90,
                    help="Number of days of historical data to use (default: 90)")
    ap.add_argument("--pair", type=str, default="BTC-USDT",
                    help="Base pair for context data (default: BTC-USDT)")
    ap.add_argument("--spread-pair", type=str, default="ETH-USDT",
                    help="Spread pair to trade (default: ETH-USDT)")
    ap.add_argument("--timeframe", type=str, default="1h",
                    help="Timeframe (default: 1h)")
    ap.add_argument("--market", type=str, choices=["spot", "perpetual"],
                    default="perpetual",
                    help="Market type (default: perpetual)")
    ap.add_argument("--output", type=str, default="data/regime_model.pkl",
                    help="Output model path (default: data/regime_model.pkl)")
    ap.add_argument("--train-ratio", type=float, default=DEFAULT_TRAIN_RATIO,
                    help="Train/test split ratio (default: 0.7)")
    ap.add_argument("--horizon", type=int, default=DEFAULT_HORIZON,
                    help="Prediction horizon in bars (default: 102)")
    ap.add_argument("--dry-run", action="store_true",
                    help="Run without saving model")
    ap.add_argument("--download", action="store_true",
                    help="Download latest data before training")
    return ap.parse_args()
def download_data(dm: DataManager, pair: str, timeframe: str, market_type: MarketType):
    """Download the latest candles for *pair* through the data manager.

    Any download failure is logged and then re-raised so the caller can
    decide whether to abort the training run.
    """
    logger.info(f"Downloading latest data for {pair}...")
    try:
        dm.download_data("okx", pair, timeframe, market_type)
        logger.info(f"Downloaded {pair} data")
    except Exception as exc:
        logger.error(f"Failed to download {pair}: {exc}")
        raise
def load_data(
    dm: DataManager,
    base_pair: str,
    spread_pair: str,
    timeframe: str,
    market_type: MarketType,
    days: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load the last *days* days of history for both pairs, index-aligned."""
    end_date = pd.Timestamp.now(tz="UTC")
    start_date = end_date - timedelta(days=days)

    frames = []
    for pair in (base_pair, spread_pair):
        df = dm.load_data("okx", pair, timeframe, market_type)
        # Restrict each frame to the requested lookback window.
        frames.append(df[(df.index >= start_date) & (df.index <= end_date)])
    df_base, df_spread = frames

    # Keep only timestamps present in both series so the spread is well-defined.
    common = df_base.index.intersection(df_spread.index)
    df_base, df_spread = df_base.loc[common], df_spread.loc[common]
    logger.info(
        f"Loaded {len(common)} bars from {common.min()} to {common.max()}"
    )
    return df_base, df_spread
def load_cryptoquant_data() -> pd.DataFrame | None:
    """Load CryptoQuant on-chain data if available.

    Returns the frame with a UTC-aware index, or None when the CSV is
    missing or unreadable (training then proceeds without on-chain features).
    """
    csv_path = Path("data/cq_training_data.csv")
    try:
        if not csv_path.exists():
            logger.info("CryptoQuant data not found, skipping on-chain features")
            return None
        frame = pd.read_csv(csv_path, index_col='timestamp', parse_dates=True)
        # Normalize to UTC so joins against exchange data line up.
        if frame.index.tz is None:
            frame.index = frame.index.tz_localize('UTC')
        logger.info(f"Loaded CryptoQuant data: {len(frame)} rows")
        return frame
    except Exception as e:
        logger.warning(f"Could not load CryptoQuant data: {e}")
        return None
def calculate_features(
    df_base: pd.DataFrame,
    df_spread: pd.DataFrame,
    cq_df: pd.DataFrame | None = None,
    z_window: int = DEFAULT_Z_WINDOW
) -> pd.DataFrame:
    """Build the model's feature matrix from two aligned OHLCV frames.

    Args:
        df_base: OHLCV frame for the base leg (e.g. BTC).
        df_spread: OHLCV frame for the traded leg (e.g. ETH).
        cq_df: Optional CryptoQuant on-chain frame; joined when provided.
        z_window: Rolling window (bars) for the Z-score and volatilities.

    Returns:
        Feature DataFrame with all rows containing NaNs dropped.
    """
    ratio = df_spread['close'] / df_base['close']

    # Z-score of the price ratio over the rolling window.
    roll_mean = ratio.rolling(window=z_window).mean()
    roll_std = ratio.rolling(window=z_window).std()

    # Per-leg return volatilities over the same window.
    ret_base = df_base['close'].pct_change()
    ret_traded = df_spread['close'].pct_change()
    sigma_base = ret_base.rolling(window=z_window).std()
    sigma_traded = ret_traded.rolling(window=z_window).std()

    # Relative volume of the traded leg versus the base leg.
    volume_ratio = df_spread['volume'] / df_base['volume']

    features = pd.DataFrame({
        'spread': ratio,
        'z_score': (ratio - roll_mean) / roll_std,
        'spread_rsi': ta.momentum.RSIIndicator(ratio, window=14).rsi(),
        'spread_roc': ratio.pct_change(periods=5) * 100,
        'spread_change_1h': ratio.pct_change(periods=1),
        'vol_ratio': volume_ratio,
        'vol_ratio_rel': volume_ratio / volume_ratio.rolling(window=12).mean(),
        'vol_diff_ratio': sigma_traded / sigma_base,
    }, index=ratio.index)

    # Join on-chain features when CryptoQuant data is available.
    if cq_df is not None:
        cq_aligned = cq_df.reindex(features.index, method='ffill')
        if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
            cq_aligned['funding_diff'] = (
                cq_aligned['eth_funding'] - cq_aligned['btc_funding']
            )
        if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
            cq_aligned['inflow_ratio'] = (
                cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
            )
        features = features.join(cq_aligned)
    return features.dropna()
def calculate_targets(
features: pd.DataFrame,
horizon: int,
profit_target: float = DEFAULT_PROFIT_TARGET,
z_threshold: float = DEFAULT_Z_THRESHOLD
) -> tuple[np.ndarray, pd.Series]:
"""Calculate target labels for training."""
spread = features['spread']
z_score = features['z_score']
# For Short (Z > threshold): Did spread drop below target?
future_min = spread.rolling(window=horizon).min().shift(-horizon)
target_short = spread * (1 - profit_target)
success_short = (z_score > z_threshold) & (future_min < target_short)
# For Long (Z < -threshold): Did spread rise above target?
future_max = spread.rolling(window=horizon).max().shift(-horizon)
target_long = spread * (1 + profit_target)
success_long = (z_score < -z_threshold) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Create valid mask (rows with complete future data)
valid_mask = future_min.notna() & future_max.notna()
return targets, valid_mask
def train_model(
    features: pd.DataFrame,
    train_ratio: float = DEFAULT_TRAIN_RATIO,
    horizon: int = DEFAULT_HORIZON
) -> tuple[RandomForestClassifier, list[str], dict]:
    """
    Train Random Forest model with walk-forward split.

    The split is chronological: the first `train_ratio` fraction of rows is
    used for fitting, the remainder for evaluation.

    Args:
        features: DataFrame with calculated features
        train_ratio: Fraction of data to use for training
        horizon: Prediction horizon in bars

    Returns:
        Tuple of (trained model, feature columns, metrics dict)

    Raises:
        ValueError: if fewer than 100 valid training samples remain.
    """
    # Label every row; valid_mask marks rows with a full look-ahead window.
    targets, valid_mask = calculate_targets(features, horizon)
    # Walk-forward (chronological) split — no shuffling.
    # NOTE(review): labels for rows within `horizon` bars of the split point
    # look into the test period; consider purging that boundary — confirm.
    n_samples = len(features)
    train_size = int(n_samples * train_ratio)
    train_features = features.iloc[:train_size]
    test_features = features.iloc[train_size:]
    train_targets = targets[:train_size]
    test_targets = targets[train_size:]
    train_valid = valid_mask.iloc[:train_size]
    test_valid = valid_mask.iloc[train_size:]
    # Prepare training data: 'spread' is the raw price ratio, not a feature.
    exclude = ['spread']
    feature_cols = [c for c in features.columns if c not in exclude]
    # Neutralize NaNs/infs so the forest receives finite inputs only.
    X_train = train_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    X_train_valid = X_train[train_valid]
    y_train_valid = train_targets[train_valid]
    if len(X_train_valid) < 100:
        raise ValueError(
            f"Not enough training data: {len(X_train_valid)} samples (need >= 100)"
        )
    logger.info(f"Training on {len(X_train_valid)} samples...")
    # Train model; class_weight upweights the rare positive class 3:1,
    # random_state pins the fit for reproducibility.
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=5,
        min_samples_leaf=30,
        class_weight={0: 1, 1: 3},
        random_state=42
    )
    model.fit(X_train_valid, y_train_valid)
    # Evaluate on the held-out (later) test set with the same cleaning.
    X_test = test_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    predictions = model.predict(X_test)
    # Only score rows whose labels had complete future data.
    test_valid_mask = test_valid.values
    y_test_valid = test_targets[test_valid_mask]
    pred_valid = predictions[test_valid_mask]
    # Calculate metrics; zero_division=0 keeps F1 defined when no positives.
    f1 = f1_score(y_test_valid, pred_valid, zero_division=0)
    metrics = {
        'train_samples': len(X_train_valid),
        'test_samples': len(X_test),
        'f1_score': f1,
        'train_end': train_features.index[-1].isoformat(),
        'test_start': test_features.index[0].isoformat(),
        'horizon': horizon,
        'feature_cols': feature_cols,
    }
    logger.info(f"Model trained. F1 Score: {f1:.3f}")
    logger.info(
        f"Train period: {train_features.index[0]} to {train_features.index[-1]}"
    )
    logger.info(
        f"Test period: {test_features.index[0]} to {test_features.index[-1]}"
    )
    return model, feature_cols, metrics
def save_model(
    model: RandomForestClassifier,
    feature_cols: list[str],
    metrics: dict,
    output_path: str,
    versioned: bool = True
):
    """
    Persist the trained model and its metadata to disk via pickle.

    Args:
        model: Trained model
        feature_cols: List of feature column names
        metrics: Training metrics
        output_path: Output file path
        versioned: If True, also save a timestamped version
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        'model': model,
        'feature_cols': feature_cols,
        'metrics': metrics,
        'trained_at': datetime.now().isoformat(),
    }

    # Write the canonical model file that the live system loads.
    with target.open('wb') as fh:
        pickle.dump(payload, fh)
    logger.info(f"Saved model to {target}")

    # Optionally keep a timestamped snapshot alongside it.
    if versioned:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        snapshot = target.parent / f"regime_model_{stamp}.pkl"
        with snapshot.open('wb') as fh:
            pickle.dump(payload, fh)
        logger.info(f"Saved versioned model to {snapshot}")
def main():
    """Main training function.

    Orchestrates the full workflow: parse args, optionally download data,
    load and align history, build features, train, report, and save.
    Returns 0 on success; exits with status 1 on data/training failures.
    """
    args = parse_args()
    market_type = MarketType.PERPETUAL if args.market == "perpetual" else MarketType.SPOT
    dm = DataManager()
    # Download latest data if requested via --download.
    if args.download:
        download_data(dm, args.pair, args.timeframe, market_type)
        download_data(dm, args.spread_pair, args.timeframe, market_type)
    # Load data; abort with a hint if local data is missing/stale.
    try:
        df_base, df_spread = load_data(
            dm, args.pair, args.spread_pair, args.timeframe, market_type, args.days
        )
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        logger.info("Try running with --download flag to fetch latest data")
        sys.exit(1)
    # Load on-chain data (optional; None when unavailable).
    cq_df = load_cryptoquant_data()
    # Calculate features for the aligned frames.
    features = calculate_features(df_base, df_spread, cq_df)
    logger.info(
        f"Calculated {len(features)} feature rows with {len(features.columns)} columns"
    )
    # Guard against training on too little history.
    if len(features) < 200:
        logger.error(f"Not enough data: {len(features)} rows (need >= 200)")
        sys.exit(1)
    # Train model; train_model raises ValueError on insufficient valid samples.
    try:
        model, feature_cols, metrics = train_model(
            features, args.train_ratio, args.horizon
        )
    except ValueError as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)
    # Print metrics summary for the operator / cron log.
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Train samples: {metrics['train_samples']}")
    print(f"Test samples: {metrics['test_samples']}")
    print(f"F1 Score: {metrics['f1_score']:.3f}")
    print(f"Horizon: {metrics['horizon']} bars")
    print(f"Features: {len(feature_cols)}")
    print("=" * 60)
    # Save model unless this is a dry run.
    if not args.dry_run:
        save_model(model, feature_cols, metrics, args.output)
    else:
        logger.info("Dry run - model not saved")
    return 0
# Script entry point: propagate main()'s return code (0 on success) to the shell.
if __name__ == "__main__":
    sys.exit(main())