from __future__ import annotations from pathlib import Path from typing import Dict, Iterator, List, Tuple import sqlite3 import logging from models import Trade class SQLiteOrderflowRepository: """Read-only repository for loading orderflow data from a SQLite database.""" def __init__(self, db_path: Path) -> None: self.db_path = db_path def connect(self) -> sqlite3.Connection: conn = sqlite3.connect(str(self.db_path)) conn.execute("PRAGMA journal_mode = OFF") conn.execute("PRAGMA synchronous = OFF") conn.execute("PRAGMA cache_size = 100000") conn.execute("PRAGMA temp_store = MEMORY") conn.execute("PRAGMA mmap_size = 30000000000") return conn def count_rows(self, conn: sqlite3.Connection, table: str) -> int: allowed_tables = {"book", "trades"} if table not in allowed_tables: raise ValueError(f"Unsupported table name: {table}") try: row = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() return int(row[0]) if row and row[0] is not None else 0 except sqlite3.Error as e: logging.error(f"Error counting rows in table {table}: {e}") return 0 def load_trades_by_timestamp(self, conn: sqlite3.Connection) -> Dict[int, List[Trade]]: trades_by_timestamp: Dict[int, List[Trade]] = {} try: cursor = conn.cursor() cursor.execute( "SELECT id, trade_id, price, size, side, timestamp FROM trades ORDER BY timestamp ASC" ) for batch in iter(lambda: cursor.fetchmany(5000), []): for id_, trade_id, price, size, side, ts in batch: timestamp_int = int(ts) trade = Trade( id=id_, trade_id=float(trade_id), price=float(price), size=float(size), side=str(side), timestamp=timestamp_int, ) if timestamp_int not in trades_by_timestamp: trades_by_timestamp[timestamp_int] = [] trades_by_timestamp[timestamp_int].append(trade) return trades_by_timestamp except sqlite3.Error as e: logging.error(f"Error loading trades: {e}") return {} def iterate_book_rows(self, conn: sqlite3.Connection) -> Iterator[Tuple[int, str, str, int]]: cursor = conn.cursor() cursor.execute("SELECT id, bids, asks, timestamp FROM book ORDER BY timestamp ASC") while True: rows = cursor.fetchmany(5000) if not rows: break for row in rows: yield row # (id, bids, asks, timestamp)