orderflow_backtest/tests/test_metrics_repository.py

127 lines
4.9 KiB
Python
Raw Normal View History

"""Tests for SQLiteMetricsRepository table creation and schema validation."""
import sys
import sqlite3
import tempfile
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1]))
from repositories.sqlite_metrics_repository import SQLiteMetricsRepository
from models import Metric
def test_create_metrics_table():
    """Verify create_metrics_table builds the expected schema and indexes.

    Checks that the ``metrics`` table exists after creation, that all
    expected columns are present, and that the timestamp / snapshot_id
    indexes were created.
    """
    # Reserve a temporary file path for an on-disk SQLite database;
    # delete=False so the path survives the context manager on all platforms.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
        db_path = Path(tmp_file.name)
    try:
        repo = SQLiteMetricsRepository(db_path)
        with repo.connect() as conn:
            repo.create_metrics_table(conn)

            # The table should now be visible to the repository helper.
            assert repo.table_exists(conn, "metrics")

            cursor = conn.cursor()

            # Column layout: PRAGMA table_info rows are
            # (cid, name, type, notnull, dflt_value, pk) — name is index 1.
            cursor.execute("PRAGMA table_info(metrics)")
            present = {row[1] for row in cursor.fetchall()}
            for expected in ("id", "snapshot_id", "timestamp", "obi", "cvd", "best_bid", "best_ask"):
                assert expected in present, f"Column {expected} missing from metrics table"

            # Index layout: PRAGMA index_list rows carry the index name at index 1.
            cursor.execute("PRAGMA index_list(metrics)")
            index_names = {row[1] for row in cursor.fetchall()}
            assert "idx_metrics_timestamp" in index_names
            assert "idx_metrics_snapshot_id" in index_names
    finally:
        # Remove the temp database even if an assertion failed.
        db_path.unlink(missing_ok=True)
def test_insert_metrics_batch():
    """Verify insert_metrics_batch persists all rows, including optional fields.

    Inserts three Metric records (one without best_bid/best_ask) and checks
    the row count and the stored values via raw SQL.
    """
    # Reserve a temp file path; delete=False keeps the path usable after close.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
        db_path = Path(tmp_file.name)
    try:
        repo = SQLiteMetricsRepository(db_path)
        with repo.connect() as conn:
            repo.create_metrics_table(conn)

            batch = [
                Metric(snapshot_id=1, timestamp=1000, obi=0.5, cvd=100.0, best_bid=50000.0, best_ask=50001.0),
                Metric(snapshot_id=2, timestamp=1001, obi=-0.2, cvd=150.0, best_bid=50002.0, best_ask=50003.0),
                # Optional best_bid/best_ask omitted — should round-trip as NULL.
                Metric(snapshot_id=3, timestamp=1002, obi=0.0, cvd=125.0),
            ]
            repo.insert_metrics_batch(conn, batch)
            conn.commit()

            cursor = conn.cursor()

            # All three rows should be present.
            cursor.execute("SELECT COUNT(*) FROM metrics")
            assert cursor.fetchone()[0] == 3

            # Verify the stored values field by field, in timestamp order.
            # NOTE(review): timestamps are expected back as TEXT ("1000")
            # although inserted as ints — presumably the column is declared
            # TEXT in the repository schema; confirm against
            # create_metrics_table.
            cursor.execute("SELECT snapshot_id, timestamp, obi, cvd, best_bid, best_ask FROM metrics ORDER BY timestamp")
            rows = cursor.fetchall()
            assert rows[0] == (1, "1000", 0.5, 100.0, 50000.0, 50001.0)
            assert rows[1] == (2, "1001", -0.2, 150.0, 50002.0, 50003.0)
            assert rows[2] == (3, "1002", 0.0, 125.0, None, None)
    finally:
        # Clean up the temp database regardless of test outcome.
        db_path.unlink(missing_ok=True)
def test_load_metrics_by_timerange():
    """Verify load_metrics_by_timerange returns only rows inside the range.

    Seeds four metrics at timestamps 1000/1005/1010/1015, then queries an
    interior window, an empty window, and a single-point window.
    """
    # Reserve a temp file path; delete=False keeps the path usable after close.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
        db_path = Path(tmp_file.name)
    try:
        repo = SQLiteMetricsRepository(db_path)
        with repo.connect() as conn:
            repo.create_metrics_table(conn)
            seed = [
                Metric(snapshot_id=1, timestamp=1000, obi=0.1, cvd=10.0, best_bid=50000.0, best_ask=50001.0),
                Metric(snapshot_id=2, timestamp=1005, obi=0.2, cvd=20.0, best_bid=50002.0, best_ask=50003.0),
                Metric(snapshot_id=3, timestamp=1010, obi=0.3, cvd=30.0, best_bid=50004.0, best_ask=50005.0),
                Metric(snapshot_id=4, timestamp=1015, obi=0.4, cvd=40.0, best_bid=50006.0, best_ask=50007.0),
            ]
            repo.insert_metrics_batch(conn, seed)
            conn.commit()

            # Interior window [1003, 1012] should catch exactly the middle two.
            window = repo.load_metrics_by_timerange(conn, 1003, 1012)
            assert len(window) == 2
            assert (window[0].timestamp, window[0].obi) == (1005, 0.2)
            assert (window[1].timestamp, window[1].obi) == (1010, 0.3)

            # Range beyond all data -> empty; degenerate range on an exact
            # timestamp -> one row (bounds are inclusive).
            assert len(repo.load_metrics_by_timerange(conn, 2000, 3000)) == 0
            assert len(repo.load_metrics_by_timerange(conn, 1000, 1000)) == 1
    finally:
        # Clean up the temp database regardless of test outcome.
        db_path.unlink(missing_ok=True)