"""
Backtest execution and history endpoints.
"""

from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session

from api.models.database import get_db
from api.models.schemas import (
    BacktestListResponse,
    BacktestRequest,
    BacktestResult,
    CompareRequest,
    CompareResult,
)
from api.services.runner import get_runner
from api.services.storage import get_storage
from engine.logging_config import get_logger

router = APIRouter()
logger = get_logger(__name__)

# Top-level BacktestResult attributes that _calculate_param_diff compares in
# addition to each run's ``params`` entries.
_TOP_LEVEL_DIFF_FIELDS = ("strategy", "symbol", "leverage", "timeframe")


@router.post("/backtest", response_model=BacktestResult)
async def run_backtest(
    request: BacktestRequest,
    db: Session = Depends(get_db),
):
    """
    Execute a backtest with the specified configuration.

    Runs the strategy on historical data and returns metrics, equity curve,
    and trade records. Results are automatically saved.

    Error mapping for failures while *running* the backtest:
      - ``KeyError``          -> 400 "Invalid strategy: ..."
      - ``FileNotFoundError`` -> 404 "Data not found: ..."
      - any other exception   -> 500 with the exception text as detail
    Any failure while *saving* the result is reported as 500.
    """
    # NOTE(review): runner.run and the Session-based storage calls look
    # synchronous; inside an ``async def`` route they block the event loop for
    # the whole backtest. Consider a plain ``def`` route (FastAPI then runs it
    # in a threadpool) — kept async here so the coroutine interface is unchanged.
    runner = get_runner()
    storage = get_storage()

    # Only runner.run lives in this try: the KeyError / FileNotFoundError
    # mappings describe bad input to the runner and must not misreport an
    # unrelated KeyError raised later while persisting the result.
    try:
        result = runner.run(request)
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"Invalid strategy: {e}") from e
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=f"Data not found: {e}") from e
    except Exception as e:
        logger.error("Backtest failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        storage.save_run(db, result)
    except Exception as e:
        logger.error("Saving backtest %s failed: %s", result.run_id, e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    logger.info(
        "Backtest completed and saved: %s (return=%.2f%%, sharpe=%.2f)",
        result.run_id,
        result.metrics.total_return,
        result.metrics.sharpe_ratio,
    )
    return result


@router.get("/backtests", response_model=BacktestListResponse)
async def list_backtests(
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    strategy: str | None = None,
    symbol: str | None = None,
    db: Session = Depends(get_db),
):
    """
    List saved backtest runs with optional filtering.

    Returns summaries for quick display in the history sidebar.

    Args:
        limit: Page size (1-200, validated by FastAPI).
        offset: Number of runs to skip (>= 0).
        strategy: Optional filter forwarded to storage.list_runs.
        symbol: Optional filter forwarded to storage.list_runs.
        db: Database session injected by get_db.

    Returns:
        BacktestListResponse holding the page of runs and the total count
        reported by storage.
    """
    storage = get_storage()
    runs, total = storage.list_runs(
        db,
        limit=limit,
        offset=offset,
        strategy=strategy,
        symbol=symbol,
    )
    return BacktestListResponse(runs=runs, total=total)


@router.get("/backtest/{run_id}", response_model=BacktestResult)
async def get_backtest(
    run_id: str,
    db: Session = Depends(get_db),
):
    """
    Retrieve a specific backtest run by ID.

    Returns full result including equity curve and trades.

    Raises:
        HTTPException(404): if storage returns no run for ``run_id``.
    """
    storage = get_storage()
    result = storage.get_run(db, run_id)
    if not result:
        raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
    return result


@router.delete("/backtest/{run_id}")
async def delete_backtest(
    run_id: str,
    db: Session = Depends(get_db),
):
    """
    Delete a backtest run.

    Returns:
        ``{"status": "deleted", "run_id": run_id}`` on success.

    Raises:
        HTTPException(404): if storage reports nothing was deleted.
    """
    storage = get_storage()
    deleted = storage.delete_run(db, run_id)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
    return {"status": "deleted", "run_id": run_id}


@router.post("/compare", response_model=CompareResult)
async def compare_runs(
    request: CompareRequest,
    db: Session = Depends(get_db),
):
    """
    Compare multiple backtest runs (2-5 runs).

    Returns full results for each run plus parameter differences.

    Raises:
        HTTPException(404): listing every requested run ID that storage did
        not return.
    """
    storage = get_storage()
    runs = storage.get_runs_by_ids(db, request.run_ids)

    # Decide "missing" by ID rather than by comparing counts: a request that
    # repeats an existing ID would otherwise fail with "Runs not found: []".
    found_ids = {r.run_id for r in runs}
    missing = [rid for rid in request.run_ids if rid not in found_ids]
    if missing:
        raise HTTPException(status_code=404, detail=f"Runs not found: {missing}")

    param_diff = _calculate_param_diff(runs)
    return CompareResult(runs=runs, param_diff=param_diff)


def _calculate_param_diff(runs: list[BacktestResult]) -> dict[str, list[Any]]:
    """
    Find parameters that differ between runs.

    Considers every key present in any run's ``params`` plus the top-level
    fields in ``_TOP_LEVEL_DIFF_FIELDS``.

    Returns:
        Dict mapping param name to list of values (one per run, in ``runs``
        order), containing only keys whose values are not all equal. Keys are
        inserted in sorted order. A run lacking a param contributes ``None``.
    """
    if not runs:
        return {}

    all_keys: set[str] = set(_TOP_LEVEL_DIFF_FIELDS)
    for run in runs:
        all_keys.update(run.params.keys())

    diff: dict[str, list[Any]] = {}
    for key in sorted(all_keys):
        if key in _TOP_LEVEL_DIFF_FIELDS:
            # Top-level fields take precedence over a same-named strategy param.
            values = [getattr(run, key) for run in runs]
        else:
            values = [run.params.get(key) for run in runs]
        # Compare string forms so unhashable values (lists, dicts) work; this
        # also means e.g. 1 and "1" are treated as equal.
        if len({str(v) for v in values}) > 1:
            diff[key] = values
    return diff