Initialize granular-backtest package with core modules and CLI functionality

This commit is contained in:
Simon Moisy
2025-08-14 10:34:25 +08:00
parent b80a8f881b
commit 69c5cd6236
18 changed files with 2117 additions and 0 deletions

24
book_backtest/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""Book-only granular backtester package.
This package provides an event-driven backtesting engine operating solely on
order book snapshots streamed from an aggregated SQLite database. It is scoped
to long-only strategies for the initial implementation.
Modules are intentionally small and composable. The CLI entrypoint is
`book_backtest.cli` with subcommands for initializing the aggregated database,
ingesting source OKX book databases, and running a backtest.
"""
__all__ = [
"cli",
"db",
"events",
"engine",
"strategy",
"orders",
"fees",
"metrics",
"io",
]

218
book_backtest/cli.py Normal file
View File

@@ -0,0 +1,218 @@
import argparse
from pathlib import Path
from .db import init_aggregated_db, ingest_source_db, get_instrument_time_bounds
from .events import stream_book_events
from .engine import PortfolioState, process_event
from .strategy import DemoMomentumConfig, DemoMomentumStrategy
from .metrics import compute_metrics
from .io import write_trade_log, append_summary_row
import datetime as _dt
import re as _re
import time as _time
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with its three subcommands.

    Subcommands (stored in ``args.command``; one is mandatory):

    * ``init-db`` -- create the aggregated database.
    * ``ingest``  -- copy one or more source book databases into it.
    * ``run``     -- run a backtest over the aggregated database.
    """
    root = argparse.ArgumentParser(description="Granular order-book backtester (book-only)")
    commands = root.add_subparsers(dest="command", required=True)

    # init-db
    init_cmd = commands.add_parser(
        "init-db",
        help="Create aggregated SQLite database for book snapshots",
    )
    init_cmd.add_argument(
        "--agg-db",
        required=True,
        help="Path to aggregated database (e.g., ./okx_agg.db)",
    )

    # ingest
    ingest_cmd = commands.add_parser(
        "ingest",
        help="Ingest one or more OKX book SQLite file(s) into the aggregated DB",
    )
    ingest_cmd.add_argument("--agg-db", required=True, help="Path to aggregated database")
    ingest_cmd.add_argument(
        "--db",
        action="append",
        required=True,
        help="Path to a source OKX SQLite file containing a book table. Repeat flag to add multiple files.",
    )

    # run: options are declared as a table and registered in order so the
    # generated --help output keeps the same ordering.
    run_cmd = commands.add_parser("run", help="Run a backtest over the aggregated DB (book-only)")
    run_options = (
        ("--agg-db", {"required": True, "help": "Path to aggregated database"}),
        ("--instrument", {"required": True, "help": "Instrument symbol, e.g., BTC-USDT"}),
        (
            "--since",
            {"required": False, "help": "ISO8601 or epoch milliseconds start time (defaults to earliest)"},
        ),
        (
            "--until",
            {"required": False, "help": "ISO8601 or epoch milliseconds end time (defaults to latest)"},
        ),
        ("--stoploss", {"type": float, "default": 0.05, "help": "Stop-loss percentage (e.g., 0.05)"}),
        ("--maker", {"action": "store_true", "help": "Use maker fees (default taker)"}),
        ("--init-usd", {"type": float, "default": 1000.0, "help": "Initial USD balance"}),
        ("--strategy", {"default": "demo", "help": "Strategy to run (demo for now)"}),
    )
    for flag, spec in run_options:
        run_cmd.add_argument(flag, **spec)

    return root
def cmd_init_db(agg_db: Path) -> None:
    """Handle the ``init-db`` subcommand: create the schema, then report the path."""
    init_aggregated_db(agg_db)
    message = f"[init-db] Aggregated DB ready at: {agg_db}"
    print(message)
def cmd_ingest(agg_db: Path, src_dbs: list[Path]) -> None:
    """Handle the ``ingest`` subcommand.

    Each source database is ingested in the order given; a per-file row count
    is printed as soon as that file finishes, followed by the grand total.
    """
    per_file_counts = []
    for source_path in src_dbs:
        added = ingest_source_db(agg_db, source_path)
        per_file_counts.append(added)
        print(f"[ingest] {added} rows ingested from {source_path}")
    total_inserted = sum(per_file_counts)
    print(f"[ingest] Total inserted: {total_inserted} into {agg_db}")
def _format_iso_utc(ms: int) -> str:
    """Render epoch milliseconds as an ISO 8601 UTC string with a ``Z`` suffix."""
    stamp = _dt.datetime.fromtimestamp(ms / 1000.0, tz=_dt.timezone.utc)
    return stamp.isoformat().replace("+00:00", "Z")


