Implement FastAPI backend and Vue 3 frontend for Lowkey Backtest UI
- Added FastAPI backend with core API endpoints for strategies, backtests, and data management. - Introduced Vue 3 frontend with a dark theme, enabling users to run backtests, adjust parameters, and compare results. - Implemented Pydantic schemas for request/response validation and SQLAlchemy models for database interactions. - Enhanced project structure with dedicated modules for services, routers, and components. - Updated dependencies in `pyproject.toml` and `frontend/package.json` to include FastAPI, SQLAlchemy, and Vue-related packages. - Improved `.gitignore` to exclude unnecessary files and directories.
This commit is contained in:
3
api/__init__.py
Normal file
3
api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
FastAPI backend for Lowkey Backtest UI.
|
||||
"""
|
||||
47
api/main.py
Normal file
47
api/main.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
FastAPI application entry point for Lowkey Backtest UI.
|
||||
|
||||
Run with: uvicorn api.main:app --reload
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from api.models.database import init_db
|
||||
from api.routers import backtest, data, strategies
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize database on startup."""
|
||||
init_db()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Lowkey Backtest API",
|
||||
description="API for running and analyzing trading strategy backtests",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS configuration for local development
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Register routers
|
||||
app.include_router(strategies.router, prefix="/api", tags=["strategies"])
|
||||
app.include_router(data.router, prefix="/api", tags=["data"])
|
||||
app.include_router(backtest.router, prefix="/api", tags=["backtest"])
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok", "service": "lowkey-backtest-api"}
|
||||
3
api/models/__init__.py
Normal file
3
api/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Pydantic schemas and database models.
|
||||
"""
|
||||
99
api/models/database.py
Normal file
99
api/models/database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
SQLAlchemy database models and session management for backtest run persistence.
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text, create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
# Database file location
|
||||
DB_PATH = Path(__file__).parent.parent.parent / "data" / "backtest_runs.db"
|
||||
DATABASE_URL = f"sqlite:///{DB_PATH}"
|
||||
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for SQLAlchemy models."""
|
||||
pass
|
||||
|
||||
|
||||
class BacktestRun(Base):
|
||||
"""
|
||||
Persisted backtest run record.
|
||||
|
||||
Stores all information needed to display and compare runs.
|
||||
"""
|
||||
__tablename__ = "backtest_runs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
run_id = Column(String(36), unique=True, nullable=False, index=True)
|
||||
|
||||
# Configuration
|
||||
strategy = Column(String(50), nullable=False, index=True)
|
||||
symbol = Column(String(20), nullable=False, index=True)
|
||||
exchange = Column(String(20), nullable=False, default="okx")
|
||||
market_type = Column(String(20), nullable=False)
|
||||
timeframe = Column(String(10), nullable=False)
|
||||
leverage = Column(Integer, nullable=False, default=1)
|
||||
params = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Date range
|
||||
start_date = Column(String(20), nullable=True)
|
||||
end_date = Column(String(20), nullable=True)
|
||||
|
||||
# Metrics (denormalized for quick listing)
|
||||
total_return = Column(Float, nullable=False)
|
||||
benchmark_return = Column(Float, nullable=False, default=0.0)
|
||||
alpha = Column(Float, nullable=False, default=0.0)
|
||||
sharpe_ratio = Column(Float, nullable=False)
|
||||
max_drawdown = Column(Float, nullable=False)
|
||||
win_rate = Column(Float, nullable=False)
|
||||
total_trades = Column(Integer, nullable=False)
|
||||
profit_factor = Column(Float, nullable=True)
|
||||
total_fees = Column(Float, nullable=False, default=0.0)
|
||||
total_funding = Column(Float, nullable=False, default=0.0)
|
||||
liquidation_count = Column(Integer, nullable=False, default=0)
|
||||
liquidation_loss = Column(Float, nullable=False, default=0.0)
|
||||
adjusted_return = Column(Float, nullable=True)
|
||||
|
||||
# Full data (JSON serialized)
|
||||
equity_curve = Column(Text, nullable=False) # JSON array
|
||||
trades = Column(Text, nullable=False) # JSON array
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def set_equity_curve(self, data: list[dict]):
|
||||
"""Serialize equity curve to JSON string."""
|
||||
self.equity_curve = json.dumps(data)
|
||||
|
||||
def get_equity_curve(self) -> list[dict]:
|
||||
"""Deserialize equity curve from JSON string."""
|
||||
return json.loads(self.equity_curve) if self.equity_curve else []
|
||||
|
||||
def set_trades(self, data: list[dict]):
|
||||
"""Serialize trades to JSON string."""
|
||||
self.trades = json.dumps(data)
|
||||
|
||||
def get_trades(self) -> list[dict]:
|
||||
"""Deserialize trades from JSON string."""
|
||||
return json.loads(self.trades) if self.trades else []
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Create database tables if they don't exist."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db() -> Session:
|
||||
"""Get database session (dependency injection)."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
162
api/models/schemas.py
Normal file
162
api/models/schemas.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Pydantic schemas for API request/response models.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# --- Strategy Schemas ---
|
||||
|
||||
class StrategyParam(BaseModel):
|
||||
"""Single strategy parameter definition."""
|
||||
name: str
|
||||
value: Any
|
||||
param_type: str = Field(description="Type: int, float, bool, list")
|
||||
min_value: float | None = None
|
||||
max_value: float | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class StrategyInfo(BaseModel):
|
||||
"""Strategy information with parameters."""
|
||||
name: str
|
||||
display_name: str
|
||||
market_type: str
|
||||
default_leverage: int
|
||||
default_params: dict[str, Any]
|
||||
grid_params: dict[str, Any]
|
||||
|
||||
|
||||
class StrategiesResponse(BaseModel):
|
||||
"""Response for GET /api/strategies."""
|
||||
strategies: list[StrategyInfo]
|
||||
|
||||
|
||||
# --- Symbol/Data Schemas ---
|
||||
|
||||
class SymbolInfo(BaseModel):
|
||||
"""Available symbol information."""
|
||||
symbol: str
|
||||
exchange: str
|
||||
market_type: str
|
||||
timeframes: list[str]
|
||||
start_date: str | None = None
|
||||
end_date: str | None = None
|
||||
row_count: int = 0
|
||||
|
||||
|
||||
class DataStatusResponse(BaseModel):
|
||||
"""Response for GET /api/data/status."""
|
||||
symbols: list[SymbolInfo]
|
||||
|
||||
|
||||
# --- Backtest Schemas ---
|
||||
|
||||
class BacktestRequest(BaseModel):
|
||||
"""Request body for POST /api/backtest."""
|
||||
strategy: str
|
||||
symbol: str
|
||||
exchange: str = "okx"
|
||||
timeframe: str = "1h"
|
||||
market_type: str = "perpetual"
|
||||
start_date: str | None = None
|
||||
end_date: str | None = None
|
||||
init_cash: float = 10000.0
|
||||
leverage: int | None = None
|
||||
fees: float | None = None
|
||||
slippage: float = 0.001
|
||||
sl_stop: float | None = None
|
||||
tp_stop: float | None = None
|
||||
sl_trail: bool = False
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TradeRecord(BaseModel):
|
||||
"""Single trade record."""
|
||||
entry_time: str
|
||||
exit_time: str | None = None
|
||||
entry_price: float
|
||||
exit_price: float | None = None
|
||||
size: float
|
||||
direction: str
|
||||
pnl: float | None = None
|
||||
return_pct: float | None = None
|
||||
status: str = "closed"
|
||||
|
||||
|
||||
class EquityPoint(BaseModel):
|
||||
"""Single point on equity curve."""
|
||||
timestamp: str
|
||||
value: float
|
||||
drawdown: float = 0.0
|
||||
|
||||
|
||||
class BacktestMetrics(BaseModel):
|
||||
"""Backtest performance metrics."""
|
||||
total_return: float
|
||||
benchmark_return: float = 0.0
|
||||
alpha: float = 0.0
|
||||
sharpe_ratio: float
|
||||
max_drawdown: float
|
||||
win_rate: float
|
||||
total_trades: int
|
||||
profit_factor: float | None = None
|
||||
avg_trade_return: float | None = None
|
||||
total_fees: float = 0.0
|
||||
total_funding: float = 0.0
|
||||
liquidation_count: int = 0
|
||||
liquidation_loss: float = 0.0
|
||||
adjusted_return: float | None = None
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
|
||||
"""Complete backtest result."""
|
||||
run_id: str
|
||||
strategy: str
|
||||
symbol: str
|
||||
market_type: str
|
||||
timeframe: str
|
||||
start_date: str
|
||||
end_date: str
|
||||
leverage: int
|
||||
params: dict[str, Any]
|
||||
metrics: BacktestMetrics
|
||||
equity_curve: list[EquityPoint]
|
||||
trades: list[TradeRecord]
|
||||
created_at: str
|
||||
|
||||
|
||||
class BacktestSummary(BaseModel):
|
||||
"""Summary for backtest list view."""
|
||||
run_id: str
|
||||
strategy: str
|
||||
symbol: str
|
||||
market_type: str
|
||||
timeframe: str
|
||||
total_return: float
|
||||
sharpe_ratio: float
|
||||
max_drawdown: float
|
||||
total_trades: int
|
||||
created_at: str
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class BacktestListResponse(BaseModel):
|
||||
"""Response for GET /api/backtests."""
|
||||
runs: list[BacktestSummary]
|
||||
total: int
|
||||
|
||||
|
||||
# --- Comparison Schemas ---
|
||||
|
||||
class CompareRequest(BaseModel):
|
||||
"""Request body for POST /api/compare."""
|
||||
run_ids: list[str] = Field(min_length=2, max_length=5)
|
||||
|
||||
|
||||
class CompareResult(BaseModel):
|
||||
"""Comparison of multiple backtest runs."""
|
||||
runs: list[BacktestResult]
|
||||
param_diff: dict[str, list[Any]]
|
||||
3
api/routers/__init__.py
Normal file
3
api/routers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API routers for backtest, strategies, and data endpoints.
|
||||
"""
|
||||
193
api/routers/backtest.py
Normal file
193
api/routers/backtest.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Backtest execution and history endpoints.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.models.database import get_db
|
||||
from api.models.schemas import (
|
||||
BacktestListResponse,
|
||||
BacktestRequest,
|
||||
BacktestResult,
|
||||
CompareRequest,
|
||||
CompareResult,
|
||||
)
|
||||
from api.services.runner import get_runner
|
||||
from api.services.storage import get_storage
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("/backtest", response_model=BacktestResult)
|
||||
async def run_backtest(
|
||||
request: BacktestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Execute a backtest with the specified configuration.
|
||||
|
||||
Runs the strategy on historical data and returns metrics,
|
||||
equity curve, and trade records. Results are automatically saved.
|
||||
"""
|
||||
runner = get_runner()
|
||||
storage = get_storage()
|
||||
|
||||
try:
|
||||
# Execute backtest
|
||||
result = runner.run(request)
|
||||
|
||||
# Save to database
|
||||
storage.save_run(db, result)
|
||||
|
||||
logger.info(
|
||||
"Backtest completed and saved: %s (return=%.2f%%, sharpe=%.2f)",
|
||||
result.run_id,
|
||||
result.metrics.total_return,
|
||||
result.metrics.sharpe_ratio,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid strategy: {e}")
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=f"Data not found: {e}")
|
||||
except Exception as e:
|
||||
logger.error("Backtest failed: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/backtests", response_model=BacktestListResponse)
|
||||
async def list_backtests(
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
strategy: str | None = None,
|
||||
symbol: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List saved backtest runs with optional filtering.
|
||||
|
||||
Returns summaries for quick display in the history sidebar.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
runs, total = storage.list_runs(
|
||||
db,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
strategy=strategy,
|
||||
symbol=symbol,
|
||||
)
|
||||
|
||||
return BacktestListResponse(runs=runs, total=total)
|
||||
|
||||
|
||||
@router.get("/backtest/{run_id}", response_model=BacktestResult)
|
||||
async def get_backtest(
|
||||
run_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve a specific backtest run by ID.
|
||||
|
||||
Returns full result including equity curve and trades.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
result = storage.get_run(db, run_id)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/backtest/{run_id}")
|
||||
async def delete_backtest(
|
||||
run_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a backtest run.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
deleted = storage.delete_run(db, run_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
|
||||
|
||||
return {"status": "deleted", "run_id": run_id}
|
||||
|
||||
|
||||
@router.post("/compare", response_model=CompareResult)
|
||||
async def compare_runs(
|
||||
request: CompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Compare multiple backtest runs (2-5 runs).
|
||||
|
||||
Returns full results for each run plus parameter differences.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
runs = storage.get_runs_by_ids(db, request.run_ids)
|
||||
|
||||
if len(runs) != len(request.run_ids):
|
||||
found_ids = {r.run_id for r in runs}
|
||||
missing = [rid for rid in request.run_ids if rid not in found_ids]
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Runs not found: {missing}"
|
||||
)
|
||||
|
||||
# Calculate parameter differences
|
||||
param_diff = _calculate_param_diff(runs)
|
||||
|
||||
return CompareResult(runs=runs, param_diff=param_diff)
|
||||
|
||||
|
||||
def _calculate_param_diff(runs: list[BacktestResult]) -> dict[str, list[Any]]:
|
||||
"""
|
||||
Find parameters that differ between runs.
|
||||
|
||||
Returns dict mapping param name to list of values (one per run).
|
||||
"""
|
||||
if not runs:
|
||||
return {}
|
||||
|
||||
# Collect all param keys
|
||||
all_keys: set[str] = set()
|
||||
for run in runs:
|
||||
all_keys.update(run.params.keys())
|
||||
|
||||
# Also include strategy and key config
|
||||
all_keys.update(['strategy', 'symbol', 'leverage', 'timeframe'])
|
||||
|
||||
diff: dict[str, list[Any]] = {}
|
||||
|
||||
for key in sorted(all_keys):
|
||||
values = []
|
||||
for run in runs:
|
||||
if key == 'strategy':
|
||||
values.append(run.strategy)
|
||||
elif key == 'symbol':
|
||||
values.append(run.symbol)
|
||||
elif key == 'leverage':
|
||||
values.append(run.leverage)
|
||||
elif key == 'timeframe':
|
||||
values.append(run.timeframe)
|
||||
else:
|
||||
values.append(run.params.get(key))
|
||||
|
||||
# Only include if values differ
|
||||
if len(set(str(v) for v in values)) > 1:
|
||||
diff[key] = values
|
||||
|
||||
return diff
|
||||
97
api/routers/data.py
Normal file
97
api/routers/data.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Data status and symbol information endpoints.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter
|
||||
|
||||
from api.models.schemas import DataStatusResponse, SymbolInfo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Base path for CCXT data
|
||||
DATA_BASE = Path(__file__).parent.parent.parent / "data" / "ccxt"
|
||||
|
||||
|
||||
def _scan_available_data() -> list[SymbolInfo]:
|
||||
"""
|
||||
Scan the data directory for available symbols and timeframes.
|
||||
|
||||
Returns list of SymbolInfo with date ranges and row counts.
|
||||
"""
|
||||
symbols = []
|
||||
|
||||
if not DATA_BASE.exists():
|
||||
return symbols
|
||||
|
||||
# Structure: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
|
||||
for exchange_dir in DATA_BASE.iterdir():
|
||||
if not exchange_dir.is_dir():
|
||||
continue
|
||||
exchange = exchange_dir.name
|
||||
|
||||
for market_dir in exchange_dir.iterdir():
|
||||
if not market_dir.is_dir():
|
||||
continue
|
||||
market_type = market_dir.name
|
||||
|
||||
for symbol_dir in market_dir.iterdir():
|
||||
if not symbol_dir.is_dir():
|
||||
continue
|
||||
symbol = symbol_dir.name
|
||||
|
||||
# Find all timeframes
|
||||
timeframes = []
|
||||
start_date = None
|
||||
end_date = None
|
||||
row_count = 0
|
||||
|
||||
for csv_file in symbol_dir.glob("*.csv"):
|
||||
tf = csv_file.stem
|
||||
timeframes.append(tf)
|
||||
|
||||
# Read first and last rows for date range
|
||||
try:
|
||||
df = pd.read_csv(csv_file, parse_dates=['timestamp'])
|
||||
if not df.empty:
|
||||
row_count = len(df)
|
||||
start_date = df['timestamp'].min().strftime("%Y-%m-%d")
|
||||
end_date = df['timestamp'].max().strftime("%Y-%m-%d")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if timeframes:
|
||||
symbols.append(SymbolInfo(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
market_type=market_type,
|
||||
timeframes=sorted(timeframes),
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
row_count=row_count,
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
|
||||
@router.get("/symbols", response_model=DataStatusResponse)
|
||||
async def get_symbols():
|
||||
"""
|
||||
Get list of available symbols with their data ranges.
|
||||
|
||||
Scans the local data directory for downloaded OHLCV data.
|
||||
"""
|
||||
symbols = _scan_available_data()
|
||||
return DataStatusResponse(symbols=symbols)
|
||||
|
||||
|
||||
@router.get("/data/status", response_model=DataStatusResponse)
|
||||
async def get_data_status():
|
||||
"""
|
||||
Get detailed data inventory status.
|
||||
|
||||
Alias for /symbols with additional metadata.
|
||||
"""
|
||||
symbols = _scan_available_data()
|
||||
return DataStatusResponse(symbols=symbols)
|
||||
67
api/routers/strategies.py
Normal file
67
api/routers/strategies.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Strategy information endpoints.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter
|
||||
|
||||
from api.models.schemas import StrategiesResponse, StrategyInfo
|
||||
from strategies.factory import get_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _serialize_param_value(value: Any) -> Any:
|
||||
"""Convert numpy arrays and other types to JSON-serializable format."""
|
||||
if isinstance(value, np.ndarray):
|
||||
return value.tolist()
|
||||
if isinstance(value, (np.integer, np.floating)):
|
||||
return value.item()
|
||||
return value
|
||||
|
||||
|
||||
def _get_display_name(name: str) -> str:
|
||||
"""Convert strategy key to display name."""
|
||||
display_names = {
|
||||
"rsi": "RSI Strategy",
|
||||
"macross": "MA Crossover",
|
||||
"meta_st": "Meta Supertrend",
|
||||
"regime": "Regime Reversion (ML)",
|
||||
}
|
||||
return display_names.get(name, name.replace("_", " ").title())
|
||||
|
||||
|
||||
@router.get("/strategies", response_model=StrategiesResponse)
|
||||
async def get_strategies():
|
||||
"""
|
||||
Get list of available strategies with their parameters.
|
||||
|
||||
Returns strategy names, default parameters, and grid search ranges.
|
||||
"""
|
||||
registry = get_registry()
|
||||
strategies = []
|
||||
|
||||
for name, config in registry.items():
|
||||
strategy_instance = config.strategy_class()
|
||||
|
||||
# Serialize parameters (convert numpy arrays to lists)
|
||||
default_params = {
|
||||
k: _serialize_param_value(v)
|
||||
for k, v in config.default_params.items()
|
||||
}
|
||||
grid_params = {
|
||||
k: _serialize_param_value(v)
|
||||
for k, v in config.grid_params.items()
|
||||
}
|
||||
|
||||
strategies.append(StrategyInfo(
|
||||
name=name,
|
||||
display_name=_get_display_name(name),
|
||||
market_type=strategy_instance.default_market_type.value,
|
||||
default_leverage=strategy_instance.default_leverage,
|
||||
default_params=default_params,
|
||||
grid_params=grid_params,
|
||||
))
|
||||
|
||||
return StrategiesResponse(strategies=strategies)
|
||||
3
api/services/__init__.py
Normal file
3
api/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Business logic services for backtest execution and storage.
|
||||
"""
|
||||
300
api/services/runner.py
Normal file
300
api/services/runner.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Backtest runner service that wraps the existing Backtester engine.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from api.models.schemas import (
|
||||
BacktestMetrics,
|
||||
BacktestRequest,
|
||||
BacktestResult,
|
||||
EquityPoint,
|
||||
TradeRecord,
|
||||
)
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType
|
||||
from strategies.factory import get_strategy
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BacktestRunner:
|
||||
"""
|
||||
Service for executing backtests via the API.
|
||||
|
||||
Wraps the existing Backtester engine and converts results
|
||||
to API response format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.dm = DataManager()
|
||||
self.bt = Backtester(self.dm)
|
||||
|
||||
def run(self, request: BacktestRequest) -> BacktestResult:
|
||||
"""
|
||||
Execute a backtest and return structured results.
|
||||
|
||||
Args:
|
||||
request: BacktestRequest with strategy, symbol, and parameters
|
||||
|
||||
Returns:
|
||||
BacktestResult with metrics, equity curve, and trades
|
||||
"""
|
||||
# Get strategy instance
|
||||
strategy, default_params = get_strategy(request.strategy, is_grid=False)
|
||||
|
||||
# Merge default params with request params
|
||||
params = {**default_params, **request.params}
|
||||
|
||||
# Convert market type string to enum
|
||||
market_type = MarketType(request.market_type)
|
||||
|
||||
# Override strategy market type if specified
|
||||
strategy.default_market_type = market_type
|
||||
|
||||
logger.info(
|
||||
"Running backtest: %s on %s (%s), params=%s",
|
||||
request.strategy, request.symbol, request.timeframe, params
|
||||
)
|
||||
|
||||
# Execute backtest
|
||||
result = self.bt.run_strategy(
|
||||
strategy=strategy,
|
||||
exchange_id=request.exchange,
|
||||
symbol=request.symbol,
|
||||
timeframe=request.timeframe,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
init_cash=request.init_cash,
|
||||
fees=request.fees,
|
||||
slippage=request.slippage,
|
||||
sl_stop=request.sl_stop,
|
||||
tp_stop=request.tp_stop,
|
||||
sl_trail=request.sl_trail,
|
||||
leverage=request.leverage,
|
||||
**params
|
||||
)
|
||||
|
||||
# Extract data from portfolio
|
||||
portfolio = result.portfolio
|
||||
|
||||
# Build trade records
|
||||
trades = self._build_trade_records(portfolio)
|
||||
|
||||
# Build equity curve (trimmed to trading period)
|
||||
equity_curve = self._build_equity_curve(portfolio)
|
||||
|
||||
# Build metrics
|
||||
metrics = self._build_metrics(result, portfolio)
|
||||
|
||||
# Get date range from actual trading period (first trade to end)
|
||||
idx = portfolio.wrapper.index
|
||||
end_date = idx[-1].strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
# Use first trade time as start if trades exist
|
||||
trades_df = portfolio.trades.records_readable
|
||||
if not trades_df.empty:
|
||||
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
|
||||
if first_entry_col in trades_df.columns:
|
||||
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
|
||||
start_date = first_trade_time.strftime("%Y-%m-%d %H:%M")
|
||||
else:
|
||||
start_date = idx[0].strftime("%Y-%m-%d %H:%M")
|
||||
else:
|
||||
start_date = idx[0].strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
return BacktestResult(
|
||||
run_id=str(uuid.uuid4()),
|
||||
strategy=request.strategy,
|
||||
symbol=request.symbol,
|
||||
market_type=result.market_type.value,
|
||||
timeframe=request.timeframe,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
leverage=result.leverage,
|
||||
params=params,
|
||||
metrics=metrics,
|
||||
equity_curve=equity_curve,
|
||||
trades=trades,
|
||||
created_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
def _build_equity_curve(self, portfolio) -> list[EquityPoint]:
|
||||
"""Extract equity curve with drawdown from portfolio, starting from first trade."""
|
||||
value_series = portfolio.value()
|
||||
drawdown_series = portfolio.drawdown()
|
||||
|
||||
# Handle multi-column case (from grid search)
|
||||
if hasattr(value_series, 'columns') and len(value_series.columns) > 1:
|
||||
value_series = value_series.iloc[:, 0]
|
||||
drawdown_series = drawdown_series.iloc[:, 0]
|
||||
elif hasattr(value_series, 'columns'):
|
||||
value_series = value_series.iloc[:, 0]
|
||||
drawdown_series = drawdown_series.iloc[:, 0]
|
||||
|
||||
# Find first trade time to trim equity curve
|
||||
first_trade_idx = 0
|
||||
trades_df = portfolio.trades.records_readable
|
||||
if not trades_df.empty:
|
||||
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
|
||||
if first_entry_col in trades_df.columns:
|
||||
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
|
||||
# Find index in value_series closest to first trade
|
||||
if hasattr(value_series.index, 'get_indexer'):
|
||||
first_trade_idx = value_series.index.get_indexer([first_trade_time], method='nearest')[0]
|
||||
# Start a few bars before first trade for context
|
||||
first_trade_idx = max(0, first_trade_idx - 5)
|
||||
|
||||
# Slice from first trade onwards
|
||||
value_series = value_series.iloc[first_trade_idx:]
|
||||
drawdown_series = drawdown_series.iloc[first_trade_idx:]
|
||||
|
||||
points = []
|
||||
for i, (ts, val) in enumerate(value_series.items()):
|
||||
dd = drawdown_series.iloc[i] if i < len(drawdown_series) else 0.0
|
||||
points.append(EquityPoint(
|
||||
timestamp=ts.isoformat(),
|
||||
value=float(val),
|
||||
drawdown=float(dd) * 100, # Convert to percentage
|
||||
))
|
||||
|
||||
return points
|
||||
|
||||
def _build_trade_records(self, portfolio) -> list[TradeRecord]:
|
||||
"""Extract trade records from portfolio."""
|
||||
trades_df = portfolio.trades.records_readable
|
||||
|
||||
if trades_df.empty:
|
||||
return []
|
||||
|
||||
records = []
|
||||
for _, row in trades_df.iterrows():
|
||||
# Handle different column names in vectorbt
|
||||
entry_time = row.get('Entry Timestamp', row.get('Entry Time', ''))
|
||||
exit_time = row.get('Exit Timestamp', row.get('Exit Time', ''))
|
||||
|
||||
records.append(TradeRecord(
|
||||
entry_time=str(entry_time) if pd.notna(entry_time) else "",
|
||||
exit_time=str(exit_time) if pd.notna(exit_time) else None,
|
||||
entry_price=float(row.get('Avg Entry Price', row.get('Entry Price', 0))),
|
||||
exit_price=float(row.get('Avg Exit Price', row.get('Exit Price', 0)))
|
||||
if pd.notna(row.get('Avg Exit Price', row.get('Exit Price'))) else None,
|
||||
size=float(row.get('Size', 0)),
|
||||
direction=str(row.get('Direction', 'Long')),
|
||||
pnl=float(row.get('PnL', 0)) if pd.notna(row.get('PnL')) else None,
|
||||
return_pct=float(row.get('Return', 0)) * 100
|
||||
if pd.notna(row.get('Return')) else None,
|
||||
status="closed" if pd.notna(exit_time) else "open",
|
||||
))
|
||||
|
||||
return records
|
||||
|
||||
def _build_metrics(self, result, portfolio) -> BacktestMetrics:
|
||||
"""Build metrics from backtest result."""
|
||||
stats = portfolio.stats()
|
||||
|
||||
# Extract values, handling potential multi-column results
|
||||
def get_stat(key: str, default: float = 0.0) -> float:
|
||||
val = stats.get(key, default)
|
||||
if hasattr(val, 'mean'):
|
||||
return float(val.mean())
|
||||
return float(val) if pd.notna(val) else default
|
||||
|
||||
total_return = portfolio.total_return()
|
||||
if hasattr(total_return, 'mean'):
|
||||
total_return = total_return.mean()
|
||||
|
||||
# Calculate benchmark return from first trade to end (not full period)
|
||||
# This gives accurate comparison when strategy has training period
|
||||
close = portfolio.close
|
||||
benchmark_return = 0.0
|
||||
|
||||
if hasattr(close, 'iloc'):
|
||||
# Find first trade entry time
|
||||
trades_df = portfolio.trades.records_readable
|
||||
if not trades_df.empty:
|
||||
# Get the first trade entry timestamp
|
||||
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
|
||||
if first_entry_col in trades_df.columns:
|
||||
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
|
||||
|
||||
# Find the price at first trade
|
||||
if hasattr(close.index, 'get_indexer'):
|
||||
# Find closest index to first trade time
|
||||
idx = close.index.get_indexer([first_trade_time], method='nearest')[0]
|
||||
start_price = close.iloc[idx]
|
||||
else:
|
||||
start_price = close.iloc[0]
|
||||
|
||||
end_price = close.iloc[-1]
|
||||
|
||||
if hasattr(start_price, 'mean'):
|
||||
start_price = start_price.mean()
|
||||
if hasattr(end_price, 'mean'):
|
||||
end_price = end_price.mean()
|
||||
|
||||
benchmark_return = ((end_price - start_price) / start_price)
|
||||
else:
|
||||
# No trades - use full period
|
||||
start_price = close.iloc[0]
|
||||
end_price = close.iloc[-1]
|
||||
if hasattr(start_price, 'mean'):
|
||||
start_price = start_price.mean()
|
||||
if hasattr(end_price, 'mean'):
|
||||
end_price = end_price.mean()
|
||||
benchmark_return = ((end_price - start_price) / start_price)
|
||||
|
||||
# Alpha = strategy return - benchmark return
|
||||
alpha = float(total_return) - float(benchmark_return)
|
||||
|
||||
sharpe = portfolio.sharpe_ratio()
|
||||
if hasattr(sharpe, 'mean'):
|
||||
sharpe = sharpe.mean()
|
||||
|
||||
max_dd = portfolio.max_drawdown()
|
||||
if hasattr(max_dd, 'mean'):
|
||||
max_dd = max_dd.mean()
|
||||
|
||||
win_rate = portfolio.trades.win_rate()
|
||||
if hasattr(win_rate, 'mean'):
|
||||
win_rate = win_rate.mean()
|
||||
|
||||
trade_count = portfolio.trades.count()
|
||||
if hasattr(trade_count, 'mean'):
|
||||
trade_count = int(trade_count.mean())
|
||||
else:
|
||||
trade_count = int(trade_count)
|
||||
|
||||
return BacktestMetrics(
|
||||
total_return=float(total_return) * 100,
|
||||
benchmark_return=float(benchmark_return) * 100,
|
||||
alpha=float(alpha) * 100,
|
||||
sharpe_ratio=float(sharpe) if pd.notna(sharpe) else 0.0,
|
||||
max_drawdown=float(max_dd) * 100,
|
||||
win_rate=float(win_rate) * 100 if pd.notna(win_rate) else 0.0,
|
||||
total_trades=trade_count,
|
||||
profit_factor=get_stat('Profit Factor'),
|
||||
avg_trade_return=get_stat('Avg Winning Trade [%]'),
|
||||
total_fees=get_stat('Total Fees Paid'),
|
||||
total_funding=result.total_funding_paid,
|
||||
liquidation_count=result.liquidation_count,
|
||||
liquidation_loss=result.total_liquidation_loss,
|
||||
adjusted_return=result.adjusted_return,
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_runner: BacktestRunner | None = None
|
||||
|
||||
|
||||
def get_runner() -> BacktestRunner:
|
||||
"""Get or create the backtest runner instance."""
|
||||
global _runner
|
||||
if _runner is None:
|
||||
_runner = BacktestRunner()
|
||||
return _runner
|
||||
225
api/services/storage.py
Normal file
225
api/services/storage.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Storage service for persisting and retrieving backtest runs.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.models.database import BacktestRun
|
||||
from api.models.schemas import (
|
||||
BacktestResult,
|
||||
BacktestSummary,
|
||||
EquityPoint,
|
||||
BacktestMetrics,
|
||||
TradeRecord,
|
||||
)
|
||||
|
||||
|
||||
class StorageService:
|
||||
"""
|
||||
Service for saving and loading backtest runs from SQLite.
|
||||
"""
|
||||
|
||||
def save_run(self, db: Session, result: BacktestResult) -> BacktestRun:
|
||||
"""
|
||||
Persist a backtest result to the database.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
result: BacktestResult to save
|
||||
|
||||
Returns:
|
||||
Created BacktestRun record
|
||||
"""
|
||||
run = BacktestRun(
|
||||
run_id=result.run_id,
|
||||
strategy=result.strategy,
|
||||
symbol=result.symbol,
|
||||
market_type=result.market_type,
|
||||
timeframe=result.timeframe,
|
||||
leverage=result.leverage,
|
||||
params=result.params,
|
||||
start_date=result.start_date,
|
||||
end_date=result.end_date,
|
||||
total_return=result.metrics.total_return,
|
||||
benchmark_return=result.metrics.benchmark_return,
|
||||
alpha=result.metrics.alpha,
|
||||
sharpe_ratio=result.metrics.sharpe_ratio,
|
||||
max_drawdown=result.metrics.max_drawdown,
|
||||
win_rate=result.metrics.win_rate,
|
||||
total_trades=result.metrics.total_trades,
|
||||
profit_factor=result.metrics.profit_factor,
|
||||
total_fees=result.metrics.total_fees,
|
||||
total_funding=result.metrics.total_funding,
|
||||
liquidation_count=result.metrics.liquidation_count,
|
||||
liquidation_loss=result.metrics.liquidation_loss,
|
||||
adjusted_return=result.metrics.adjusted_return,
|
||||
)
|
||||
|
||||
# Serialize complex data
|
||||
run.set_equity_curve([p.model_dump() for p in result.equity_curve])
|
||||
run.set_trades([t.model_dump() for t in result.trades])
|
||||
|
||||
db.add(run)
|
||||
db.commit()
|
||||
db.refresh(run)
|
||||
|
||||
return run
|
||||
|
||||
def get_run(self, db: Session, run_id: str) -> BacktestResult | None:
|
||||
"""
|
||||
Retrieve a backtest run by ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
run_id: UUID of the run
|
||||
|
||||
Returns:
|
||||
BacktestResult or None if not found
|
||||
"""
|
||||
run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
|
||||
|
||||
if not run:
|
||||
return None
|
||||
|
||||
return self._to_result(run)
|
||||
|
||||
def list_runs(
|
||||
self,
|
||||
db: Session,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
strategy: str | None = None,
|
||||
symbol: str | None = None,
|
||||
) -> tuple[list[BacktestSummary], int]:
|
||||
"""
|
||||
List backtest runs with optional filtering.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
limit: Maximum number of runs to return
|
||||
offset: Offset for pagination
|
||||
strategy: Filter by strategy name
|
||||
symbol: Filter by symbol
|
||||
|
||||
Returns:
|
||||
Tuple of (list of summaries, total count)
|
||||
"""
|
||||
query = db.query(BacktestRun)
|
||||
|
||||
if strategy:
|
||||
query = query.filter(BacktestRun.strategy == strategy)
|
||||
if symbol:
|
||||
query = query.filter(BacktestRun.symbol == symbol)
|
||||
|
||||
total = query.count()
|
||||
|
||||
runs = query.order_by(BacktestRun.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
summaries = [self._to_summary(run) for run in runs]
|
||||
|
||||
return summaries, total
|
||||
|
||||
def get_runs_by_ids(self, db: Session, run_ids: list[str]) -> list[BacktestResult]:
|
||||
"""
|
||||
Retrieve multiple runs by their IDs.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
run_ids: List of run UUIDs
|
||||
|
||||
Returns:
|
||||
List of BacktestResults (preserves order)
|
||||
"""
|
||||
runs = db.query(BacktestRun).filter(BacktestRun.run_id.in_(run_ids)).all()
|
||||
|
||||
# Create lookup and preserve order
|
||||
run_map = {run.run_id: run for run in runs}
|
||||
results = []
|
||||
|
||||
for run_id in run_ids:
|
||||
if run_id in run_map:
|
||||
results.append(self._to_result(run_map[run_id]))
|
||||
|
||||
return results
|
||||
|
||||
def delete_run(self, db: Session, run_id: str) -> bool:
|
||||
"""
|
||||
Delete a backtest run.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
run_id: UUID of the run to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
|
||||
|
||||
if not run:
|
||||
return False
|
||||
|
||||
db.delete(run)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def _to_result(self, run: BacktestRun) -> BacktestResult:
|
||||
"""Convert database record to BacktestResult schema."""
|
||||
equity_data = run.get_equity_curve()
|
||||
trades_data = run.get_trades()
|
||||
|
||||
return BacktestResult(
|
||||
run_id=run.run_id,
|
||||
strategy=run.strategy,
|
||||
symbol=run.symbol,
|
||||
market_type=run.market_type,
|
||||
timeframe=run.timeframe,
|
||||
start_date=run.start_date or "",
|
||||
end_date=run.end_date or "",
|
||||
leverage=run.leverage,
|
||||
params=run.params or {},
|
||||
metrics=BacktestMetrics(
|
||||
total_return=run.total_return,
|
||||
benchmark_return=run.benchmark_return or 0.0,
|
||||
alpha=run.alpha or 0.0,
|
||||
sharpe_ratio=run.sharpe_ratio,
|
||||
max_drawdown=run.max_drawdown,
|
||||
win_rate=run.win_rate,
|
||||
total_trades=run.total_trades,
|
||||
profit_factor=run.profit_factor,
|
||||
total_fees=run.total_fees,
|
||||
total_funding=run.total_funding,
|
||||
liquidation_count=run.liquidation_count,
|
||||
liquidation_loss=run.liquidation_loss,
|
||||
adjusted_return=run.adjusted_return,
|
||||
),
|
||||
equity_curve=[EquityPoint(**p) for p in equity_data],
|
||||
trades=[TradeRecord(**t) for t in trades_data],
|
||||
created_at=run.created_at.isoformat() if run.created_at else "",
|
||||
)
|
||||
|
||||
def _to_summary(self, run: BacktestRun) -> BacktestSummary:
|
||||
"""Convert database record to BacktestSummary schema."""
|
||||
return BacktestSummary(
|
||||
run_id=run.run_id,
|
||||
strategy=run.strategy,
|
||||
symbol=run.symbol,
|
||||
market_type=run.market_type,
|
||||
timeframe=run.timeframe,
|
||||
total_return=run.total_return,
|
||||
sharpe_ratio=run.sharpe_ratio,
|
||||
max_drawdown=run.max_drawdown,
|
||||
total_trades=run.total_trades,
|
||||
created_at=run.created_at.isoformat() if run.created_at else "",
|
||||
params=run.params or {},
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_storage: StorageService | None = None
|
||||
|
||||
|
||||
def get_storage() -> StorageService:
|
||||
"""Get or create the storage service instance."""
|
||||
global _storage
|
||||
if _storage is None:
|
||||
_storage = StorageService()
|
||||
return _storage
|
||||
Reference in New Issue
Block a user