from __future__ import annotations

from pathlib import Path
from typing import Iterator, List, Tuple
import sqlite3
import logging

from models import Trade, Metric


class SQLiteOrderflowRepository:
    """Repository for loading orderflow data from a SQLite database and persisting derived metrics."""

    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        self.conn = None

    def connect(self) -> None:
        self.conn = sqlite3.connect(str(self.db_path))
        # Performance-oriented pragmas: journaling and fsync are disabled,
        # so durability is traded for read/write speed.
        self.conn.execute("PRAGMA journal_mode = OFF")
        self.conn.execute("PRAGMA synchronous = OFF")
        self.conn.execute("PRAGMA cache_size = 100000")
        self.conn.execute("PRAGMA temp_store = MEMORY")
        self.conn.execute("PRAGMA mmap_size = 30000000000")

    def count_rows(self, table: str) -> int:
        allowed_tables = {"book", "trades"}
        if table not in allowed_tables:
            raise ValueError(f"Unsupported table name: {table}")
        try:
            row = self.conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
            return int(row[0]) if row and row[0] is not None else 0
        except sqlite3.Error as e:
            logging.error(f"Error counting rows in table {table}: {e}")
            return 0

    def load_trades(self) -> List[Trade]:
        trades: List[Trade] = []
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT id, trade_id, price, size, side, timestamp FROM trades ORDER BY timestamp ASC"
            )
            for batch in iter(lambda: cursor.fetchmany(5000), []):
                for id_, trade_id, price, size, side, ts in batch:
                    trade = Trade(
                        id=id_,
                        trade_id=float(trade_id),
                        price=float(price),
                        size=float(size),
                        side=str(side),
                        timestamp=int(ts),
                    )
                    trades.append(trade)
            return trades
        except sqlite3.Error as e:
            logging.error(f"Error loading trades: {e}")
            return []

    def iterate_book_rows(self) -> Iterator[Tuple[int, str, str, int]]:
        cursor = self.conn.cursor()
        cursor.execute("SELECT id, bids, asks, timestamp FROM book ORDER BY timestamp ASC")
        while True:
            rows = cursor.fetchmany(5000)
            if not rows:
                break
            for row in rows:
                yield row  # (id, bids, asks, timestamp)

    def create_metrics_table(self) -> None:
        """Create the metrics table with indexes and a foreign key constraint."""
        try:
            # Create metrics table following the PRD schema
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    snapshot_id INTEGER NOT NULL,
                    timestamp TEXT NOT NULL,
                    obi REAL NOT NULL,
                    cvd REAL NOT NULL,
                    best_bid REAL,
                    best_ask REAL,
                    FOREIGN KEY (snapshot_id) REFERENCES book(id)
                )
            """)
            # Create indexes for efficient querying
            self.conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_timestamp ON metrics(timestamp)")
            self.conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_snapshot_id ON metrics(snapshot_id)")
            self.conn.commit()
            logging.info("Metrics table and indexes created successfully")
        except sqlite3.Error as e:
            logging.error(f"Error creating metrics table: {e}")
            raise

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        Args:
            table_name: Name of the table to check.

        Returns:
            True if the table exists, False otherwise.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,),
            )
            return cursor.fetchone() is not None
        except sqlite3.Error as e:
            logging.error(f"Error checking if table {table_name} exists: {e}")
            return False

    def insert_metrics_batch(self, metrics: List[Metric]) -> None:
        """Insert multiple metrics in a single batch operation for performance.

        Args:
            metrics: List of Metric objects to insert.
        """
        if not metrics:
            return
        try:
            # Prepare batch data following the existing batch pattern
            batch_data = [
                (m.snapshot_id, m.timestamp, m.obi, m.cvd, m.best_bid, m.best_ask)
                for m in metrics
            ]
            # Use executemany for batch insertion.
            # Note: no commit here; the caller is expected to commit the transaction.
            self.conn.executemany(
                "INSERT INTO metrics (snapshot_id, timestamp, obi, cvd, best_bid, best_ask) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                batch_data,
            )
            logging.debug(f"Inserted {len(metrics)} metrics records")
        except sqlite3.Error as e:
            logging.error(f"Error inserting metrics batch: {e}")
            raise

    def load_metrics_by_timerange(self, start_timestamp: int, end_timestamp: int) -> List[Metric]:
        """Load metrics within a specified timestamp range.

        Args:
            start_timestamp: Start of the time range (inclusive).
            end_timestamp: End of the time range (inclusive).

        Returns:
            List of Metric objects ordered by timestamp.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT snapshot_id, timestamp, obi, cvd, best_bid, best_ask "
                "FROM metrics WHERE timestamp >= ? AND timestamp <= ? ORDER BY timestamp ASC",
                (start_timestamp, end_timestamp),
            )
            metrics: List[Metric] = []
            for batch in iter(lambda: cursor.fetchmany(5000), []):
                for snapshot_id, timestamp, obi, cvd, best_bid, best_ask in batch:
                    metric = Metric(
                        snapshot_id=int(snapshot_id),
                        timestamp=int(timestamp),
                        obi=float(obi),
                        cvd=float(cvd),
                        best_bid=float(best_bid) if best_bid is not None else None,
                        best_ask=float(best_ask) if best_ask is not None else None,
                    )
                    metrics.append(metric)
            return metrics
        except sqlite3.Error as e:
            logging.error(f"Error loading metrics by timerange: {e}")
            return []
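

# Illustrative usage sketch (not part of the original module): the database path
# below is hypothetical, and this assumes `models.Trade` / `models.Metric` accept
# the field names used above. It only demonstrates the intended call order:
# connect, ensure the metrics table exists, then read the source tables.
if __name__ == "__main__":
    repo = SQLiteOrderflowRepository(Path("orderflow.db"))  # hypothetical path
    repo.connect()
    if not repo.table_exists("metrics"):
        repo.create_metrics_table()
    print(f"trades rows: {repo.count_rows('trades')}")
    trades = repo.load_trades()
    print(f"loaded {len(trades)} trades")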