def _parse_time_ms(s: str) -> int:
    """Convert a CLI time argument to epoch milliseconds.

    Accepts either a string of decimal digits (taken as epoch milliseconds) or
    an ISO 8601 timestamp such as ``2024-06-09T12:00:00Z`` or
    ``2024-06-09 12:00:00``. A timestamp without an explicit UTC offset is
    interpreted as UTC.

    Raises:
        SystemExit: if the value is neither form.
    """
    # isdecimal() only admits characters int() can parse; isdigit() also
    # admits e.g. superscript digits, which would make int() raise.
    if s.isdecimal():
        return int(s)
    normalized = s.replace(" ", "T")
    # Strip a trailing "Z" because datetime.fromisoformat() only accepts it
    # on newer Python versions.
    if normalized.endswith("Z"):
        normalized = normalized[:-1]
    try:
        parsed = _dt.datetime.fromisoformat(normalized)
    except ValueError:
        raise SystemExit(f"Cannot parse time: {s}") from None
    # A naive datetime would be converted using the *local* timezone by
    # timestamp(); pin it to UTC explicitly so results do not depend on the
    # machine the backtest runs on.
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=_dt.timezone.utc)
    # Integer timedelta arithmetic avoids losing a millisecond to float
    # rounding in `timestamp() * 1000`.
    epoch = _dt.datetime(1970, 1, 1, tzinfo=_dt.timezone.utc)
    return (parsed - epoch) // _dt.timedelta(milliseconds=1)


def cmd_run(
    agg_db: Path,
    instrument: str,
    since: str | None,
    until: str | None,
    stoploss: float,
    is_maker: bool,
    init_usd: float,
    strategy_name: str,
) -> None:
    """Handle the ``run`` subcommand: replay book events through a strategy.

    Args:
        agg_db: Path to the aggregated database.
        instrument: Instrument symbol to replay.
        since: Start time (ISO 8601 or epoch ms); ``None`` means the earliest
            timestamp stored for the instrument.
        until: End time (ISO 8601 or epoch ms); ``None`` means the latest
            timestamp stored for the instrument.
        stoploss: Stop-loss fraction forwarded to ``process_event``.
        is_maker: Whether fees are computed at the maker rate.
        init_usd: Starting cash balance.
        strategy_name: Strategy identifier; only ``"demo"`` is recognised.

    Side effects:
        Prints progress to stdout, writes the trade log to
        ``trade_log_book_sl<stoploss>.csv`` via ``write_trade_log`` and appends
        one summary row via ``append_summary_row``.

    Raises:
        SystemExit: on an unparsable time, an instrument with no data (when a
            bound has to be looked up), or an unknown strategy name.
    """
    # Resolve the replay window; the database is only consulted when at least
    # one bound was left unspecified.
    if since is None or until is None:
        bounds = get_instrument_time_bounds(agg_db, instrument)
        if not bounds:
            raise SystemExit("No data for instrument in aggregated DB")
        min_ms, max_ms = bounds
        since_ms = min_ms if since is None else _parse_time_ms(since)
        until_ms = max_ms if until is None else _parse_time_ms(until)
        print(f"[run] Using time range: {_format_iso_utc(since_ms)} to {_format_iso_utc(until_ms)}")
    else:
        since_ms = _parse_time_ms(since)
        until_ms = _parse_time_ms(until)

    state = PortfolioState(cash_usd=init_usd)
    if strategy_name == "demo":
        strat = DemoMomentumStrategy(DemoMomentumConfig())
    else:
        raise SystemExit(f"Unknown strategy: {strategy_name}")

    logs = []
    trade_results = []
    last_event = None
    # Guard against a zero/negative span so the progress ratio never divides by zero.
    total_span = max(1, until_ms - since_ms)
    started_at = _time.time()
    count = 0
    next_report_pct = 1

    for ev in stream_book_events(agg_db, instrument, since_ms, until_ms):
        last_event = ev
        orders = strat.on_book_event(ev)
        exec_logs = process_event(state, ev, orders, stoploss_pct=stoploss, is_maker=is_maker)
        for le in exec_logs:
            logs.append(le)
            if le.get("type") in {"sell", "stop_loss"} and "pnl" in le:
                # Best effort: a non-numeric pnl is left out of the trade
                # statistics, but unrelated errors are no longer hidden.
                try:
                    trade_results.append(float(le["pnl"]))
                except (TypeError, ValueError):
                    pass
        count += 1

        # Progress line, refreshed at most once per whole percent of the time span.
        pct = int(min(100.0, max(0.0, (ev.timestamp_ms - since_ms) * 100.0 / total_span)))
        if pct >= next_report_pct:
            elapsed = max(1e-6, _time.time() - started_at)
            speed_k = (count / elapsed) / 1000.0
            eq = state.equity_curve[-1][1] if state.equity_curve else state.cash_usd
            print(f"\r[run] {pct:3d}% events={count} speed={speed_k:6.1f}k ev/s equity=${eq:,.2f}", end="", flush=True)
            next_report_pct = pct + 1
    if count:
        # Terminate the carriage-return progress line.
        print()

    # Forced close if holding
    if last_event and state.coin > 0:
        price = last_event.best_bid
        gross = state.coin * price
        from .fees import calculate_okx_fee
        fee = calculate_okx_fee(gross, is_maker=is_maker)
        state.cash_usd += (gross - fee)
        state.fees_paid += fee
        pnl = 0.0
        if state.entry_price:
            pnl = (price - state.entry_price) / state.entry_price
        logs.append({
            "type": "forced_close",
            "time": last_event.timestamp_ms,
            "price": price,
            "usd": state.cash_usd,
            "coin": 0.0,
            "pnl": pnl,
            "fee": fee,
        })
        trade_results.append(pnl)
        state.coin = 0.0
        state.entry_price = None
        state.entry_time_ms = None
        # The last equity point was recorded before this liquidation, so it
        # does not include the closing fee. Overwrite it so the reported final
        # equity and the metrics agree with the actual ending cash.
        if state.equity_curve:
            state.equity_curve[-1] = (last_event.timestamp_ms, state.cash_usd)

    # Outputs
    stop_str = f"{stoploss:.2f}".replace(".", "p")
    logfile = f"trade_log_book_sl{stop_str}.csv"
    write_trade_log(logs, logfile)
    metrics = compute_metrics(state.equity_curve, trade_results)
    summary_row = {
        "timeframe": "book",
        "stop_loss": stoploss,
        "total_return": f"{metrics['total_return']*100:.2f}%",
        "max_drawdown": f"{metrics['max_drawdown']*100:.2f}%",
        "sharpe_ratio": f"{metrics['sharpe_ratio']:.2f}",
        "win_rate": f"{metrics['win_rate']*100:.2f}%",
        "num_trades": int(metrics["num_trades"]),
        "final_equity": f"${state.equity_curve[-1][1]:.2f}" if state.equity_curve else "$0.00",
        "initial_equity": f"${state.equity_curve[0][1]:.2f}" if state.equity_curve else "$0.00",
        "num_stop_losses": sum(1 for entry in logs if entry.get("type") == "stop_loss"),
        "total_fees": f"${state.fees_paid:.4f}",
    }
    append_summary_row(summary_row)
