# Source listing: 130 lines, 4.7 KiB, Python
import os
import pickle
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import accuracy_score, classification_report, precision_recall_fscore_support
from sklearn.model_selection import train_test_split

import strategy_config as config

def _split_info_path(model_path):
    """Return the path of the split-metadata pickle stored beside the model.

    ``models/xgb.pkl`` maps to ``models/xgb_split.pkl``. Only the file name is
    changed, and the result always differs from ``model_path``. This prevents
    the model pickle from overwriting the split info when the path lacks a
    ``.pkl`` suffix.
    """
    model_path = Path(model_path)
    return model_path.with_name(f"{model_path.stem}_split.pkl")


def _print_threshold_analysis(y_true, y_proba, thresholds):
    """Print precision/recall/F1 and the signal rate at each probability cut-off."""
    for thresh in thresholds:
        pred_at_thresh = (y_proba >= thresh).astype(int)
        if pred_at_thresh.sum() == 0:
            print(f" Thresh {thresh:.2f}: No signals generated")
            continue
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, pred_at_thresh, average='binary', zero_division=0
        )
        signal_rate = pred_at_thresh.mean() * 100
        print(f" Thresh {thresh:.2f}: Precision={precision:.2f}, Recall={recall:.2f}, "
              f"F1={f1:.2f}, Signals={signal_rate:.1f}%")


def train_model(*, test_size: float = 0.3, val_size: float = 0.15) -> None:
    """Train the XGBoost buy-signal classifier and persist it with its split info.

    Steps:
      1. Load the feature CSV at ``config.FEATURES_PATH``, indexed by its
         ``timestamp`` column when that column is present.
      2. Use the ``config.FEATURE_NAMES`` columns that exist in the data as
         features and the ``target`` column as the label.
      3. Split chronologically (no shuffling). The last ``test_size`` fraction
         of rows is the test set. The last ``val_size`` fraction of the
         training window is held out for early stopping, so the test set never
         influences training.
      4. Fit the model, then print test-set metrics and a threshold sweep.
      5. Pickle the model to ``config.MODEL_PATH``. Then pickle the split
         metadata to ``<model stem>_split.pkl`` beside it.

    Args:
        test_size: Fraction of rows, taken from the end, reserved for testing.
        val_size: Fraction of the training window, taken from its end, used as
            the early-stopping validation set.

    Returns:
        None. Exceptions (subclasses of ``Exception``) are printed with a
        traceback instead of being raised.
    """
    print("--- Starting Model Training Pipeline ---")

    try:
        if not Path(config.FEATURES_PATH).exists():
            print(f"Error: {config.FEATURES_PATH} not found. Run prepare_data.py first.")
            return

        df = pd.read_csv(config.FEATURES_PATH)
        # prepare_data saved with index (timestamp). read_csv turns that index
        # into an ordinary column, so restore it here.
        if 'timestamp' in df.columns:
            df = df.set_index('timestamp')

        print(f"Loaded {len(df)} data points from {config.FEATURES_PATH}")

        if 'target' not in df.columns:
            print(f"Error: 'target' column not found in {config.FEATURES_PATH}.")
            return

        y = df['target']
        print(f"Buy signals rate: {y.mean():.1%}")

        # Use the dynamic feature list from config, tolerating absent columns.
        available_feats = [f for f in config.FEATURE_NAMES if f in df.columns]
        missing_feats = [f for f in config.FEATURE_NAMES if f not in df.columns]
        if missing_feats:
            print(f"⚠️ Missing features: {missing_feats}")
        if not available_feats:
            print("Error: none of config.FEATURE_NAMES are present in the data.")
            return
        print(f"Proceeding with {len(available_feats)} features.")

        X = df[available_feats]

        # Chronological split: every test row comes after every training row.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, shuffle=False
        )
        # Early stopping needs its own hold-out. Stopping on the test set
        # would tune the model on the data it is later scored on.
        X_fit, X_val, y_fit, y_val = train_test_split(
            X_train, y_train, test_size=val_size, shuffle=False
        )

        # The backtester should start trading at this row, which prevents
        # train/test leakage during backtesting.
        test_start_idx = len(X_train)
        test_start_timestamp = X_test.index[0]

        print(f"Training set size: {len(X_train)}")
        print(f"  (fit: {len(X_fit)}, early-stopping validation: {len(X_val)})")
        print(f"Test set size: {len(X_test)}")

        print("\nTraining XGBoost model...")
        model = xgb.XGBClassifier(
            objective='binary:logistic',
            eval_metric='logloss',
            n_estimators=200,
            learning_rate=0.05,
            scale_pos_weight=8.0,
            max_depth=5,
            subsample=0.8,
            random_state=42,
            early_stopping_rounds=10,
        )
        model.fit(X_fit, y_fit, eval_set=[(X_val, y_val)], verbose=False)
        print("Model training complete.")

        y_pred = model.predict(X_test)
        y_proba = model.predict_proba(X_test)[:, 1]

        accuracy = accuracy_score(y_test, y_pred)
        print("\n--- Model Evaluation ---")
        print(f"Accuracy on Test Set: {accuracy * 100:.2f}%")
        print("\nClassification Report:")
        # labels=[0, 1] keeps the report valid when only one class appears in
        # y_test and y_pred. Without it, sklearn rejects the two target_names.
        print(classification_report(
            y_test, y_pred, labels=[0, 1],
            target_names=['Hold/Sell (0)', 'Buy (1)'], zero_division=0,
        ))

        print("\n--- Probability Threshold Analysis ---")
        _print_threshold_analysis(y_test, y_proba, [0.35, 0.40, 0.45, 0.50, 0.55, 0.60])

        with open(config.MODEL_PATH, 'wb') as f:
            pickle.dump(model, f)
        print(f"\nSUCCESS: Model saved to {config.MODEL_PATH}")

        # Save the split info only after the model. A run that fails
        # mid-training then never leaves split metadata that disagrees with
        # the model file on disk.
        split_info = {
            'test_start_idx': test_start_idx,
            'test_start_timestamp': str(test_start_timestamp),
            'train_size': len(X_train),
            'test_size': len(X_test),
        }
        with open(_split_info_path(config.MODEL_PATH), 'wb') as f:
            pickle.dump(split_info, f)
        print(f"Split info saved: Test starts at index {test_start_idx} ({test_start_timestamp})")

        # Feature importance
        importance = pd.DataFrame({
            'feature': X.columns,
            'importance': model.feature_importances_,
        }).sort_values('importance', ascending=False)

        print("\nTop 10 Features:")
        print(importance.head(10))

    except Exception as e:
        print(f"AN ERROR OCCURRED: {e}")
        import traceback
        traceback.print_exc()
# Script entry point: run the training pipeline when executed directly.
if __name__ == "__main__":
    train_model()