219 lines
7.6 KiB
Python
219 lines
7.6 KiB
Python
|
|
import argparse
|
||
|
|
from pathlib import Path
|
||
|
|
from .db import init_aggregated_db, ingest_source_db, get_instrument_time_bounds
|
||
|
|
from .events import stream_book_events
|
||
|
|
from .engine import PortfolioState, process_event
|
||
|
|
from .strategy import DemoMomentumConfig, DemoMomentumStrategy
|
||
|
|
from .metrics import compute_metrics
|
||
|
|
from .io import write_trade_log, append_summary_row
|
||
|
|
import datetime as _dt
|
||
|
|
import re as _re
|
||
|
|
import time as _time
|
||
|
|
|
||
|
|
|
||
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the book-only backtester.

    Exactly one sub-command is required; its name is stored in ``args.command``:

    * ``init-db`` -- takes ``--agg-db``.
    * ``ingest``  -- takes ``--agg-db`` plus one or more repeated ``--db`` flags.
    * ``run``     -- takes ``--agg-db``, ``--instrument`` and optional tuning flags.

    Returns:
        A configured ``argparse.ArgumentParser`` (nothing has been parsed yet).
    """
    cli = argparse.ArgumentParser(description="Granular order-book backtester (book-only)")
    subcommands = cli.add_subparsers(dest="command", required=True)

    # init-db: only the target database path is needed.
    init_cmd = subcommands.add_parser(
        "init-db",
        help="Create aggregated SQLite database for book snapshots",
    )
    init_cmd.add_argument(
        "--agg-db",
        required=True,
        help="Path to aggregated database (e.g., ./okx_agg.db)",
    )

    # ingest: one aggregated DB, any number of source DBs.
    ingest_cmd = subcommands.add_parser(
        "ingest",
        help="Ingest one or more OKX book SQLite file(s) into the aggregated DB",
    )
    ingest_cmd.add_argument("--agg-db", required=True, help="Path to aggregated database")
    # action="append" gathers every occurrence of --db into a list.
    ingest_cmd.add_argument(
        "--db",
        action="append",
        required=True,
        help="Path to a source OKX SQLite file containing a book table. Repeat flag to add multiple files.",
    )

    # run: options are registered in table order, which is also the --help order.
    run_cmd = subcommands.add_parser(
        "run",
        help="Run a backtest over the aggregated DB (book-only)",
    )
    run_options = (
        ("--agg-db", {"required": True, "help": "Path to aggregated database"}),
        ("--instrument", {"required": True, "help": "Instrument symbol, e.g., BTC-USDT"}),
        ("--since", {"required": False, "help": "ISO8601 or epoch milliseconds start time (defaults to earliest)"}),
        ("--until", {"required": False, "help": "ISO8601 or epoch milliseconds end time (defaults to latest)"}),
        ("--stoploss", {"type": float, "default": 0.05, "help": "Stop-loss percentage (e.g., 0.05)"}),
        ("--maker", {"action": "store_true", "help": "Use maker fees (default taker)"}),
        ("--init-usd", {"type": float, "default": 1000.0, "help": "Initial USD balance"}),
        ("--strategy", {"default": "demo", "help": "Strategy to run (demo for now)"}),
    )
    for flag, options in run_options:
        run_cmd.add_argument(flag, **options)

    return cli
|
||
|
|
|
||
|
|
|
||
|
|
def cmd_init_db(agg_db: Path) -> None:
    """Initialise the aggregated database and report where it lives.

    Args:
        agg_db: Path handed to ``init_aggregated_db``.
    """
    init_aggregated_db(agg_db)
    ready_message = f"[init-db] Aggregated DB ready at: {agg_db}"
    print(ready_message)
|
||
|
|
|
||
|
|
|
||
|
|
def cmd_ingest(agg_db: Path, src_dbs: list[Path]) -> None:
    """Ingest every source database into the aggregated database.

    Sources are processed in the given order. After each one, the row count
    returned by ``ingest_source_db`` is printed; a grand total follows at the end.

    Args:
        agg_db: Destination aggregated database.
        src_dbs: Source database files to ingest.
    """
    grand_total = 0
    for source_path in src_dbs:
        rows = ingest_source_db(agg_db, source_path)
        print(f"[ingest] {rows} rows ingested from {source_path}")
        grand_total += rows
    print(f"[ingest] Total inserted: {grand_total} into {agg_db}")
|
||
|
|
|
||
|
|
|
||
|
|
def cmd_run(
|
||
|
|
agg_db: Path,
|
||
|
|
instrument: str,
|
||
|
|
since: str | None,
|
||
|
|
until: str | None,
|
||
|
|
stoploss: float,
|
||
|
|
is_maker: bool,
|
||
|
|
init_usd: float,
|
||
|
|
strategy_name: str,
|
||
|
|
) -> None:
|
||
|
|
def _iso(ms: int) -> str:
|
||
|
|
return _dt.datetime.fromtimestamp(ms / 1000.0, tz=_dt.timezone.utc).isoformat().replace("+00:00", "Z")
|
||
|
|
def _parse_time(s: str) -> int:
|
||
|
|
# epoch ms
|
||
|
|
if s.isdigit():
|
||
|
|
return int(s)
|
||
|
|
# basic ISO8601 without timezone -> assume Z
|
||
|
|
# Accept formats like 2024-06-09T12:00:00Z or 2024-06-09 12:00:00
|
||
|
|
s2 = s.replace(" ", "T")
|
||
|
|
if s2.endswith("Z"):
|
||
|
|
s2 = s2[:-1]
|
||
|
|
try:
|
||
|
|
dt = _dt.datetime.fromisoformat(s2)
|
||
|
|
except ValueError:
|
||
|
|
raise SystemExit(f"Cannot parse time: {s}")
|
||
|
|
# Treat naive datetime as UTC
|
||
|
|
return int(dt.timestamp() * 1000)
|
||
|
|
|
||
|
|
if since is None or until is None:
|
||
|
|
bounds = get_instrument_time_bounds(agg_db, instrument)
|
||
|
|
if not bounds:
|
||
|
|
raise SystemExit("No data for instrument in aggregated DB")
|
||
|
|
min_ms, max_ms = bounds
|
||
|
|
if since is None:
|
||
|
|
since_ms = min_ms
|
||
|
|
else:
|
||
|
|
since_ms = _parse_time(since)
|
||
|
|
if until is None:
|
||
|
|
until_ms = max_ms
|
||
|
|
else:
|
||
|
|
until_ms = _parse_time(until)
|
||
|
|
print(f"[run] Using time range: {_iso(since_ms)} to {_iso(until_ms)}")
|
||
|
|
else:
|
||
|
|
since_ms = _parse_time(since)
|
||
|
|
until_ms = _parse_time(until)
|
||
|
|
|
||
|
|
state = PortfolioState(cash_usd=init_usd)
|
||
|
|
if strategy_name == "demo":
|
||
|
|
strat = DemoMomentumStrategy(DemoMomentumConfig())
|
||
|
|
else:
|
||
|
|
raise SystemExit(f"Unknown strategy: {strategy_name}")
|
||
|
|
|
||
|
|
logs = []
|
||
|
|
trade_results = []
|
||
|
|
last_event = None
|
||
|
|
total_span = max(1, until_ms - since_ms)
|
||
|
|
started_at = _time.time()
|
||
|
|
count = 0
|
||
|
|
next_report_pct = 1
|
||
|
|
for ev in stream_book_events(agg_db, instrument, since_ms, until_ms):
|
||
|
|
last_event = ev
|
||
|
|
orders = strat.on_book_event(ev)
|
||
|
|
exec_logs = process_event(state, ev, orders, stoploss_pct=stoploss, is_maker=is_maker)
|
||
|
|
for le in exec_logs:
|
||
|
|
logs.append(le)
|
||
|
|
if le.get("type") in {"sell", "stop_loss"} and "pnl" in le:
|
||
|
|
try:
|
||
|
|
trade_results.append(float(le["pnl"]))
|
||
|
|
except Exception:
|
||
|
|
pass
|
||
|
|
count += 1
|
||
|
|
pct = int(min(100.0, max(0.0, (ev.timestamp_ms - since_ms) * 100.0 / total_span)))
|
||
|
|
if pct >= next_report_pct:
|
||
|
|
elapsed = max(1e-6, _time.time() - started_at)
|
||
|
|
speed_k = (count / elapsed) / 1000.0
|
||
|
|
eq = state.equity_curve[-1][1] if state.equity_curve else state.cash_usd
|
||
|
|
print(f"\r[run] {pct:3d}% events={count} speed={speed_k:6.1f}k ev/s equity=${eq:,.2f}", end="", flush=True)
|
||
|
|
next_report_pct = pct + 1
|
||
|
|
if count:
|
||
|
|
print()
|
||
|
|
|
||
|
|
# Forced close if holding
|
||
|
|
if last_event and state.coin > 0:
|
||
|
|
price = last_event.best_bid
|
||
|
|
gross = state.coin * price
|
||
|
|
from .fees import calculate_okx_fee
|
||
|
|
fee = calculate_okx_fee(gross, is_maker=is_maker)
|
||
|
|
state.cash_usd += (gross - fee)
|
||
|
|
state.fees_paid += fee
|
||
|
|
pnl = 0.0
|
||
|
|
if state.entry_price:
|
||
|
|
pnl = (price - state.entry_price) / state.entry_price
|
||
|
|
logs.append({
|
||
|
|
"type": "forced_close",
|
||
|
|
"time": last_event.timestamp_ms,
|
||
|
|
"price": price,
|
||
|
|
"usd": state.cash_usd,
|
||
|
|
"coin": 0.0,
|
||
|
|
"pnl": pnl,
|
||
|
|
"fee": fee,
|
||
|
|
})
|
||
|
|
trade_results.append(pnl)
|
||
|
|
state.coin = 0.0
|
||
|
|
state.entry_price = None
|
||
|
|
state.entry_time_ms = None
|
||
|
|
|
||
|
|
# Outputs
|
||
|
|
stop_str = f"{stoploss:.2f}".replace(".", "p")
|
||
|
|
logfile = f"trade_log_book_sl{stop_str}.csv"
|
||
|
|
write_trade_log(logs, logfile)
|
||
|
|
|
||
|
|
metrics = compute_metrics(state.equity_curve, trade_results)
|
||
|
|
summary_row = {
|
||
|
|
"timeframe": "book",
|
||
|
|
"stop_loss": stoploss,
|
||
|
|
"total_return": f"{metrics['total_return']*100:.2f}%",
|
||
|
|
"max_drawdown": f"{metrics['max_drawdown']*100:.2f}%",
|
||
|
|
"sharpe_ratio": f"{metrics['sharpe_ratio']:.2f}",
|
||
|
|
"win_rate": f"{metrics['win_rate']*100:.2f}%",
|
||
|
|
"num_trades": int(metrics["num_trades"]),
|
||
|
|
"final_equity": f"${state.equity_curve[-1][1]:.2f}" if state.equity_curve else "$0.00",
|
||
|
|
"initial_equity": f"${state.equity_curve[0][1]:.2f}" if state.equity_curve else "$0.00",
|
||
|
|
"num_stop_losses": sum(1 for l in logs if l.get("type") == "stop_loss"),
|
||
|
|
"total_fees": f"${state.fees_paid:.4f}",
|
||
|
|
}
|
||
|
|
append_summary_row(summary_row)
|
||
|
|
|
||
|
|
|
||
|
|
def main(argv: list[str] | None = None) -> None:
    """Parse command-line arguments and dispatch to the selected sub-command.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so script behaviour
            is unchanged; passing a list enables programmatic or test use.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    command = args.command
    if command == "init-db":
        cmd_init_db(Path(args.agg_db))
    elif command == "ingest":
        sources = [Path(p) for p in (args.db or [])]
        cmd_ingest(Path(args.agg_db), sources)
    elif command == "run":
        # The run sub-parser always defines since/until (default None), so
        # plain attribute access is safe.
        cmd_run(
            Path(args.agg_db),
            args.instrument,
            args.since,
            args.until,
            args.stoploss,
            args.maker,
            args.init_usd,
            args.strategy,
        )
    else:
        # Not reachable while the sub-parsers are required; kept as a safeguard.
        parser.error("Unknown command")
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point. The module uses relative imports (from .db, .events, ...),
# so it must be launched as part of its package (python -m <package>.<module>),
# not as a bare file path.
if __name__ == "__main__":
    main()
|
||
|
|
|
||
|
|
|