def main() -> None:
    """CLI entry point: parse ``sys.argv`` and dispatch to the matching handler."""
    parser = build_parser()
    ns = parser.parse_args()
    command = ns.command

    if command == "init-db":
        cmd_init_db(Path(ns.agg_db))
        return

    if command == "ingest":
        source_paths = [Path(item) for item in (ns.db or [])]
        cmd_ingest(Path(ns.agg_db), source_paths)
        return

    if command == "run":
        cmd_run(
            Path(ns.agg_db),
            ns.instrument,
            getattr(ns, "since", None),
            getattr(ns, "until", None),
            ns.stoploss,
            ns.maker,
            ns.init_usd,
            ns.strategy,
        )
        return

    # Fallback for a command name none of the branches above recognise.
    parser.error("Unknown command")
# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()

89
book_backtest/db.py Normal file
View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import sqlite3
from pathlib import Path
from typing import Iterator, Optional, Tuple
# DDL for the aggregated database, executed by init_aggregated_db().
# - `timestamp` is stored as TEXT; get_instrument_time_bounds() casts it to
#   INTEGER when computing the min/max bounds.
# - The unique index on (instrument, timestamp) is what lets the
#   `INSERT OR IGNORE` in ingest_source_db() skip rows that are already present.
# NOTE(review): idx_book_instrument_ts covers the same columns as the unique
# index ux_book_inst_ts and looks redundant -- confirm nothing refers to it by
# name before dropping it.
AGG_SCHEMA = """
CREATE TABLE IF NOT EXISTS book (
id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT NOT NULL,
bids TEXT NOT NULL,
asks TEXT NOT NULL,
timestamp TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_book_instrument_ts ON book (instrument, timestamp);
CREATE UNIQUE INDEX IF NOT EXISTS ux_book_inst_ts ON book(instrument, timestamp);
"""
def init_aggregated_db(db_path: Path) -> None:
    """Create the aggregated database file (if needed) and apply ``AGG_SCHEMA``.

    The schema uses ``IF NOT EXISTS`` throughout, so calling this on an
    already-initialised database leaves existing data untouched.
    """
    connection = sqlite3.connect(str(db_path))
    try:
        # Connection.executescript() runs the script on a temporary cursor.
        connection.executescript(AGG_SCHEMA)
        connection.commit()
    finally:
        connection.close()
