74 lines
2.8 KiB
Python
74 lines
2.8 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Dict, Iterator, List, Tuple
|
||
|
|
import sqlite3
|
||
|
|
import logging
|
||
|
|
|
||
|
|
from models import Trade
|
||
|
|
|
||
|
|
|
||
|
|
class SQLiteOrderflowRepository:
    """Read-only repository for loading orderflow data from a SQLite database.

    The repository does not hold a connection itself: callers obtain one from
    :meth:`connect` and pass it to each query method.
    """

    # Table names that ``count_rows`` is allowed to interpolate into SQL.
    _ALLOWED_TABLES = frozenset({"book", "trades"})
    # Number of rows requested per ``fetchmany`` call when streaming results.
    _FETCH_BATCH_SIZE = 5000

    def __init__(self, db_path: Path) -> None:
        """Store the database location; no connection is opened here.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path

    def connect(self) -> sqlite3.Connection:
        """Open a new connection to ``db_path`` and apply tuning PRAGMAs.

        Returns:
            A fresh ``sqlite3.Connection``. The caller is responsible for
            closing it.
        """
        # NOTE(review): the class is described as read-only, but this opens the
        # database in sqlite3's default read-write mode, which creates the file
        # if it does not exist. Confirm whether a ``mode=ro`` URI is wanted.
        conn = sqlite3.connect(str(self.db_path))
        # The PRAGMAs below disable the rollback journal and fsync waits, and
        # enlarge the page cache / memory-map limit. Disabling the journal and
        # synchronous writes is only safe as long as nothing writes through
        # this connection.
        for pragma in (
            "PRAGMA journal_mode = OFF",
            "PRAGMA synchronous = OFF",
            "PRAGMA cache_size = 100000",
            "PRAGMA temp_store = MEMORY",
            "PRAGMA mmap_size = 30000000000",
        ):
            conn.execute(pragma)
        return conn

    @staticmethod
    def _close_quietly(cursor: sqlite3.Cursor) -> None:
        """Close ``cursor``, ignoring errors from an already-closed connection."""
        try:
            cursor.close()
        except sqlite3.Error:
            # Closing a cursor whose connection is already closed raises
            # ProgrammingError; at that point there is nothing left to release.
            pass

    def count_rows(self, conn: sqlite3.Connection, table: str) -> int:
        """Return the number of rows in ``table``.

        Args:
            conn: Open connection to query.
            table: Table to count; must be ``"book"`` or ``"trades"``.

        Returns:
            The row count, or ``0`` if the query fails with a SQLite error
            (the error is logged).

        Raises:
            ValueError: If ``table`` is not one of the supported tables.
        """
        if table not in self._ALLOWED_TABLES:
            raise ValueError(f"Unsupported table name: {table}")
        try:
            # Identifiers cannot be bound as SQL parameters; interpolation is
            # safe here only because ``table`` passed the allow-list above.
            row = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        except sqlite3.Error as e:
            logging.error("Error counting rows in table %s: %s", table, e)
            return 0
        return int(row[0]) if row and row[0] is not None else 0

    def load_trades_by_timestamp(self, conn: sqlite3.Connection) -> Dict[int, List[Trade]]:
        """Load every trade and group them by integer timestamp.

        Rows are read in ascending ``timestamp`` order, so each list keeps the
        order in which the database returned its trades.

        Args:
            conn: Open connection to query.

        Returns:
            Mapping of ``int(timestamp)`` to the trades at that timestamp, or
            an empty dict if a SQLite error occurs (the error is logged and
            any partially loaded data is discarded).
        """
        trades_by_timestamp: Dict[int, List[Trade]] = {}
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "SELECT id, trade_id, price, size, side, timestamp FROM trades ORDER BY timestamp ASC"
                )
                while True:
                    batch = cursor.fetchmany(self._FETCH_BATCH_SIZE)
                    if not batch:
                        break
                    for id_, trade_id, price, size, side, ts in batch:
                        timestamp_int = int(ts)
                        trade = Trade(
                            id=id_,
                            # NOTE(review): a float cannot represent integer
                            # ids above 2**53 exactly; confirm the Trade model
                            # really expects a float here.
                            trade_id=float(trade_id),
                            price=float(price),
                            size=float(size),
                            side=str(side),
                            timestamp=timestamp_int,
                        )
                        trades_by_timestamp.setdefault(timestamp_int, []).append(trade)
            finally:
                # Release the statement even when reading fails part-way.
                self._close_quietly(cursor)
        except sqlite3.Error as e:
            logging.error("Error loading trades: %s", e)
            return {}
        return trades_by_timestamp

    def iterate_book_rows(self, conn: sqlite3.Connection) -> Iterator[Tuple[int, str, str, int]]:
        """Lazily yield ``book`` rows in ascending ``timestamp`` order.

        Rows are fetched from SQLite in batches so the whole table is never
        held in memory at once. SQLite errors are not caught here; they
        propagate to the consumer of the generator.

        Args:
            conn: Open connection to query.

        Yields:
            ``(id, bids, asks, timestamp)`` tuples as returned by SQLite.
        """
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT id, bids, asks, timestamp FROM book ORDER BY timestamp ASC")
            while True:
                rows = cursor.fetchmany(self._FETCH_BATCH_SIZE)
                if not rows:
                    break
                yield from rows  # (id, bids, asks, timestamp)
        finally:
            # Runs on exhaustion, on error, and when the consumer abandons the
            # generator early, so the open SELECT never outlives its use.
            self._close_quietly(cursor)
|
||
|
|
|
||
|
|
|