from __future__ import annotations
import logging
import sqlite3
from pathlib import Path
from typing import Iterator, List, Tuple

from models import Trade, Metric


class SQLiteOrderflowRepository:
    """Repository for loading orderflow data from a SQLite database and
    persisting derived metrics back to it."""

    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        # The connection is opened lazily by connect().
        self.conn: sqlite3.Connection | None = None

    def connect(self) -> None:
        self.conn = sqlite3.connect(str(self.db_path))
        # Trade durability for raw throughput: the source tables are only
        # read here, and derived metrics can be rebuilt if the process dies.
        self.conn.execute("PRAGMA journal_mode = OFF")
        self.conn.execute("PRAGMA synchronous = OFF")
        self.conn.execute("PRAGMA cache_size = 100000")
        self.conn.execute("PRAGMA temp_store = MEMORY")
        self.conn.execute("PRAGMA mmap_size = 30000000000")  # ~30 GB memory map

    def count_rows(self, table: str) -> int:
        # Table names cannot be bound as SQL parameters, so whitelist them
        # before interpolating into the query.
        allowed_tables = {"book", "trades"}
        if table not in allowed_tables:
            raise ValueError(f"Unsupported table name: {table}")
        try:
            row = self.conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
            return int(row[0]) if row and row[0] is not None else 0
        except sqlite3.Error as e:
            logging.error(f"Error counting rows in table {table}: {e}")
            return 0

    def load_trades(self) -> List[Trade]:
        """Load all trades ordered by timestamp ascending."""
        trades: List[Trade] = []
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT id, trade_id, price, size, side, timestamp FROM trades ORDER BY timestamp ASC"
            )
            # Drain the cursor in batches of 5000 rows to bound peak memory.
            for batch in iter(lambda: cursor.fetchmany(5000), []):
                for id_, trade_id, price, size, side, ts in batch:
                    trades.append(
                        Trade(
                            id=id_,
                            trade_id=float(trade_id),
                            price=float(price),
                            size=float(size),
                            side=str(side),
                            timestamp=int(ts),
                        )
                    )
            return trades
        except sqlite3.Error as e:
            logging.error(f"Error loading trades: {e}")
            return []

    def iterate_book_rows(self) -> Iterator[Tuple[int, str, str, int]]:
        """Stream (id, bids, asks, timestamp) rows from the book table."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT id, bids, asks, timestamp FROM book ORDER BY timestamp ASC")
        while True:
            rows = cursor.fetchmany(5000)
            if not rows:
                break
            for row in rows:
                yield row  # (id, bids, asks, timestamp)

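    # Illustrative consumption of iterate_book_rows (a sketch, not part of the
    # module: it assumes the bids/asks columns hold JSON-encoded price levels,
    # which this repository does not itself guarantee):
    #
    #     for snapshot_id, bids_json, asks_json, ts in repo.iterate_book_rows():
    #         bids = json.loads(bids_json)               # hypothetical decoding
    #         best_bid = max(float(p) for p, _ in bids)
    #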
    def create_metrics_table(self) -> None:
        """Create the metrics table plus its indexes.

        The foreign key to book(id) is declared but is only enforced when the
        connection runs PRAGMA foreign_keys = ON.
        """
        try:
            # Metrics schema per the PRD. The timestamp column is INTEGER so
            # that the epoch values written by insert_metrics_batch compare
            # numerically in range queries (a TEXT column would compare them
            # lexicographically).
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    snapshot_id INTEGER NOT NULL,
                    timestamp INTEGER NOT NULL,
                    obi REAL NOT NULL,
                    cvd REAL NOT NULL,
                    best_bid REAL,
                    best_ask REAL,
                    FOREIGN KEY (snapshot_id) REFERENCES book(id)
                )
            """)

            # Indexes for time-range scans and per-snapshot lookups.
            self.conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_timestamp ON metrics(timestamp)")
            self.conn.execute("CREATE INDEX IF NOT EXISTS idx_metrics_snapshot_id ON metrics(snapshot_id)")

            self.conn.commit()
            logging.info("Metrics table and indexes created successfully")
        except sqlite3.Error as e:
            logging.error(f"Error creating metrics table: {e}")
            raise

    def table_exists(self, table_name: str) -> bool:
        """Check whether a table exists in the database.

        Args:
            table_name: Name of the table to check.

        Returns:
            True if the table exists, False otherwise.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,),
            )
            return cursor.fetchone() is not None
        except sqlite3.Error as e:
            logging.error(f"Error checking if table {table_name} exists: {e}")
            return False

    def insert_metrics_batch(self, metrics: List[Metric]) -> None:
        """Insert multiple metrics in a single batch operation for performance.

        No commit is issued here; the caller decides transaction boundaries so
        that several batches can share one transaction.

        Args:
            metrics: List of Metric objects to insert.
        """
        if not metrics:
            return

        try:
            batch_data = [
                (m.snapshot_id, m.timestamp, m.obi, m.cvd, m.best_bid, m.best_ask)
                for m in metrics
            ]
            # executemany sends the whole batch through one prepared statement.
            self.conn.executemany(
                "INSERT INTO metrics (snapshot_id, timestamp, obi, cvd, best_bid, best_ask) VALUES (?, ?, ?, ?, ?, ?)",
                batch_data,
            )
            logging.debug(f"Inserted {len(metrics)} metrics records")
        except sqlite3.Error as e:
            logging.error(f"Error inserting metrics batch: {e}")
            raise

    def load_metrics_by_timerange(self, start_timestamp: int, end_timestamp: int) -> List[Metric]:
        """Load metrics within a timestamp range.

        Args:
            start_timestamp: Start of the time range (inclusive).
            end_timestamp: End of the time range (inclusive).

        Returns:
            List of Metric objects ordered by timestamp.
        """
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "SELECT snapshot_id, timestamp, obi, cvd, best_bid, best_ask FROM metrics WHERE timestamp >= ? AND timestamp <= ? ORDER BY timestamp ASC",
                (start_timestamp, end_timestamp),
            )

            metrics: List[Metric] = []
            for batch in iter(lambda: cursor.fetchmany(5000), []):
                for snapshot_id, timestamp, obi, cvd, best_bid, best_ask in batch:
                    metrics.append(
                        Metric(
                            snapshot_id=int(snapshot_id),
                            timestamp=int(timestamp),
                            obi=float(obi),
                            cvd=float(cvd),
                            best_bid=float(best_bid) if best_bid is not None else None,
                            best_ask=float(best_ask) if best_ask is not None else None,
                        )
                    )
            return metrics
        except sqlite3.Error as e:
            logging.error(f"Error loading metrics by timerange: {e}")
            return []
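

# Minimal usage sketch. The database filename is a hypothetical placeholder,
# and the commented insert assumes the Metric keyword signature used above;
# neither is guaranteed by this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    repo = SQLiteOrderflowRepository(Path("orderflow.sqlite"))  # hypothetical path
    repo.connect()

    print("book rows:", repo.count_rows("book"))
    print("trade rows:", repo.count_rows("trades"))

    if not repo.table_exists("metrics"):
        repo.create_metrics_table()

    trades = repo.load_trades()
    if trades:
        first_ts, last_ts = trades[0].timestamp, trades[-1].timestamp
        metrics = repo.load_metrics_by_timerange(first_ts, last_ts)
        print(f"{len(metrics)} metrics between {first_ts} and {last_ts}")

    # Writing metrics would look like this (commit is the caller's job):
    #     repo.insert_metrics_batch([Metric(snapshot_id=1, timestamp=first_ts,
    #                                       obi=0.0, cvd=0.0,
    #                                       best_bid=None, best_ask=None)])
    #     repo.conn.commit()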