219 lines
7.6 KiB
Python

import argparse
from pathlib import Path
from .db import init_aggregated_db, ingest_source_db, get_instrument_time_bounds
from .events import stream_book_events
from .engine import PortfolioState, process_event
from .strategy import DemoMomentumConfig, DemoMomentumStrategy
from .metrics import compute_metrics
from .io import write_trade_log, append_summary_row
import datetime as _dt
import re as _re
import time as _time
def build_parser() -> argparse.ArgumentParser:
    """Assemble the command-line interface.

    The returned parser requires exactly one sub-command, stored on the
    namespace as ``command``:

    * ``init-db`` -- takes ``--agg-db``.
    * ``ingest``  -- takes ``--agg-db`` and one or more ``--db`` flags
      (collected into a list).
    * ``run``     -- takes ``--agg-db`` and ``--instrument`` plus the optional
      ``--since``, ``--until``, ``--stoploss``, ``--maker``, ``--init-usd``
      and ``--strategy`` settings.
    """
    root = argparse.ArgumentParser(
        description="Granular order-book backtester (book-only)",
    )
    commands = root.add_subparsers(dest="command", required=True)

    # init-db
    init_cmd = commands.add_parser(
        "init-db",
        help="Create aggregated SQLite database for book snapshots",
    )
    init_cmd.add_argument(
        "--agg-db",
        required=True,
        help="Path to aggregated database (e.g., ./okx_agg.db)",
    )

    # ingest
    ingest_cmd = commands.add_parser(
        "ingest",
        help="Ingest one or more OKX book SQLite file(s) into the aggregated DB",
    )
    ingest_cmd.add_argument("--agg-db", required=True, help="Path to aggregated database")
    ingest_cmd.add_argument(
        "--db",
        action="append",
        required=True,
        help="Path to a source OKX SQLite file containing a book table. Repeat flag to add multiple files.",
    )

    # run: options are registered in table order so --help keeps its layout.
    run_cmd = commands.add_parser(
        "run",
        help="Run a backtest over the aggregated DB (book-only)",
    )
    run_options = (
        ("--agg-db", {"required": True, "help": "Path to aggregated database"}),
        ("--instrument", {"required": True, "help": "Instrument symbol, e.g., BTC-USDT"}),
        ("--since", {"required": False, "help": "ISO8601 or epoch milliseconds start time (defaults to earliest)"}),
        ("--until", {"required": False, "help": "ISO8601 or epoch milliseconds end time (defaults to latest)"}),
        ("--stoploss", {"type": float, "default": 0.05, "help": "Stop-loss percentage (e.g., 0.05)"}),
        ("--maker", {"action": "store_true", "help": "Use maker fees (default taker)"}),
        ("--init-usd", {"type": float, "default": 1000.0, "help": "Initial USD balance"}),
        ("--strategy", {"default": "demo", "help": "Strategy to run (demo for now)"}),
    )
    for flag, spec in run_options:
        run_cmd.add_argument(flag, **spec)

    return root
def cmd_init_db(agg_db: Path) -> None:
    """Handle ``init-db``: run ``init_aggregated_db`` on *agg_db*, then print a confirmation line."""
    init_aggregated_db(agg_db)
    confirmation = f"[init-db] Aggregated DB ready at: {agg_db}"
    print(confirmation)
def cmd_ingest(agg_db: Path, src_dbs: list[Path]) -> None:
    """Handle ``ingest``: feed every source file into the aggregated database.

    Each path in *src_dbs* is passed to ``ingest_source_db`` in order; the
    value it returns is printed as that file's row count, and the sum over
    all files is printed at the end.
    """
    row_counts = []
    for source_path in src_dbs:
        n_rows = ingest_source_db(agg_db, source_path)
        row_counts.append(n_rows)
        print(f"[ingest] {n_rows} rows ingested from {source_path}")
    print(f"[ingest] Total inserted: {sum(row_counts)} into {agg_db}")
