56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
#!/usr/bin/env python3
|
|
import argparse, numpy as np, pandas as pd
|
|
from pathlib import Path
|
|
from hmmlearn.hmm import GaussianHMM
|
|
from sklearn.preprocessing import StandardScaler
|
|
from main import _load_bitstamp_csv, _align_minutely, build_features, feature_matrix # reuse helpers
|
|
|
|
# --- Command line ---------------------------------------------------------------
# The data/model inputs are mandatory: without them the code below either
# crashes on None or, for --split, silently evaluates on the training data
# (g.loc[:None] and g.loc[None:] both return the full sample).
p = argparse.ArgumentParser(
    description=(
        "Fit a Gaussian HMM on BTC/ETH features, map each hidden state to a "
        "long/flat/short stance from its in-sample mean target, and report "
        "out-of-sample coverage, hit rate and Sharpe."
    )
)
p.add_argument("--btc", type=Path, required=True,
               help="BTC price CSV read by _load_bitstamp_csv")
p.add_argument("--eth", type=Path, required=True,
               help="ETH price CSV read by _load_bitstamp_csv")
p.add_argument("--rules", required=True,
               help="bar size understood by pd.Timedelta and build_features, e.g. '5min'")
p.add_argument("--states", type=int, required=True,
               help="number of hidden states of the Gaussian HMM")
p.add_argument("--split", required=True,
               help="index label (e.g. a date) separating training rows (up to it) "
                    "from test rows (after it)")
p.add_argument("--horizon", type=int, required=True,
               help="target horizon passed to build_features")
p.add_argument("--conf", type=float, default=0.0,
               help="minimum posterior probability of the decoded state needed to "
                    "take a position (default: 0.0, i.e. never filter)")
a = p.parse_args()

# Reject values the model/gating below cannot use, with a proper CLI error.
if a.states < 1:
    p.error("--states must be at least 1")
if not 0.0 <= a.conf <= 1.0:
    p.error("--conf must lie in [0, 1]")
|
|
|
|
# --- Data: load, featurize, split chronologically, standardize ---------------
btc = _load_bitstamp_csv(a.btc, "btc")
eth = _load_bitstamp_csv(a.eth, "eth")
minute = _align_minutely(btc, eth)
g = build_features(minute, a.rules, a.horizon)

# ``.loc`` label slices include BOTH endpoints, and on a DatetimeIndex a
# date-only string such as "2024-01-01" matches that whole day, so
# g.loc[:split] and g.loc[split:] would share rows (train/test leakage).
# g.loc[:split] selects from the first row up to the split label, i.e. a
# positional prefix of g, so the test set is everything after that prefix.
# NOTE(review): if build_features labels each row with a return `a.horizon`
# bars ahead, the last training rows still see prices from the test period;
# consider purging a `horizon`-bar gap — confirm against build_features.
train = g.loc[:a.split]
test = g.iloc[len(train):]
if train.empty or test.empty:
    raise SystemExit(
        f"--split {a.split!r} leaves an empty set "
        f"(train={len(train)} rows, test={len(test)} rows)"
    )

Xtr, ytr, _ = feature_matrix(train)
Xte, yte, _ = feature_matrix(test)

# Fit the scaler on training data only so no test statistics leak into the model.
scaler = StandardScaler()
Xtr_s = scaler.fit_transform(Xtr)
Xte_s = scaler.transform(Xte)
|
|
|
|
def _filtered_posteriors(model, X):
    """Return causal state probabilities P(state_t | x_1..x_t) for each row of X.

    ``model.predict`` (Viterbi) and ``model.predict_proba`` (forward-backward)
    both condition on the whole sequence, i.e. on observations *after* t. That
    is look-ahead bias when the state at t drives a position at t. This helper
    runs only the forward recursion of a diagonal-covariance Gaussian HMM.

    Parameters
    ----------
    model : fitted hmmlearn ``GaussianHMM`` with ``covariance_type="diag"``;
        only ``startprob_``, ``transmat_``, ``means_`` and ``covars_`` are read.
    X : array of shape (n_samples, n_features), scaled like the training data.

    Returns
    -------
    ndarray of shape (n_samples, n_components); each row sums to 1.
    """
    cov = np.asarray(model.covars_, dtype=float)
    # hmmlearn may expose diag covariances as full matrices; accept either form.
    var = np.diagonal(cov, axis1=1, axis2=2) if cov.ndim == 3 else cov
    mu = np.asarray(model.means_, dtype=float)
    inv_var = 1.0 / var
    X = np.asarray(X, dtype=float)
    # Per-state Gaussian log-density up to a constant shared by all states,
    # expanded so memory stays O(n_samples * n_components).
    log_lik = -0.5 * (
        (X ** 2) @ inv_var.T
        - 2.0 * X @ (mu * inv_var).T
        + (mu ** 2 * inv_var).sum(axis=1)
        + np.log(var).sum(axis=1)
    )
    # Shift each row so its best state has likelihood 1 (avoids underflow);
    # per-row constants cancel when the filter normalizes.
    lik = np.exp(log_lik - log_lik.max(axis=1, keepdims=True))

    trans = np.asarray(model.transmat_, dtype=float)
    prior = np.asarray(model.startprob_, dtype=float)
    n_states = prior.size
    post = np.empty_like(lik)
    for t, lik_t in enumerate(lik):
        alpha = prior * lik_t
        total = alpha.sum()
        # Degenerate row (all mass zero): fall back to a uniform belief.
        post[t] = alpha / total if total > 0 else 1.0 / n_states
        prior = post[t] @ trans
    return post


