BTC_ETH_regime_predictor/main_conf_metrics.py

56 lines
2.0 KiB
Python

#!/usr/bin/env python3
import argparse, numpy as np, pandas as pd
from pathlib import Path
from hmmlearn.hmm import GaussianHMM
from sklearn.preprocessing import StandardScaler
from main import _load_bitstamp_csv, _align_minutely, build_features, feature_matrix # reuse helpers
p = argparse.ArgumentParser()
p.add_argument("--btc", type=Path)
p.add_argument("--eth", type=Path)
p.add_argument("--rules")
p.add_argument("--states", type=int)
p.add_argument("--split")
p.add_argument("--horizon", type=int)
p.add_argument("--conf", type=float)
a = p.parse_args()
btc = _load_bitstamp_csv(a.btc, "btc")
eth = _load_bitstamp_csv(a.eth, "eth")
minute = _align_minutely(btc, eth)
g = build_features(minute, a.rules, a.horizon)
train, test = g.loc[:a.split], g.loc[a.split:]
Xtr, ytr, _ = feature_matrix(train)
Xte, yte, _ = feature_matrix(test)
scaler = StandardScaler()
Xtr_s, Xte_s = scaler.fit_transform(Xtr), scaler.transform(Xte)
hmm = GaussianHMM(n_components=a.states, covariance_type="diag", n_iter=300, random_state=7)
hmm.fit(Xtr_s)
st_tr, st_te = hmm.predict(Xtr_s), hmm.predict(Xte_s)
means = {s: float(np.nanmean(ytr[st_tr == s])) for s in range(a.states)}
thr = np.nanpercentile(np.abs(list(means.values())), 30)
state_to_stance = {s: (1 if m > +thr else (-1 if m < -thr else 0)) for s, m in means.items()}
post_te = hmm.predict_proba(Xte_s)
maxp = post_te.max(axis=1)
raw_pred = np.vectorize(state_to_stance.get)(st_te)
preds = np.where(maxp >= a.conf, raw_pred, 0).astype(np.int8)
# Metrics
y, preds = yte[:len(preds)], preds
mask = preds != 0
coverage = mask.mean()
hit_rate = (np.sign(preds) == np.sign(y)).mean()
hit_trades = (np.sign(preds[mask]) == np.sign(y[mask])).mean() if mask.any() else np.nan
pnl = preds * y
bars_day = int(round(24*60 / max(1,int(pd.Timedelta(a.rules).total_seconds()/60))))
ann = np.sqrt(365 * bars_day)
sharpe = np.nanmean(pnl) / (np.nanstd(pnl)+1e-12) * ann
print(f"{a.rules:>5} conf={a.conf:.2f} cov={coverage:.3f} hit={hit_rate:.3f} "
f"hit_trades={hit_trades:.3f} Sharpe={sharpe:.3f}")