def stream_source_book_rows(src_db: Path) -> Iterator[Tuple[str, str, str, str]]:
    """Lazily yield ``(instrument, bids, asks, timestamp)`` rows from a source DB.

    Rows come from the source's ``book`` table ordered by ``timestamp``
    ascending. The connection is closed when the generator is exhausted or
    closed early.
    """
    source = sqlite3.connect(str(src_db))
    try:
        yield from source.execute(
            "SELECT instrument, bids, asks, timestamp FROM book ORDER BY timestamp ASC"
        )
    finally:
        source.close()
def ingest_source_db(agg_db: Path, src_db: Path) -> int:
    """Copy book rows from a source database into the aggregated database.

    Rows are inserted with ``INSERT OR IGNORE``, so a row that conflicts with
    an existing one is skipped rather than duplicated.

    Args:
        agg_db: Path to the aggregated database (must already contain the
            ``book`` table).
        src_db: Path to the source database to read from.

    Returns:
        The number of rows actually inserted (ignored rows are not counted).

    Raises:
        sqlite3.Error: for failures that are not about an individual row's
            data, e.g. a missing ``book`` table in the aggregated database.
    """
    commit_every = 10000
    conn = sqlite3.connect(str(agg_db))
    inserted = 0
    uncommitted = 0
    try:
        cur = conn.cursor()
        cur.execute("PRAGMA journal_mode = WAL;")
        cur.execute("PRAGMA synchronous = NORMAL;")
        cur.execute("PRAGMA temp_store = MEMORY;")
        cur.execute("PRAGMA cache_size = -20000;")  # ~20MB
        for instrument, bids, asks, ts in stream_source_book_rows(src_db):
            try:
                cur.execute(
                    "INSERT OR IGNORE INTO book (instrument, bids, asks, timestamp) VALUES (?, ?, ?, ?)",
                    (instrument, bids, asks, ts),
                )
            except (sqlite3.IntegrityError, sqlite3.DataError):
                # Skip malformed rows silently during ingestion. Only
                # row-level data errors are skipped; operational failures
                # propagate so a broken target DB is not reported as
                # "0 rows ingested".
                continue
            if cur.rowcount > 0:
                inserted += 1
                uncommitted += 1
                # Commit per batch of *new* rows. Testing `inserted % N == 0`
                # would commit after every single row while `inserted` sits on
                # a multiple of N (notably 0 when re-ingesting duplicates).
                if uncommitted >= commit_every:
                    conn.commit()
                    uncommitted = 0
        conn.commit()
    finally:
        conn.close()
    return inserted
def get_instrument_time_bounds(agg_db: Path, instrument: str) -> Optional[Tuple[int, int]]:
    """Look up the earliest and latest stored timestamps for ``instrument``.

    Timestamps are cast to integers inside the query, so the comparison is
    numeric even though the column is declared as TEXT.

    Returns:
        ``(min_timestamp_ms, max_timestamp_ms)``, or ``None`` when the
        instrument has no rows.
    """
    connection = sqlite3.connect(str(agg_db))
    try:
        result = connection.execute(
            "SELECT MIN(CAST(timestamp AS INTEGER)), MAX(CAST(timestamp AS INTEGER)) FROM book WHERE instrument = ?",
            (instrument,),
        ).fetchone()
    finally:
        connection.close()

    if not result:
        return None
    lowest, highest = result[0], result[1]
    # MIN/MAX over zero rows yield NULL, which arrives here as None.
    if lowest is None or highest is None:
        return None
    return int(lowest), int(highest)

114
book_backtest/engine.py Normal file
View File

@@ -0,0 +1,114 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Dict, Any
from .events import BookEvent
from .orders import MarketBuy, MarketSell
from .fees import calculate_okx_fee
@dataclass
class PortfolioState:
    """Mutable account state for the backtest.

    ``process_event`` updates an instance in place as book events and orders
    are applied.
    """

    # Cash balance in USD.
    cash_usd: float
    # Quantity of the instrument currently held.
    coin: float = 0.0
    # Price of the most recent buy fill; reset to None once the position is closed.
    entry_price: float | None = None
    # Event timestamp (ms) of the most recent buy fill; reset alongside entry_price.
    entry_time_ms: int | None = None
    # Running total of all fees charged so far.
    fees_paid: float = 0.0
    # One (timestamp_ms, equity) point appended per processed event.
    equity_curve: list[tuple[int, float]] = field(default_factory=list)