# --- Model: fit on train, label states, trade on test ---------------------------
hmm = GaussianHMM(n_components=a.states, covariance_type="diag", n_iter=300, random_state=7)
hmm.fit(Xtr_s)

# In-sample Viterbi labels are only used to characterize each state.
st_tr = hmm.predict(Xtr_s)

# Mean training target per state; a state never visited in training gets NaN
# (hence a flat stance) instead of triggering a "mean of empty slice" warning.
ytr_arr = np.asarray(ytr, dtype=float)
means = {}
for s in range(a.states):
    in_state = st_tr == s
    means[s] = float(np.nanmean(ytr_arr[in_state])) if in_state.any() else float("nan")

# States whose |mean| does not exceed the 30th percentile of all |means|
# (or whose mean is NaN) stay flat; the rest go long (+1) or short (-1).
thr = np.nanpercentile(np.abs(list(means.values())), 30)
state_to_stance = {s: (1 if m > +thr else (-1 if m < -thr else 0)) for s, m in means.items()}

# Causal (forward-only) posteriors on the test set. The traded state is the
# posterior argmax, so the --conf gate refers to the same state as the stance
# (a Viterbi path from hmm.predict can disagree with the posterior argmax).
post_te = _filtered_posteriors(hmm, Xte_s)
st_te = post_te.argmax(axis=1)
maxp = post_te.max(axis=1)

# Table lookup instead of np.vectorize, which raises on size-0 input.
stance_lut = np.array([state_to_stance[s] for s in range(a.states)], dtype=np.int8)
raw_pred = stance_lut[st_te]
preds = np.where(maxp >= a.conf, raw_pred, 0).astype(np.int8)
|
|
|
|
# --- Metrics ---------------------------------------------------------------------
# Align targets with predictions. Bars whose target is NaN cannot be scored:
# keeping them would count them as misses in the hit rates while the Sharpe
# ignores them, so drop them up front and score every metric on the same bars.
y = np.asarray(yte, dtype=float)[:len(preds)]
scored = np.isfinite(y)
if not scored.any():
    raise SystemExit("no test bars with a finite target to score")
y, preds = y[scored], np.asarray(preds)[scored]

mask = preds != 0  # bars where a position is taken
coverage = mask.mean()
# Over all scored bars: a flat prediction is a hit only when the target is exactly 0.
hit_rate = (np.sign(preds) == np.sign(y)).mean()
# Over traded bars only.
hit_trades = (np.sign(preds[mask]) == np.sign(y[mask])).mean() if mask.any() else np.nan
pnl = preds * y

# Bars per day via true division, so sub-minute or non-whole-minute bar sizes
# (e.g. '30s', '90s') are not truncated to whole minutes.
bars_day = pd.Timedelta(days=1) / pd.Timedelta(a.rules)
ann = np.sqrt(365 * bars_day)  # 365: BTC/ETH trade every calendar day
# NOTE(review): sqrt-of-time annualization assumes the per-bar pnl values are
# non-overlapping one-bar returns; if the target spans `a.horizon` bars the
# Sharpe is overstated — confirm against build_features.
sharpe = pnl.mean() / (pnl.std() + 1e-12) * ann

print(f"{a.rules:>5} conf={a.conf:.2f} cov={coverage:.3f} hit={hit_rate:.3f} "
      f"hit_trades={hit_trades:.3f} Sharpe={sharpe:.3f}")
|