""" SQLAlchemy database models and session management for backtest run persistence. """ import json from datetime import datetime, timezone from pathlib import Path from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text, create_engine from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker # Database file location DB_PATH = Path(__file__).parent.parent.parent / "data" / "backtest_runs.db" DATABASE_URL = f"sqlite:///{DB_PATH}" engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) class Base(DeclarativeBase): """Base class for SQLAlchemy models.""" pass class BacktestRun(Base): """ Persisted backtest run record. Stores all information needed to display and compare runs. """ __tablename__ = "backtest_runs" id = Column(Integer, primary_key=True, autoincrement=True) run_id = Column(String(36), unique=True, nullable=False, index=True) # Configuration strategy = Column(String(50), nullable=False, index=True) symbol = Column(String(20), nullable=False, index=True) exchange = Column(String(20), nullable=False, default="okx") market_type = Column(String(20), nullable=False) timeframe = Column(String(10), nullable=False) leverage = Column(Integer, nullable=False, default=1) params = Column(JSON, nullable=False, default=dict) # Date range start_date = Column(String(20), nullable=True) end_date = Column(String(20), nullable=True) # Metrics (denormalized for quick listing) total_return = Column(Float, nullable=False) benchmark_return = Column(Float, nullable=False, default=0.0) alpha = Column(Float, nullable=False, default=0.0) sharpe_ratio = Column(Float, nullable=False) max_drawdown = Column(Float, nullable=False) win_rate = Column(Float, nullable=False) total_trades = Column(Integer, nullable=False) profit_factor = Column(Float, nullable=True) total_fees = Column(Float, nullable=False, default=0.0) total_funding = Column(Float, nullable=False, default=0.0) liquidation_count = Column(Integer, nullable=False, default=0) liquidation_loss = Column(Float, nullable=False, default=0.0) adjusted_return = Column(Float, nullable=True) # Full data (JSON serialized) equity_curve = Column(Text, nullable=False) # JSON array trades = Column(Text, nullable=False) # JSON array # Metadata created_at = Column(DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)) def set_equity_curve(self, data: list[dict]): """Serialize equity curve to JSON string.""" self.equity_curve = json.dumps(data) def get_equity_curve(self) -> list[dict]: """Deserialize equity curve from JSON string.""" return json.loads(self.equity_curve) if self.equity_curve else [] def set_trades(self, data: list[dict]): """Serialize trades to JSON string.""" self.trades = json.dumps(data) def get_trades(self) -> list[dict]: """Deserialize trades from JSON string.""" return json.loads(self.trades) if self.trades else [] def init_db(): """Create database tables if they don't exist.""" DB_PATH.parent.mkdir(parents=True, exist_ok=True) Base.metadata.create_all(bind=engine) def get_db() -> Session: """Get database session (dependency injection).""" db = SessionLocal() try: yield db finally: db.close()