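"""End-to-end training pipeline: load and filter OHLCV data, engineer features,
optionally evaluate and prune them via walk-forward CV, train a model on log
returns, and persist the model, feature list, per-feature metrics, and charts.
"""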
from __future__ import annotations

from typing import Dict

import os
import csv
import json

import numpy as np
import pandas as pd

from .config import RunConfig
from .data import load_and_filter_data
from .preprocess import add_basic_time_features, downcast_numeric_columns, handle_nans
from .selection import build_feature_list, prune_features
from .model import train_model, predict, get_feature_importance
from .metrics import compute_price_series_from_log_returns, compute_metrics_from_prices

from evaluation import walk_forward_cv
from feature_engineering import feature_engineering
from plot_results import plot_prediction_error_distribution


def ensure_charts_dir(path: str) -> None:
    # exist_ok=True already makes this safe when the directory exists,
    # so no separate existence check is needed.
    os.makedirs(path, exist_ok=True)


def run_pipeline(cfg: RunConfig) -> Dict[str, float]:
    # Set up outputs
    ensure_charts_dir(cfg.output.charts_dir)

    # Load data and target
    df = load_and_filter_data(cfg.data)

    # Feature engineering
    features_dict = feature_engineering(
        df,
        os.path.splitext(os.path.basename(cfg.data.csv_path))[0],
        cfg.features.ohlcv_cols,
        cfg.features.lags,
        cfg.features.window_sizes,
    )
    # pd.concat aligns on the index; features_df is assumed to share df's
    # index, otherwise misaligned rows would silently become NaN.
    features_df = pd.DataFrame(features_dict)
    df = pd.concat([df, features_df], axis=1)

    # Preprocess
    df = downcast_numeric_columns(df)
    df = add_basic_time_features(df)
    df = handle_nans(df, cfg.preprocess)
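    # Lag and rolling-window features leave NaNs at the head of the frame, so
    # NaN handling has to run after feature engineering, not before.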

    # Feature selection and pruning
    feature_cols = build_feature_list(df.columns)

    # Use float32 to halve memory relative to float64.
    X = df[feature_cols].values.astype(np.float32)
    y = df["log_return"].values.astype(np.float32)

    # Chronological 80/20 split: shuffling would leak future information.
    split_idx = int(len(X) * 0.8)
    y_train, y_test = y[:split_idx], y[split_idx:]

    importance_avg = None
    if cfg.pruning.do_walk_forward_cv:
        metrics_avg, importance_avg = walk_forward_cv(
            X, y, feature_cols, n_splits=cfg.pruning.n_splits
        )
        # Optional: metrics_avg could be logged or returned for diagnostics.
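        # importance_avg is assumed to hold per-feature importances averaged
        # across folds; it feeds prune_features below when auto_prune is set.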

    kept_feature_cols = (
        prune_features(feature_cols, importance_avg, cfg.pruning)
        if cfg.pruning.auto_prune
        else feature_cols
    )
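    # Note: with auto_prune set but do_walk_forward_cv unset, importance_avg
    # is still None here; prune_features is assumed to handle that case.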

    # Train model on the kept features only; convert once and reuse the array.
    X_kept = df[kept_feature_cols].values.astype(np.float32)
    model = train_model(
        X_kept[:split_idx],
        X_kept[split_idx:],
        y_train,
        y_test,
        eval_metric='rmse',
    )

    # Save model
    model.save_model(cfg.output.model_output_path)

    # Persist the exact feature list used for training next to the model
    try:
        features_path = os.path.splitext(cfg.output.model_output_path)[0] + "_features.json"
        with open(features_path, "w") as f:
            json.dump({"feature_names": kept_feature_cols}, f)
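        # Illustrative contents (feature names are examples only):
        # {"feature_names": ["Close_lag_1", "rolling_mean_5", "day_of_week"]}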
    except Exception:
        # Feature-list persistence is optional; avoid breaking the run on failure.
        pass

    # Predict on the held-out test rows
    X_test_kept = X_kept[split_idx:]
    test_preds = predict(model, X_test_kept)

    # Reconstruct price series from log returns
    close_prices = df['Close'].values
    start_price = close_prices[split_idx]
    actual_prices = compute_price_series_from_log_returns(start_price, y_test)
    predicted_prices = compute_price_series_from_log_returns(start_price, test_preds)
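    # Both series are assumed to compound multiplicatively from the same
    # anchor, i.e. price[t] = start_price * exp(r[0] + ... + r[t]), so the
    # predicted and actual prices remain directly comparable point by point.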

    # Metrics
    metrics = compute_metrics_from_prices(actual_prices, predicted_prices)

    # Plot prediction error distribution to charts dir (parity with previous behavior)
    try:
        plot_prediction_error_distribution(predicted_prices, actual_prices, prefix="all_features")
    except Exception:
        # Plotting is optional; ignore failures in headless environments.
        pass

    # Persist per-feature metrics and importances
    feat_importance = get_feature_importance(model, kept_feature_cols)
    if not os.path.exists(cfg.output.results_csv):
        with open(cfg.output.results_csv, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['feature', 'rmse', 'mape', 'r2', 'directional_accuracy', 'feature_importance'])
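            # Each subsequent row repeats the same run-level metrics and adds
            # that feature's importance, e.g. (name illustrative):
            # sma_5,<rmse>,<mape>,<r2>,<directional_accuracy>,<importance>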
    with open(cfg.output.results_csv, 'a', newline='') as f:
        writer = csv.writer(f)
        for feature in kept_feature_cols:
            importance = feat_importance.get(feature, 0.0)
            row = [feature]
            for key in ['rmse', 'mape', 'r2', 'directional_accuracy']:
                val = metrics[key]
                row.append(f"{val:.10f}")
            row.append(f"{importance:.6f}")
            writer.writerow(row)

    return metrics
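

# Example usage (illustrative; how RunConfig is constructed depends on .config):
#
#     cfg = RunConfig(...)  # populate data/features/preprocess/pruning/output
#     metrics = run_pipeline(cfg)
#     print(metrics["rmse"], metrics["directional_accuracy"])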