Files
lowkey_backtest/api/models/database.py

100 lines
3.5 KiB
Python
Raw Normal View History

"""
SQLAlchemy database models and session management for backtest run persistence.
"""
import json
from collections.abc import Iterator
from datetime import datetime, timezone
from pathlib import Path

from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
# SQLite file location: three ``parent`` hops up from this module, then data/backtest_runs.db
DB_PATH = Path(__file__).parent.parent.parent.joinpath("data", "backtest_runs.db")
DATABASE_URL = "sqlite:///" + str(DB_PATH)

# check_same_thread=False is handed to the SQLite driver through connect_args.
engine = create_engine(DATABASE_URL, connect_args=dict(check_same_thread=False))

# Session factory bound to the engine above, with autoflush and autocommit disabled.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
class Base(DeclarativeBase):
    """Declarative base class shared by the SQLAlchemy models in this module."""
class BacktestRun(Base):
    """
    One stored backtest run, mapped to the ``backtest_runs`` table.

    Keeps the run configuration, the headline metrics, and the JSON-encoded
    equity curve and trade list so runs can be displayed and compared.
    """

    __tablename__ = "backtest_runs"

    # Keys: auto-incrementing integer primary key plus a unique, indexed run id.
    id = Column(Integer, autoincrement=True, primary_key=True)
    run_id = Column(String(36), index=True, nullable=False, unique=True)

    # Run configuration
    strategy = Column(String(50), index=True, nullable=False)
    symbol = Column(String(20), index=True, nullable=False)
    exchange = Column(String(20), default="okx", nullable=False)
    market_type = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)
    leverage = Column(Integer, default=1, nullable=False)
    params = Column(JSON, default=dict, nullable=False)

    # Backtest date range, kept as strings; both ends are optional.
    start_date = Column(String(20), nullable=True)
    end_date = Column(String(20), nullable=True)

    # Headline metrics, duplicated into columns so listings are cheap.
    total_return = Column(Float, nullable=False)
    benchmark_return = Column(Float, default=0.0, nullable=False)
    alpha = Column(Float, default=0.0, nullable=False)
    sharpe_ratio = Column(Float, nullable=False)
    max_drawdown = Column(Float, nullable=False)
    win_rate = Column(Float, nullable=False)
    total_trades = Column(Integer, nullable=False)
    profit_factor = Column(Float, nullable=True)
    total_fees = Column(Float, default=0.0, nullable=False)
    total_funding = Column(Float, default=0.0, nullable=False)
    liquidation_count = Column(Integer, default=0, nullable=False)
    liquidation_loss = Column(Float, default=0.0, nullable=False)
    adjusted_return = Column(Float, nullable=True)

    # Complete payloads, held as JSON text (each one a JSON array).
    equity_curve = Column(Text, nullable=False)
    trades = Column(Text, nullable=False)

    # Row metadata: creation time defaults to the current UTC time.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), nullable=False)

    def set_equity_curve(self, data: list[dict]):
        """Encode ``data`` as JSON and store it on ``equity_curve``."""
        encoded = json.dumps(data)
        self.equity_curve = encoded

    def get_equity_curve(self) -> list[dict]:
        """Decode ``equity_curve`` from JSON; an empty/unset value yields ``[]``."""
        raw = self.equity_curve
        if not raw:
            return []
        return json.loads(raw)

    def set_trades(self, data: list[dict]):
        """Encode ``data`` as JSON and store it on ``trades``."""
        encoded = json.dumps(data)
        self.trades = encoded

    def get_trades(self) -> list[dict]:
        """Decode ``trades`` from JSON; an empty/unset value yields ``[]``."""
        raw = self.trades
        if not raw:
            return []
        return json.loads(raw)
def init_db():
    """Make sure the database directory exists, then create any missing tables."""
    data_dir = DB_PATH.parent
    data_dir.mkdir(exist_ok=True, parents=True)
    Base.metadata.create_all(bind=engine)
def get_db() -> Iterator[Session]:
    """
    Yield a database session and close it afterwards (dependency injection).

    This is a generator function: it creates a session from ``SessionLocal``,
    yields it to the consumer, and closes it in ``finally`` so the session is
    released even when the consumer raises.

    Yields:
        Session: a new session created by ``SessionLocal``.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        # Runs whether the generator finishes normally, is closed, or has an
        # exception thrown into it.
        db.close()