Remove deprecated training scripts and Systemd service files
- Deleted `install_cron.sh`, `setup_schedule.sh`, and `train_daily.sh` as part of the transition to a new scheduling mechanism. - Removed associated Systemd service and timer files for daily model training. - Updated `live_regime_strategy.py` and `main.py` to reflect changes in model training and scheduling logic. - Adjusted `regime_strategy.py` to align with new target calculation methods and updated optimal parameters. - Enhanced `regime_detection.py` to incorporate path-dependent labeling for target calculations.
This commit is contained in:
@@ -1,29 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Install cron job for daily model training
|
||||
# Runs daily at 00:30
|
||||
|
||||
PROJECT_DIR="/home/tamaya/Documents/Work/TCP/lowkey_backtest_live"
|
||||
SCRIPT_PATH="$PROJECT_DIR/train_daily.sh"
|
||||
LOG_PATH="$PROJECT_DIR/logs/training.log"
|
||||
|
||||
# Check if script exists
|
||||
if [ ! -f "$SCRIPT_PATH" ]; then
|
||||
echo "Error: $SCRIPT_PATH not found!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Make executable
|
||||
chmod +x "$SCRIPT_PATH"
|
||||
|
||||
# Prepare cron entry
|
||||
# 30 0 * * * = 00:30 daily
|
||||
CRON_CMD="30 0 * * * $SCRIPT_PATH >> $LOG_PATH 2>&1"
|
||||
|
||||
# Check if job already exists
|
||||
(crontab -l 2>/dev/null | grep -F "$SCRIPT_PATH") && echo "Cron job already exists." && exit 0
|
||||
|
||||
# Add to crontab
|
||||
(crontab -l 2>/dev/null; echo "$CRON_CMD") | crontab -
|
||||
|
||||
echo "Cron job installed successfully:"
|
||||
echo "$CRON_CMD"
|
||||
@@ -6,6 +6,7 @@ Uses a pre-trained ML model or trains on historical data.
|
||||
"""
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -39,8 +40,9 @@ class LiveRegimeStrategy:
|
||||
self.paths = path_config
|
||||
self.model: Optional[RandomForestClassifier] = None
|
||||
self.feature_cols: Optional[list] = None
|
||||
self.horizon: int = 102 # Default horizon
|
||||
self.horizon: int = 54 # Default horizon
|
||||
self._last_model_load_time: float = 0.0
|
||||
self._last_train_time: float = 0.0
|
||||
self._load_or_train_model()
|
||||
|
||||
def reload_model_if_changed(self) -> None:
|
||||
@@ -72,6 +74,13 @@ class LiveRegimeStrategy:
|
||||
logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
|
||||
else:
|
||||
logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
|
||||
|
||||
# Load timestamp if available
|
||||
if 'timestamp' in saved:
|
||||
self._last_train_time = saved['timestamp']
|
||||
else:
|
||||
self._last_train_time = self._last_model_load_time
|
||||
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load model: {e}")
|
||||
@@ -88,12 +97,20 @@ class LiveRegimeStrategy:
|
||||
pickle.dump({
|
||||
'model': self.model,
|
||||
'feature_cols': self.feature_cols,
|
||||
'metrics': {'horizon': self.horizon} # Save horizon
|
||||
'metrics': {'horizon': self.horizon}, # Save horizon
|
||||
'timestamp': time.time()
|
||||
}, f)
|
||||
logger.info(f"Saved model to {self.paths.model_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Could not save model: {e}")
|
||||
|
||||
def check_retrain(self, features: pd.DataFrame) -> None:
|
||||
"""Check if model needs retraining (older than 24h)."""
|
||||
if time.time() - self._last_train_time > 24 * 3600:
|
||||
logger.info("Model is older than 24h. Retraining...")
|
||||
self.train_model(features)
|
||||
self._last_train_time = time.time()
|
||||
|
||||
def train_model(self, features: pd.DataFrame) -> None:
|
||||
"""
|
||||
Train the Random Forest model on historical data.
|
||||
@@ -106,18 +123,61 @@ class LiveRegimeStrategy:
|
||||
z_thresh = self.config.z_entry_threshold
|
||||
horizon = self.horizon
|
||||
profit_target = 0.005 # 0.5% profit threshold
|
||||
stop_loss_pct = self.config.stop_loss_pct
|
||||
|
||||
# Define targets
|
||||
future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
|
||||
future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
|
||||
# Calculate targets path-dependently
|
||||
spread = features['spread'].values
|
||||
z_score = features['z_score'].values
|
||||
n = len(spread)
|
||||
|
||||
target_short = features['spread'] * (1 - profit_target)
|
||||
target_long = features['spread'] * (1 + profit_target)
|
||||
targets = np.zeros(n, dtype=int)
|
||||
|
||||
success_short = (features['z_score'] > z_thresh) & (future_min < target_short)
|
||||
success_long = (features['z_score'] < -z_thresh) & (future_max > target_long)
|
||||
candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
for i in candidates:
|
||||
if i + horizon >= n:
|
||||
continue
|
||||
|
||||
entry_price = spread[i]
|
||||
future_prices = spread[i+1 : i+1+horizon]
|
||||
|
||||
if z_score[i] > z_thresh: # Short
|
||||
target_price = entry_price * (1 - profit_target)
|
||||
stop_price = entry_price * (1 + stop_loss_pct)
|
||||
|
||||
hit_tp = future_prices <= target_price
|
||||
hit_sl = future_prices >= stop_price
|
||||
|
||||
if not np.any(hit_tp):
|
||||
targets[i] = 0
|
||||
elif not np.any(hit_sl):
|
||||
targets[i] = 1
|
||||
else:
|
||||
first_tp_idx = np.argmax(hit_tp)
|
||||
first_sl_idx = np.argmax(hit_sl)
|
||||
if first_tp_idx < first_sl_idx:
|
||||
targets[i] = 1
|
||||
else:
|
||||
targets[i] = 0
|
||||
|
||||
else: # Long
|
||||
target_price = entry_price * (1 + profit_target)
|
||||
stop_price = entry_price * (1 - stop_loss_pct)
|
||||
|
||||
hit_tp = future_prices >= target_price
|
||||
hit_sl = future_prices <= stop_price
|
||||
|
||||
if not np.any(hit_tp):
|
||||
targets[i] = 0
|
||||
elif not np.any(hit_sl):
|
||||
targets[i] = 1
|
||||
else:
|
||||
first_tp_idx = np.argmax(hit_tp)
|
||||
first_sl_idx = np.argmax(hit_sl)
|
||||
if first_tp_idx < first_sl_idx:
|
||||
targets[i] = 1
|
||||
else:
|
||||
targets[i] = 0
|
||||
|
||||
# Exclude non-feature columns
|
||||
exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
|
||||
@@ -127,8 +187,10 @@ class LiveRegimeStrategy:
|
||||
X = features[self.feature_cols].fillna(0)
|
||||
X = X.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Remove rows with invalid targets
|
||||
valid_mask = ~np.isnan(targets) & future_min.notna().values & future_max.notna().values
|
||||
# Use rows where we had enough data to look ahead
|
||||
valid_mask = np.zeros(n, dtype=bool)
|
||||
valid_mask[:n-horizon] = True
|
||||
|
||||
X_clean = X[valid_mask]
|
||||
y_clean = targets[valid_mask]
|
||||
|
||||
@@ -152,7 +214,8 @@ class LiveRegimeStrategy:
|
||||
def generate_signal(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
current_funding: dict
|
||||
current_funding: dict,
|
||||
position_side: Optional[str] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Generate trading signal from latest features.
|
||||
@@ -160,10 +223,14 @@ class LiveRegimeStrategy:
|
||||
Args:
|
||||
features: DataFrame with calculated features
|
||||
current_funding: Dictionary with funding rate data
|
||||
position_side: Current position side ('long', 'short', or None)
|
||||
|
||||
Returns:
|
||||
Signal dictionary with action, side, confidence, etc.
|
||||
"""
|
||||
# Check if retraining is needed
|
||||
self.check_retrain(features)
|
||||
|
||||
if self.model is None:
|
||||
# Train model if not available
|
||||
if len(features) >= 200:
|
||||
@@ -233,12 +300,17 @@ class LiveRegimeStrategy:
|
||||
signal['action'] = 'hold'
|
||||
signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'
|
||||
|
||||
# Check for exit conditions (mean reversion complete)
|
||||
if signal['action'] == 'hold':
|
||||
# Z-score crossed back through 0
|
||||
if abs(z_score) < 0.3:
|
||||
# Check for exit conditions (Overshoot Logic)
|
||||
if signal['action'] == 'hold' and position_side:
|
||||
# Overshoot Logic
|
||||
# If Long, exit if Z > 0.5 (Reverted past 0 to +0.5)
|
||||
if position_side == 'long' and z_score > 0.5:
|
||||
signal['action'] = 'check_exit'
|
||||
signal['reason'] = f'z_score_reverted_to_mean ({z_score:.2f})'
|
||||
signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > 0.5)'
|
||||
# If Short, exit if Z < -0.5 (Reverted past 0 to -0.5)
|
||||
elif position_side == 'short' and z_score < -0.5:
|
||||
signal['action'] = 'check_exit'
|
||||
signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -0.5)'
|
||||
|
||||
logger.info(
|
||||
f"Signal: {signal['action']} {signal['side'] or ''} "
|
||||
|
||||
@@ -206,11 +206,16 @@ class LiveTradingBot:
|
||||
# 3. Sync with exchange positions
|
||||
self.position_manager.sync_with_exchange()
|
||||
|
||||
# Get current position side for signal generation
|
||||
symbol = self.trading_config.eth_symbol
|
||||
position = self.position_manager.get_position_for_symbol(symbol)
|
||||
position_side = position.side if position else None
|
||||
|
||||
# 4. Get current funding rates
|
||||
funding = self.data_feed.get_current_funding_rates()
|
||||
|
||||
# 5. Generate trading signal
|
||||
sig = self.strategy.generate_signal(features, funding)
|
||||
sig = self.strategy.generate_signal(features, funding, position_side=position_side)
|
||||
|
||||
# 6. Update shared state with strategy info
|
||||
self._update_strategy_state(sig, funding)
|
||||
|
||||
@@ -32,6 +32,7 @@ logger = get_logger(__name__)
|
||||
# Configuration
|
||||
TRAIN_RATIO = 0.7 # 70% train, 30% test
|
||||
PROFIT_THRESHOLD = 0.005 # 0.5% profit target
|
||||
STOP_LOSS_PCT = 0.06 # 6% stop loss
|
||||
Z_WINDOW = 24
|
||||
FEE_RATE = 0.001 # 0.1% round-trip fee
|
||||
DEFAULT_DAYS = 90 # Default lookback period in days
|
||||
@@ -139,26 +140,74 @@ def calculate_features(df_btc, df_eth, cq_df=None):
|
||||
|
||||
|
||||
def calculate_targets(features, horizon):
|
||||
"""Calculate target labels for a given horizon."""
|
||||
spread = features['spread']
|
||||
z_score = features['z_score']
|
||||
"""
|
||||
Calculate target labels for a given horizon.
|
||||
|
||||
# For Short (Z > 1): Did spread drop below target?
|
||||
future_min = spread.rolling(window=horizon).min().shift(-horizon)
|
||||
target_short = spread * (1 - PROFIT_THRESHOLD)
|
||||
success_short = (z_score > 1.0) & (future_min < target_short)
|
||||
Uses path-dependent labeling: Success is hitting Profit Target BEFORE Stop Loss.
|
||||
"""
|
||||
spread = features['spread'].values
|
||||
z_score = features['z_score'].values
|
||||
n = len(spread)
|
||||
|
||||
# For Long (Z < -1): Did spread rise above target?
|
||||
future_max = spread.rolling(window=horizon).max().shift(-horizon)
|
||||
target_long = spread * (1 + PROFIT_THRESHOLD)
|
||||
success_long = (z_score < -1.0) & (future_max > target_long)
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
targets = np.zeros(n, dtype=int)
|
||||
|
||||
# Create valid mask (rows with complete future data)
|
||||
valid_mask = future_min.notna() & future_max.notna()
|
||||
valid_mask = np.zeros(n, dtype=bool)
|
||||
valid_mask[:n-horizon] = True
|
||||
|
||||
return targets, valid_mask, future_min, future_max
|
||||
# Only iterate relevant rows for efficiency
|
||||
candidates = np.where((z_score > 1.0) | (z_score < -1.0))[0]
|
||||
|
||||
for i in candidates:
|
||||
if i + horizon >= n:
|
||||
continue
|
||||
|
||||
entry_price = spread[i]
|
||||
future_prices = spread[i+1 : i+1+horizon]
|
||||
|
||||
if z_score[i] > 1.0: # Short
|
||||
target_price = entry_price * (1 - PROFIT_THRESHOLD)
|
||||
stop_price = entry_price * (1 + STOP_LOSS_PCT)
|
||||
|
||||
# Identify first hit indices
|
||||
hit_tp = future_prices <= target_price
|
||||
hit_sl = future_prices >= stop_price
|
||||
|
||||
if not np.any(hit_tp):
|
||||
targets[i] = 0 # Target never hit
|
||||
elif not np.any(hit_sl):
|
||||
targets[i] = 1 # Target hit, SL never hit
|
||||
else:
|
||||
first_tp_idx = np.argmax(hit_tp)
|
||||
first_sl_idx = np.argmax(hit_sl)
|
||||
|
||||
# Success if TP hit before SL
|
||||
if first_tp_idx < first_sl_idx:
|
||||
targets[i] = 1
|
||||
else:
|
||||
targets[i] = 0
|
||||
|
||||
else: # Long
|
||||
target_price = entry_price * (1 + PROFIT_THRESHOLD)
|
||||
stop_price = entry_price * (1 - STOP_LOSS_PCT)
|
||||
|
||||
hit_tp = future_prices >= target_price
|
||||
hit_sl = future_prices <= stop_price
|
||||
|
||||
if not np.any(hit_tp):
|
||||
targets[i] = 0
|
||||
elif not np.any(hit_sl):
|
||||
targets[i] = 1
|
||||
else:
|
||||
first_tp_idx = np.argmax(hit_tp)
|
||||
first_sl_idx = np.argmax(hit_sl)
|
||||
|
||||
if first_tp_idx < first_sl_idx:
|
||||
targets[i] = 1
|
||||
else:
|
||||
targets[i] = 0
|
||||
|
||||
return targets, pd.Series(valid_mask, index=features.index), None, None
|
||||
|
||||
|
||||
def calculate_mae(features, predictions, test_idx, horizon):
|
||||
@@ -197,7 +246,7 @@ def calculate_mae(features, predictions, test_idx, horizon):
|
||||
def calculate_net_profit(features, predictions, test_idx, horizon):
|
||||
"""
|
||||
Calculate estimated net profit including fees.
|
||||
Enforces 'one trade at a time' to avoid inflating returns with overlapping signals.
|
||||
Enforces 'one trade at a time' and simulates SL/TP exits.
|
||||
"""
|
||||
test_features = features.loc[test_idx]
|
||||
spread = test_features['spread']
|
||||
@@ -209,6 +258,9 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
|
||||
# Track when we are free to trade again
|
||||
next_trade_idx = 0
|
||||
|
||||
# Pre-calculate indices for speed
|
||||
all_indices = features.index
|
||||
|
||||
for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
|
||||
# Skip if we are still in a trade
|
||||
if i < next_trade_idx:
|
||||
@@ -221,29 +273,76 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
|
||||
z = z_score.loc[idx]
|
||||
|
||||
# Get future spread values
|
||||
future_idx = features.index.get_loc(idx)
|
||||
future_end = min(future_idx + horizon, len(features))
|
||||
future_spreads = features['spread'].iloc[future_idx:future_end]
|
||||
current_loc = features.index.get_loc(idx)
|
||||
future_end_loc = min(current_loc + horizon, len(features))
|
||||
future_spreads = features['spread'].iloc[current_loc+1 : future_end_loc]
|
||||
|
||||
if len(future_spreads) < 2:
|
||||
if len(future_spreads) < 1:
|
||||
continue
|
||||
|
||||
# Calculate PnL based on direction
|
||||
if z > 1.0: # Short trade - profit if spread drops
|
||||
exit_spread = future_spreads.iloc[-1] # Exit at horizon
|
||||
pnl = (entry_spread - exit_spread) / entry_spread
|
||||
else: # Long trade - profit if spread rises
|
||||
exit_spread = future_spreads.iloc[-1]
|
||||
pnl = (exit_spread - entry_spread) / entry_spread
|
||||
pnl = 0.0
|
||||
trade_duration = len(future_spreads)
|
||||
|
||||
if z > 1.0: # Short trade
|
||||
tp_price = entry_spread * (1 - PROFIT_THRESHOLD)
|
||||
sl_price = entry_spread * (1 + STOP_LOSS_PCT)
|
||||
|
||||
hit_tp = future_spreads <= tp_price
|
||||
hit_sl = future_spreads >= sl_price
|
||||
|
||||
# Check what happened first
|
||||
first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
|
||||
first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
|
||||
|
||||
if first_sl < first_tp and first_sl < 99999:
|
||||
# Stopped out
|
||||
exit_price = future_spreads.iloc[first_sl] # Approx SL price
|
||||
# Use exact SL price for realistic simulation? Or close
|
||||
# Let's use the close price of the bar where it crossed
|
||||
pnl = (entry_spread - exit_price) / entry_spread
|
||||
trade_duration = first_sl + 1
|
||||
elif first_tp < first_sl and first_tp < 99999:
|
||||
# Take profit
|
||||
exit_price = future_spreads.iloc[first_tp]
|
||||
pnl = (entry_spread - exit_price) / entry_spread
|
||||
trade_duration = first_tp + 1
|
||||
else:
|
||||
# Held to horizon
|
||||
exit_price = future_spreads.iloc[-1]
|
||||
pnl = (entry_spread - exit_price) / entry_spread
|
||||
|
||||
else: # Long trade
|
||||
tp_price = entry_spread * (1 + PROFIT_THRESHOLD)
|
||||
sl_price = entry_spread * (1 - STOP_LOSS_PCT)
|
||||
|
||||
hit_tp = future_spreads >= tp_price
|
||||
hit_sl = future_spreads <= sl_price
|
||||
|
||||
first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
|
||||
first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
|
||||
|
||||
if first_sl < first_tp and first_sl < 99999:
|
||||
# Stopped out
|
||||
exit_price = future_spreads.iloc[first_sl]
|
||||
pnl = (exit_price - entry_spread) / entry_spread
|
||||
trade_duration = first_sl + 1
|
||||
elif first_tp < first_sl and first_tp < 99999:
|
||||
# Take profit
|
||||
exit_price = future_spreads.iloc[first_tp]
|
||||
pnl = (exit_price - entry_spread) / entry_spread
|
||||
trade_duration = first_tp + 1
|
||||
else:
|
||||
# Held to horizon
|
||||
exit_price = future_spreads.iloc[-1]
|
||||
pnl = (exit_price - entry_spread) / entry_spread
|
||||
|
||||
# Subtract fees
|
||||
net_pnl = pnl - FEE_RATE
|
||||
total_pnl += net_pnl
|
||||
n_trades += 1
|
||||
|
||||
# Set next available trade index (simple non-overlapping logic)
|
||||
# We assume we hold for 'horizon' bars
|
||||
next_trade_idx = i + horizon
|
||||
# Set next available trade index
|
||||
next_trade_idx = i + trade_duration
|
||||
|
||||
return total_pnl, n_trades
|
||||
|
||||
@@ -321,7 +420,7 @@ def test_horizons(features, horizons):
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("WALK-FORWARD HORIZON OPTIMIZATION")
|
||||
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
|
||||
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Stop Loss: {STOP_LOSS_PCT*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
|
||||
print("=" * 80)
|
||||
|
||||
for h in horizons:
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Setup script for Systemd Timer (Daily Training)
|
||||
|
||||
SERVICE_FILE="tasks/lowkey-training.service"
|
||||
TIMER_FILE="tasks/lowkey-training.timer"
|
||||
SYSTEMD_DIR="/etc/systemd/system"
|
||||
|
||||
echo "To install the daily training schedule, please run the following commands:"
|
||||
echo ""
|
||||
echo "sudo cp $SERVICE_FILE $SYSTEMD_DIR/"
|
||||
echo "sudo cp $TIMER_FILE $SYSTEMD_DIR/"
|
||||
echo "sudo systemctl daemon-reload"
|
||||
echo "sudo systemctl enable --now lowkey-training.timer"
|
||||
echo ""
|
||||
echo "To check the status:"
|
||||
echo "systemctl list-timers --all | grep lowkey"
|
||||
@@ -30,7 +30,7 @@ class RegimeReversionStrategy(BaseStrategy):
|
||||
|
||||
# Optimal parameters from walk-forward research (2025-10 to 2025-12)
|
||||
# Research: research/horizon_optimization_results.csv
|
||||
OPTIMAL_HORIZON = 102 # 4.25 days - best Net PnL (+232%)
|
||||
OPTIMAL_HORIZON = 54 # Updated from 102h based on corrected labeling
|
||||
OPTIMAL_Z_WINDOW = 24 # 24h rolling window for spread Z-score
|
||||
OPTIMAL_TRAIN_RATIO = 0.7 # 70% train / 30% test split
|
||||
OPTIMAL_PROFIT_TARGET = 0.005 # 0.5% profit threshold for target definition
|
||||
@@ -321,21 +321,64 @@ class RegimeReversionStrategy(BaseStrategy):
|
||||
train_features: DataFrame containing features for training period only
|
||||
"""
|
||||
threshold = self.profit_target
|
||||
stop_loss_pct = self.stop_loss
|
||||
horizon = self.horizon
|
||||
z_thresh = self.z_entry_threshold
|
||||
|
||||
# Define targets using ONLY training data
|
||||
# For Short Spread (Z > threshold): Did spread drop below target within horizon?
|
||||
future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
|
||||
target_short = train_features['spread'] * (1 - threshold)
|
||||
success_short = (train_features['z_score'] > z_thresh) & (future_min < target_short)
|
||||
# Calculate targets path-dependently (checking SL before TP)
|
||||
spread = train_features['spread'].values
|
||||
z_score = train_features['z_score'].values
|
||||
n = len(spread)
|
||||
|
||||
# For Long Spread (Z < -threshold): Did spread rise above target within horizon?
|
||||
future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
|
||||
target_long = train_features['spread'] * (1 + threshold)
|
||||
success_long = (train_features['z_score'] < -z_thresh) & (future_max > target_long)
|
||||
targets = np.zeros(n, dtype=int)
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
# Only iterate relevant rows for efficiency
|
||||
candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
|
||||
|
||||
for i in candidates:
|
||||
if i + horizon >= n:
|
||||
continue
|
||||
|
||||
entry_price = spread[i]
|
||||
future_prices = spread[i+1 : i+1+horizon]
|
||||
|
||||
if z_score[i] > z_thresh: # Short
|
||||
target_price = entry_price * (1 - threshold)
|
||||
stop_price = entry_price * (1 + stop_loss_pct)
|
||||
|
||||
hit_tp = future_prices <= target_price
|
||||
hit_sl = future_prices >= stop_price
|
||||
|
||||
if not np.any(hit_tp):
|
||||
targets[i] = 0
|
||||
elif not np.any(hit_sl):
|
||||
targets[i] = 1
|
||||
else:
|
||||
first_tp_idx = np.argmax(hit_tp)
|
||||
first_sl_idx = np.argmax(hit_sl)
|
||||
if first_tp_idx < first_sl_idx:
|
||||
targets[i] = 1
|
||||
else:
|
||||
targets[i] = 0
|
||||
|
||||
else: # Long
|
||||
target_price = entry_price * (1 + threshold)
|
||||
stop_price = entry_price * (1 - stop_loss_pct)
|
||||
|
||||
hit_tp = future_prices >= target_price
|
||||
hit_sl = future_prices <= stop_price
|
||||
|
||||
if not np.any(hit_tp):
|
||||
targets[i] = 0
|
||||
elif not np.any(hit_sl):
|
||||
targets[i] = 1
|
||||
else:
|
||||
first_tp_idx = np.argmax(hit_tp)
|
||||
first_sl_idx = np.argmax(hit_sl)
|
||||
if first_tp_idx < first_sl_idx:
|
||||
targets[i] = 1
|
||||
else:
|
||||
targets[i] = 0
|
||||
|
||||
# Build model
|
||||
model = RandomForestClassifier(
|
||||
@@ -351,10 +394,9 @@ class RegimeReversionStrategy(BaseStrategy):
|
||||
X_train = train_features[cols].fillna(0)
|
||||
X_train = X_train.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Remove rows with NaN targets (from rolling window at end of training period)
|
||||
valid_mask = ~np.isnan(targets) & ~np.isinf(targets)
|
||||
# Also check for rows where future data doesn't exist (shift created NaNs)
|
||||
valid_mask = valid_mask & (future_min.notna().values) & (future_max.notna().values)
|
||||
# Use rows where we had enough data to look ahead
|
||||
valid_mask = np.zeros(n, dtype=bool)
|
||||
valid_mask[:n-horizon] = True
|
||||
|
||||
X_train_clean = X_train[valid_mask]
|
||||
targets_clean = targets[valid_mask]
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
[Unit]
|
||||
Description=Lowkey Backtest Daily Model Training
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=oneshot
|
||||
WorkingDirectory=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live
|
||||
ExecStart=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/train_daily.sh
|
||||
User=tamaya
|
||||
Group=tamaya
|
||||
StandardOutput=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
|
||||
StandardError=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
@@ -1,10 +0,0 @@
|
||||
[Unit]
|
||||
Description=Run Lowkey Backtest Training Daily
|
||||
|
||||
[Timer]
|
||||
OnCalendar=*-*-* 00:30:00
|
||||
Persistent=true
|
||||
Unit=lowkey-training.service
|
||||
|
||||
[Install]
|
||||
WantedBy=timers.target
|
||||
@@ -1,50 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Daily ML Model Training Script
|
||||
#
|
||||
# Downloads fresh data and retrains the regime detection model.
|
||||
# Can be run manually or scheduled via cron.
|
||||
#
|
||||
# Usage:
|
||||
# ./train_daily.sh # Full workflow
|
||||
# ./train_daily.sh --skip-research # Skip research validation
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
LOG_DIR="logs"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
|
||||
echo "[$TIMESTAMP] Starting daily training..."
|
||||
|
||||
# 1. Download fresh data
|
||||
echo "Downloading BTC-USDT 1h data..."
|
||||
uv run python main.py download -p BTC-USDT -t 1h
|
||||
|
||||
echo "Downloading ETH-USDT 1h data..."
|
||||
uv run python main.py download -p ETH-USDT -t 1h
|
||||
|
||||
# 2. Research optimization (find best horizon)
|
||||
echo "Running research optimization..."
|
||||
uv run python research/regime_detection.py --output-horizon data/optimal_horizon.txt
|
||||
|
||||
# 3. Read best horizon
|
||||
if [[ -f "data/optimal_horizon.txt" ]]; then
|
||||
BEST_HORIZON=$(cat data/optimal_horizon.txt)
|
||||
echo "Found optimal horizon: ${BEST_HORIZON} bars"
|
||||
else
|
||||
BEST_HORIZON=102
|
||||
echo "Warning: Could not find optimal horizon file. Using default: ${BEST_HORIZON}"
|
||||
fi
|
||||
|
||||
# 4. Train model
|
||||
echo "Training ML model with horizon ${BEST_HORIZON}..."
|
||||
uv run python train_model.py --horizon "$BEST_HORIZON"
|
||||
|
||||
# 5. Cleanup
|
||||
rm -f data/optimal_horizon.txt
|
||||
|
||||
TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
|
||||
echo "[$TIMESTAMP] Daily training complete."
|
||||
451
train_model.py
451
train_model.py
@@ -1,451 +0,0 @@
|
||||
"""
|
||||
ML Model Training Script.
|
||||
|
||||
Trains the regime detection Random Forest model on historical data.
|
||||
Can be run manually or scheduled via cron for daily retraining.
|
||||
|
||||
Usage:
|
||||
uv run python train_model.py [options]
|
||||
|
||||
Options:
|
||||
--days DAYS Number of days of historical data to use (default: 90)
|
||||
--pair PAIR Trading pair for context (default: BTC-USDT)
|
||||
--timeframe TF Timeframe (default: 1h)
|
||||
--output PATH Output model path (default: data/regime_model.pkl)
|
||||
--train-ratio R Train/test split ratio (default: 0.7)
|
||||
--dry-run Run without saving model
|
||||
"""
|
||||
import argparse
|
||||
import pickle
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ta
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report, f1_score
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Default configuration (from research optimization)
|
||||
DEFAULT_HORIZON = 102 # 4.25 days - optimal from research
|
||||
DEFAULT_Z_WINDOW = 24 # 24h rolling window
|
||||
DEFAULT_PROFIT_TARGET = 0.005 # 0.5% profit threshold
|
||||
DEFAULT_Z_THRESHOLD = 1.0 # Z-score entry threshold
|
||||
DEFAULT_TRAIN_RATIO = 0.7 # 70% train / 30% test
|
||||
FEE_RATE = 0.001 # 0.1% round-trip fee
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train the regime detection ML model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--days",
|
||||
type=int,
|
||||
default=90,
|
||||
help="Number of days of historical data to use (default: 90)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pair",
|
||||
type=str,
|
||||
default="BTC-USDT",
|
||||
help="Base pair for context data (default: BTC-USDT)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spread-pair",
|
||||
type=str,
|
||||
default="ETH-USDT",
|
||||
help="Spread pair to trade (default: ETH-USDT)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeframe",
|
||||
type=str,
|
||||
default="1h",
|
||||
help="Timeframe (default: 1h)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--market",
|
||||
type=str,
|
||||
choices=["spot", "perpetual"],
|
||||
default="perpetual",
|
||||
help="Market type (default: perpetual)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="data/regime_model.pkl",
|
||||
help="Output model path (default: data/regime_model.pkl)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-ratio",
|
||||
type=float,
|
||||
default=DEFAULT_TRAIN_RATIO,
|
||||
help="Train/test split ratio (default: 0.7)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon",
|
||||
type=int,
|
||||
default=DEFAULT_HORIZON,
|
||||
help="Prediction horizon in bars (default: 102)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Run without saving model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download",
|
||||
action="store_true",
|
||||
help="Download latest data before training"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def download_data(dm: DataManager, pair: str, timeframe: str, market_type: MarketType):
|
||||
"""Download latest data for a pair."""
|
||||
logger.info(f"Downloading latest data for {pair}...")
|
||||
try:
|
||||
dm.download_data("okx", pair, timeframe, market_type)
|
||||
logger.info(f"Downloaded {pair} data")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download {pair}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def load_data(
|
||||
dm: DataManager,
|
||||
base_pair: str,
|
||||
spread_pair: str,
|
||||
timeframe: str,
|
||||
market_type: MarketType,
|
||||
days: int
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Load and align historical data for both pairs."""
|
||||
df_base = dm.load_data("okx", base_pair, timeframe, market_type)
|
||||
df_spread = dm.load_data("okx", spread_pair, timeframe, market_type)
|
||||
|
||||
# Filter to last N days
|
||||
end_date = pd.Timestamp.now(tz="UTC")
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
df_base = df_base[(df_base.index >= start_date) & (df_base.index <= end_date)]
|
||||
df_spread = df_spread[(df_spread.index >= start_date) & (df_spread.index <= end_date)]
|
||||
|
||||
# Align indices
|
||||
common = df_base.index.intersection(df_spread.index)
|
||||
df_base = df_base.loc[common]
|
||||
df_spread = df_spread.loc[common]
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(common)} bars from {common.min()} to {common.max()}"
|
||||
)
|
||||
return df_base, df_spread
|
||||
|
||||
|
||||
def load_cryptoquant_data() -> pd.DataFrame | None:
|
||||
"""Load CryptoQuant on-chain data if available."""
|
||||
try:
|
||||
cq_path = Path("data/cq_training_data.csv")
|
||||
if not cq_path.exists():
|
||||
logger.info("CryptoQuant data not found, skipping on-chain features")
|
||||
return None
|
||||
|
||||
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
|
||||
if cq_df.index.tz is None:
|
||||
cq_df.index = cq_df.index.tz_localize('UTC')
|
||||
logger.info(f"Loaded CryptoQuant data: {len(cq_df)} rows")
|
||||
return cq_df
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load CryptoQuant data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def calculate_features(
    df_base: pd.DataFrame,
    df_spread: pd.DataFrame,
    cq_df: pd.DataFrame | None = None,
    z_window: int = DEFAULT_Z_WINDOW
) -> pd.DataFrame:
    """Build the model feature matrix from price, volume and on-chain data.

    Features are derived from the spread (spread pair close / base pair
    close): rolling z-score, RSI, rate-of-change, relative volume and a
    realized-volatility ratio. Optional CryptoQuant columns are forward-
    filled onto the price index. Rows with any NaN are dropped.
    """
    spread = df_spread['close'] / df_base['close']

    # Rolling z-score of the spread over the configured window.
    rolling_mean = spread.rolling(window=z_window).mean()
    rolling_std = spread.rolling(window=z_window).std()

    # 14-period RSI computed on the spread series itself.
    rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()

    # Relative volume between the two legs.
    vol_ratio = df_spread['volume'] / df_base['volume']

    # Realized volatility of each leg's returns, same window as the z-score.
    vol_base = df_base['close'].pct_change().rolling(window=z_window).std()
    vol_spread = df_spread['close'].pct_change().rolling(window=z_window).std()

    features = pd.DataFrame(
        {
            'spread': spread,
            'z_score': (spread - rolling_mean) / rolling_std,
            'spread_rsi': rsi,
            'spread_roc': spread.pct_change(periods=5) * 100,
            'spread_change_1h': spread.pct_change(periods=1),
            'vol_ratio': vol_ratio,
            'vol_ratio_rel': vol_ratio / vol_ratio.rolling(window=12).mean(),
            'vol_diff_ratio': vol_spread / vol_base,
        },
        index=spread.index,
    )

    # Optional on-chain features, forward-filled onto the price index.
    if cq_df is not None:
        cq_aligned = cq_df.reindex(features.index, method='ffill')
        if {'btc_funding', 'eth_funding'} <= set(cq_aligned.columns):
            cq_aligned['funding_diff'] = (
                cq_aligned['eth_funding'] - cq_aligned['btc_funding']
            )
        if {'btc_inflow', 'eth_inflow'} <= set(cq_aligned.columns):
            # +1 in the denominator guards against division by zero inflow.
            cq_aligned['inflow_ratio'] = (
                cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
            )
        features = features.join(cq_aligned)

    return features.dropna()
|
||||
|
||||
|
||||
def calculate_targets(
    features: pd.DataFrame,
    horizon: int,
    profit_target: float = DEFAULT_PROFIT_TARGET,
    z_threshold: float = DEFAULT_Z_THRESHOLD
) -> tuple[np.ndarray, pd.Series]:
    """Label each bar 1 when a mean-reversion entry would have paid off.

    Path-dependent labeling: a short setup (z-score above the threshold)
    succeeds when the spread's minimum over the next ``horizon`` bars falls
    below the profit target; a long setup mirrors this with the future
    maximum.

    Returns:
        (targets, valid_mask) where targets is a 0/1 array per row and
        valid_mask flags rows that have a complete ``horizon``-bar lookahead.
    """
    spread = features['spread']
    z = features['z_score']

    # Extremes over the NEXT `horizon` bars: a trailing rolling stat shifted
    # backwards by `horizon` becomes a forward-looking window (excl. current bar).
    future_min = spread.rolling(window=horizon).min().shift(-horizon)
    future_max = spread.rolling(window=horizon).max().shift(-horizon)

    # Short: stretched high and the spread subsequently drops enough.
    hit_short = (z > z_threshold) & (future_min < spread * (1 - profit_target))
    # Long: stretched low and the spread subsequently rises enough.
    hit_long = (z < -z_threshold) & (future_max > spread * (1 + profit_target))

    targets = np.where(hit_short | hit_long, 1, 0)

    # Rows near the end of the sample lack a full lookahead window.
    valid_mask = future_min.notna() & future_max.notna()

    return targets, valid_mask
|
||||
|
||||
|
||||
def train_model(
    features: pd.DataFrame,
    train_ratio: float = DEFAULT_TRAIN_RATIO,
    horizon: int = DEFAULT_HORIZON
) -> tuple[RandomForestClassifier, list[str], dict]:
    """
    Train Random Forest model with walk-forward split.

    The split is chronological (no shuffling) so the test period always
    follows the training period, avoiding lookahead bias. Rows without a
    complete `horizon`-bar lookahead window are excluded from fitting and
    scoring via the valid mask from calculate_targets().

    Args:
        features: DataFrame with calculated features
        train_ratio: Fraction of data to use for training
        horizon: Prediction horizon in bars

    Returns:
        Tuple of (trained model, feature columns, metrics dict)

    Raises:
        ValueError: If fewer than 100 valid training samples remain.
    """
    # Calculate targets (0/1 labels plus a mask of rows with full lookahead)
    targets, valid_mask = calculate_targets(features, horizon)

    # Walk-forward split: first train_ratio of rows train, the rest test
    n_samples = len(features)
    train_size = int(n_samples * train_ratio)

    train_features = features.iloc[:train_size]
    test_features = features.iloc[train_size:]

    # `targets` is a numpy array, so plain positional slicing applies here
    train_targets = targets[:train_size]
    test_targets = targets[train_size:]

    train_valid = valid_mask.iloc[:train_size]
    test_valid = valid_mask.iloc[train_size:]

    # Prepare training data: 'spread' is the raw series, not a model input
    exclude = ['spread']
    feature_cols = [c for c in features.columns if c not in exclude]

    # Sanitize NaN/inf so the sklearn estimator never sees non-finite values
    X_train = train_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    X_train_valid = X_train[train_valid]
    y_train_valid = train_targets[train_valid]

    if len(X_train_valid) < 100:
        raise ValueError(
            f"Not enough training data: {len(X_train_valid)} samples (need >= 100)"
        )

    logger.info(f"Training on {len(X_train_valid)} samples...")

    # Train model. class_weight up-weights the rare positive (profitable) class;
    # shallow depth and a large leaf size act as regularization against noise.
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=5,
        min_samples_leaf=30,
        class_weight={0: 1, 1: 3},
        random_state=42
    )
    model.fit(X_train_valid, y_train_valid)

    # Evaluate on test set (same NaN/inf sanitization as training)
    X_test = test_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    predictions = model.predict(X_test)

    # Only evaluate on valid test rows (those with a full lookahead window)
    test_valid_mask = test_valid.values
    y_test_valid = test_targets[test_valid_mask]
    pred_valid = predictions[test_valid_mask]

    # Calculate metrics; zero_division=0 keeps F1 defined when no positives exist
    f1 = f1_score(y_test_valid, pred_valid, zero_division=0)

    # NOTE: 'test_samples' counts all test rows, while f1_score is computed on
    # valid rows only — the two figures intentionally describe different sets.
    metrics = {
        'train_samples': len(X_train_valid),
        'test_samples': len(X_test),
        'f1_score': f1,
        'train_end': train_features.index[-1].isoformat(),
        'test_start': test_features.index[0].isoformat(),
        'horizon': horizon,
        'feature_cols': feature_cols,
    }

    logger.info(f"Model trained. F1 Score: {f1:.3f}")
    logger.info(
        f"Train period: {train_features.index[0]} to {train_features.index[-1]}"
    )
    logger.info(
        f"Test period: {test_features.index[0]} to {test_features.index[-1]}"
    )

    return model, feature_cols, metrics
|
||||
|
||||
|
||||
def save_model(
    model: RandomForestClassifier,
    feature_cols: list[str],
    metrics: dict,
    output_path: str,
    versioned: bool = True
):
    """
    Save trained model to file.

    Args:
        model: Trained model
        feature_cols: List of feature column names
        metrics: Training metrics
        output_path: Output file path
        versioned: If True, also save a timestamped version
    """
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    # Bundle the estimator with everything needed to use it later.
    payload = {
        'model': model,
        'feature_cols': feature_cols,
        'metrics': metrics,
        'trained_at': datetime.now().isoformat(),
    }

    # Main artifact at the canonical path.
    with output.open('wb') as f:
        pickle.dump(payload, f)
    logger.info(f"Saved model to {output}")

    # Optional immutable timestamped copy, kept for rollback/audit.
    if versioned:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        versioned_path = output.parent / f"regime_model_{stamp}.pkl"
        with versioned_path.open('wb') as f:
            pickle.dump(payload, f)
        logger.info(f"Saved versioned model to {versioned_path}")
|
||||
|
||||
|
||||
def main():
    """Train and persist the regime model end-to-end.

    Steps: optionally download fresh data, load and align history, build
    features, train a walk-forward model, print a metrics summary, and save
    the artifact unless --dry-run was given. Exits with status 1 on any
    unrecoverable error; returns 0 on success.
    """
    args = parse_args()

    market_type = MarketType.PERPETUAL if args.market == "perpetual" else MarketType.SPOT
    dm = DataManager()

    # Optionally refresh local candles for both legs before training.
    if args.download:
        for pair in (args.pair, args.spread_pair):
            download_data(dm, pair, args.timeframe, market_type)

    # Load and align the two price series.
    try:
        df_base, df_spread = load_data(
            dm, args.pair, args.spread_pair, args.timeframe, market_type, args.days
        )
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        logger.info("Try running with --download flag to fetch latest data")
        sys.exit(1)

    # On-chain data is optional; None simply skips those features.
    cq_df = load_cryptoquant_data()

    features = calculate_features(df_base, df_spread, cq_df)
    logger.info(
        f"Calculated {len(features)} feature rows with {len(features.columns)} columns"
    )

    if len(features) < 200:
        logger.error(f"Not enough data: {len(features)} rows (need >= 200)")
        sys.exit(1)

    try:
        model, feature_cols, metrics = train_model(
            features, args.train_ratio, args.horizon
        )
    except ValueError as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)

    # Human-readable summary on stdout.
    banner = "=" * 60
    print("\n" + banner)
    print("TRAINING COMPLETE")
    print(banner)
    print(f"Train samples: {metrics['train_samples']}")
    print(f"Test samples: {metrics['test_samples']}")
    print(f"F1 Score: {metrics['f1_score']:.3f}")
    print(f"Horizon: {metrics['horizon']} bars")
    print(f"Features: {len(feature_cols)}")
    print(banner)

    if args.dry_run:
        logger.info("Dry run - model not saved")
    else:
        save_model(model, feature_cols, metrics, args.output)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (SystemExit is exactly what sys.exit raises).
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user