"""
Storage service for persisting and retrieving backtest runs.
"""

from sqlalchemy.orm import Session

from api.models.database import BacktestRun
from api.models.schemas import (
    BacktestResult,
    BacktestSummary,
    EquityPoint,
    BacktestMetrics,
    TradeRecord,
)


class StorageService:
    """
    Service for saving and loading backtest runs from SQLite.
    """

    def save_run(self, db: Session, result: BacktestResult) -> BacktestRun:
        """
        Persist a backtest result to the database.

        Args:
            db: Database session
            result: BacktestResult to save

        Returns:
            Created BacktestRun record, refreshed from the database after commit

        Raises:
            Any exception raised by the commit is re-raised after the session
            has been rolled back.
        """
        metrics = result.metrics
        run = BacktestRun(
            run_id=result.run_id,
            strategy=result.strategy,
            symbol=result.symbol,
            market_type=result.market_type,
            timeframe=result.timeframe,
            leverage=result.leverage,
            params=result.params,
            start_date=result.start_date,
            end_date=result.end_date,
            total_return=metrics.total_return,
            benchmark_return=metrics.benchmark_return,
            alpha=metrics.alpha,
            sharpe_ratio=metrics.sharpe_ratio,
            max_drawdown=metrics.max_drawdown,
            win_rate=metrics.win_rate,
            total_trades=metrics.total_trades,
            profit_factor=metrics.profit_factor,
            total_fees=metrics.total_fees,
            total_funding=metrics.total_funding,
            liquidation_count=metrics.liquidation_count,
            liquidation_loss=metrics.liquidation_loss,
            adjusted_return=metrics.adjusted_return,
        )

        # Equity curve and trades are nested models; hand BacktestRun's own
        # setters plain dicts to serialize.
        run.set_equity_curve([p.model_dump() for p in result.equity_curve])
        run.set_trades([t.model_dump() for t in result.trades])

        db.add(run)
        self._commit(db)
        db.refresh(run)
        return run

    def get_run(self, db: Session, run_id: str) -> BacktestResult | None:
        """
        Retrieve a backtest run by ID.

        Args:
            db: Database session
            run_id: UUID of the run

        Returns:
            BacktestResult or None if not found
        """
        run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
        if run is None:
            return None
        return self._to_result(run)

    def list_runs(
        self,
        db: Session,
        limit: int = 50,
        offset: int = 0,
        strategy: str | None = None,
        symbol: str | None = None,
    ) -> tuple[list[BacktestSummary], int]:
        """
        List backtest runs with optional filtering, newest first.

        Args:
            db: Database session
            limit: Maximum number of runs to return
            offset: Offset for pagination
            strategy: Filter by strategy name (ignored when empty/None)
            symbol: Filter by symbol (ignored when empty/None)

        Returns:
            Tuple of (list of summaries for the requested page, total count of
            runs matching the filters before pagination)
        """
        query = db.query(BacktestRun)
        if strategy:
            query = query.filter(BacktestRun.strategy == strategy)
        if symbol:
            query = query.filter(BacktestRun.symbol == symbol)

        # Count before applying offset/limit so callers can paginate.
        total = query.count()
        runs = (
            query.order_by(BacktestRun.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return [self._to_summary(run) for run in runs], total

    def get_runs_by_ids(self, db: Session, run_ids: list[str]) -> list[BacktestResult]:
        """
        Retrieve multiple runs by their IDs.

        Args:
            db: Database session
            run_ids: List of run UUIDs

        Returns:
            List of BacktestResults in the order of ``run_ids``. IDs with no
            matching run are skipped; a repeated ID yields a repeated result.
        """
        if not run_ids:
            # Avoid a pointless round trip with an empty IN () predicate.
            return []

        runs = db.query(BacktestRun).filter(BacktestRun.run_id.in_(run_ids)).all()

        # The query returns rows in arbitrary order; re-order by the request.
        run_map = {run.run_id: run for run in runs}
        return [
            self._to_result(run_map[run_id])
            for run_id in run_ids
            if run_id in run_map
        ]

    def delete_run(self, db: Session, run_id: str) -> bool:
        """
        Delete a backtest run.

        Args:
            db: Database session
            run_id: UUID of the run to delete

        Returns:
            True if deleted, False if not found

        Raises:
            Any exception raised by the commit is re-raised after the session
            has been rolled back.
        """
        run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
        if run is None:
            return False

        db.delete(run)
        self._commit(db)
        return True

    @staticmethod
    def _commit(db: Session) -> None:
        """Commit ``db``, rolling back before re-raising if the commit fails."""
        try:
            db.commit()
        except Exception:
            # Per SQLAlchemy, a session whose flush/commit failed must be
            # rolled back before it can be used again; do that here so the
            # caller's session is not left pending rollback.
            db.rollback()
            raise

    @staticmethod
    def _format_created_at(run: BacktestRun) -> str:
        """Return ``run.created_at`` as an ISO-8601 string, or "" when unset."""
        return run.created_at.isoformat() if run.created_at else ""

    def _to_result(self, run: BacktestRun) -> BacktestResult:
        """Convert database record to BacktestResult schema."""
        equity_data = run.get_equity_curve()
        trades_data = run.get_trades()

        return BacktestResult(
            run_id=run.run_id,
            strategy=run.strategy,
            symbol=run.symbol,
            market_type=run.market_type,
            timeframe=run.timeframe,
            start_date=run.start_date or "",
            end_date=run.end_date or "",
            leverage=run.leverage,
            params=run.params or {},
            metrics=BacktestMetrics(
                total_return=run.total_return,
                # These two columns may be NULL on stored rows; default to 0.0.
                benchmark_return=run.benchmark_return or 0.0,
                alpha=run.alpha or 0.0,
                sharpe_ratio=run.sharpe_ratio,
                max_drawdown=run.max_drawdown,
                win_rate=run.win_rate,
                total_trades=run.total_trades,
                profit_factor=run.profit_factor,
                total_fees=run.total_fees,
                total_funding=run.total_funding,
                liquidation_count=run.liquidation_count,
                liquidation_loss=run.liquidation_loss,
                adjusted_return=run.adjusted_return,
            ),
            equity_curve=[EquityPoint(**p) for p in equity_data],
            trades=[TradeRecord(**t) for t in trades_data],
            created_at=self._format_created_at(run),
        )

    def _to_summary(self, run: BacktestRun) -> BacktestSummary:
        """Convert database record to BacktestSummary schema."""
        return BacktestSummary(
            run_id=run.run_id,
            strategy=run.strategy,
            symbol=run.symbol,
            market_type=run.market_type,
            timeframe=run.timeframe,
            total_return=run.total_return,
            sharpe_ratio=run.sharpe_ratio,
            max_drawdown=run.max_drawdown,
            total_trades=run.total_trades,
            created_at=self._format_created_at(run),
            params=run.params or {},
        )


# Singleton instance
_storage: StorageService | None = None


def get_storage() -> StorageService:
    """Get or create the storage service instance."""
    global _storage
    if _storage is None:
        _storage = StorageService()
    return _storage