# Changelog:
# - Introduced `train_daily.sh` for automating daily model retraining, including
#   data download and model training steps.
# - Added `install_cron.sh` for setting up a cron job to run the daily training script.
# - Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks.
# - Implemented a terminal UI using Rich for real-time monitoring of trading
#   performance, including metrics display and log handling.
# - Updated `pyproject.toml` to include the `rich` dependency for UI functionality.
# - Enhanced `.gitignore` to exclude model and log files.
# - Added database support for trade persistence and metrics calculation.
# - Updated README with installation and usage instructions for the new features.
"""
|
|
ML Model Training Script.
|
|
|
|
Trains the regime detection Random Forest model on historical data.
|
|
Can be run manually or scheduled via cron for daily retraining.
|
|
|
|
Usage:
|
|
uv run python train_model.py [options]
|
|
|
|
Options:
|
|
--days DAYS Number of days of historical data to use (default: 90)
|
|
--pair PAIR Trading pair for context (default: BTC-USDT)
|
|
--timeframe TF Timeframe (default: 1h)
|
|
--output PATH Output model path (default: data/regime_model.pkl)
|
|
--train-ratio R Train/test split ratio (default: 0.7)
|
|
--dry-run Run without saving model
|
|
"""
|
|
import argparse
|
|
import pickle
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import ta
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from sklearn.metrics import classification_report, f1_score
|
|
|
|
from engine.data_manager import DataManager
|
|
from engine.market import MarketType
|
|
from engine.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# Default configuration (from research optimization)
DEFAULT_HORIZON = 102  # prediction horizon in bars; 4.25 days at 1h - optimal from research
DEFAULT_Z_WINDOW = 24  # rolling window for z-score / volatility (24 bars = 24h at 1h)
DEFAULT_PROFIT_TARGET = 0.005  # 0.5% profit threshold used when labelling targets
DEFAULT_Z_THRESHOLD = 1.0  # absolute z-score required to count as an entry signal
DEFAULT_TRAIN_RATIO = 0.7  # walk-forward split: 70% train / 30% test
FEE_RATE = 0.001  # 0.1% round-trip fee  # NOTE(review): not referenced in this file - confirm still needed
|
|
|
|
|
|
def parse_args():
    """Define and parse the trainer's command-line options."""
    parser = argparse.ArgumentParser(
        description="Train the regime detection ML model"
    )

    # (flag, add_argument kwargs) pairs, registered in the order shown by --help.
    option_specs = [
        ("--days", dict(
            type=int, default=90,
            help="Number of days of historical data to use (default: 90)")),
        ("--pair", dict(
            type=str, default="BTC-USDT",
            help="Base pair for context data (default: BTC-USDT)")),
        ("--spread-pair", dict(
            type=str, default="ETH-USDT",
            help="Spread pair to trade (default: ETH-USDT)")),
        ("--timeframe", dict(
            type=str, default="1h",
            help="Timeframe (default: 1h)")),
        ("--market", dict(
            type=str, choices=["spot", "perpetual"], default="perpetual",
            help="Market type (default: perpetual)")),
        ("--output", dict(
            type=str, default="data/regime_model.pkl",
            help="Output model path (default: data/regime_model.pkl)")),
        ("--train-ratio", dict(
            type=float, default=DEFAULT_TRAIN_RATIO,
            help="Train/test split ratio (default: 0.7)")),
        ("--horizon", dict(
            type=int, default=DEFAULT_HORIZON,
            help="Prediction horizon in bars (default: 102)")),
        ("--dry-run", dict(
            action="store_true",
            help="Run without saving model")),
        ("--download", dict(
            action="store_true",
            help="Download latest data before training")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
|
|
|
|
|
|
def download_data(dm: DataManager, pair: str, timeframe: str, market_type: MarketType):
    """Fetch the latest OHLCV history for *pair* from OKX via the data manager.

    Re-raises any download error after logging it, so callers decide
    whether a failed refresh is fatal.
    """
    logger.info(f"Downloading latest data for {pair}...")
    try:
        dm.download_data("okx", pair, timeframe, market_type)
        logger.info(f"Downloaded {pair} data")
    except Exception as exc:
        logger.error(f"Failed to download {pair}: {exc}")
        raise
|
|
|
|
|
|
def load_data(
    dm: DataManager,
    base_pair: str,
    spread_pair: str,
    timeframe: str,
    market_type: MarketType,
    days: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load, window, and index-align OHLCV history for both legs of the spread."""
    # Load base leg first, then spread leg (same order callers expect back).
    frames = [
        dm.load_data("okx", pair, timeframe, market_type)
        for pair in (base_pair, spread_pair)
    ]

    # Restrict both frames to the trailing N-day window, anchored at "now" UTC.
    end_date = pd.Timestamp.now(tz="UTC")
    start_date = end_date - timedelta(days=days)
    frames = [
        df[(df.index >= start_date) & (df.index <= end_date)]
        for df in frames
    ]

    # Keep only timestamps present in BOTH series so features line up row-wise.
    common = frames[0].index.intersection(frames[1].index)
    df_base, df_spread = (df.loc[common] for df in frames)

    logger.info(
        f"Loaded {len(common)} bars from {common.min()} to {common.max()}"
    )
    return df_base, df_spread
|
|
|
|
|
|
def load_cryptoquant_data() -> pd.DataFrame | None:
    """Read the optional CryptoQuant on-chain CSV; return None when unavailable.

    Best-effort by design: any failure is logged and swallowed so training
    can proceed without on-chain features.
    """
    try:
        cq_path = Path("data/cq_training_data.csv")
        if not cq_path.exists():
            logger.info("CryptoQuant data not found, skipping on-chain features")
            return None

        frame = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
        # Normalise to UTC so joins against exchange data are well-defined.
        if frame.index.tz is None:
            frame.index = frame.index.tz_localize('UTC')
        logger.info(f"Loaded CryptoQuant data: {len(frame)} rows")
        return frame
    except Exception as exc:
        logger.warning(f"Could not load CryptoQuant data: {exc}")
        return None
|
|
|
|
|
|
def calculate_features(
    df_base: pd.DataFrame,
    df_spread: pd.DataFrame,
    cq_df: pd.DataFrame | None = None,
    z_window: int = DEFAULT_Z_WINDOW
) -> pd.DataFrame:
    """Build the model's feature matrix from the two price series.

    Columns (in order): spread, z_score, spread_rsi, spread_roc,
    spread_change_1h, vol_ratio, vol_ratio_rel, vol_diff_ratio, plus any
    CryptoQuant columns when `cq_df` is supplied. Rows with any NaN
    (warm-up periods, missing on-chain data) are dropped.
    """
    spread = df_spread['close'] / df_base['close']

    # Mean-reversion signal: z-score of the spread over the rolling window.
    spread_mean = spread.rolling(window=z_window).mean()
    spread_std = spread.rolling(window=z_window).std()

    # Realised volatility of each leg over the same window.
    base_vol = df_base['close'].pct_change().rolling(window=z_window).std()
    spread_leg_vol = df_spread['close'].pct_change().rolling(window=z_window).std()

    volume_ratio = df_spread['volume'] / df_base['volume']

    features = pd.DataFrame(
        {
            'spread': spread,
            'z_score': (spread - spread_mean) / spread_std,
            'spread_rsi': ta.momentum.RSIIndicator(spread, window=14).rsi(),
            'spread_roc': spread.pct_change(periods=5) * 100,
            'spread_change_1h': spread.pct_change(periods=1),
            'vol_ratio': volume_ratio,
            # Current volume ratio relative to its 12-bar average.
            'vol_ratio_rel': volume_ratio / volume_ratio.rolling(window=12).mean(),
            'vol_diff_ratio': spread_leg_vol / base_vol,
        },
        index=spread.index,
    )

    # Optional on-chain features, forward-filled onto the price index.
    if cq_df is not None:
        aligned = cq_df.reindex(features.index, method='ffill')
        if 'btc_funding' in aligned.columns and 'eth_funding' in aligned.columns:
            aligned['funding_diff'] = (
                aligned['eth_funding'] - aligned['btc_funding']
            )
        if 'btc_inflow' in aligned.columns and 'eth_inflow' in aligned.columns:
            # +1 guards against division by zero on quiet days.
            aligned['inflow_ratio'] = (
                aligned['eth_inflow'] / (aligned['btc_inflow'] + 1)
            )
        features = features.join(aligned)

    return features.dropna()
|
|
|
|
|
|
def calculate_targets(
    features: pd.DataFrame,
    horizon: int,
    profit_target: float = DEFAULT_PROFIT_TARGET,
    z_threshold: float = DEFAULT_Z_THRESHOLD
) -> tuple[np.ndarray, pd.Series]:
    """Label each bar with whether a z-score-triggered entry would have paid off.

    A bar is labelled 1 when the z-score breached the entry threshold and,
    within the next `horizon` bars, the spread moved far enough in the
    mean-reverting direction to clear `profit_target`; otherwise 0.

    Args:
        features: Feature frame containing 'spread' and 'z_score' columns.
        horizon: Look-ahead window length in bars.
        profit_target: Fractional move that counts as a win.
        z_threshold: Absolute z-score needed to trigger an entry.

    Returns:
        Tuple of (0/1 label array, boolean Series marking rows whose full
        look-ahead window exists in the data).
    """
    spread = features['spread']
    z = features['z_score']

    # Extremes of the spread over the NEXT `horizon` bars, aligned to "now".
    look_ahead_low = spread.rolling(window=horizon).min().shift(-horizon)
    look_ahead_high = spread.rolling(window=horizon).max().shift(-horizon)

    # Short entries (z above threshold) win when the spread falls through the
    # profit target; long entries (z below -threshold) win when it rises.
    short_win = (z > z_threshold) & (look_ahead_low < spread * (1 - profit_target))
    long_win = (z < -z_threshold) & (look_ahead_high > spread * (1 + profit_target))

    labels = np.select([short_win, long_win], [1, 1], default=0)

    # Rows near the end of the series lack a complete future window.
    has_full_window = look_ahead_low.notna() & look_ahead_high.notna()

    return labels, has_full_window
|
|
|
|
|
|
def train_model(
    features: pd.DataFrame,
    train_ratio: float = DEFAULT_TRAIN_RATIO,
    horizon: int = DEFAULT_HORIZON
) -> tuple[RandomForestClassifier, list[str], dict]:
    """
    Train Random Forest model with walk-forward split.

    The split is chronological (no shuffling): the first `train_ratio`
    fraction of rows trains the model and the remainder is held out, so
    evaluation never uses data from the training period's future.

    Args:
        features: DataFrame with calculated features
        train_ratio: Fraction of data to use for training
        horizon: Prediction horizon in bars

    Returns:
        Tuple of (trained model, feature columns, metrics dict)

    Raises:
        ValueError: If fewer than 100 valid training samples remain.
    """
    # Calculate targets: 0/1 labels plus a mask of rows whose full
    # look-ahead window exists (rows at the tail cannot be labelled).
    targets, valid_mask = calculate_targets(features, horizon)

    # Walk-forward split
    n_samples = len(features)
    train_size = int(n_samples * train_ratio)

    train_features = features.iloc[:train_size]
    test_features = features.iloc[train_size:]

    train_targets = targets[:train_size]
    test_targets = targets[train_size:]

    train_valid = valid_mask.iloc[:train_size]
    test_valid = valid_mask.iloc[train_size:]

    # Prepare training data: 'spread' is the raw series used for labelling,
    # not a predictor, so it is excluded from the feature set.
    exclude = ['spread']
    feature_cols = [c for c in features.columns if c not in exclude]

    # Neutralise NaN/inf so the forest never sees non-finite values.
    X_train = train_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    X_train_valid = X_train[train_valid]
    y_train_valid = train_targets[train_valid]

    if len(X_train_valid) < 100:
        raise ValueError(
            f"Not enough training data: {len(X_train_valid)} samples (need >= 100)"
        )

    logger.info(f"Training on {len(X_train_valid)} samples...")

    # Train model. Shallow trees + large leaves limit overfitting on noisy
    # financial data; class_weight up-weights the rarer positive class
    # (1:3 ratio presumably tuned offline - TODO confirm).
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=5,
        min_samples_leaf=30,
        class_weight={0: 1, 1: 3},
        random_state=42
    )
    model.fit(X_train_valid, y_train_valid)

    # Evaluate on test set
    X_test = test_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    predictions = model.predict(X_test)

    # Only evaluate on valid test rows (those with a complete future window).
    test_valid_mask = test_valid.values
    y_test_valid = test_targets[test_valid_mask]
    pred_valid = predictions[test_valid_mask]

    # Calculate metrics
    f1 = f1_score(y_test_valid, pred_valid, zero_division=0)

    metrics = {
        'train_samples': len(X_train_valid),
        # NOTE: counts ALL test rows, including those later dropped by the
        # valid mask during scoring.
        'test_samples': len(X_test),
        'f1_score': f1,
        'train_end': train_features.index[-1].isoformat(),
        'test_start': test_features.index[0].isoformat(),
        'horizon': horizon,
        'feature_cols': feature_cols,
    }

    logger.info(f"Model trained. F1 Score: {f1:.3f}")
    logger.info(
        f"Train period: {train_features.index[0]} to {train_features.index[-1]}"
    )
    logger.info(
        f"Test period: {test_features.index[0]} to {test_features.index[-1]}"
    )

    return model, feature_cols, metrics
|
|
|
|
|
|
def save_model(
    model: RandomForestClassifier,
    feature_cols: list[str],
    metrics: dict,
    output_path: str,
    versioned: bool = True
):
    """
    Persist the trained model bundle to disk.

    Writes the pickle to `output_path` and, when `versioned` is true, an
    additional timestamped copy alongside it for rollback/history.

    Args:
        model: Trained model
        feature_cols: List of feature column names
        metrics: Training metrics
        output_path: Output file path
        versioned: If True, also save a timestamped version
    """
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    bundle = {
        'model': model,
        'feature_cols': feature_cols,
        'metrics': metrics,
        'trained_at': datetime.now().isoformat(),
    }

    def _dump(path):
        # Single place that owns the serialisation details.
        with open(path, 'wb') as fh:
            pickle.dump(bundle, fh)

    # Main model file, then the optional versioned copy.
    _dump(output)
    logger.info(f"Saved model to {output}")

    if versioned:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        versioned_path = output.parent / f"regime_model_{stamp}.pkl"
        _dump(versioned_path)
        logger.info(f"Saved versioned model to {versioned_path}")
|
|
|
|
|
|
def main():
    """CLI entry point: load data, build features, train, report, persist."""
    args = parse_args()

    market_type = MarketType.PERPETUAL if args.market == "perpetual" else MarketType.SPOT
    dm = DataManager()

    # Download latest data if requested - refresh both legs before training.
    if args.download:
        for pair in (args.pair, args.spread_pair):
            download_data(dm, pair, args.timeframe, market_type)

    # Load data; a load failure is fatal but we hint at the --download fix.
    try:
        df_base, df_spread = load_data(
            dm, args.pair, args.spread_pair, args.timeframe, market_type, args.days
        )
    except Exception as exc:
        logger.error(f"Failed to load data: {exc}")
        logger.info("Try running with --download flag to fetch latest data")
        sys.exit(1)

    # Optional on-chain features.
    cq_df = load_cryptoquant_data()

    features = calculate_features(df_base, df_spread, cq_df)
    logger.info(
        f"Calculated {len(features)} feature rows with {len(features.columns)} columns"
    )

    if len(features) < 200:
        logger.error(f"Not enough data: {len(features)} rows (need >= 200)")
        sys.exit(1)

    try:
        model, feature_cols, metrics = train_model(
            features, args.train_ratio, args.horizon
        )
    except ValueError as exc:
        logger.error(f"Training failed: {exc}")
        sys.exit(1)

    # Human-readable summary on stdout (logs already carry the same data).
    separator = "=" * 60
    print("\n" + separator)
    print("TRAINING COMPLETE")
    print(separator)
    print(f"Train samples: {metrics['train_samples']}")
    print(f"Test samples: {metrics['test_samples']}")
    print(f"F1 Score: {metrics['f1_score']:.3f}")
    print(f"Horizon: {metrics['horizon']} bars")
    print(f"Features: {len(feature_cols)}")
    print(separator)

    if args.dry_run:
        logger.info("Dry run - model not saved")
    else:
        save_model(model, feature_cols, metrics, args.output)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|