def cmd_run(
agg_db: Path,
instrument: str,
since: str | None,
until: str | None,
stoploss: float,
is_maker: bool,
init_usd: float,
strategy_name: str,
) -> None:
def _iso(ms: int) -> str:
return _dt.datetime.fromtimestamp(ms / 1000.0, tz=_dt.timezone.utc).isoformat().replace("+00:00", "Z")
def _parse_time(s: str) -> int:
# epoch ms
if s.isdigit():
return int(s)
# basic ISO8601 without timezone -> assume Z
# Accept formats like 2024-06-09T12:00:00Z or 2024-06-09 12:00:00
s2 = s.replace(" ", "T")
if s2.endswith("Z"):
s2 = s2[:-1]
try:
dt = _dt.datetime.fromisoformat(s2)
except ValueError:
raise SystemExit(f"Cannot parse time: {s}")
# Treat naive datetime as UTC
return int(dt.timestamp() * 1000)
if since is None or until is None:
bounds = get_instrument_time_bounds(agg_db, instrument)
if not bounds:
raise SystemExit("No data for instrument in aggregated DB")
min_ms, max_ms = bounds
if since is None:
since_ms = min_ms
else:
since_ms = _parse_time(since)
if until is None:
until_ms = max_ms
else:
until_ms = _parse_time(until)
print(f"[run] Using time range: {_iso(since_ms)} to {_iso(until_ms)}")
else:
since_ms = _parse_time(since)
until_ms = _parse_time(until)
state = PortfolioState(cash_usd=init_usd)
if strategy_name == "demo":
strat = DemoMomentumStrategy(DemoMomentumConfig())
else:
raise SystemExit(f"Unknown strategy: {strategy_name}")
logs = []
trade_results = []
last_event = None
total_span = max(1, until_ms - since_ms)
started_at = _time.time()
count = 0
next_report_pct = 1
for ev in stream_book_events(agg_db, instrument, since_ms, until_ms):
last_event = ev
orders = strat.on_book_event(ev)
exec_logs = process_event(state, ev, orders, stoploss_pct=stoploss, is_maker=is_maker)
for le in exec_logs:
logs.append(le)
if le.get("type") in {"sell", "stop_loss"} and "pnl" in le:
try:
trade_results.append(float(le["pnl"]))
except Exception:
pass
count += 1
pct = int(min(100.0, max(0.0, (ev.timestamp_ms - since_ms) * 100.0 / total_span)))
if pct >= next_report_pct:
elapsed = max(1e-6, _time.time() - started_at)
speed_k = (count / elapsed) / 1000.0
eq = state.equity_curve[-1][1] if state.equity_curve else state.cash_usd
print(f"\r[run] {pct:3d}% events={count} speed={speed_k:6.1f}k ev/s equity=${eq:,.2f}", end="", flush=True)
next_report_pct = pct + 1
if count:
print()
# Forced close if holding
if last_event and state.coin > 0:
price = last_event.best_bid
gross = state.coin * price
from .fees import calculate_okx_fee
fee = calculate_okx_fee(gross, is_maker=is_maker)
state.cash_usd += (gross - fee)
state.fees_paid += fee
pnl = 0.0
if state.entry_price:
pnl = (price - state.entry_price) / state.entry_price
logs.append({
"type": "forced_close",
"time": last_event.timestamp_ms,
"price": price,
"usd": state.cash_usd,
"coin": 0.0,
"pnl": pnl,
"fee": fee,
})
trade_results.append(pnl)
state.coin = 0.0
state.entry_price = None
state.entry_time_ms = None
# Outputs
stop_str = f"{stoploss:.2f}".replace(".", "p")
logfile = f"trade_log_book_sl{stop_str}.csv"
write_trade_log(logs, logfile)
metrics = compute_metrics(state.equity_curve, trade_results)
summary_row = {
"timeframe": "book",
"stop_loss": stoploss,
"total_return": f"{metrics['total_return']*100:.2f}%",
"max_drawdown": f"{metrics['max_drawdown']*100:.2f}%",
"sharpe_ratio": f"{metrics['sharpe_ratio']:.2f}",
"win_rate": f"{metrics['win_rate']*100:.2f}%",
"num_trades": int(metrics["num_trades"]),
"final_equity": f"${state.equity_curve[-1][1]:.2f}" if state.equity_curve else "$0.00",
"initial_equity": f"${state.equity_curve[0][1]:.2f}" if state.equity_curve else "$0.00",
"num_stop_losses": sum(1 for l in logs if l.get("type") == "stop_loss"),
"total_fees": f"${state.fees_paid:.4f}",
}
append_summary_row(summary_row)
def main() -> None:
parser = build_parser()
args = parser.parse_args()
cmd = args.command
if cmd == "init-db":
cmd_init_db(Path(args.agg_db))
elif cmd == "ingest":
srcs = [Path(p) for p in (args.db or [])]
cmd_ingest(Path(args.agg_db), srcs)
elif cmd == "run":
cmd_run(
Path(args.agg_db),
args.instrument,
getattr(args, "since", None),
getattr(args, "until", None),
args.stoploss,
args.maker,
args.init_usd,
args.strategy,
)
else:
parser.error("Unknown command")
# Run the CLI only when this file is executed as a program, not on import.
# NOTE: the imports at the top of the file are package-relative, so the module
# has to be launched as part of its package (``python -m <package>.<module>``).
if __name__ == "__main__":
    main()