"""Command-line interface for the granular order-book backtester (book-only).

Subcommands:
    init-db   create the aggregated SQLite database for book snapshots
    ingest    ingest one or more source OKX book SQLite files into it
    run       replay book events from it through a strategy, then write a
              trade log CSV and append a summary row
"""

import argparse
import datetime as _dt
import re as _re  # noqa: F401  (currently unused in this module)
import time as _time
from pathlib import Path

from .db import init_aggregated_db, ingest_source_db, get_instrument_time_bounds
from .events import stream_book_events
from .engine import PortfolioState, process_event
from .strategy import DemoMomentumConfig, DemoMomentumStrategy
from .metrics import compute_metrics
from .io import write_trade_log, append_summary_row

# Execution-log "type" values whose "pnl" field (when present) is counted as
# a completed trade result for the metrics.
_CLOSING_TYPES = frozenset({"sell", "stop_loss"})

# Used for exact integer epoch-millisecond conversion (avoids float rounding).
_EPOCH = _dt.datetime(1970, 1, 1, tzinfo=_dt.timezone.utc)
_ONE_MS = _dt.timedelta(milliseconds=1)


def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser with the init-db, ingest and run subcommands."""
    parser = argparse.ArgumentParser(
        description="Granular order-book backtester (book-only)"
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # init-db
    init_p = sub.add_parser("init-db", help="Create aggregated SQLite database for book snapshots")
    init_p.add_argument("--agg-db", required=True, help="Path to aggregated database (e.g., ./okx_agg.db)")

    # ingest
    ingest_p = sub.add_parser("ingest", help="Ingest one or more OKX book SQLite file(s) into the aggregated DB")
    ingest_p.add_argument("--agg-db", required=True, help="Path to aggregated database")
    ingest_p.add_argument(
        "--db",
        action="append",
        required=True,
        help="Path to a source OKX SQLite file containing a book table. Repeat flag to add multiple files.",
    )

    # run
    run_p = sub.add_parser("run", help="Run a backtest over the aggregated DB (book-only)")
    run_p.add_argument("--agg-db", required=True, help="Path to aggregated database")
    run_p.add_argument("--instrument", required=True, help="Instrument symbol, e.g., BTC-USDT")
    run_p.add_argument("--since", required=False, help="ISO8601 or epoch milliseconds start time (defaults to earliest)")
    run_p.add_argument("--until", required=False, help="ISO8601 or epoch milliseconds end time (defaults to latest)")
    run_p.add_argument("--stoploss", type=float, default=0.05, help="Stop-loss percentage (e.g., 0.05)")
    run_p.add_argument("--maker", action="store_true", help="Use maker fees (default taker)")
    run_p.add_argument("--init-usd", type=float, default=1000.0, help="Initial USD balance")
    run_p.add_argument("--strategy", default="demo", help="Strategy to run (demo for now)")

    return parser


def cmd_init_db(agg_db: Path) -> None:
    """Create (or prepare) the aggregated database at *agg_db*."""
    init_aggregated_db(agg_db)
    print(f"[init-db] Aggregated DB ready at: {agg_db}")


def cmd_ingest(agg_db: Path, src_dbs: list[Path]) -> None:
    """Ingest each source DB in *src_dbs* into *agg_db*, reporting row counts."""
    total_inserted = 0
    for src in src_dbs:
        inserted = ingest_source_db(agg_db, src)
        total_inserted += inserted
        print(f"[ingest] {inserted} rows ingested from {src}")
    print(f"[ingest] Total inserted: {total_inserted} into {agg_db}")


def _iso_utc(ms: int) -> str:
    """Format epoch milliseconds as an ISO-8601 UTC string with a ``Z`` suffix."""
    return (
        _dt.datetime.fromtimestamp(ms / 1000.0, tz=_dt.timezone.utc)
        .isoformat()
        .replace("+00:00", "Z")
    )


def _parse_time(s: str) -> int:
    """Parse a CLI time argument into epoch milliseconds.

    Accepts a plain non-negative integer (taken as epoch milliseconds) or an
    ISO-8601 timestamp such as ``2024-06-09T12:00:00Z``, ``2024-06-09 12:00:00``
    or ``2024-06-09T12:00:00+02:00``. Timestamps without an explicit offset are
    interpreted as UTC.

    Raises:
        SystemExit: if the value cannot be parsed.
    """
    s = s.strip()
    # Epoch ms. isascii() excludes Unicode digits (e.g. "²") that pass
    # isdigit() but make int() raise.
    if s.isascii() and s.isdigit():
        return int(s)
    s2 = s.replace(" ", "T")
    # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 on,
    # so strip it; the value is then treated as UTC below.
    if s2.endswith(("Z", "z")):
        s2 = s2[:-1]
    try:
        dt = _dt.datetime.fromisoformat(s2)
    except ValueError:
        raise SystemExit(f"Cannot parse time: {s}") from None
    # A naive datetime's .timestamp() would use the host's local timezone;
    # pin it to UTC so results do not depend on where the backtest runs.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=_dt.timezone.utc)
    # Integer timedelta division is exact, unlike int(timestamp() * 1000).
    return (dt - _EPOCH) // _ONE_MS


def _resolve_time_range(
    agg_db: Path, instrument: str, since: str | None, until: str | None
) -> tuple[int, int]:
    """Return ``(since_ms, until_ms)`` for the run.

    Missing bounds are filled from the instrument's earliest/latest timestamps
    in the aggregated DB; in that case the resolved range is printed.

    Raises:
        SystemExit: if a bound is needed but the DB has no data for the
            instrument, if a time cannot be parsed, or if since > until.
    """
    if since is not None and until is not None:
        since_ms = _parse_time(since)
        until_ms = _parse_time(until)
    else:
        bounds = get_instrument_time_bounds(agg_db, instrument)
        if not bounds:
            raise SystemExit("No data for instrument in aggregated DB")
        min_ms, max_ms = bounds
        since_ms = min_ms if since is None else _parse_time(since)
        until_ms = max_ms if until is None else _parse_time(until)
        print(f"[run] Using time range: {_iso_utc(since_ms)} to {_iso_utc(until_ms)}")

    if since_ms > until_ms:
        raise SystemExit(
            f"Empty time range: since={since_ms} is after until={until_ms} (epoch ms)"
        )
    return since_ms, until_ms


def _forced_close(state: PortfolioState, last_event, is_maker: bool) -> dict:
    """Liquidate the open position at *last_event*'s best bid.

    Credits the net proceeds and fee to *state* and clears the position fields.
    *last_event* is the final event from ``stream_book_events``; only its
    ``best_bid`` and ``timestamp_ms`` attributes are read.

    Returns:
        The ``forced_close`` log entry. Its ``pnl`` is the price return
        relative to ``state.entry_price`` (fees not included), or 0.0 when no
        entry price is recorded.
    """
    from .fees import calculate_okx_fee

    price = last_event.best_bid
    gross = state.coin * price
    fee = calculate_okx_fee(gross, is_maker=is_maker)
    state.cash_usd += gross - fee
    state.fees_paid += fee

    pnl = 0.0
    if state.entry_price:
        pnl = (price - state.entry_price) / state.entry_price

    state.coin = 0.0
    state.entry_price = None
    state.entry_time_ms = None

    return {
        "type": "forced_close",
        "time": last_event.timestamp_ms,
        "price": price,
        "usd": state.cash_usd,
        "coin": 0.0,
        "pnl": pnl,
        "fee": fee,
    }


def cmd_run(
    agg_db: Path,
    instrument: str,
    since: str | None,
    until: str | None,
    stoploss: float,
    is_maker: bool,
    init_usd: float,
    strategy_name: str,
) -> None:
    """Run a book-only backtest for *instrument* and write its outputs.

    Book events between *since* and *until* (epoch ms or ISO-8601; missing
    bounds default to the data's extent) are fed to the strategy, and its
    orders are executed by ``process_event``. Any position still open at the
    end is force-closed at the last best bid.

    Writes ``trade_log_book_sl<stoploss>.csv`` (in the current directory) via
    ``write_trade_log`` and appends one row via ``append_summary_row``.

    Raises:
        SystemExit: on an unknown strategy, unparseable or inverted time range,
            or when defaults are needed but the instrument has no data.
    """
    # Fail on a bad strategy name before touching the database.
    if strategy_name == "demo":
        strat = DemoMomentumStrategy(DemoMomentumConfig())
    else:
        raise SystemExit(f"Unknown strategy: {strategy_name}")

    since_ms, until_ms = _resolve_time_range(agg_db, instrument, since, until)
    state = PortfolioState(cash_usd=init_usd)

    logs: list[dict] = []
    trade_results: list[float] = []
    last_event = None
    total_span = max(1, until_ms - since_ms)
    started_at = _time.perf_counter()
    count = 0
    next_report_pct = 1

    for ev in stream_book_events(agg_db, instrument, since_ms, until_ms):
        last_event = ev
        orders = strat.on_book_event(ev)
        exec_logs = process_event(state, ev, orders, stoploss_pct=stoploss, is_maker=is_maker)
        for le in exec_logs:
            logs.append(le)
            if le.get("type") in _CLOSING_TYPES and "pnl" in le:
                try:
                    trade_results.append(float(le["pnl"]))
                except (TypeError, ValueError):
                    # Non-numeric pnl: keep the log row, leave it out of metrics.
                    pass
        count += 1

        # Progress is reported at most once per whole percent of the time span.
        pct = int(min(100.0, max(0.0, (ev.timestamp_ms - since_ms) * 100.0 / total_span)))
        if pct >= next_report_pct:
            elapsed = max(1e-6, _time.perf_counter() - started_at)
            speed_k = (count / elapsed) / 1000.0
            eq = state.equity_curve[-1][1] if state.equity_curve else state.cash_usd
            print(
                f"\r[run] {pct:3d}% events={count} speed={speed_k:6.1f}k ev/s equity=${eq:,.2f}",
                end="",
                flush=True,
            )
            next_report_pct = pct + 1

    # Terminate the "\r" progress line only if one was actually printed.
    if next_report_pct > 1:
        print()

    if last_event is not None and state.coin > 0:
        entry = _forced_close(state, last_event, is_maker)
        logs.append(entry)
        trade_results.append(entry["pnl"])
        # NOTE(review): state.equity_curve is not updated here, so the summary's
        # final_equity may not reflect the forced-close fee — confirm intent.

    # Outputs
    stop_str = f"{stoploss:.2f}".replace(".", "p")
    logfile = f"trade_log_book_sl{stop_str}.csv"
    write_trade_log(logs, logfile)

    metrics = compute_metrics(state.equity_curve, trade_results)
    summary_row = {
        "timeframe": "book",
        "stop_loss": stoploss,
        "total_return": f"{metrics['total_return']*100:.2f}%",
        "max_drawdown": f"{metrics['max_drawdown']*100:.2f}%",
        "sharpe_ratio": f"{metrics['sharpe_ratio']:.2f}",
        "win_rate": f"{metrics['win_rate']*100:.2f}%",
        "num_trades": int(metrics["num_trades"]),
        "final_equity": f"${state.equity_curve[-1][1]:.2f}" if state.equity_curve else "$0.00",
        "initial_equity": f"${state.equity_curve[0][1]:.2f}" if state.equity_curve else "$0.00",
        "num_stop_losses": sum(1 for entry in logs if entry.get("type") == "stop_loss"),
        "total_fees": f"${state.fees_paid:.4f}",
    }
    append_summary_row(summary_row)


def main() -> None:
    """Parse command-line arguments and dispatch to the chosen subcommand."""
    parser = build_parser()
    args = parser.parse_args()
    cmd = args.command
    if cmd == "init-db":
        cmd_init_db(Path(args.agg_db))
    elif cmd == "ingest":
        srcs = [Path(p) for p in (args.db or [])]
        cmd_ingest(Path(args.agg_db), srcs)
    elif cmd == "run":
        cmd_run(
            Path(args.agg_db),
            args.instrument,
            args.since,
            args.until,
            args.stoploss,
            args.maker,
            args.init_usd,
            args.strategy,
        )
    else:
        parser.error("Unknown command")


if __name__ == "__main__":
    main()