def process_event(
    state: PortfolioState,
    event: BookEvent,
    orders: list,
    stoploss_pct: float,
    is_maker: bool,
) -> list[Dict[str, Any]]:
    """Apply orders at top-of-book and update stop-loss logic.

    Processing order for each event:

    1. If a long position is open and the best bid is at or below
       ``entry_price * (1 - stoploss_pct)``, the whole position is liquidated.
    2. Each order is executed in sequence: ``MarketBuy`` fills at the best
       ask, ``MarketSell`` fills at the best bid.
    3. One ``(timestamp_ms, equity)`` point is appended to
       ``state.equity_curve``, marking the position at the best bid.

    Args:
        state: Portfolio state, mutated in place.
        event: Book event supplying ``best_bid``, ``best_ask`` and
            ``timestamp_ms``.
        orders: Orders to execute against this event.
        stoploss_pct: Stop-loss distance as a fraction of the entry price.
        is_maker: Whether fees are computed at the maker rate.

    Returns:
        Log entries for any executions that occurred.
    """
    logs: list[Dict[str, Any]] = []

    # Evaluate stop-loss for an open long position
    if state.coin > 0 and state.entry_price is not None:
        threshold = state.entry_price * (1.0 - stoploss_pct)
        if event.best_bid <= threshold:
            # Exit at the worse of best_bid and threshold to emulate slippage
            exit_price = min(event.best_bid, threshold)
            gross = state.coin * exit_price
            fee = calculate_okx_fee(gross, is_maker=is_maker)
            # Credit the proceeds on top of existing cash. Assigning here
            # would throw away any cash that was not invested in the position.
            state.cash_usd += gross - fee
            state.fees_paid += fee
            pnl = (exit_price - state.entry_price) / state.entry_price
            logs.append({
                "type": "stop_loss",
                "time": event.timestamp_ms,
                "price": exit_price,
                "usd": state.cash_usd,
                "coin": 0.0,
                "pnl": pnl,
                "fee": fee,
            })
            state.coin = 0.0
            state.entry_price = None
            state.entry_time_ms = None

    # Apply incoming orders
    for order in orders:
        if isinstance(order, MarketBuy):
            if state.cash_usd <= 0:
                continue
            # Execute at best ask
            price = event.best_ask
            if price <= 0:
                # No usable ask: skip instead of paying for zero quantity.
                continue
            notional = min(order.usd_notional, state.cash_usd)
            if notional <= 0:
                continue
            qty = notional / price
            fee = calculate_okx_fee(notional, is_maker=is_maker)
            # Deduct both notional and fee from cash
            # NOTE(review): when the order uses the full balance, cash ends up
            # negative by the fee amount -- confirm whether that is intended.
            state.cash_usd -= (notional + fee)
            state.fees_paid += fee
            state.coin += qty
            # NOTE(review): adding to an existing position overwrites the
            # entry price with the latest fill rather than averaging it.
            state.entry_price = price
            state.entry_time_ms = event.timestamp_ms
            logs.append({
                "type": "buy",
                "time": event.timestamp_ms,
                "price": price,
                "usd": state.cash_usd,
                "coin": state.coin,
                "fee": fee,
            })
        elif isinstance(order, MarketSell):
            if state.coin <= 0:
                continue
            price = event.best_bid
            qty = min(order.amount, state.coin)
            if qty <= 0:
                # A non-positive sell size would otherwise increase the position.
                continue
            gross = qty * price
            fee = calculate_okx_fee(gross, is_maker=is_maker)
            state.cash_usd += (gross - fee)
            state.fees_paid += fee
            state.coin -= qty
            pnl = 0.0
            if state.entry_price:
                pnl = (price - state.entry_price) / state.entry_price
            logs.append({
                "type": "sell",
                "time": event.timestamp_ms,
                "price": price,
                "usd": state.cash_usd,
                "coin": state.coin,
                "pnl": pnl,
                "fee": fee,
            })
            if state.coin <= 0:
                state.entry_price = None
                state.entry_time_ms = None

    # Track equity at each event using best bid for mark-to-market
    mark_price = event.best_bid
    equity = state.cash_usd + state.coin * mark_price
    state.equity_curve.append((event.timestamp_ms, equity))
    return logs