orderflow_backtest/repositories/sqlite_repository.py

from __future__ import annotations

from pathlib import Path
from typing import Iterator, List, Optional, Tuple
import sqlite3
import logging

from models import Trade, Metric


class SQLiteOrderflowRepository:
    """Repository for loading orderflow data from a SQLite database and
    persisting computed metrics back to it."""

    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
    def connect(self) -> None:
        """Open the database connection and apply speed-oriented PRAGMAs.

        Note: journal_mode=OFF and synchronous=OFF trade crash safety for
        throughput, which suits a database that can be rebuilt from source data.
        """
        self.conn = sqlite3.connect(str(self.db_path))
        self.conn.execute("PRAGMA journal_mode = OFF")
        self.conn.execute("PRAGMA synchronous = OFF")
        self.conn.execute("PRAGMA cache_size = 100000")
        self.conn.execute("PRAGMA temp_store = MEMORY")
        self.conn.execute("PRAGMA mmap_size = 30000000000")  # allow up to ~30 GB memory-mapped I/O
    def count_rows(self, table: str) -> int:
        """Return the row count of an allowlisted table, or 0 on error."""
        allowed_tables = {"book", "trades"}
        if table not in allowed_tables:
            raise ValueError(f"Unsupported table name: {table}")
        try:
            # The table name is interpolated only after the allowlist check above.
            row = self.conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
            return int(row[0]) if row and row[0] is not None else 0
        except sqlite3.Error as e:
            logging.error(f"Error counting rows in table {table}: {e}")
            return 0
    def load_trades(self) -> List[Trade]:
        """Load all trades ordered by timestamp; returns an empty list on error."""
        trades: List[Trade] = []
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT id, trade_id, price, size, side, timestamp FROM trades ORDER BY timestamp ASC"
            )
            # Fetch in batches of 5000 rows to bound memory use per fetch call.
            for batch in iter(lambda: cursor.fetchmany(5000), []):
                for id_, trade_id, price, size, side, ts in batch:
                    trade = Trade(
                        id=id_,
                        trade_id=float(trade_id),
                        price=float(price),
                        size=float(size),
                        side=str(side),
                        timestamp=int(ts),
                    )
                    trades.append(trade)
            return trades
        except sqlite3.Error as e:
            logging.error(f"Error loading trades: {e}")
            return []
    def iterate_book_rows(self) -> Iterator[Tuple[int, str, str, int]]:
        cursor = self.conn.cursor()
        cursor.execute("SELECT id, bids, asks, timestamp FROM book ORDER BY timestamp ASC")
        while True:
            rows = cursor.fetchmany(5000)
            if not rows:
                break
            for row in rows:
                yield row  # (id, bids, asks, timestamp)
    def create_metrics_table(self) -> None:
        """Create the metrics table with indexes and a foreign key to book(id)."""
        try:
            # Timestamp is stored as INTEGER to match the integer epoch values
            # that insert_metrics_batch writes and load_metrics_by_timerange
            # reads; a TEXT column would break range comparisons against the
            # integer bind parameters used below.
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    snapshot_id INTEGER NOT NULL,
                    timestamp INTEGER NOT NULL,
                    obi REAL NOT NULL,
                    cvd REAL NOT NULL,
                    best_bid REAL,
                    best_ask REAL,
                    FOREIGN KEY (snapshot_id) REFERENCES book(id)
                )
            """)
            # Indexes for efficient time-range and snapshot lookups.
            self.conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_timestamp ON metrics(timestamp)")
            self.conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_snapshot_id ON metrics(snapshot_id)")
            self.conn.commit()
            logging.info("Metrics table and indexes created successfully")
        except sqlite3.Error as e:
            logging.error(f"Error creating metrics table: {e}")
            raise
    def table_exists(self, table_name: str) -> bool:
        """Check whether a table exists in the database.

        Args:
            table_name: Name of the table to check.

        Returns:
            True if the table exists, False otherwise.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,)
            )
            return cursor.fetchone() is not None
        except sqlite3.Error as e:
            logging.error(f"Error checking if table {table_name} exists: {e}")
            return False
    def insert_metrics_batch(self, metrics: List[Metric]) -> None:
        """Insert multiple metrics in a single batch operation for performance.

        Committing is left to the caller, so several batches can share one
        transaction.

        Args:
            metrics: List of Metric objects to insert.
        """
        if not metrics:
            return
        try:
            batch_data = [
                (m.snapshot_id, m.timestamp, m.obi, m.cvd, m.best_bid, m.best_ask)
                for m in metrics
            ]
            # executemany issues the whole batch in a single call.
            self.conn.executemany(
                "INSERT INTO metrics (snapshot_id, timestamp, obi, cvd, best_bid, best_ask) VALUES (?, ?, ?, ?, ?, ?)",
                batch_data
            )
            logging.debug(f"Inserted {len(metrics)} metrics records")
        except sqlite3.Error as e:
            logging.error(f"Error inserting metrics batch: {e}")
            raise
    def load_metrics_by_timerange(self, start_timestamp: int, end_timestamp: int) -> List[Metric]:
        """Load metrics within a timestamp range.

        Args:
            start_timestamp: Start of the time range (inclusive).
            end_timestamp: End of the time range (inclusive).

        Returns:
            List of Metric objects ordered by timestamp; empty on error.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT snapshot_id, timestamp, obi, cvd, best_bid, best_ask FROM metrics WHERE timestamp >= ? AND timestamp <= ? ORDER BY timestamp ASC",
                (start_timestamp, end_timestamp)
            )
            metrics = []
            for batch in iter(lambda: cursor.fetchmany(5000), []):
                for snapshot_id, timestamp, obi, cvd, best_bid, best_ask in batch:
                    metric = Metric(
                        snapshot_id=int(snapshot_id),
                        timestamp=int(timestamp),
                        obi=float(obi),
                        cvd=float(cvd),
                        best_bid=float(best_bid) if best_bid is not None else None,
                        best_ask=float(best_ask) if best_ask is not None else None,
                    )
                    metrics.append(metric)
            return metrics
        except sqlite3.Error as e:
            logging.error(f"Error loading metrics by timerange: {e}")
            return []
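

# Minimal usage sketch (hypothetical, not part of the original module): the
# database path, the Metric constructor arguments, and the sample values are
# assumptions inferred from the column lists used above.
if __name__ == "__main__":
    repo = SQLiteOrderflowRepository(Path("orderflow.db"))
    repo.connect()
    try:
        print("trades rows:", repo.count_rows("trades"))
        if not repo.table_exists("metrics"):
            repo.create_metrics_table()
        # Insert one illustrative metric row, then read it back by time range.
        repo.insert_metrics_batch([
            Metric(snapshot_id=1, timestamp=1700000000, obi=0.12,
                   cvd=-3.5, best_bid=100.0, best_ask=100.5),
        ])
        repo.conn.commit()  # insert_metrics_batch leaves committing to the caller
        for m in repo.load_metrics_by_timerange(1699999999, 1700000001):
            print(m)
    finally:
        repo.close()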