orderflow_backtest/repositories/sqlite_repository.py

74 lines
2.8 KiB
Python
Raw Normal View History

from __future__ import annotations
from pathlib import Path
from typing import Dict, Iterator, List, Tuple
import sqlite3
import logging
from models import Trade
class SQLiteOrderflowRepository:
    """Read-only repository for loading orderflow data from a SQLite database.

    The repository does not own a connection: callers obtain one from
    :meth:`connect` and pass it to the query methods, and they remain
    responsible for closing it.
    """

    # Number of rows requested per ``fetchmany`` call when streaming results.
    _BATCH_SIZE = 5000

    # Connection tuning statements applied, in order, by ``connect``.
    _PRAGMAS = (
        "PRAGMA journal_mode = OFF",
        "PRAGMA synchronous = OFF",
        "PRAGMA cache_size = 100000",
        "PRAGMA temp_store = MEMORY",
        "PRAGMA mmap_size = 30000000000",
    )

    # Tables ``count_rows`` may interpolate into SQL (identifiers cannot be
    # bound as parameters, so the name is validated against this allow-list).
    _COUNTABLE_TABLES = frozenset({"book", "trades"})

    def __init__(self, db_path: Path) -> None:
        """Remember the database location; no connection is opened here."""
        self.db_path = db_path

    def connect(self) -> sqlite3.Connection:
        """Open a connection to ``self.db_path`` and apply the tuning PRAGMAs.

        Returns:
            The configured connection. The caller must close it.

        Raises:
            sqlite3.Error: if connecting or any PRAGMA fails. In the PRAGMA
                case the half-configured connection is closed before the
                error propagates, so it is not leaked.
        """
        conn = sqlite3.connect(str(self.db_path))
        try:
            for pragma in self._PRAGMAS:
                conn.execute(pragma)
        except sqlite3.Error:
            conn.close()
            raise
        return conn

    def count_rows(self, conn: sqlite3.Connection, table: str) -> int:
        """Return the number of rows in ``table``.

        Args:
            conn: Open connection to query.
            table: Table name; must be ``"book"`` or ``"trades"``.

        Returns:
            The row count, or ``0`` if the query fails (the failure is
            logged rather than raised).

        Raises:
            ValueError: if ``table`` is not one of the supported names.
        """
        if table not in self._COUNTABLE_TABLES:
            raise ValueError(f"Unsupported table name: {table}")
        try:
            # Safe to interpolate: ``table`` was checked against the allow-list.
            row = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        except sqlite3.Error as exc:
            logging.error("Error counting rows in table %s: %s", table, exc)
            return 0
        if not row or row[0] is None:
            return 0
        return int(row[0])

    def load_trades_by_timestamp(self, conn: sqlite3.Connection) -> Dict[int, List[Trade]]:
        """Load every row of ``trades`` grouped by integer timestamp.

        Rows are read in ascending timestamp order, so the returned dict's
        keys are inserted in that order and each list keeps the order the
        database returned its rows in.

        Args:
            conn: Open connection to query.

        Returns:
            Mapping of ``int(timestamp)`` to the ``Trade`` objects sharing
            it. An empty dict is returned if any SQLite error occurs (the
            failure is logged and partial results are discarded).
        """
        grouped: Dict[int, List[Trade]] = {}
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "SELECT id, trade_id, price, size, side, timestamp FROM trades ORDER BY timestamp ASC"
                )
                for batch in self._iter_batches(cursor, self._BATCH_SIZE):
                    for id_, trade_id, price, size, side, ts in batch:
                        timestamp_int = int(ts)
                        trade = Trade(
                            id=id_,
                            # NOTE(review): trade_id is coerced to float, as
                            # the original code did -- confirm against Trade.
                            trade_id=float(trade_id),
                            price=float(price),
                            size=float(size),
                            side=str(side),
                            timestamp=timestamp_int,
                        )
                        grouped.setdefault(timestamp_int, []).append(trade)
            finally:
                # Release the statement even when a row fails to convert.
                cursor.close()
        except sqlite3.Error as exc:
            logging.error("Error loading trades: %s", exc)
            return {}
        return grouped

    def iterate_book_rows(self, conn: sqlite3.Connection) -> Iterator[Tuple[int, str, str, int]]:
        """Lazily yield ``book`` rows in ascending timestamp order.

        The query is not executed until the first item is requested. SQLite
        errors are not caught here; they propagate to the consumer.

        Args:
            conn: Open connection to query.

        Yields:
            ``(id, bids, asks, timestamp)`` tuples exactly as returned by
            the database driver.
        """
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT id, bids, asks, timestamp FROM book ORDER BY timestamp ASC")
            for batch in self._iter_batches(cursor, self._BATCH_SIZE):
                for row in batch:
                    yield row  # (id, bids, asks, timestamp)
        finally:
            # Also runs when the consumer stops early and the generator is
            # closed, so the cursor never outlives the iteration.
            cursor.close()

    @staticmethod
    def _iter_batches(cursor: sqlite3.Cursor, batch_size: int) -> Iterator[list]:
        """Yield successive non-empty ``fetchmany`` batches until exhausted."""
        while True:
            rows = cursor.fetchmany(batch_size)
            if not rows:
                return
            yield rows