Compare commits
9 Commits
main
...
regime-imb
| Author | SHA1 | Date | |
|---|---|---|---|
| 1af0aab5fa | |||
| df37366603 | |||
| 7e4a6874a2 | |||
| c4ecb29d4c | |||
| 0c82c4f366 | |||
| 1e4cb87da3 | |||
| 10bb371054 | |||
| e6d69ed04d | |||
| 44fac1ed25 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -170,3 +170,8 @@ cython_debug/
|
||||
|
||||
./logs/
|
||||
*.csv
|
||||
research/regime_results.html
|
||||
data/backtest_runs.db
|
||||
.gitignore
|
||||
live_trading/regime_model.pkl
|
||||
live_trading/positions.json
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../OHLCVPredictor
|
||||
3
api/__init__.py
Normal file
3
api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
FastAPI backend for Lowkey Backtest UI.
|
||||
"""
|
||||
47
api/main.py
Normal file
47
api/main.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
FastAPI application entry point for Lowkey Backtest UI.
|
||||
|
||||
Run with: uvicorn api.main:app --reload
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from api.models.database import init_db
|
||||
from api.routers import backtest, data, strategies
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize database on startup."""
|
||||
init_db()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Lowkey Backtest API",
|
||||
description="API for running and analyzing trading strategy backtests",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS configuration for local development
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Register routers
|
||||
app.include_router(strategies.router, prefix="/api", tags=["strategies"])
|
||||
app.include_router(data.router, prefix="/api", tags=["data"])
|
||||
app.include_router(backtest.router, prefix="/api", tags=["backtest"])
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok", "service": "lowkey-backtest-api"}
|
||||
3
api/models/__init__.py
Normal file
3
api/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Pydantic schemas and database models.
|
||||
"""
|
||||
99
api/models/database.py
Normal file
99
api/models/database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
SQLAlchemy database models and session management for backtest run persistence.
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text, create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
# Database file location
|
||||
DB_PATH = Path(__file__).parent.parent.parent / "data" / "backtest_runs.db"
|
||||
DATABASE_URL = f"sqlite:///{DB_PATH}"
|
||||
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for SQLAlchemy models."""
|
||||
pass
|
||||
|
||||
|
||||
class BacktestRun(Base):
|
||||
"""
|
||||
Persisted backtest run record.
|
||||
|
||||
Stores all information needed to display and compare runs.
|
||||
"""
|
||||
__tablename__ = "backtest_runs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
run_id = Column(String(36), unique=True, nullable=False, index=True)
|
||||
|
||||
# Configuration
|
||||
strategy = Column(String(50), nullable=False, index=True)
|
||||
symbol = Column(String(20), nullable=False, index=True)
|
||||
exchange = Column(String(20), nullable=False, default="okx")
|
||||
market_type = Column(String(20), nullable=False)
|
||||
timeframe = Column(String(10), nullable=False)
|
||||
leverage = Column(Integer, nullable=False, default=1)
|
||||
params = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Date range
|
||||
start_date = Column(String(20), nullable=True)
|
||||
end_date = Column(String(20), nullable=True)
|
||||
|
||||
# Metrics (denormalized for quick listing)
|
||||
total_return = Column(Float, nullable=False)
|
||||
benchmark_return = Column(Float, nullable=False, default=0.0)
|
||||
alpha = Column(Float, nullable=False, default=0.0)
|
||||
sharpe_ratio = Column(Float, nullable=False)
|
||||
max_drawdown = Column(Float, nullable=False)
|
||||
win_rate = Column(Float, nullable=False)
|
||||
total_trades = Column(Integer, nullable=False)
|
||||
profit_factor = Column(Float, nullable=True)
|
||||
total_fees = Column(Float, nullable=False, default=0.0)
|
||||
total_funding = Column(Float, nullable=False, default=0.0)
|
||||
liquidation_count = Column(Integer, nullable=False, default=0)
|
||||
liquidation_loss = Column(Float, nullable=False, default=0.0)
|
||||
adjusted_return = Column(Float, nullable=True)
|
||||
|
||||
# Full data (JSON serialized)
|
||||
equity_curve = Column(Text, nullable=False) # JSON array
|
||||
trades = Column(Text, nullable=False) # JSON array
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def set_equity_curve(self, data: list[dict]):
|
||||
"""Serialize equity curve to JSON string."""
|
||||
self.equity_curve = json.dumps(data)
|
||||
|
||||
def get_equity_curve(self) -> list[dict]:
|
||||
"""Deserialize equity curve from JSON string."""
|
||||
return json.loads(self.equity_curve) if self.equity_curve else []
|
||||
|
||||
def set_trades(self, data: list[dict]):
|
||||
"""Serialize trades to JSON string."""
|
||||
self.trades = json.dumps(data)
|
||||
|
||||
def get_trades(self) -> list[dict]:
|
||||
"""Deserialize trades from JSON string."""
|
||||
return json.loads(self.trades) if self.trades else []
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Create database tables if they don't exist."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db() -> Session:
|
||||
"""Get database session (dependency injection)."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
162
api/models/schemas.py
Normal file
162
api/models/schemas.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Pydantic schemas for API request/response models.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# --- Strategy Schemas ---
|
||||
|
||||
class StrategyParam(BaseModel):
|
||||
"""Single strategy parameter definition."""
|
||||
name: str
|
||||
value: Any
|
||||
param_type: str = Field(description="Type: int, float, bool, list")
|
||||
min_value: float | None = None
|
||||
max_value: float | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class StrategyInfo(BaseModel):
|
||||
"""Strategy information with parameters."""
|
||||
name: str
|
||||
display_name: str
|
||||
market_type: str
|
||||
default_leverage: int
|
||||
default_params: dict[str, Any]
|
||||
grid_params: dict[str, Any]
|
||||
|
||||
|
||||
class StrategiesResponse(BaseModel):
|
||||
"""Response for GET /api/strategies."""
|
||||
strategies: list[StrategyInfo]
|
||||
|
||||
|
||||
# --- Symbol/Data Schemas ---
|
||||
|
||||
class SymbolInfo(BaseModel):
|
||||
"""Available symbol information."""
|
||||
symbol: str
|
||||
exchange: str
|
||||
market_type: str
|
||||
timeframes: list[str]
|
||||
start_date: str | None = None
|
||||
end_date: str | None = None
|
||||
row_count: int = 0
|
||||
|
||||
|
||||
class DataStatusResponse(BaseModel):
|
||||
"""Response for GET /api/data/status."""
|
||||
symbols: list[SymbolInfo]
|
||||
|
||||
|
||||
# --- Backtest Schemas ---
|
||||
|
||||
class BacktestRequest(BaseModel):
|
||||
"""Request body for POST /api/backtest."""
|
||||
strategy: str
|
||||
symbol: str
|
||||
exchange: str = "okx"
|
||||
timeframe: str = "1h"
|
||||
market_type: str = "perpetual"
|
||||
start_date: str | None = None
|
||||
end_date: str | None = None
|
||||
init_cash: float = 10000.0
|
||||
leverage: int | None = None
|
||||
fees: float | None = None
|
||||
slippage: float = 0.001
|
||||
sl_stop: float | None = None
|
||||
tp_stop: float | None = None
|
||||
sl_trail: bool = False
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TradeRecord(BaseModel):
|
||||
"""Single trade record."""
|
||||
entry_time: str
|
||||
exit_time: str | None = None
|
||||
entry_price: float
|
||||
exit_price: float | None = None
|
||||
size: float
|
||||
direction: str
|
||||
pnl: float | None = None
|
||||
return_pct: float | None = None
|
||||
status: str = "closed"
|
||||
|
||||
|
||||
class EquityPoint(BaseModel):
|
||||
"""Single point on equity curve."""
|
||||
timestamp: str
|
||||
value: float
|
||||
drawdown: float = 0.0
|
||||
|
||||
|
||||
class BacktestMetrics(BaseModel):
|
||||
"""Backtest performance metrics."""
|
||||
total_return: float
|
||||
benchmark_return: float = 0.0
|
||||
alpha: float = 0.0
|
||||
sharpe_ratio: float
|
||||
max_drawdown: float
|
||||
win_rate: float
|
||||
total_trades: int
|
||||
profit_factor: float | None = None
|
||||
avg_trade_return: float | None = None
|
||||
total_fees: float = 0.0
|
||||
total_funding: float = 0.0
|
||||
liquidation_count: int = 0
|
||||
liquidation_loss: float = 0.0
|
||||
adjusted_return: float | None = None
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
|
||||
"""Complete backtest result."""
|
||||
run_id: str
|
||||
strategy: str
|
||||
symbol: str
|
||||
market_type: str
|
||||
timeframe: str
|
||||
start_date: str
|
||||
end_date: str
|
||||
leverage: int
|
||||
params: dict[str, Any]
|
||||
metrics: BacktestMetrics
|
||||
equity_curve: list[EquityPoint]
|
||||
trades: list[TradeRecord]
|
||||
created_at: str
|
||||
|
||||
|
||||
class BacktestSummary(BaseModel):
|
||||
"""Summary for backtest list view."""
|
||||
run_id: str
|
||||
strategy: str
|
||||
symbol: str
|
||||
market_type: str
|
||||
timeframe: str
|
||||
total_return: float
|
||||
sharpe_ratio: float
|
||||
max_drawdown: float
|
||||
total_trades: int
|
||||
created_at: str
|
||||
params: dict[str, Any]
|
||||
|
||||
|
||||
class BacktestListResponse(BaseModel):
|
||||
"""Response for GET /api/backtests."""
|
||||
runs: list[BacktestSummary]
|
||||
total: int
|
||||
|
||||
|
||||
# --- Comparison Schemas ---
|
||||
|
||||
class CompareRequest(BaseModel):
|
||||
"""Request body for POST /api/compare."""
|
||||
run_ids: list[str] = Field(min_length=2, max_length=5)
|
||||
|
||||
|
||||
class CompareResult(BaseModel):
|
||||
"""Comparison of multiple backtest runs."""
|
||||
runs: list[BacktestResult]
|
||||
param_diff: dict[str, list[Any]]
|
||||
3
api/routers/__init__.py
Normal file
3
api/routers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API routers for backtest, strategies, and data endpoints.
|
||||
"""
|
||||
193
api/routers/backtest.py
Normal file
193
api/routers/backtest.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Backtest execution and history endpoints.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.models.database import get_db
|
||||
from api.models.schemas import (
|
||||
BacktestListResponse,
|
||||
BacktestRequest,
|
||||
BacktestResult,
|
||||
CompareRequest,
|
||||
CompareResult,
|
||||
)
|
||||
from api.services.runner import get_runner
|
||||
from api.services.storage import get_storage
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("/backtest", response_model=BacktestResult)
|
||||
async def run_backtest(
|
||||
request: BacktestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Execute a backtest with the specified configuration.
|
||||
|
||||
Runs the strategy on historical data and returns metrics,
|
||||
equity curve, and trade records. Results are automatically saved.
|
||||
"""
|
||||
runner = get_runner()
|
||||
storage = get_storage()
|
||||
|
||||
try:
|
||||
# Execute backtest
|
||||
result = runner.run(request)
|
||||
|
||||
# Save to database
|
||||
storage.save_run(db, result)
|
||||
|
||||
logger.info(
|
||||
"Backtest completed and saved: %s (return=%.2f%%, sharpe=%.2f)",
|
||||
result.run_id,
|
||||
result.metrics.total_return,
|
||||
result.metrics.sharpe_ratio,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid strategy: {e}")
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=f"Data not found: {e}")
|
||||
except Exception as e:
|
||||
logger.error("Backtest failed: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/backtests", response_model=BacktestListResponse)
|
||||
async def list_backtests(
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
strategy: str | None = None,
|
||||
symbol: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List saved backtest runs with optional filtering.
|
||||
|
||||
Returns summaries for quick display in the history sidebar.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
runs, total = storage.list_runs(
|
||||
db,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
strategy=strategy,
|
||||
symbol=symbol,
|
||||
)
|
||||
|
||||
return BacktestListResponse(runs=runs, total=total)
|
||||
|
||||
|
||||
@router.get("/backtest/{run_id}", response_model=BacktestResult)
|
||||
async def get_backtest(
|
||||
run_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve a specific backtest run by ID.
|
||||
|
||||
Returns full result including equity curve and trades.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
result = storage.get_run(db, run_id)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/backtest/{run_id}")
|
||||
async def delete_backtest(
|
||||
run_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a backtest run.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
deleted = storage.delete_run(db, run_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
|
||||
|
||||
return {"status": "deleted", "run_id": run_id}
|
||||
|
||||
|
||||
@router.post("/compare", response_model=CompareResult)
|
||||
async def compare_runs(
|
||||
request: CompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Compare multiple backtest runs (2-5 runs).
|
||||
|
||||
Returns full results for each run plus parameter differences.
|
||||
"""
|
||||
storage = get_storage()
|
||||
|
||||
runs = storage.get_runs_by_ids(db, request.run_ids)
|
||||
|
||||
if len(runs) != len(request.run_ids):
|
||||
found_ids = {r.run_id for r in runs}
|
||||
missing = [rid for rid in request.run_ids if rid not in found_ids]
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Runs not found: {missing}"
|
||||
)
|
||||
|
||||
# Calculate parameter differences
|
||||
param_diff = _calculate_param_diff(runs)
|
||||
|
||||
return CompareResult(runs=runs, param_diff=param_diff)
|
||||
|
||||
|
||||
def _calculate_param_diff(runs: list[BacktestResult]) -> dict[str, list[Any]]:
|
||||
"""
|
||||
Find parameters that differ between runs.
|
||||
|
||||
Returns dict mapping param name to list of values (one per run).
|
||||
"""
|
||||
if not runs:
|
||||
return {}
|
||||
|
||||
# Collect all param keys
|
||||
all_keys: set[str] = set()
|
||||
for run in runs:
|
||||
all_keys.update(run.params.keys())
|
||||
|
||||
# Also include strategy and key config
|
||||
all_keys.update(['strategy', 'symbol', 'leverage', 'timeframe'])
|
||||
|
||||
diff: dict[str, list[Any]] = {}
|
||||
|
||||
for key in sorted(all_keys):
|
||||
values = []
|
||||
for run in runs:
|
||||
if key == 'strategy':
|
||||
values.append(run.strategy)
|
||||
elif key == 'symbol':
|
||||
values.append(run.symbol)
|
||||
elif key == 'leverage':
|
||||
values.append(run.leverage)
|
||||
elif key == 'timeframe':
|
||||
values.append(run.timeframe)
|
||||
else:
|
||||
values.append(run.params.get(key))
|
||||
|
||||
# Only include if values differ
|
||||
if len(set(str(v) for v in values)) > 1:
|
||||
diff[key] = values
|
||||
|
||||
return diff
|
||||
97
api/routers/data.py
Normal file
97
api/routers/data.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Data status and symbol information endpoints.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter
|
||||
|
||||
from api.models.schemas import DataStatusResponse, SymbolInfo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Base path for CCXT data
|
||||
DATA_BASE = Path(__file__).parent.parent.parent / "data" / "ccxt"
|
||||
|
||||
|
||||
def _scan_available_data() -> list[SymbolInfo]:
|
||||
"""
|
||||
Scan the data directory for available symbols and timeframes.
|
||||
|
||||
Returns list of SymbolInfo with date ranges and row counts.
|
||||
"""
|
||||
symbols = []
|
||||
|
||||
if not DATA_BASE.exists():
|
||||
return symbols
|
||||
|
||||
# Structure: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
|
||||
for exchange_dir in DATA_BASE.iterdir():
|
||||
if not exchange_dir.is_dir():
|
||||
continue
|
||||
exchange = exchange_dir.name
|
||||
|
||||
for market_dir in exchange_dir.iterdir():
|
||||
if not market_dir.is_dir():
|
||||
continue
|
||||
market_type = market_dir.name
|
||||
|
||||
for symbol_dir in market_dir.iterdir():
|
||||
if not symbol_dir.is_dir():
|
||||
continue
|
||||
symbol = symbol_dir.name
|
||||
|
||||
# Find all timeframes
|
||||
timeframes = []
|
||||
start_date = None
|
||||
end_date = None
|
||||
row_count = 0
|
||||
|
||||
for csv_file in symbol_dir.glob("*.csv"):
|
||||
tf = csv_file.stem
|
||||
timeframes.append(tf)
|
||||
|
||||
# Read first and last rows for date range
|
||||
try:
|
||||
df = pd.read_csv(csv_file, parse_dates=['timestamp'])
|
||||
if not df.empty:
|
||||
row_count = len(df)
|
||||
start_date = df['timestamp'].min().strftime("%Y-%m-%d")
|
||||
end_date = df['timestamp'].max().strftime("%Y-%m-%d")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if timeframes:
|
||||
symbols.append(SymbolInfo(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
market_type=market_type,
|
||||
timeframes=sorted(timeframes),
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
row_count=row_count,
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
|
||||
@router.get("/symbols", response_model=DataStatusResponse)
|
||||
async def get_symbols():
|
||||
"""
|
||||
Get list of available symbols with their data ranges.
|
||||
|
||||
Scans the local data directory for downloaded OHLCV data.
|
||||
"""
|
||||
symbols = _scan_available_data()
|
||||
return DataStatusResponse(symbols=symbols)
|
||||
|
||||
|
||||
@router.get("/data/status", response_model=DataStatusResponse)
|
||||
async def get_data_status():
|
||||
"""
|
||||
Get detailed data inventory status.
|
||||
|
||||
Alias for /symbols with additional metadata.
|
||||
"""
|
||||
symbols = _scan_available_data()
|
||||
return DataStatusResponse(symbols=symbols)
|
||||
67
api/routers/strategies.py
Normal file
67
api/routers/strategies.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Strategy information endpoints.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter
|
||||
|
||||
from api.models.schemas import StrategiesResponse, StrategyInfo
|
||||
from strategies.factory import get_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _serialize_param_value(value: Any) -> Any:
|
||||
"""Convert numpy arrays and other types to JSON-serializable format."""
|
||||
if isinstance(value, np.ndarray):
|
||||
return value.tolist()
|
||||
if isinstance(value, (np.integer, np.floating)):
|
||||
return value.item()
|
||||
return value
|
||||
|
||||
|
||||
def _get_display_name(name: str) -> str:
|
||||
"""Convert strategy key to display name."""
|
||||
display_names = {
|
||||
"rsi": "RSI Strategy",
|
||||
"macross": "MA Crossover",
|
||||
"meta_st": "Meta Supertrend",
|
||||
"regime": "Regime Reversion (ML)",
|
||||
}
|
||||
return display_names.get(name, name.replace("_", " ").title())
|
||||
|
||||
|
||||
@router.get("/strategies", response_model=StrategiesResponse)
|
||||
async def get_strategies():
|
||||
"""
|
||||
Get list of available strategies with their parameters.
|
||||
|
||||
Returns strategy names, default parameters, and grid search ranges.
|
||||
"""
|
||||
registry = get_registry()
|
||||
strategies = []
|
||||
|
||||
for name, config in registry.items():
|
||||
strategy_instance = config.strategy_class()
|
||||
|
||||
# Serialize parameters (convert numpy arrays to lists)
|
||||
default_params = {
|
||||
k: _serialize_param_value(v)
|
||||
for k, v in config.default_params.items()
|
||||
}
|
||||
grid_params = {
|
||||
k: _serialize_param_value(v)
|
||||
for k, v in config.grid_params.items()
|
||||
}
|
||||
|
||||
strategies.append(StrategyInfo(
|
||||
name=name,
|
||||
display_name=_get_display_name(name),
|
||||
market_type=strategy_instance.default_market_type.value,
|
||||
default_leverage=strategy_instance.default_leverage,
|
||||
default_params=default_params,
|
||||
grid_params=grid_params,
|
||||
))
|
||||
|
||||
return StrategiesResponse(strategies=strategies)
|
||||
3
api/services/__init__.py
Normal file
3
api/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Business logic services for backtest execution and storage.
|
||||
"""
|
||||
300
api/services/runner.py
Normal file
300
api/services/runner.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Backtest runner service that wraps the existing Backtester engine.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from api.models.schemas import (
|
||||
BacktestMetrics,
|
||||
BacktestRequest,
|
||||
BacktestResult,
|
||||
EquityPoint,
|
||||
TradeRecord,
|
||||
)
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType
|
||||
from strategies.factory import get_strategy
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BacktestRunner:
|
||||
"""
|
||||
Service for executing backtests via the API.
|
||||
|
||||
Wraps the existing Backtester engine and converts results
|
||||
to API response format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.dm = DataManager()
|
||||
self.bt = Backtester(self.dm)
|
||||
|
||||
def run(self, request: BacktestRequest) -> BacktestResult:
|
||||
"""
|
||||
Execute a backtest and return structured results.
|
||||
|
||||
Args:
|
||||
request: BacktestRequest with strategy, symbol, and parameters
|
||||
|
||||
Returns:
|
||||
BacktestResult with metrics, equity curve, and trades
|
||||
"""
|
||||
# Get strategy instance
|
||||
strategy, default_params = get_strategy(request.strategy, is_grid=False)
|
||||
|
||||
# Merge default params with request params
|
||||
params = {**default_params, **request.params}
|
||||
|
||||
# Convert market type string to enum
|
||||
market_type = MarketType(request.market_type)
|
||||
|
||||
# Override strategy market type if specified
|
||||
strategy.default_market_type = market_type
|
||||
|
||||
logger.info(
|
||||
"Running backtest: %s on %s (%s), params=%s",
|
||||
request.strategy, request.symbol, request.timeframe, params
|
||||
)
|
||||
|
||||
# Execute backtest
|
||||
result = self.bt.run_strategy(
|
||||
strategy=strategy,
|
||||
exchange_id=request.exchange,
|
||||
symbol=request.symbol,
|
||||
timeframe=request.timeframe,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
init_cash=request.init_cash,
|
||||
fees=request.fees,
|
||||
slippage=request.slippage,
|
||||
sl_stop=request.sl_stop,
|
||||
tp_stop=request.tp_stop,
|
||||
sl_trail=request.sl_trail,
|
||||
leverage=request.leverage,
|
||||
**params
|
||||
)
|
||||
|
||||
# Extract data from portfolio
|
||||
portfolio = result.portfolio
|
||||
|
||||
# Build trade records
|
||||
trades = self._build_trade_records(portfolio)
|
||||
|
||||
# Build equity curve (trimmed to trading period)
|
||||
equity_curve = self._build_equity_curve(portfolio)
|
||||
|
||||
# Build metrics
|
||||
metrics = self._build_metrics(result, portfolio)
|
||||
|
||||
# Get date range from actual trading period (first trade to end)
|
||||
idx = portfolio.wrapper.index
|
||||
end_date = idx[-1].strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
# Use first trade time as start if trades exist
|
||||
trades_df = portfolio.trades.records_readable
|
||||
if not trades_df.empty:
|
||||
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
|
||||
if first_entry_col in trades_df.columns:
|
||||
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
|
||||
start_date = first_trade_time.strftime("%Y-%m-%d %H:%M")
|
||||
else:
|
||||
start_date = idx[0].strftime("%Y-%m-%d %H:%M")
|
||||
else:
|
||||
start_date = idx[0].strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
return BacktestResult(
|
||||
run_id=str(uuid.uuid4()),
|
||||
strategy=request.strategy,
|
||||
symbol=request.symbol,
|
||||
market_type=result.market_type.value,
|
||||
timeframe=request.timeframe,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
leverage=result.leverage,
|
||||
params=params,
|
||||
metrics=metrics,
|
||||
equity_curve=equity_curve,
|
||||
trades=trades,
|
||||
created_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
def _build_equity_curve(self, portfolio) -> list[EquityPoint]:
|
||||
"""Extract equity curve with drawdown from portfolio, starting from first trade."""
|
||||
value_series = portfolio.value()
|
||||
drawdown_series = portfolio.drawdown()
|
||||
|
||||
# Handle multi-column case (from grid search)
|
||||
if hasattr(value_series, 'columns') and len(value_series.columns) > 1:
|
||||
value_series = value_series.iloc[:, 0]
|
||||
drawdown_series = drawdown_series.iloc[:, 0]
|
||||
elif hasattr(value_series, 'columns'):
|
||||
value_series = value_series.iloc[:, 0]
|
||||
drawdown_series = drawdown_series.iloc[:, 0]
|
||||
|
||||
# Find first trade time to trim equity curve
|
||||
first_trade_idx = 0
|
||||
trades_df = portfolio.trades.records_readable
|
||||
if not trades_df.empty:
|
||||
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
|
||||
if first_entry_col in trades_df.columns:
|
||||
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
|
||||
# Find index in value_series closest to first trade
|
||||
if hasattr(value_series.index, 'get_indexer'):
|
||||
first_trade_idx = value_series.index.get_indexer([first_trade_time], method='nearest')[0]
|
||||
# Start a few bars before first trade for context
|
||||
first_trade_idx = max(0, first_trade_idx - 5)
|
||||
|
||||
# Slice from first trade onwards
|
||||
value_series = value_series.iloc[first_trade_idx:]
|
||||
drawdown_series = drawdown_series.iloc[first_trade_idx:]
|
||||
|
||||
points = []
|
||||
for i, (ts, val) in enumerate(value_series.items()):
|
||||
dd = drawdown_series.iloc[i] if i < len(drawdown_series) else 0.0
|
||||
points.append(EquityPoint(
|
||||
timestamp=ts.isoformat(),
|
||||
value=float(val),
|
||||
drawdown=float(dd) * 100, # Convert to percentage
|
||||
))
|
||||
|
||||
return points
|
||||
|
||||
def _build_trade_records(self, portfolio) -> list[TradeRecord]:
|
||||
"""Extract trade records from portfolio."""
|
||||
trades_df = portfolio.trades.records_readable
|
||||
|
||||
if trades_df.empty:
|
||||
return []
|
||||
|
||||
records = []
|
||||
for _, row in trades_df.iterrows():
|
||||
# Handle different column names in vectorbt
|
||||
entry_time = row.get('Entry Timestamp', row.get('Entry Time', ''))
|
||||
exit_time = row.get('Exit Timestamp', row.get('Exit Time', ''))
|
||||
|
||||
records.append(TradeRecord(
|
||||
entry_time=str(entry_time) if pd.notna(entry_time) else "",
|
||||
exit_time=str(exit_time) if pd.notna(exit_time) else None,
|
||||
entry_price=float(row.get('Avg Entry Price', row.get('Entry Price', 0))),
|
||||
exit_price=float(row.get('Avg Exit Price', row.get('Exit Price', 0)))
|
||||
if pd.notna(row.get('Avg Exit Price', row.get('Exit Price'))) else None,
|
||||
size=float(row.get('Size', 0)),
|
||||
direction=str(row.get('Direction', 'Long')),
|
||||
pnl=float(row.get('PnL', 0)) if pd.notna(row.get('PnL')) else None,
|
||||
return_pct=float(row.get('Return', 0)) * 100
|
||||
if pd.notna(row.get('Return')) else None,
|
||||
status="closed" if pd.notna(exit_time) else "open",
|
||||
))
|
||||
|
||||
return records
|
||||
|
||||
def _build_metrics(self, result, portfolio) -> BacktestMetrics:
|
||||
"""Build metrics from backtest result."""
|
||||
stats = portfolio.stats()
|
||||
|
||||
# Extract values, handling potential multi-column results
|
||||
def get_stat(key: str, default: float = 0.0) -> float:
|
||||
val = stats.get(key, default)
|
||||
if hasattr(val, 'mean'):
|
||||
return float(val.mean())
|
||||
return float(val) if pd.notna(val) else default
|
||||
|
||||
total_return = portfolio.total_return()
|
||||
if hasattr(total_return, 'mean'):
|
||||
total_return = total_return.mean()
|
||||
|
||||
# Calculate benchmark return from first trade to end (not full period)
|
||||
# This gives accurate comparison when strategy has training period
|
||||
close = portfolio.close
|
||||
benchmark_return = 0.0
|
||||
|
||||
if hasattr(close, 'iloc'):
|
||||
# Find first trade entry time
|
||||
trades_df = portfolio.trades.records_readable
|
||||
if not trades_df.empty:
|
||||
# Get the first trade entry timestamp
|
||||
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
|
||||
if first_entry_col in trades_df.columns:
|
||||
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
|
||||
|
||||
# Find the price at first trade
|
||||
if hasattr(close.index, 'get_indexer'):
|
||||
# Find closest index to first trade time
|
||||
idx = close.index.get_indexer([first_trade_time], method='nearest')[0]
|
||||
start_price = close.iloc[idx]
|
||||
else:
|
||||
start_price = close.iloc[0]
|
||||
|
||||
end_price = close.iloc[-1]
|
||||
|
||||
if hasattr(start_price, 'mean'):
|
||||
start_price = start_price.mean()
|
||||
if hasattr(end_price, 'mean'):
|
||||
end_price = end_price.mean()
|
||||
|
||||
benchmark_return = ((end_price - start_price) / start_price)
|
||||
else:
|
||||
# No trades - use full period
|
||||
start_price = close.iloc[0]
|
||||
end_price = close.iloc[-1]
|
||||
if hasattr(start_price, 'mean'):
|
||||
start_price = start_price.mean()
|
||||
if hasattr(end_price, 'mean'):
|
||||
end_price = end_price.mean()
|
||||
benchmark_return = ((end_price - start_price) / start_price)
|
||||
|
||||
# Alpha = strategy return - benchmark return
|
||||
alpha = float(total_return) - float(benchmark_return)
|
||||
|
||||
sharpe = portfolio.sharpe_ratio()
|
||||
if hasattr(sharpe, 'mean'):
|
||||
sharpe = sharpe.mean()
|
||||
|
||||
max_dd = portfolio.max_drawdown()
|
||||
if hasattr(max_dd, 'mean'):
|
||||
max_dd = max_dd.mean()
|
||||
|
||||
win_rate = portfolio.trades.win_rate()
|
||||
if hasattr(win_rate, 'mean'):
|
||||
win_rate = win_rate.mean()
|
||||
|
||||
trade_count = portfolio.trades.count()
|
||||
if hasattr(trade_count, 'mean'):
|
||||
trade_count = int(trade_count.mean())
|
||||
else:
|
||||
trade_count = int(trade_count)
|
||||
|
||||
return BacktestMetrics(
|
||||
total_return=float(total_return) * 100,
|
||||
benchmark_return=float(benchmark_return) * 100,
|
||||
alpha=float(alpha) * 100,
|
||||
sharpe_ratio=float(sharpe) if pd.notna(sharpe) else 0.0,
|
||||
max_drawdown=float(max_dd) * 100,
|
||||
win_rate=float(win_rate) * 100 if pd.notna(win_rate) else 0.0,
|
||||
total_trades=trade_count,
|
||||
profit_factor=get_stat('Profit Factor'),
|
||||
avg_trade_return=get_stat('Avg Winning Trade [%]'),
|
||||
total_fees=get_stat('Total Fees Paid'),
|
||||
total_funding=result.total_funding_paid,
|
||||
liquidation_count=result.liquidation_count,
|
||||
liquidation_loss=result.total_liquidation_loss,
|
||||
adjusted_return=result.adjusted_return,
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_runner: BacktestRunner | None = None
|
||||
|
||||
|
||||
def get_runner() -> BacktestRunner:
|
||||
"""Get or create the backtest runner instance."""
|
||||
global _runner
|
||||
if _runner is None:
|
||||
_runner = BacktestRunner()
|
||||
return _runner
|
||||
225
api/services/storage.py
Normal file
225
api/services/storage.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Storage service for persisting and retrieving backtest runs.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.models.database import BacktestRun
|
||||
from api.models.schemas import (
|
||||
BacktestResult,
|
||||
BacktestSummary,
|
||||
EquityPoint,
|
||||
BacktestMetrics,
|
||||
TradeRecord,
|
||||
)
|
||||
|
||||
|
||||
class StorageService:
|
||||
"""
|
||||
Service for saving and loading backtest runs from SQLite.
|
||||
"""
|
||||
|
||||
def save_run(self, db: Session, result: BacktestResult) -> BacktestRun:
|
||||
"""
|
||||
Persist a backtest result to the database.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
result: BacktestResult to save
|
||||
|
||||
Returns:
|
||||
Created BacktestRun record
|
||||
"""
|
||||
run = BacktestRun(
|
||||
run_id=result.run_id,
|
||||
strategy=result.strategy,
|
||||
symbol=result.symbol,
|
||||
market_type=result.market_type,
|
||||
timeframe=result.timeframe,
|
||||
leverage=result.leverage,
|
||||
params=result.params,
|
||||
start_date=result.start_date,
|
||||
end_date=result.end_date,
|
||||
total_return=result.metrics.total_return,
|
||||
benchmark_return=result.metrics.benchmark_return,
|
||||
alpha=result.metrics.alpha,
|
||||
sharpe_ratio=result.metrics.sharpe_ratio,
|
||||
max_drawdown=result.metrics.max_drawdown,
|
||||
win_rate=result.metrics.win_rate,
|
||||
total_trades=result.metrics.total_trades,
|
||||
profit_factor=result.metrics.profit_factor,
|
||||
total_fees=result.metrics.total_fees,
|
||||
total_funding=result.metrics.total_funding,
|
||||
liquidation_count=result.metrics.liquidation_count,
|
||||
liquidation_loss=result.metrics.liquidation_loss,
|
||||
adjusted_return=result.metrics.adjusted_return,
|
||||
)
|
||||
|
||||
# Serialize complex data
|
||||
run.set_equity_curve([p.model_dump() for p in result.equity_curve])
|
||||
run.set_trades([t.model_dump() for t in result.trades])
|
||||
|
||||
db.add(run)
|
||||
db.commit()
|
||||
db.refresh(run)
|
||||
|
||||
return run
|
||||
|
||||
def get_run(self, db: Session, run_id: str) -> BacktestResult | None:
|
||||
"""
|
||||
Retrieve a backtest run by ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
run_id: UUID of the run
|
||||
|
||||
Returns:
|
||||
BacktestResult or None if not found
|
||||
"""
|
||||
run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
|
||||
|
||||
if not run:
|
||||
return None
|
||||
|
||||
return self._to_result(run)
|
||||
|
||||
def list_runs(
|
||||
self,
|
||||
db: Session,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
strategy: str | None = None,
|
||||
symbol: str | None = None,
|
||||
) -> tuple[list[BacktestSummary], int]:
|
||||
"""
|
||||
List backtest runs with optional filtering.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
limit: Maximum number of runs to return
|
||||
offset: Offset for pagination
|
||||
strategy: Filter by strategy name
|
||||
symbol: Filter by symbol
|
||||
|
||||
Returns:
|
||||
Tuple of (list of summaries, total count)
|
||||
"""
|
||||
query = db.query(BacktestRun)
|
||||
|
||||
if strategy:
|
||||
query = query.filter(BacktestRun.strategy == strategy)
|
||||
if symbol:
|
||||
query = query.filter(BacktestRun.symbol == symbol)
|
||||
|
||||
total = query.count()
|
||||
|
||||
runs = query.order_by(BacktestRun.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
summaries = [self._to_summary(run) for run in runs]
|
||||
|
||||
return summaries, total
|
||||
|
||||
def get_runs_by_ids(self, db: Session, run_ids: list[str]) -> list[BacktestResult]:
|
||||
"""
|
||||
Retrieve multiple runs by their IDs.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
run_ids: List of run UUIDs
|
||||
|
||||
Returns:
|
||||
List of BacktestResults (preserves order)
|
||||
"""
|
||||
runs = db.query(BacktestRun).filter(BacktestRun.run_id.in_(run_ids)).all()
|
||||
|
||||
# Create lookup and preserve order
|
||||
run_map = {run.run_id: run for run in runs}
|
||||
results = []
|
||||
|
||||
for run_id in run_ids:
|
||||
if run_id in run_map:
|
||||
results.append(self._to_result(run_map[run_id]))
|
||||
|
||||
return results
|
||||
|
||||
def delete_run(self, db: Session, run_id: str) -> bool:
|
||||
"""
|
||||
Delete a backtest run.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
run_id: UUID of the run to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
|
||||
|
||||
if not run:
|
||||
return False
|
||||
|
||||
db.delete(run)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def _to_result(self, run: BacktestRun) -> BacktestResult:
|
||||
"""Convert database record to BacktestResult schema."""
|
||||
equity_data = run.get_equity_curve()
|
||||
trades_data = run.get_trades()
|
||||
|
||||
return BacktestResult(
|
||||
run_id=run.run_id,
|
||||
strategy=run.strategy,
|
||||
symbol=run.symbol,
|
||||
market_type=run.market_type,
|
||||
timeframe=run.timeframe,
|
||||
start_date=run.start_date or "",
|
||||
end_date=run.end_date or "",
|
||||
leverage=run.leverage,
|
||||
params=run.params or {},
|
||||
metrics=BacktestMetrics(
|
||||
total_return=run.total_return,
|
||||
benchmark_return=run.benchmark_return or 0.0,
|
||||
alpha=run.alpha or 0.0,
|
||||
sharpe_ratio=run.sharpe_ratio,
|
||||
max_drawdown=run.max_drawdown,
|
||||
win_rate=run.win_rate,
|
||||
total_trades=run.total_trades,
|
||||
profit_factor=run.profit_factor,
|
||||
total_fees=run.total_fees,
|
||||
total_funding=run.total_funding,
|
||||
liquidation_count=run.liquidation_count,
|
||||
liquidation_loss=run.liquidation_loss,
|
||||
adjusted_return=run.adjusted_return,
|
||||
),
|
||||
equity_curve=[EquityPoint(**p) for p in equity_data],
|
||||
trades=[TradeRecord(**t) for t in trades_data],
|
||||
created_at=run.created_at.isoformat() if run.created_at else "",
|
||||
)
|
||||
|
||||
def _to_summary(self, run: BacktestRun) -> BacktestSummary:
|
||||
"""Convert database record to BacktestSummary schema."""
|
||||
return BacktestSummary(
|
||||
run_id=run.run_id,
|
||||
strategy=run.strategy,
|
||||
symbol=run.symbol,
|
||||
market_type=run.market_type,
|
||||
timeframe=run.timeframe,
|
||||
total_return=run.total_return,
|
||||
sharpe_ratio=run.sharpe_ratio,
|
||||
max_drawdown=run.max_drawdown,
|
||||
total_trades=run.total_trades,
|
||||
created_at=run.created_at.isoformat() if run.created_at else "",
|
||||
params=run.params or {},
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_storage: StorageService | None = None
|
||||
|
||||
|
||||
def get_storage() -> StorageService:
|
||||
"""Get or create the storage service instance."""
|
||||
global _storage
|
||||
if _storage is None:
|
||||
_storage = StorageService()
|
||||
return _storage
|
||||
70
backtest.py
70
backtest.py
@@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from trade import TradeState, enter_long, exit_long, maybe_trailing_stop
|
||||
from indicators import add_supertrends, compute_meta_trend
|
||||
from metrics import compute_metrics
|
||||
from logging_utils import write_trade_log
|
||||
|
||||
DEFAULT_ST_SETTINGS = [(12, 3.0), (10, 1.0), (11, 2.0)]
|
||||
|
||||
def backtest(
|
||||
df: pd.DataFrame,
|
||||
df_1min: pd.DataFrame,
|
||||
timeframe_minutes: int,
|
||||
stop_loss: float,
|
||||
exit_on_bearish_flip: bool,
|
||||
fee_bps: float,
|
||||
slippage_bps: float,
|
||||
log_path: Path | None = None,
|
||||
):
|
||||
df = add_supertrends(df, DEFAULT_ST_SETTINGS)
|
||||
df["meta_bull"] = compute_meta_trend(df, DEFAULT_ST_SETTINGS)
|
||||
|
||||
state = TradeState(stop_loss_frac=stop_loss, fee_bps=fee_bps, slippage_bps=slippage_bps)
|
||||
equity, trades = [], []
|
||||
|
||||
for i, row in df.iterrows():
|
||||
price = float(row["Close"])
|
||||
ts = pd.Timestamp(row["Timestamp"])
|
||||
|
||||
if state.qty <= 0 and row["meta_bull"] == 1:
|
||||
evt = enter_long(state, price)
|
||||
if evt:
|
||||
evt.update({"t": ts.isoformat(), "reason": "bull_flip"})
|
||||
trades.append(evt)
|
||||
|
||||
start = ts
|
||||
end = df["Timestamp"].iat[i + 1] if i + 1 < len(df) else ts + pd.Timedelta(minutes=timeframe_minutes)
|
||||
|
||||
if state.qty > 0:
|
||||
win = df_1min[(df_1min["Timestamp"] >= start) & (df_1min["Timestamp"] < end)]
|
||||
for _, m in win.iterrows():
|
||||
hi = float(m["High"])
|
||||
lo = float(m["Low"])
|
||||
state.max_px = max(state.max_px or hi, hi)
|
||||
trail = state.max_px * (1.0 - state.stop_loss_frac)
|
||||
if lo <= trail:
|
||||
evt = exit_long(state, trail)
|
||||
if evt:
|
||||
prev = trades[-1]
|
||||
pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
|
||||
evt.update({"t": pd.Timestamp(m["Timestamp"]).isoformat(), "reason": "stop", "pnl": pnl})
|
||||
trades.append(evt)
|
||||
break
|
||||
|
||||
if state.qty > 0 and exit_on_bearish_flip and row["meta_bull"] == 0:
|
||||
evt = exit_long(state, price)
|
||||
if evt:
|
||||
prev = trades[-1]
|
||||
pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
|
||||
evt.update({"t": ts.isoformat(), "reason": "bearish_flip", "pnl": pnl})
|
||||
trades.append(evt)
|
||||
|
||||
equity.append(state.cash + state.qty * price)
|
||||
|
||||
equity_curve = pd.Series(equity, index=df["Timestamp"])
|
||||
if log_path:
|
||||
write_trade_log(trades, log_path)
|
||||
perf = compute_metrics(equity_curve, trades)
|
||||
return perf, equity_curve, trades
|
||||
98
check_demo_account.py
Normal file
98
check_demo_account.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Check OKX demo account positions and recent orders.
|
||||
|
||||
Usage:
|
||||
uv run python check_demo_account.py
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from live_trading.config import OKXConfig
|
||||
import ccxt
|
||||
|
||||
|
||||
def main():
|
||||
"""Check demo account status."""
|
||||
config = OKXConfig()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" OKX Demo Account Check")
|
||||
print(f"{'='*60}")
|
||||
print(f" Demo Mode: {config.demo_mode}")
|
||||
print(f" API Key: {config.api_key[:8]}..." if config.api_key else " API Key: NOT SET")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
exchange = ccxt.okx({
|
||||
'apiKey': config.api_key,
|
||||
'secret': config.secret,
|
||||
'password': config.password,
|
||||
'sandbox': config.demo_mode,
|
||||
'options': {'defaultType': 'swap'},
|
||||
'enableRateLimit': True,
|
||||
})
|
||||
|
||||
# Check balance
|
||||
print("--- BALANCE ---")
|
||||
balance = exchange.fetch_balance()
|
||||
usdt = balance.get('USDT', {})
|
||||
print(f"USDT Total: {usdt.get('total', 0):.2f}")
|
||||
print(f"USDT Free: {usdt.get('free', 0):.2f}")
|
||||
print(f"USDT Used: {usdt.get('used', 0):.2f}")
|
||||
|
||||
# Check all balances
|
||||
print("\n--- ALL NON-ZERO BALANCES ---")
|
||||
for currency, data in balance.items():
|
||||
if isinstance(data, dict) and data.get('total', 0) > 0:
|
||||
print(f"{currency}: total={data.get('total', 0):.6f}, free={data.get('free', 0):.6f}")
|
||||
|
||||
# Check open positions
|
||||
print("\n--- OPEN POSITIONS ---")
|
||||
positions = exchange.fetch_positions()
|
||||
open_positions = [p for p in positions if abs(float(p.get('contracts', 0))) > 0]
|
||||
|
||||
if open_positions:
|
||||
for pos in open_positions:
|
||||
print(f" {pos['symbol']}: {pos['side']} {pos['contracts']} contracts @ {pos.get('entryPrice', 'N/A')}")
|
||||
print(f" Unrealized PnL: {pos.get('unrealizedPnl', 'N/A')}")
|
||||
else:
|
||||
print(" No open positions")
|
||||
|
||||
# Check recent orders (last 50)
|
||||
print("\n--- RECENT ORDERS (last 24h) ---")
|
||||
try:
|
||||
# Fetch closed orders for AVAX
|
||||
orders = exchange.fetch_orders('AVAX/USDT:USDT', limit=20)
|
||||
if orders:
|
||||
for order in orders[-10:]: # Last 10
|
||||
ts = datetime.fromtimestamp(order['timestamp']/1000, tz=timezone.utc)
|
||||
print(f" [{ts.strftime('%H:%M:%S')}] {order['side'].upper()} {order['amount']} AVAX @ {order.get('average', order.get('price', 'market'))}")
|
||||
print(f" Status: {order['status']}, Filled: {order.get('filled', 0)}, ID: {order['id']}")
|
||||
else:
|
||||
print(" No recent AVAX orders")
|
||||
except Exception as e:
|
||||
print(f" Could not fetch orders: {e}")
|
||||
|
||||
# Check order history more broadly
|
||||
print("\n--- ORDER HISTORY (AVAX) ---")
|
||||
try:
|
||||
# Try fetching my trades
|
||||
trades = exchange.fetch_my_trades('AVAX/USDT:USDT', limit=10)
|
||||
if trades:
|
||||
for trade in trades[-5:]:
|
||||
ts = datetime.fromtimestamp(trade['timestamp']/1000, tz=timezone.utc)
|
||||
print(f" [{ts.strftime('%Y-%m-%d %H:%M:%S')}] {trade['side'].upper()} {trade['amount']} @ {trade['price']}")
|
||||
print(f" Fee: {trade.get('fee', {}).get('cost', 'N/A')} {trade.get('fee', {}).get('currency', '')}")
|
||||
else:
|
||||
print(" No recent AVAX trades")
|
||||
except Exception as e:
|
||||
print(f" Could not fetch trades: {e}")
|
||||
|
||||
print(f"\n{'='*60}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
28
check_symbols.py
Normal file
28
check_symbols.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import ccxt
|
||||
import sys
|
||||
|
||||
def main():
|
||||
try:
|
||||
exchange = ccxt.okx()
|
||||
print("Loading markets...")
|
||||
markets = exchange.load_markets()
|
||||
|
||||
# Filter for ETH perpetuals
|
||||
eth_perps = [
|
||||
symbol for symbol, market in markets.items()
|
||||
if 'ETH' in symbol and 'USDT' in symbol and market.get('swap') and market.get('linear')
|
||||
]
|
||||
|
||||
print(f"\nFound {len(eth_perps)} ETH Linear Perps:")
|
||||
for symbol in eth_perps:
|
||||
market = markets[symbol]
|
||||
print(f" CCXT Symbol: {symbol}")
|
||||
print(f" Exchange ID: {market['id']}")
|
||||
print(f" Type: {market['type']}")
|
||||
print("-" * 30)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
80
cli.py
80
cli.py
@@ -1,80 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
from config import CLIConfig
|
||||
from data import load_data
|
||||
from backtest import backtest
|
||||
|
||||
def parse_args() -> CLIConfig:
|
||||
p = argparse.ArgumentParser(prog="bt", description="Simple supertrend backtester")
|
||||
p.add_argument("start")
|
||||
p.add_argument("end")
|
||||
p.add_argument("--timeframe-minutes", type=int, default=15) # single TF
|
||||
p.add_argument("--timeframes-minutes", nargs="+", type=int) # multi TF: e.g. 5 15 60 240
|
||||
p.add_argument("--stop-loss", dest="stop_losses", type=float, nargs="+", default=[0.02, 0.05])
|
||||
p.add_argument("--exit-on-bearish-flip", action="store_true")
|
||||
p.add_argument("--csv", dest="csv_path", type=Path, required=True)
|
||||
p.add_argument("--out-csv", type=Path, default=Path("summary.csv"))
|
||||
p.add_argument("--log-dir", type=Path, default=Path("./logs"))
|
||||
p.add_argument("--fee-bps", type=float, default=10.0)
|
||||
p.add_argument("--slippage-bps", type=float, default=2.0)
|
||||
a = p.parse_args()
|
||||
|
||||
return CLIConfig(
|
||||
start=a.start,
|
||||
end=a.end,
|
||||
timeframe_minutes=a.timeframe_minutes,
|
||||
timeframes_minutes=a.timeframes_minutes,
|
||||
stop_losses=a.stop_losses,
|
||||
exit_on_bearish_flip=a.exit_on_bearish_flip,
|
||||
csv_path=a.csv_path,
|
||||
out_csv=a.out_csv,
|
||||
log_dir=a.log_dir,
|
||||
fee_bps=a.fee_bps,
|
||||
slippage_bps=a.slippage_bps,
|
||||
)
|
||||
|
||||
def main():
|
||||
cfg = parse_args()
|
||||
frames = cfg.timeframes_minutes or [cfg.timeframe_minutes]
|
||||
|
||||
rows: list[dict] = []
|
||||
for tfm in frames:
|
||||
df_1min, df = load_data(cfg.start, cfg.end, tfm, cfg.csv_path)
|
||||
for sl in cfg.stop_losses:
|
||||
log_path = cfg.log_dir / f"{tfm}m_sl{sl:.2%}.csv"
|
||||
perf, equity, _ = backtest(
|
||||
df=df,
|
||||
df_1min=df_1min,
|
||||
timeframe_minutes=tfm,
|
||||
stop_loss=sl,
|
||||
exit_on_bearish_flip=cfg.exit_on_bearish_flip,
|
||||
fee_bps=cfg.fee_bps,
|
||||
slippage_bps=cfg.slippage_bps,
|
||||
log_path=log_path,
|
||||
)
|
||||
rows.append({
|
||||
"timeframe": f"{tfm}min",
|
||||
"stop_loss": sl,
|
||||
"exit_on_bearish_flip": cfg.exit_on_bearish_flip,
|
||||
"total_return": f"{perf.total_return:.2%}",
|
||||
"max_drawdown": f"{perf.max_drawdown:.2%}",
|
||||
"sharpe_ratio": f"{perf.sharpe_ratio:.2f}",
|
||||
"win_rate": f"{perf.win_rate:.2%}",
|
||||
"num_trades": perf.num_trades,
|
||||
"final_equity": f"${perf.final_equity:.2f}",
|
||||
"initial_equity": f"${perf.initial_equity:.2f}",
|
||||
"num_stop_losses": perf.num_stop_losses,
|
||||
"total_fees": perf.total_fees,
|
||||
"total_slippage_usd": perf.total_slippage_usd,
|
||||
"avg_slippage_bps": perf.avg_slippage_bps,
|
||||
})
|
||||
|
||||
out = pd.DataFrame(rows)
|
||||
out.to_csv(cfg.out_csv, index=False)
|
||||
print(out.to_string(index=False))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
18
config.py
18
config.py
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
@dataclass
|
||||
class CLIConfig:
|
||||
start: str
|
||||
end: str
|
||||
timeframe_minutes: int
|
||||
timeframes_minutes: list[int] | None
|
||||
stop_losses: Sequence[float]
|
||||
exit_on_bearish_flip: bool
|
||||
csv_path: Path | None
|
||||
out_csv: Path
|
||||
log_dir: Path
|
||||
fee_bps: float
|
||||
slippage_bps: float
|
||||
24
data.py
24
data.py
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
def load_data(start: str, end: str, timeframe_minutes: int, csv_path: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
df_1min = pd.read_csv(csv_path)
|
||||
df_1min["Timestamp"] = pd.to_datetime(df_1min["Timestamp"], unit="s", utc=True)
|
||||
df_1min = df_1min[(df_1min["Timestamp"] >= pd.Timestamp(start, tz="UTC")) &
|
||||
(df_1min["Timestamp"] <= pd.Timestamp(end, tz="UTC"))] \
|
||||
.sort_values("Timestamp").reset_index(drop=True)
|
||||
|
||||
if timeframe_minutes != 1:
|
||||
g = df_1min.set_index("Timestamp").resample(f"{timeframe_minutes}min")
|
||||
df = pd.DataFrame({
|
||||
"Open": g["Open"].first(),
|
||||
"High": g["High"].max(),
|
||||
"Low": g["Low"].min(),
|
||||
"Close": g["Close"].last(),
|
||||
"Volume": g["Volume"].sum(),
|
||||
}).dropna().reset_index()
|
||||
else:
|
||||
df = df_1min.copy()
|
||||
|
||||
return df_1min, df
|
||||
BIN
data/multi_pair_model.pkl
Normal file
BIN
data/multi_pair_model.pkl
Normal file
Binary file not shown.
377
engine/backtester.py
Normal file
377
engine/backtester.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Core backtesting engine for running strategy simulations.
|
||||
|
||||
Supports multiple market types with realistic trading conditions.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType, get_market_config
|
||||
from engine.optimizer import WalkForwardOptimizer
|
||||
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
|
||||
from engine.risk import (
|
||||
LiquidationEvent,
|
||||
calculate_funding,
|
||||
calculate_liquidation_adjustment,
|
||||
inject_liquidation_exits,
|
||||
)
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BacktestResult:
|
||||
"""
|
||||
Container for backtest results with market-specific metrics.
|
||||
|
||||
Attributes:
|
||||
portfolio: VectorBT Portfolio object
|
||||
market_type: Market type used for the backtest
|
||||
leverage: Effective leverage used
|
||||
total_funding_paid: Total funding fees paid (perpetuals only)
|
||||
liquidation_count: Number of positions that were liquidated
|
||||
liquidation_events: Detailed list of liquidation events
|
||||
total_liquidation_loss: Total margin lost from liquidations
|
||||
adjusted_return: Return adjusted for liquidation losses (percentage)
|
||||
"""
|
||||
portfolio: vbt.Portfolio
|
||||
market_type: MarketType
|
||||
leverage: int
|
||||
total_funding_paid: float = 0.0
|
||||
liquidation_count: int = 0
|
||||
liquidation_events: list[LiquidationEvent] | None = None
|
||||
total_liquidation_loss: float = 0.0
|
||||
adjusted_return: float | None = None
|
||||
|
||||
|
||||
class Backtester:
|
||||
"""
|
||||
Backtester supporting multiple market types with realistic simulation.
|
||||
|
||||
Features:
|
||||
- Spot and Perpetual market support
|
||||
- Long and short position handling
|
||||
- Leverage simulation
|
||||
- Funding rate calculation (perpetuals)
|
||||
- Liquidation warnings
|
||||
"""
|
||||
|
||||
def __init__(self, data_manager: DataManager):
|
||||
self.dm = data_manager
|
||||
|
||||
def run_strategy(
|
||||
self,
|
||||
strategy: BaseStrategy,
|
||||
exchange_id: str,
|
||||
symbol: str,
|
||||
timeframe: str = '1m',
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
init_cash: float = 10000,
|
||||
fees: float | None = None,
|
||||
slippage: float = 0.001,
|
||||
sl_stop: float | None = None,
|
||||
tp_stop: float | None = None,
|
||||
sl_trail: bool = False,
|
||||
leverage: int | None = None,
|
||||
**strategy_params
|
||||
) -> BacktestResult:
|
||||
"""
|
||||
Run a backtest with market-type-aware simulation.
|
||||
|
||||
Args:
|
||||
strategy: Strategy instance to backtest
|
||||
exchange_id: Exchange identifier (e.g., 'okx')
|
||||
symbol: Trading pair (e.g., 'BTC/USDT')
|
||||
timeframe: Data timeframe (e.g., '1m', '1h', '1d')
|
||||
start_date: Start date filter (YYYY-MM-DD)
|
||||
end_date: End date filter (YYYY-MM-DD)
|
||||
init_cash: Initial capital (margin for leveraged)
|
||||
fees: Transaction fee override (uses market default if None)
|
||||
slippage: Slippage percentage
|
||||
sl_stop: Stop loss percentage
|
||||
tp_stop: Take profit percentage
|
||||
sl_trail: Enable trailing stop loss
|
||||
leverage: Leverage override (uses strategy default if None)
|
||||
**strategy_params: Additional strategy parameters
|
||||
|
||||
Returns:
|
||||
BacktestResult with portfolio and market-specific metrics
|
||||
"""
|
||||
# Get market configuration from strategy
|
||||
market_type = strategy.default_market_type
|
||||
market_config = get_market_config(market_type)
|
||||
|
||||
# Resolve leverage and fees
|
||||
effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
|
||||
effective_fees = fees if fees is not None else market_config.taker_fee
|
||||
|
||||
# Load and filter data
|
||||
df = self._load_data(
|
||||
exchange_id, symbol, timeframe, market_type, start_date, end_date
|
||||
)
|
||||
|
||||
close_price = df['close']
|
||||
high_price = df['high']
|
||||
low_price = df['low']
|
||||
open_price = df['open']
|
||||
volume = df['volume']
|
||||
|
||||
# Run strategy logic
|
||||
signals = strategy.run(
|
||||
close_price,
|
||||
high=high_price,
|
||||
low=low_price,
|
||||
open=open_price,
|
||||
volume=volume,
|
||||
**strategy_params
|
||||
)
|
||||
|
||||
# Normalize signals to 5-tuple format
|
||||
signals = self._normalize_signals(signals, close_price, market_config)
|
||||
long_entries, long_exits, short_entries, short_exits, size = signals
|
||||
|
||||
# Default size if None
|
||||
if size is None:
|
||||
size = 1.0
|
||||
|
||||
# Convert leverage multiplier to raw value (USD) for vbt
|
||||
# This works around "SizeType.Percent reversal" error
|
||||
# Effectively "Fixed Fractional" sizing based on Initial Capital
|
||||
# (Does not compound, but safe for backtesting)
|
||||
if isinstance(size, pd.Series):
|
||||
size = size * init_cash
|
||||
else:
|
||||
size = size * init_cash
|
||||
|
||||
# Process liquidations - inject forced exits at liquidation points
|
||||
liquidation_events: list[LiquidationEvent] = []
|
||||
if effective_leverage > 1:
|
||||
long_exits, short_exits, liquidation_events = inject_liquidation_exits(
|
||||
close_price, high_price, low_price,
|
||||
long_entries, long_exits,
|
||||
short_entries, short_exits,
|
||||
effective_leverage,
|
||||
market_config.maintenance_margin_rate
|
||||
)
|
||||
|
||||
# Calculate perpetual-specific metrics (after liquidation processing)
|
||||
total_funding = 0.0
|
||||
if market_type == MarketType.PERPETUAL:
|
||||
total_funding = calculate_funding(
|
||||
close_price,
|
||||
long_entries, long_exits,
|
||||
short_entries, short_exits,
|
||||
market_config,
|
||||
effective_leverage
|
||||
)
|
||||
|
||||
# Run portfolio simulation with liquidation-aware exits
|
||||
portfolio = self._run_portfolio(
|
||||
close_price, market_config,
|
||||
long_entries, long_exits,
|
||||
short_entries, short_exits,
|
||||
init_cash, effective_fees, slippage, timeframe,
|
||||
sl_stop, tp_stop, sl_trail, effective_leverage,
|
||||
size=size
|
||||
)
|
||||
|
||||
# Calculate adjusted returns accounting for liquidation losses
|
||||
total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
|
||||
liquidation_events, init_cash, effective_leverage
|
||||
)
|
||||
|
||||
raw_return = portfolio.total_return().mean() * 100
|
||||
adjusted_return = raw_return - liq_adjustment
|
||||
|
||||
if liquidation_events:
|
||||
logger.info(
|
||||
"Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
|
||||
len(liquidation_events), total_liq_loss, liq_adjustment
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Backtest completed: %s market, %dx leverage, fees=%.4f%%",
|
||||
market_type.value, effective_leverage, effective_fees * 100
|
||||
)
|
||||
|
||||
return BacktestResult(
|
||||
portfolio=portfolio,
|
||||
market_type=market_type,
|
||||
leverage=effective_leverage,
|
||||
total_funding_paid=total_funding,
|
||||
liquidation_count=len(liquidation_events),
|
||||
liquidation_events=liquidation_events,
|
||||
total_liquidation_loss=total_liq_loss,
|
||||
adjusted_return=adjusted_return
|
||||
)
|
||||
|
||||
def _resolve_leverage(
|
||||
self,
|
||||
leverage: int | None,
|
||||
strategy: BaseStrategy,
|
||||
market_type: MarketType
|
||||
) -> int:
|
||||
"""Resolve effective leverage from CLI, strategy default, or market type."""
|
||||
effective = leverage or strategy.default_leverage
|
||||
if market_type == MarketType.SPOT:
|
||||
return 1 # Spot cannot have leverage
|
||||
return effective
|
||||
|
||||
def _load_data(
|
||||
self,
|
||||
exchange_id: str,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
market_type: MarketType,
|
||||
start_date: str | None,
|
||||
end_date: str | None
|
||||
) -> pd.DataFrame:
|
||||
"""Load and filter OHLCV data."""
|
||||
try:
|
||||
df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
|
||||
except FileNotFoundError:
|
||||
logger.warning("Data not found locally. Attempting download...")
|
||||
df = self.dm.download_data(
|
||||
exchange_id, symbol, timeframe,
|
||||
start_date, end_date, market_type
|
||||
)
|
||||
|
||||
if start_date:
|
||||
df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
|
||||
if end_date:
|
||||
df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
|
||||
|
||||
return df
|
||||
|
||||
def _normalize_signals(
|
||||
self,
|
||||
signals: tuple,
|
||||
close: pd.Series,
|
||||
market_config
|
||||
) -> tuple:
|
||||
"""
|
||||
Normalize strategy signals to 5-tuple format.
|
||||
|
||||
Returns:
|
||||
(long_entries, long_exits, short_entries, short_exits, size)
|
||||
"""
|
||||
# Default size is None (will be treated as 1.0 or default later)
|
||||
size = None
|
||||
|
||||
if len(signals) == 2:
|
||||
long_entries, long_exits = signals
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
return long_entries, long_exits, short_entries, short_exits, size
|
||||
|
||||
if len(signals) == 4:
|
||||
long_entries, long_exits, short_entries, short_exits = signals
|
||||
elif len(signals) == 5:
|
||||
long_entries, long_exits, short_entries, short_exits, size = signals
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Strategy must return 2, 4, or 5 signal arrays, got {len(signals)}"
|
||||
)
|
||||
|
||||
# Warn and clear short signals on spot markets
|
||||
if not market_config.supports_short:
|
||||
has_shorts = (
|
||||
short_entries.any().any()
|
||||
if hasattr(short_entries, 'any')
|
||||
else short_entries.any()
|
||||
)
|
||||
if has_shorts:
|
||||
logger.warning(
|
||||
"Short signals detected but market type is SPOT. "
|
||||
"Short signals will be ignored."
|
||||
)
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits, size
|
||||
|
||||
def _run_portfolio(
|
||||
self,
|
||||
close: pd.Series,
|
||||
market_config,
|
||||
long_entries, long_exits,
|
||||
short_entries, short_exits,
|
||||
init_cash: float,
|
||||
fees: float,
|
||||
slippage: float,
|
||||
freq: str,
|
||||
sl_stop: float | None,
|
||||
tp_stop: float | None,
|
||||
sl_trail: bool,
|
||||
leverage: int,
|
||||
size: pd.Series | float = 1.0
|
||||
) -> vbt.Portfolio:
|
||||
"""Select and run appropriate portfolio simulation."""
|
||||
has_shorts = (
|
||||
short_entries.any().any()
|
||||
if hasattr(short_entries, 'any')
|
||||
else short_entries.any()
|
||||
)
|
||||
|
||||
if market_config.supports_short and has_shorts:
|
||||
return run_long_short_portfolio(
|
||||
close,
|
||||
long_entries, long_exits,
|
||||
short_entries, short_exits,
|
||||
init_cash, fees, slippage, freq,
|
||||
sl_stop, tp_stop, sl_trail, leverage,
|
||||
size=size
|
||||
)
|
||||
|
||||
return run_long_only_portfolio(
|
||||
close,
|
||||
long_entries, long_exits,
|
||||
init_cash, fees, slippage, freq,
|
||||
sl_stop, tp_stop, sl_trail, leverage,
|
||||
# Long-only doesn't support variable size in current implementation
|
||||
# without modification, but we can add it if needed.
|
||||
# For now, only regime strategy uses it, which is Long/Short.
|
||||
)
|
||||
|
||||
def run_wfa(
|
||||
self,
|
||||
strategy: BaseStrategy,
|
||||
exchange_id: str,
|
||||
symbol: str,
|
||||
param_grid: dict,
|
||||
n_windows: int = 10,
|
||||
timeframe: str = '1m'
|
||||
):
|
||||
"""
|
||||
Execute Walk-Forward Analysis.
|
||||
|
||||
Args:
|
||||
strategy: Strategy instance to optimize
|
||||
exchange_id: Exchange identifier
|
||||
symbol: Trading pair symbol
|
||||
param_grid: Parameter grid for optimization
|
||||
n_windows: Number of walk-forward windows
|
||||
timeframe: Data timeframe to load
|
||||
|
||||
Returns:
|
||||
Tuple of (results DataFrame, stitched equity curve)
|
||||
"""
|
||||
market_type = strategy.default_market_type
|
||||
df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
|
||||
|
||||
wfa = WalkForwardOptimizer(self, strategy, param_grid)
|
||||
|
||||
results, stitched_curve = wfa.run(
|
||||
df['close'],
|
||||
high=df['high'],
|
||||
low=df['low'],
|
||||
n_windows=n_windows
|
||||
)
|
||||
|
||||
return results, stitched_curve
|
||||
243
engine/cli.py
Normal file
243
engine/cli.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
CLI handler for Lowkey Backtest.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger, setup_logging
|
||||
from engine.market import MarketType
|
||||
from engine.reporting import Reporter
|
||||
from strategies.factory import get_strategy, get_strategy_names
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
"""Create and configure the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Lowkey Backtest CLI (VectorBT Edition)"
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
_add_download_parser(subparsers)
|
||||
_add_backtest_parser(subparsers)
|
||||
_add_wfa_parser(subparsers)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _add_download_parser(subparsers) -> None:
|
||||
"""Add download command parser."""
|
||||
dl_parser = subparsers.add_parser("download", help="Download historical data")
|
||||
dl_parser.add_argument("--exchange", "-e", type=str, default="okx")
|
||||
dl_parser.add_argument("--pair", "-p", type=str, required=True)
|
||||
dl_parser.add_argument("--timeframe", "-t", type=str, default="1m")
|
||||
dl_parser.add_argument("--start", type=str, help="Start Date (YYYY-MM-DD)")
|
||||
dl_parser.add_argument(
|
||||
"--market", "-m",
|
||||
type=str,
|
||||
choices=["spot", "perpetual"],
|
||||
default="spot"
|
||||
)
|
||||
|
||||
|
||||
def _add_backtest_parser(subparsers) -> None:
|
||||
"""Add backtest command parser."""
|
||||
strategy_choices = get_strategy_names()
|
||||
|
||||
bt_parser = subparsers.add_parser("backtest", help="Run a backtest")
|
||||
bt_parser.add_argument(
|
||||
"--strategy", "-s",
|
||||
type=str,
|
||||
choices=strategy_choices,
|
||||
required=True
|
||||
)
|
||||
bt_parser.add_argument("--exchange", "-e", type=str, default="okx")
|
||||
bt_parser.add_argument("--pair", "-p", type=str, required=True)
|
||||
bt_parser.add_argument("--timeframe", "-t", type=str, default="1m")
|
||||
bt_parser.add_argument("--start", type=str)
|
||||
bt_parser.add_argument("--end", type=str)
|
||||
bt_parser.add_argument("--grid", "-g", action="store_true")
|
||||
bt_parser.add_argument("--plot", action="store_true")
|
||||
|
||||
# Risk parameters
|
||||
bt_parser.add_argument("--sl", type=float, help="Stop Loss %%")
|
||||
bt_parser.add_argument("--tp", type=float, help="Take Profit %%")
|
||||
bt_parser.add_argument("--trail", action="store_true")
|
||||
bt_parser.add_argument("--no-bear-exit", action="store_true")
|
||||
|
||||
# Cost parameters
|
||||
bt_parser.add_argument("--fees", type=float, default=None)
|
||||
bt_parser.add_argument("--slippage", type=float, default=0.001)
|
||||
bt_parser.add_argument("--leverage", "-l", type=int, default=None)
|
||||
|
||||
|
||||
def _add_wfa_parser(subparsers) -> None:
|
||||
"""Add walk-forward analysis command parser."""
|
||||
strategy_choices = get_strategy_names()
|
||||
|
||||
wfa_parser = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")
|
||||
wfa_parser.add_argument(
|
||||
"--strategy", "-s",
|
||||
type=str,
|
||||
choices=strategy_choices,
|
||||
required=True
|
||||
)
|
||||
wfa_parser.add_argument("--pair", "-p", type=str, required=True)
|
||||
wfa_parser.add_argument("--timeframe", "-t", type=str, default="1d")
|
||||
wfa_parser.add_argument("--windows", "-w", type=int, default=10)
|
||||
wfa_parser.add_argument("--plot", action="store_true")
|
||||
|
||||
|
||||
def run_download(args) -> None:
|
||||
"""Execute download command."""
|
||||
dm = DataManager()
|
||||
market_type = MarketType(args.market)
|
||||
dm.download_data(
|
||||
args.exchange,
|
||||
args.pair,
|
||||
args.timeframe,
|
||||
start_date=args.start,
|
||||
market_type=market_type
|
||||
)
|
||||
|
||||
|
||||
def run_backtest(args) -> None:
|
||||
"""Execute backtest command."""
|
||||
dm = DataManager()
|
||||
bt = Backtester(dm)
|
||||
reporter = Reporter()
|
||||
|
||||
strategy, params = get_strategy(args.strategy, args.grid)
|
||||
|
||||
# Apply CLI overrides for meta_st strategy
|
||||
params = _apply_strategy_overrides(args, strategy, params)
|
||||
|
||||
if args.grid and args.strategy == "meta_st":
|
||||
logger.info("Running Grid Search for Meta Supertrend...")
|
||||
|
||||
try:
|
||||
result = bt.run_strategy(
|
||||
strategy,
|
||||
args.exchange,
|
||||
args.pair,
|
||||
timeframe=args.timeframe,
|
||||
start_date=args.start,
|
||||
end_date=args.end,
|
||||
fees=args.fees,
|
||||
slippage=args.slippage,
|
||||
sl_stop=args.sl,
|
||||
tp_stop=args.tp,
|
||||
sl_trail=args.trail,
|
||||
leverage=args.leverage,
|
||||
**params
|
||||
)
|
||||
|
||||
reporter.print_summary(result)
|
||||
reporter.save_reports(result, f"{args.strategy}_{args.pair.replace('/','-')}")
|
||||
|
||||
if args.plot and not args.grid:
|
||||
reporter.plot(result.portfolio)
|
||||
elif args.plot and args.grid:
|
||||
logger.info("Plotting skipped for Grid Search. Check CSV results.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Backtest failed: %s", e, exc_info=True)
|
||||
|
||||
|
||||
def run_wfa(args) -> None:
|
||||
"""Execute walk-forward analysis command."""
|
||||
dm = DataManager()
|
||||
bt = Backtester(dm)
|
||||
reporter = Reporter()
|
||||
|
||||
strategy, params = get_strategy(args.strategy, is_grid=True)
|
||||
|
||||
logger.info(
|
||||
"Running WFA on %s for %s (%s) with %d windows...",
|
||||
args.strategy, args.pair, args.timeframe, args.windows
|
||||
)
|
||||
|
||||
try:
|
||||
results, stitched_curve = bt.run_wfa(
|
||||
strategy,
|
||||
"okx",
|
||||
args.pair,
|
||||
params,
|
||||
n_windows=args.windows,
|
||||
timeframe=args.timeframe
|
||||
)
|
||||
|
||||
_log_wfa_results(results)
|
||||
_save_wfa_results(args, results, stitched_curve, reporter)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("WFA failed: %s", e, exc_info=True)
|
||||
|
||||
|
||||
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
|
||||
"""Apply CLI argument overrides to strategy parameters."""
|
||||
if args.strategy != "meta_st":
|
||||
return params
|
||||
|
||||
if args.no_bear_exit:
|
||||
params['exit_on_bearish_flip'] = False
|
||||
|
||||
if args.sl is None:
|
||||
args.sl = strategy.default_sl_stop
|
||||
|
||||
if not args.trail:
|
||||
args.trail = strategy.default_sl_trail
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _log_wfa_results(results) -> None:
|
||||
"""Log WFA results summary."""
|
||||
logger.info("Walk-Forward Analysis Results:")
|
||||
|
||||
if results.empty or 'window' not in results.columns:
|
||||
logger.warning("No valid WFA results. All windows may have failed.")
|
||||
return
|
||||
|
||||
columns = ['window', 'train_score', 'test_score', 'test_return']
|
||||
logger.info("\n%s", results[columns].to_string(index=False))
|
||||
|
||||
avg_test_sharpe = results['test_score'].mean()
|
||||
avg_test_return = results['test_return'].mean()
|
||||
logger.info("Average Test Sharpe: %.2f", avg_test_sharpe)
|
||||
logger.info("Average Test Return: %.2f%%", avg_test_return * 100)
|
||||
|
||||
|
||||
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
|
||||
"""Save WFA results to file and optionally plot."""
|
||||
if results.empty:
|
||||
return
|
||||
|
||||
output_path = f"backtest_logs/wfa_{args.strategy}_{args.pair.replace('/','-')}.csv"
|
||||
results.to_csv(output_path)
|
||||
logger.info("Saved full results to %s", output_path)
|
||||
|
||||
if args.plot:
|
||||
reporter.plot_wfa(results, stitched_curve)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
setup_logging()
|
||||
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
commands = {
|
||||
"download": run_download,
|
||||
"backtest": run_backtest,
|
||||
"wfa": run_wfa,
|
||||
}
|
||||
|
||||
if args.command in commands:
|
||||
commands[args.command](args)
|
||||
else:
|
||||
parser.print_help()
|
||||
201
engine/cryptoquant.py
Normal file
201
engine/cryptoquant.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load env vars from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Fix path for direct execution
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class CryptoQuantClient:
|
||||
"""
|
||||
Client for fetching data from CryptoQuant API.
|
||||
"""
|
||||
BASE_URL = "https://api.cryptoquant.com/v1"
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or os.getenv("CRYPTOQUANT_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("CryptoQuant API Key not found. Set CRYPTOQUANT_API_KEY env var.")
|
||||
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
|
||||
def fetch_metric(
|
||||
self,
|
||||
metric_path: str,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
exchange: str | None = "all_exchange",
|
||||
window: str = "day"
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch a specific metric from CryptoQuant.
|
||||
"""
|
||||
url = f"{self.BASE_URL}/{metric_path}"
|
||||
|
||||
params = {
|
||||
"window": window,
|
||||
"from": start_date,
|
||||
"to": end_date,
|
||||
"limit": 100000
|
||||
}
|
||||
|
||||
if exchange:
|
||||
params["exchange"] = exchange
|
||||
|
||||
logger.info(f"Fetching {metric_path} for {symbol} ({start_date}-{end_date})...")
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=self.headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if 'result' in data and 'data' in data['result']:
|
||||
df = pd.DataFrame(data['result']['data'])
|
||||
if not df.empty:
|
||||
if 'date' in df.columns:
|
||||
df['timestamp'] = pd.to_datetime(df['date'])
|
||||
df.set_index('timestamp', inplace=True)
|
||||
df.sort_index(inplace=True)
|
||||
return df
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CQ data {metric_path}: {e}")
|
||||
if 'response' in locals() and hasattr(response, 'text'):
|
||||
logger.error(f"Response: {response.text}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def fetch_multi_metrics(self, symbols: list[str], metrics: dict, start_date: str, end_date: str):
|
||||
"""
|
||||
Fetch multiple metrics for multiple symbols and combine them.
|
||||
"""
|
||||
combined_df = pd.DataFrame()
|
||||
|
||||
for symbol in symbols:
|
||||
asset = symbol.lower()
|
||||
|
||||
for metric_name, api_path in metrics.items():
|
||||
full_path = f"{asset}/{api_path}"
|
||||
|
||||
# Some metrics (like funding rates) might need specific exchange vs all_exchange
|
||||
# Defaulting to all_exchange is usually safe for flows, but check specific logic if needed
|
||||
exchange_param = "all_exchange"
|
||||
if "funding-rates" in api_path:
|
||||
# For funding rates, 'all_exchange' might not be valid or might be aggregated
|
||||
# Let's try 'binance' as a proxy for market sentiment if all fails,
|
||||
# or keep 'all_exchange' if supported.
|
||||
# Based on testing, 'all_exchange' is standard for flows.
|
||||
pass
|
||||
|
||||
df = self.fetch_metric(full_path, asset, start_date, end_date, exchange=exchange_param)
|
||||
|
||||
if not df.empty:
|
||||
target_col = None
|
||||
# Heuristic to find the value column
|
||||
candidates = ['funding_rate', 'reserve', 'inflow_total', 'outflow_total', 'open_interest', 'ratio', 'value']
|
||||
|
||||
for col in df.columns:
|
||||
if col in candidates:
|
||||
target_col = col
|
||||
break
|
||||
|
||||
if not target_col:
|
||||
# Fallback: take first numeric col that isn't date
|
||||
for col in df.columns:
|
||||
if col not in ['date', 'datetime', 'timestamp_str', 'block_height']:
|
||||
target_col = col
|
||||
break
|
||||
|
||||
if target_col:
|
||||
col_name = f"{asset}_{metric_name}"
|
||||
subset = df[[target_col]].rename(columns={target_col: col_name})
|
||||
|
||||
if combined_df.empty:
|
||||
combined_df = subset
|
||||
else:
|
||||
combined_df = combined_df.join(subset, how='outer')
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
return combined_df
|
||||
|
||||
def fetch_history_chunked(
|
||||
self,
|
||||
symbols: list[str],
|
||||
metrics: dict,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
chunk_months: int = 3
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch historical data in chunks to avoid API limits.
|
||||
"""
|
||||
start_dt = datetime.strptime(start_date, "%Y%m%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y%m%d")
|
||||
|
||||
all_data = []
|
||||
|
||||
current = start_dt
|
||||
while current < end_dt:
|
||||
next_chunk = current + timedelta(days=chunk_months * 30)
|
||||
if next_chunk > end_dt:
|
||||
next_chunk = end_dt
|
||||
|
||||
s_str = current.strftime("%Y%m%d")
|
||||
e_str = next_chunk.strftime("%Y%m%d")
|
||||
|
||||
logger.info(f"Processing chunk: {s_str} to {e_str}")
|
||||
chunk_df = self.fetch_multi_metrics(symbols, metrics, s_str, e_str)
|
||||
|
||||
if not chunk_df.empty:
|
||||
all_data.append(chunk_df)
|
||||
|
||||
current = next_chunk + timedelta(days=1)
|
||||
time.sleep(1) # Be nice to API
|
||||
|
||||
if not all_data:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Combine all chunks
|
||||
full_df = pd.concat(all_data)
|
||||
# Remove duplicates if any overlap
|
||||
full_df = full_df[~full_df.index.duplicated(keep='first')]
|
||||
full_df.sort_index(inplace=True)
|
||||
|
||||
return full_df
|
||||
|
||||
if __name__ == "__main__":
|
||||
cq = CryptoQuantClient()
|
||||
|
||||
# 12 Months Data (Jan 1 2025 - Jan 14 2026)
|
||||
start = "20250101"
|
||||
end = "20260114"
|
||||
|
||||
metrics = {
|
||||
"reserves": "exchange-flows/exchange-reserve",
|
||||
"inflow": "exchange-flows/inflow",
|
||||
"funding": "market-data/funding-rates"
|
||||
}
|
||||
|
||||
print(f"Fetching training data from {start} to {end}...")
|
||||
df = cq.fetch_history_chunked(["btc", "eth"], metrics, start, end)
|
||||
|
||||
output_file = "data/cq_training_data.csv"
|
||||
os.makedirs("data", exist_ok=True)
|
||||
df.to_csv(output_file)
|
||||
print(f"\nSaved {len(df)} rows to {output_file}")
|
||||
print(df.head())
|
||||
209
engine/data_manager.py
Normal file
209
engine/data_manager.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Data management for OHLCV data download and storage.
|
||||
|
||||
Handles data retrieval from exchanges and local file management.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import ccxt
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType, get_ccxt_symbol
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataManager:
|
||||
"""
|
||||
Manages OHLCV data download and storage for different market types.
|
||||
|
||||
Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str = "data/ccxt"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.exchanges: dict[str, ccxt.Exchange] = {}
|
||||
|
||||
def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
|
||||
"""Get or create a CCXT exchange instance."""
|
||||
if exchange_id not in self.exchanges:
|
||||
exchange_class = getattr(ccxt, exchange_id)
|
||||
self.exchanges[exchange_id] = exchange_class({
|
||||
'enableRateLimit': True,
|
||||
})
|
||||
return self.exchanges[exchange_id]
|
||||
|
||||
def _get_data_path(
|
||||
self,
|
||||
exchange_id: str,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
market_type: MarketType
|
||||
) -> Path:
|
||||
"""
|
||||
Get the file path for storing/loading data.
|
||||
|
||||
Args:
|
||||
exchange_id: Exchange name (e.g., 'okx')
|
||||
symbol: Trading pair (e.g., 'BTC/USDT')
|
||||
timeframe: Candle timeframe (e.g., '1m')
|
||||
market_type: Market type (spot or perpetual)
|
||||
|
||||
Returns:
|
||||
Path to the CSV file
|
||||
"""
|
||||
safe_symbol = symbol.replace('/', '-')
|
||||
return (
|
||||
self.data_dir
|
||||
/ exchange_id
|
||||
/ market_type.value
|
||||
/ safe_symbol
|
||||
/ f"{timeframe}.csv"
|
||||
)
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
exchange_id: str,
|
||||
symbol: str,
|
||||
timeframe: str = '1m',
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
market_type: MarketType = MarketType.SPOT
|
||||
) -> pd.DataFrame | None:
|
||||
"""
|
||||
Download OHLCV data from exchange and save to CSV.
|
||||
|
||||
Args:
|
||||
exchange_id: Exchange name (e.g., 'okx')
|
||||
symbol: Trading pair (e.g., 'BTC/USDT')
|
||||
timeframe: Candle timeframe (e.g., '1m')
|
||||
start_date: Start date string (YYYY-MM-DD)
|
||||
end_date: End date string (YYYY-MM-DD)
|
||||
market_type: Market type (spot or perpetual)
|
||||
|
||||
Returns:
|
||||
DataFrame with OHLCV data, or None if download failed
|
||||
"""
|
||||
exchange = self.get_exchange(exchange_id)
|
||||
|
||||
file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ccxt_symbol = get_ccxt_symbol(symbol, market_type)
|
||||
|
||||
since, until = self._parse_date_range(exchange, start_date, end_date)
|
||||
|
||||
logger.info(
|
||||
"Downloading %s (%s) from %s...",
|
||||
symbol, market_type.value, exchange_id
|
||||
)
|
||||
|
||||
all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)
|
||||
|
||||
if not all_ohlcv:
|
||||
logger.warning("No data downloaded.")
|
||||
return None
|
||||
|
||||
df = self._convert_to_dataframe(all_ohlcv)
|
||||
df.to_csv(file_path)
|
||||
logger.info("Saved %d candles to %s", len(df), file_path)
|
||||
return df
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
exchange_id: str,
|
||||
symbol: str,
|
||||
timeframe: str = '1m',
|
||||
market_type: MarketType = MarketType.SPOT
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Load saved OHLCV data for vectorbt.
|
||||
|
||||
Args:
|
||||
exchange_id: Exchange name (e.g., 'okx')
|
||||
symbol: Trading pair (e.g., 'BTC/USDT')
|
||||
timeframe: Candle timeframe (e.g., '1m')
|
||||
market_type: Market type (spot or perpetual)
|
||||
|
||||
Returns:
|
||||
DataFrame with OHLCV data indexed by timestamp
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If data file does not exist
|
||||
"""
|
||||
file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Data not found at {file_path}. "
|
||||
f"Run: uv run python main.py download --pair {symbol} "
|
||||
f"--market {market_type.value}"
|
||||
)
|
||||
|
||||
return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)
|
||||
|
||||
def _parse_date_range(
|
||||
self,
|
||||
exchange: ccxt.Exchange,
|
||||
start_date: str | None,
|
||||
end_date: str | None
|
||||
) -> tuple[int, int]:
|
||||
"""Parse date strings into millisecond timestamps."""
|
||||
if start_date:
|
||||
since = exchange.parse8601(f"{start_date}T00:00:00Z")
|
||||
else:
|
||||
since = exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
|
||||
|
||||
if end_date:
|
||||
until = exchange.parse8601(f"{end_date}T23:59:59Z")
|
||||
else:
|
||||
until = exchange.milliseconds()
|
||||
|
||||
return since, until
|
||||
|
||||
def _fetch_all_candles(
|
||||
self,
|
||||
exchange: ccxt.Exchange,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
since: int,
|
||||
until: int
|
||||
) -> list:
|
||||
"""Fetch all candles in the date range."""
|
||||
all_ohlcv = []
|
||||
|
||||
while since < until:
|
||||
try:
|
||||
ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since, limit=100)
|
||||
if not ohlcv:
|
||||
break
|
||||
|
||||
all_ohlcv.extend(ohlcv)
|
||||
since = ohlcv[-1][0] + 1
|
||||
|
||||
current_date = datetime.fromtimestamp(
|
||||
since/1000, tz=timezone.utc
|
||||
).strftime('%Y-%m-%d')
|
||||
logger.debug("Fetched up to %s", current_date)
|
||||
|
||||
time.sleep(exchange.rateLimit / 1000)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error fetching data: %s", e)
|
||||
break
|
||||
|
||||
return all_ohlcv
|
||||
|
||||
def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
|
||||
"""Convert OHLCV list to DataFrame."""
|
||||
df = pd.DataFrame(
|
||||
ohlcv,
|
||||
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
df.set_index('timestamp', inplace=True)
|
||||
return df
|
||||
124
engine/logging_config.py
Normal file
124
engine/logging_config.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Centralized logging configuration for the backtest engine.
|
||||
|
||||
Provides colored console output and rotating file logs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
class Colors:
|
||||
"""ANSI escape codes for colored terminal output."""
|
||||
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
|
||||
# Log level colors
|
||||
DEBUG = "\033[36m" # Cyan
|
||||
INFO = "\033[32m" # Green
|
||||
WARNING = "\033[33m" # Yellow
|
||||
ERROR = "\033[31m" # Red
|
||||
CRITICAL = "\033[35m" # Magenta
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""
|
||||
Custom formatter that adds colors to log level names in terminal output.
|
||||
"""
|
||||
|
||||
LEVEL_COLORS = {
|
||||
logging.DEBUG: Colors.DEBUG,
|
||||
logging.INFO: Colors.INFO,
|
||||
logging.WARNING: Colors.WARNING,
|
||||
logging.ERROR: Colors.ERROR,
|
||||
logging.CRITICAL: Colors.CRITICAL,
|
||||
}
|
||||
|
||||
def __init__(self, fmt: str = None, datefmt: str = None):
|
||||
super().__init__(fmt, datefmt)
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
# Save original levelname
|
||||
original_levelname = record.levelname
|
||||
|
||||
# Add color to levelname
|
||||
color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
|
||||
record.levelname = f"{color}{record.levelname}{Colors.RESET}"
|
||||
|
||||
# Format the message
|
||||
result = super().format(record)
|
||||
|
||||
# Restore original levelname
|
||||
record.levelname = original_levelname
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def setup_logging(
|
||||
log_dir: str = "logs",
|
||||
log_level: int = logging.INFO,
|
||||
console_level: int = logging.INFO,
|
||||
max_bytes: int = 5 * 1024 * 1024, # 5MB
|
||||
backup_count: int = 3
|
||||
) -> None:
|
||||
"""
|
||||
Configure logging for the application.
|
||||
|
||||
Args:
|
||||
log_dir: Directory for log files
|
||||
log_level: File logging level
|
||||
console_level: Console logging level
|
||||
max_bytes: Max size per log file before rotation
|
||||
backup_count: Number of backup files to keep
|
||||
"""
|
||||
log_path = Path(log_dir)
|
||||
log_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.DEBUG) # Capture all, handlers filter
|
||||
|
||||
# Clear existing handlers
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Console handler with colors
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(console_level)
|
||||
console_fmt = ColoredFormatter(
|
||||
fmt="[%(asctime)s] %(levelname)s - %(message)s",
|
||||
datefmt="%H:%M:%S"
|
||||
)
|
||||
console_handler.setFormatter(console_fmt)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler with rotation
|
||||
file_handler = RotatingFileHandler(
|
||||
log_path / "backtest.log",
|
||||
maxBytes=max_bytes,
|
||||
backupCount=backup_count,
|
||||
encoding="utf-8"
|
||||
)
|
||||
file_handler.setLevel(log_level)
|
||||
file_fmt = logging.Formatter(
|
||||
fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
file_handler.setFormatter(file_fmt)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger instance for the given module name.
|
||||
|
||||
Args:
|
||||
name: Module name (typically __name__)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
168
engine/market.py
Normal file
168
engine/market.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Market type definitions and configuration for backtesting.
|
||||
|
||||
Supports different market types with their specific trading conditions:
|
||||
- SPOT: No leverage, no funding, long-only
|
||||
- PERPETUAL: Leverage, funding rates, long/short
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MarketType(Enum):
|
||||
"""Supported market types for backtesting."""
|
||||
SPOT = "spot"
|
||||
PERPETUAL = "perpetual"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MarketConfig:
|
||||
"""
|
||||
Configuration for a specific market type.
|
||||
|
||||
Attributes:
|
||||
market_type: The market type enum value
|
||||
maker_fee: Maker fee as decimal (e.g., 0.0008 = 0.08%)
|
||||
taker_fee: Taker fee as decimal (e.g., 0.001 = 0.1%)
|
||||
max_leverage: Maximum allowed leverage
|
||||
funding_rate: Funding rate per 8 hours as decimal (perpetuals only)
|
||||
funding_interval_hours: Hours between funding payments
|
||||
maintenance_margin_rate: Rate for liquidation calculation
|
||||
supports_short: Whether short-selling is supported
|
||||
"""
|
||||
market_type: MarketType
|
||||
maker_fee: float
|
||||
taker_fee: float
|
||||
max_leverage: int
|
||||
funding_rate: float
|
||||
funding_interval_hours: int
|
||||
maintenance_margin_rate: float
|
||||
supports_short: bool
|
||||
|
||||
|
||||
# OKX-based default configurations
|
||||
SPOT_CONFIG = MarketConfig(
|
||||
market_type=MarketType.SPOT,
|
||||
maker_fee=0.0008, # 0.08%
|
||||
taker_fee=0.0010, # 0.10%
|
||||
max_leverage=1,
|
||||
funding_rate=0.0,
|
||||
funding_interval_hours=0,
|
||||
maintenance_margin_rate=0.0,
|
||||
supports_short=False,
|
||||
)
|
||||
|
||||
PERPETUAL_CONFIG = MarketConfig(
|
||||
market_type=MarketType.PERPETUAL,
|
||||
maker_fee=0.0002, # 0.02%
|
||||
taker_fee=0.0005, # 0.05%
|
||||
max_leverage=125,
|
||||
funding_rate=0.0001, # 0.01% per 8 hours (simplified average)
|
||||
funding_interval_hours=8,
|
||||
maintenance_margin_rate=0.004, # 0.4% for BTC on OKX
|
||||
supports_short=True,
|
||||
)
|
||||
|
||||
|
||||
def get_market_config(market_type: MarketType) -> MarketConfig:
|
||||
"""
|
||||
Get the configuration for a specific market type.
|
||||
|
||||
Args:
|
||||
market_type: The market type to get configuration for
|
||||
|
||||
Returns:
|
||||
MarketConfig with default values for that market type
|
||||
"""
|
||||
configs = {
|
||||
MarketType.SPOT: SPOT_CONFIG,
|
||||
MarketType.PERPETUAL: PERPETUAL_CONFIG,
|
||||
}
|
||||
return configs[market_type]
|
||||
|
||||
|
||||
def get_ccxt_symbol(symbol: str, market_type: MarketType) -> str:
|
||||
"""
|
||||
Convert a standard symbol to CCXT format for the given market type.
|
||||
|
||||
Args:
|
||||
symbol: Standard symbol (e.g., 'BTC/USDT')
|
||||
market_type: The market type
|
||||
|
||||
Returns:
|
||||
CCXT-formatted symbol (e.g., 'BTC/USDT:USDT' for perpetuals)
|
||||
"""
|
||||
if market_type == MarketType.PERPETUAL:
|
||||
# OKX perpetual format: BTC/USDT:USDT
|
||||
if '/' in symbol:
|
||||
base, quote = symbol.split('/')
|
||||
return f"{symbol}:{quote}"
|
||||
elif '-' in symbol:
|
||||
base, quote = symbol.split('-')
|
||||
return f"{base}/{quote}:{quote}"
|
||||
else:
|
||||
# Assume base is symbol, quote is USDT default
|
||||
return f"{symbol}/USDT:USDT"
|
||||
|
||||
# For spot, normalize dash to slash for CCXT
|
||||
if '-' in symbol:
|
||||
return symbol.replace('-', '/')
|
||||
|
||||
return symbol
|
||||
|
||||
|
||||
def calculate_leverage_stop_loss(
|
||||
leverage: int,
|
||||
maintenance_margin_rate: float = 0.004
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the implicit stop-loss percentage from leverage.
|
||||
|
||||
At a given leverage, liquidation occurs when the position loses
|
||||
approximately (1/leverage - maintenance_margin_rate) of its value.
|
||||
|
||||
Args:
|
||||
leverage: Position leverage multiplier
|
||||
maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)
|
||||
|
||||
Returns:
|
||||
Stop-loss percentage as decimal (e.g., 0.196 for 19.6%)
|
||||
"""
|
||||
if leverage <= 1:
|
||||
return 1.0 # No forced stop for spot
|
||||
|
||||
return (1 / leverage) - maintenance_margin_rate
|
||||
|
||||
|
||||
def calculate_liquidation_price(
|
||||
entry_price: float,
|
||||
leverage: float,
|
||||
is_long: bool,
|
||||
maintenance_margin_rate: float = 0.004
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the liquidation price for a leveraged position.
|
||||
|
||||
Args:
|
||||
entry_price: Position entry price
|
||||
leverage: Position leverage
|
||||
is_long: True for long positions, False for short
|
||||
maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)
|
||||
|
||||
Returns:
|
||||
Liquidation price
|
||||
"""
|
||||
if leverage <= 1:
|
||||
return 0.0 if is_long else float('inf')
|
||||
|
||||
# Simplified liquidation formula
|
||||
# Long: liq_price = entry * (1 - 1/leverage + maintenance_margin_rate)
|
||||
# Short: liq_price = entry * (1 + 1/leverage - maintenance_margin_rate)
|
||||
margin_ratio = 1 / leverage
|
||||
|
||||
if is_long:
|
||||
liq_price = entry_price * (1 - margin_ratio + maintenance_margin_rate)
|
||||
else:
|
||||
liq_price = entry_price * (1 + margin_ratio - maintenance_margin_rate)
|
||||
|
||||
return liq_price
|
||||
245
engine/optimizer.py
Normal file
245
engine/optimizer.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Walk-Forward Analysis optimizer for strategy parameter optimization.
|
||||
|
||||
Implements expanding window walk-forward analysis with train/test splits.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_rolling_windows(
|
||||
index: pd.Index,
|
||||
n_windows: int,
|
||||
train_split: float = 0.7
|
||||
):
|
||||
"""
|
||||
Create rolling train/test split indices using expanding window approach.
|
||||
|
||||
Args:
|
||||
index: DataFrame index to split
|
||||
n_windows: Number of walk-forward windows
|
||||
train_split: Unused, kept for API compatibility
|
||||
|
||||
Yields:
|
||||
Tuples of (train_idx, test_idx) numpy arrays
|
||||
"""
|
||||
chunks = np.array_split(index, n_windows + 1)
|
||||
|
||||
for i in range(n_windows):
|
||||
train_idx = np.concatenate([c for c in chunks[:i+1]])
|
||||
test_idx = chunks[i+1]
|
||||
yield train_idx, test_idx
|
||||
|
||||
|
||||
class WalkForwardOptimizer:
|
||||
"""
|
||||
Walk-Forward Analysis optimizer for strategy backtesting.
|
||||
|
||||
Optimizes strategy parameters on training windows and validates
|
||||
on out-of-sample test windows.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backtester,
|
||||
strategy,
|
||||
param_grid: dict,
|
||||
metric: str = 'Sharpe Ratio',
|
||||
fees: float = 0.001,
|
||||
freq: str = '1m'
|
||||
):
|
||||
"""
|
||||
Initialize the optimizer.
|
||||
|
||||
Args:
|
||||
backtester: Backtester instance
|
||||
strategy: Strategy instance to optimize
|
||||
param_grid: Parameter grid for optimization
|
||||
metric: Performance metric to optimize
|
||||
fees: Transaction fees for simulation
|
||||
freq: Data frequency for portfolio simulation
|
||||
"""
|
||||
self.bt = backtester
|
||||
self.strategy = strategy
|
||||
self.param_grid = param_grid
|
||||
self.metric = metric
|
||||
self.fees = fees
|
||||
self.freq = freq
|
||||
|
||||
# Separate grid params (lists) from fixed params (scalars)
|
||||
self.grid_keys = []
|
||||
self.fixed_params = {}
|
||||
for k, v in param_grid.items():
|
||||
if isinstance(v, (list, np.ndarray)):
|
||||
self.grid_keys.append(k)
|
||||
else:
|
||||
self.fixed_params[k] = v
|
||||
|
||||
def run(
|
||||
self,
|
||||
close_price: pd.Series,
|
||||
high: pd.Series | None = None,
|
||||
low: pd.Series | None = None,
|
||||
n_windows: int = 10
|
||||
) -> tuple[pd.DataFrame, pd.Series | None]:
|
||||
"""
|
||||
Execute walk-forward analysis.
|
||||
|
||||
Args:
|
||||
close_price: Close price series
|
||||
high: High price series (optional)
|
||||
low: Low price series (optional)
|
||||
n_windows: Number of walk-forward windows
|
||||
|
||||
Returns:
|
||||
Tuple of (results DataFrame, stitched equity curve)
|
||||
"""
|
||||
results = []
|
||||
equity_curves = []
|
||||
|
||||
logger.info(
|
||||
"Starting Walk-Forward Analysis with %d windows (Expanding Train)...",
|
||||
n_windows
|
||||
)
|
||||
|
||||
splitter = create_rolling_windows(close_price.index, n_windows)
|
||||
|
||||
for i, (train_idx, test_idx) in enumerate(splitter):
|
||||
logger.info("Processing Window %d/%d...", i + 1, n_windows)
|
||||
|
||||
window_result = self._process_window(
|
||||
i, train_idx, test_idx, close_price, high, low
|
||||
)
|
||||
|
||||
if window_result is not None:
|
||||
result_dict, eq_curve = window_result
|
||||
results.append(result_dict)
|
||||
equity_curves.append(eq_curve)
|
||||
|
||||
stitched_series = self._stitch_equity_curves(equity_curves)
|
||||
return pd.DataFrame(results), stitched_series
|
||||
|
||||
def _process_window(
|
||||
self,
|
||||
window_idx: int,
|
||||
train_idx: np.ndarray,
|
||||
test_idx: np.ndarray,
|
||||
close_price: pd.Series,
|
||||
high: pd.Series | None,
|
||||
low: pd.Series | None
|
||||
) -> tuple[dict, pd.Series] | None:
|
||||
"""Process a single WFA window."""
|
||||
try:
|
||||
# Slice data for train/test
|
||||
train_close = close_price.loc[train_idx]
|
||||
train_high = high.loc[train_idx] if high is not None else None
|
||||
train_low = low.loc[train_idx] if low is not None else None
|
||||
|
||||
# Train phase: find best parameters
|
||||
best_params, best_score = self._optimize_train(
|
||||
train_close, train_high, train_low
|
||||
)
|
||||
|
||||
# Test phase: validate with best params
|
||||
test_close = close_price.loc[test_idx]
|
||||
test_high = high.loc[test_idx] if high is not None else None
|
||||
test_low = low.loc[test_idx] if low is not None else None
|
||||
|
||||
test_params = {**self.fixed_params, **best_params}
|
||||
test_score, test_return, eq_curve = self._run_test(
|
||||
test_close, test_high, test_low, test_params
|
||||
)
|
||||
|
||||
return {
|
||||
'window': window_idx + 1,
|
||||
'train_start': train_idx[0],
|
||||
'train_end': train_idx[-1],
|
||||
'test_start': test_idx[0],
|
||||
'test_end': test_idx[-1],
|
||||
'best_params': best_params,
|
||||
'train_score': best_score,
|
||||
'test_score': test_score,
|
||||
'test_return': test_return
|
||||
}, eq_curve
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in window %d: %s", window_idx + 1, e, exc_info=True)
|
||||
return None
|
||||
|
||||
def _optimize_train(
|
||||
self,
|
||||
close: pd.Series,
|
||||
high: pd.Series | None,
|
||||
low: pd.Series | None
|
||||
) -> tuple[dict, float]:
|
||||
"""Run grid search on training data to find best parameters."""
|
||||
entries, exits = self.strategy.run(
|
||||
close, high=high, low=low, **self.param_grid
|
||||
)
|
||||
|
||||
pf_train = vbt.Portfolio.from_signals(
|
||||
close, entries, exits,
|
||||
fees=self.fees,
|
||||
freq=self.freq
|
||||
)
|
||||
|
||||
perf_stats = pf_train.sharpe_ratio()
|
||||
perf_stats = perf_stats.fillna(-999)
|
||||
|
||||
best_idx = perf_stats.idxmax()
|
||||
best_score = perf_stats.max()
|
||||
|
||||
# Extract best params from grid search
|
||||
if len(self.grid_keys) == 1:
|
||||
best_params = {self.grid_keys[0]: best_idx}
|
||||
elif len(self.grid_keys) > 1:
|
||||
best_params = dict(zip(self.grid_keys, best_idx))
|
||||
else:
|
||||
best_params = {}
|
||||
|
||||
return best_params, best_score
|
||||
|
||||
def _run_test(
|
||||
self,
|
||||
close: pd.Series,
|
||||
high: pd.Series | None,
|
||||
low: pd.Series | None,
|
||||
params: dict
|
||||
) -> tuple[float, float, pd.Series]:
|
||||
"""Run test phase with given parameters."""
|
||||
entries, exits = self.strategy.run(
|
||||
close, high=high, low=low, **params
|
||||
)
|
||||
|
||||
pf_test = vbt.Portfolio.from_signals(
|
||||
close, entries, exits,
|
||||
fees=self.fees,
|
||||
freq=self.freq
|
||||
)
|
||||
|
||||
return pf_test.sharpe_ratio(), pf_test.total_return(), pf_test.value()
|
||||
|
||||
def _stitch_equity_curves(
|
||||
self,
|
||||
equity_curves: list[pd.Series]
|
||||
) -> pd.Series | None:
|
||||
"""Stitch multiple equity curves into a continuous series."""
|
||||
if not equity_curves:
|
||||
return None
|
||||
|
||||
stitched = [equity_curves[0]]
|
||||
for j in range(1, len(equity_curves)):
|
||||
prev_end_val = stitched[-1].iloc[-1]
|
||||
curr_curve = equity_curves[j]
|
||||
init_cash = curr_curve.iloc[0]
|
||||
|
||||
# Scale curve to continue from previous end value
|
||||
scaled_curve = (curr_curve / init_cash) * prev_end_val
|
||||
stitched.append(scaled_curve)
|
||||
|
||||
return pd.concat(stitched)
|
||||
108
engine/portfolio.py
Normal file
108
engine/portfolio.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Portfolio simulation utilities for backtesting.
|
||||
|
||||
Handles long-only and long/short portfolio creation using VectorBT.
|
||||
"""
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_long_only_portfolio(
|
||||
close: pd.Series,
|
||||
entries: pd.DataFrame,
|
||||
exits: pd.DataFrame,
|
||||
init_cash: float,
|
||||
fees: float,
|
||||
slippage: float,
|
||||
freq: str,
|
||||
sl_stop: float | None,
|
||||
tp_stop: float | None,
|
||||
sl_trail: bool,
|
||||
leverage: int
|
||||
) -> vbt.Portfolio:
|
||||
"""
|
||||
Run a long-only portfolio simulation.
|
||||
|
||||
Args:
|
||||
close: Close price series
|
||||
entries: Entry signals
|
||||
exits: Exit signals
|
||||
init_cash: Initial capital
|
||||
fees: Transaction fee percentage
|
||||
slippage: Slippage percentage
|
||||
freq: Data frequency string
|
||||
sl_stop: Stop loss percentage
|
||||
tp_stop: Take profit percentage
|
||||
sl_trail: Enable trailing stop loss
|
||||
leverage: Leverage multiplier
|
||||
|
||||
Returns:
|
||||
VectorBT Portfolio object
|
||||
"""
|
||||
effective_cash = init_cash * leverage
|
||||
|
||||
return vbt.Portfolio.from_signals(
|
||||
close=close,
|
||||
entries=entries,
|
||||
exits=exits,
|
||||
init_cash=effective_cash,
|
||||
fees=fees,
|
||||
slippage=slippage,
|
||||
freq=freq,
|
||||
sl_stop=sl_stop,
|
||||
tp_stop=tp_stop,
|
||||
sl_trail=sl_trail,
|
||||
size=1.0,
|
||||
size_type='percent',
|
||||
)
|
||||
|
||||
|
||||
def run_long_short_portfolio(
|
||||
close: pd.Series,
|
||||
long_entries: pd.DataFrame,
|
||||
long_exits: pd.DataFrame,
|
||||
short_entries: pd.DataFrame,
|
||||
short_exits: pd.DataFrame,
|
||||
init_cash: float,
|
||||
fees: float,
|
||||
slippage: float,
|
||||
freq: str,
|
||||
sl_stop: float | None,
|
||||
tp_stop: float | None,
|
||||
sl_trail: bool,
|
||||
leverage: int,
|
||||
size: pd.Series | float = 1.0,
|
||||
size_type: str = 'value' # Changed to 'value' to support reversals/sizing
|
||||
) -> vbt.Portfolio:
|
||||
"""
|
||||
Run a portfolio supporting both long and short positions.
|
||||
|
||||
Uses VectorBT's native support for short_entries/short_exits
|
||||
to simulate a single unified portfolio.
|
||||
"""
|
||||
effective_cash = init_cash * leverage
|
||||
|
||||
# If size is passed as value (USD), we don't scale it by leverage here
|
||||
# The backtester has already scaled it by init_cash.
|
||||
# If using 'value', vbt treats it as "Amount of CASH to use for the trade"
|
||||
|
||||
return vbt.Portfolio.from_signals(
|
||||
close=close,
|
||||
entries=long_entries,
|
||||
exits=long_exits,
|
||||
short_entries=short_entries,
|
||||
short_exits=short_exits,
|
||||
init_cash=effective_cash,
|
||||
fees=fees,
|
||||
slippage=slippage,
|
||||
freq=freq,
|
||||
sl_stop=sl_stop,
|
||||
tp_stop=tp_stop,
|
||||
sl_trail=sl_trail,
|
||||
size=size,
|
||||
size_type=size_type,
|
||||
)
|
||||
228
engine/reporting.py
Normal file
228
engine/reporting.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Reporting module for backtest results.
|
||||
|
||||
Handles summary printing, CSV exports, and plotting.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import vectorbt as vbt
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Reporter:
|
||||
"""Reporter for backtest results with market-specific metrics."""
|
||||
|
||||
def __init__(self, output_dir: str = "backtest_logs"):
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
def print_summary(self, result) -> None:
|
||||
"""
|
||||
Print backtest summary to console via logger.
|
||||
|
||||
Args:
|
||||
result: BacktestResult or vbt.Portfolio object
|
||||
"""
|
||||
(portfolio, market_type, leverage, funding_paid,
|
||||
liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
|
||||
|
||||
# Extract period info
|
||||
idx = portfolio.wrapper.index
|
||||
start_date = idx[0].strftime("%Y-%m-%d")
|
||||
end_date = idx[-1].strftime("%Y-%m-%d")
|
||||
|
||||
# Extract price info
|
||||
close = portfolio.close
|
||||
start_price = close.iloc[0].mean() if hasattr(close.iloc[0], 'mean') else close.iloc[0]
|
||||
end_price = close.iloc[-1].mean() if hasattr(close.iloc[-1], 'mean') else close.iloc[-1]
|
||||
price_change = ((end_price - start_price) / start_price) * 100
|
||||
|
||||
# Extract fees
|
||||
stats = portfolio.stats()
|
||||
total_fees = stats.get('Total Fees Paid', 0)
|
||||
|
||||
raw_return = portfolio.total_return().mean() * 100
|
||||
|
||||
# Build summary
|
||||
summary_lines = [
|
||||
"",
|
||||
"=" * 50,
|
||||
"BACKTEST RESULTS",
|
||||
"=" * 50,
|
||||
f"Market Type: [{market_type.upper()}]",
|
||||
f"Leverage: [{leverage}x]",
|
||||
f"Period: [{start_date} to {end_date}]",
|
||||
f"Price: [{start_price:,.2f} -> {end_price:,.2f} ({price_change:+.2f}%)]",
|
||||
]
|
||||
|
||||
# Show adjusted return if liquidations occurred
|
||||
if liq_count > 0 and adjusted_return is not None:
|
||||
summary_lines.append(f"Raw Return: [%{raw_return:.2f}] (before liq adjustment)")
|
||||
summary_lines.append(f"Adj Return: [%{adjusted_return:.2f}] (after liq losses)")
|
||||
else:
|
||||
summary_lines.append(f"Total Return: [%{raw_return:.2f}]")
|
||||
|
||||
summary_lines.extend([
|
||||
f"Sharpe Ratio: [{portfolio.sharpe_ratio().mean():.2f}]",
|
||||
f"Max Drawdown: [%{portfolio.max_drawdown().mean() * 100:.2f}]",
|
||||
f"Total Trades: [{portfolio.trades.count().mean():.0f}]",
|
||||
f"Win Rate: [%{portfolio.trades.win_rate().mean() * 100:.2f}]",
|
||||
f"Total Fees: [{total_fees:,.2f}]",
|
||||
])
|
||||
|
||||
if funding_paid != 0:
|
||||
summary_lines.append(f"Funding Paid: [{funding_paid:,.2f}]")
|
||||
if liq_count > 0:
|
||||
summary_lines.append(f"Liquidations: [{liq_count}] (${liq_loss:,.2f} margin lost)")
|
||||
|
||||
summary_lines.append("=" * 50)
|
||||
logger.info("\n".join(summary_lines))
|
||||
|
||||
def save_reports(self, result, filename_prefix: str) -> None:
|
||||
"""
|
||||
Save trade log, stats, and liquidation events to CSV files.
|
||||
|
||||
Args:
|
||||
result: BacktestResult or vbt.Portfolio object
|
||||
filename_prefix: Prefix for output filenames
|
||||
"""
|
||||
(portfolio, market_type, leverage, funding_paid,
|
||||
liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
|
||||
|
||||
# Save trades
|
||||
self._save_csv(
|
||||
data=portfolio.trades.records_readable,
|
||||
path=self.output_dir / f"{filename_prefix}_trades.csv",
|
||||
description="trade log"
|
||||
)
|
||||
|
||||
# Save stats with market-specific additions
|
||||
stats = portfolio.stats()
|
||||
stats['Market Type'] = market_type
|
||||
stats['Leverage'] = leverage
|
||||
stats['Total Funding Paid'] = funding_paid
|
||||
stats['Liquidations'] = liq_count
|
||||
stats['Liquidation Loss'] = liq_loss
|
||||
if adjusted_return is not None:
|
||||
stats['Adjusted Return'] = adjusted_return
|
||||
|
||||
self._save_csv(
|
||||
data=stats,
|
||||
path=self.output_dir / f"{filename_prefix}_stats.csv",
|
||||
description="stats"
|
||||
)
|
||||
|
||||
# Save liquidation events if any
|
||||
if hasattr(result, 'liquidation_events') and result.liquidation_events:
|
||||
liq_df = pd.DataFrame([
|
||||
{
|
||||
'entry_time': e.entry_time,
|
||||
'entry_price': e.entry_price,
|
||||
'liquidation_time': e.liquidation_time,
|
||||
'liquidation_price': e.liquidation_price,
|
||||
'actual_price': e.actual_price,
|
||||
'direction': e.direction,
|
||||
'margin_lost_pct': e.margin_lost_pct
|
||||
}
|
||||
for e in result.liquidation_events
|
||||
])
|
||||
self._save_csv(
|
||||
data=liq_df,
|
||||
path=self.output_dir / f"{filename_prefix}_liquidations.csv",
|
||||
description="liquidation events"
|
||||
)
|
||||
|
||||
def plot(self, portfolio: vbt.Portfolio, show: bool = True) -> None:
|
||||
"""Display portfolio plot."""
|
||||
if show:
|
||||
portfolio.plot().show()
|
||||
|
||||
def plot_wfa(
|
||||
self,
|
||||
wfa_results: pd.DataFrame,
|
||||
stitched_curve: pd.Series | None = None,
|
||||
show: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Plot Walk-Forward Analysis results.
|
||||
|
||||
Args:
|
||||
wfa_results: DataFrame with WFA window results
|
||||
stitched_curve: Stitched out-of-sample equity curve
|
||||
show: Whether to display the plot
|
||||
"""
|
||||
fig = make_subplots(
|
||||
rows=2, cols=1,
|
||||
shared_xaxes=False,
|
||||
vertical_spacing=0.1,
|
||||
subplot_titles=(
|
||||
"Walk-Forward Test Scores (Sharpe)",
|
||||
"Stitched Out-of-Sample Equity"
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=wfa_results['window'],
|
||||
y=wfa_results['test_score'],
|
||||
name="Test Sharpe"
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
if stitched_curve is not None:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=stitched_curve.index,
|
||||
y=stitched_curve.values,
|
||||
name="Equity",
|
||||
mode='lines'
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
fig.update_layout(height=800, title_text="Walk-Forward Analysis Report")
|
||||
|
||||
if show:
|
||||
fig.show()
|
||||
|
||||
def _extract_result_data(self, result) -> tuple:
|
||||
"""
|
||||
Extract data from BacktestResult or raw Portfolio.
|
||||
|
||||
Returns:
|
||||
Tuple of (portfolio, market_type, leverage, funding_paid, liq_count,
|
||||
liq_loss, adjusted_return)
|
||||
"""
|
||||
if hasattr(result, 'portfolio'):
|
||||
return (
|
||||
result.portfolio,
|
||||
result.market_type.value,
|
||||
result.leverage,
|
||||
result.total_funding_paid,
|
||||
result.liquidation_count,
|
||||
getattr(result, 'total_liquidation_loss', 0.0),
|
||||
getattr(result, 'adjusted_return', None)
|
||||
)
|
||||
return (result, "unknown", 1, 0.0, 0, 0.0, None)
|
||||
|
||||
def _save_csv(self, data, path: Path, description: str) -> None:
|
||||
"""
|
||||
Save data to CSV with consistent error handling.
|
||||
|
||||
Args:
|
||||
data: DataFrame or Series to save
|
||||
path: Output file path
|
||||
description: Human-readable description for logging
|
||||
"""
|
||||
try:
|
||||
data.to_csv(path)
|
||||
logger.info("Saved %s to %s", description, path)
|
||||
except Exception as e:
|
||||
logger.error("Could not save %s: %s", description, e)
|
||||
395
engine/risk.py
Normal file
395
engine/risk.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Risk management utilities for backtesting.
|
||||
|
||||
Handles funding rate calculations and liquidation detection.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketConfig, calculate_liquidation_price
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiquidationEvent:
|
||||
"""
|
||||
Record of a liquidation event during backtesting.
|
||||
|
||||
Attributes:
|
||||
entry_time: Timestamp when position was opened
|
||||
entry_price: Price at position entry
|
||||
liquidation_time: Timestamp when liquidation occurred
|
||||
liquidation_price: Calculated liquidation price
|
||||
actual_price: Actual price that triggered liquidation (high/low)
|
||||
direction: 'long' or 'short'
|
||||
margin_lost_pct: Percentage of margin lost (typically 100%)
|
||||
"""
|
||||
entry_time: pd.Timestamp
|
||||
entry_price: float
|
||||
liquidation_time: pd.Timestamp
|
||||
liquidation_price: float
|
||||
actual_price: float
|
||||
direction: str
|
||||
margin_lost_pct: float = 1.0
|
||||
|
||||
|
||||
def calculate_funding(
|
||||
close: pd.Series,
|
||||
long_entries: pd.DataFrame,
|
||||
long_exits: pd.DataFrame,
|
||||
short_entries: pd.DataFrame,
|
||||
short_exits: pd.DataFrame,
|
||||
market_config: MarketConfig,
|
||||
leverage: int
|
||||
) -> float:
|
||||
"""
|
||||
Calculate total funding paid/received for perpetual positions.
|
||||
|
||||
Simplified model: applies funding rate every 8 hours to open positions.
|
||||
Positive rate means longs pay shorts.
|
||||
|
||||
Args:
|
||||
close: Price series
|
||||
long_entries: Long entry signals
|
||||
long_exits: Long exit signals
|
||||
short_entries: Short entry signals
|
||||
short_exits: Short exit signals
|
||||
market_config: Market configuration with funding parameters
|
||||
leverage: Position leverage
|
||||
|
||||
Returns:
|
||||
Total funding paid (positive) or received (negative)
|
||||
"""
|
||||
if market_config.funding_interval_hours == 0:
|
||||
return 0.0
|
||||
|
||||
funding_rate = market_config.funding_rate
|
||||
interval_hours = market_config.funding_interval_hours
|
||||
|
||||
# Determine position state at each bar
|
||||
long_position = long_entries.cumsum() - long_exits.cumsum()
|
||||
short_position = short_entries.cumsum() - short_exits.cumsum()
|
||||
|
||||
# Clamp to 0/1 (either in position or not)
|
||||
long_position = (long_position > 0).astype(int)
|
||||
short_position = (short_position > 0).astype(int)
|
||||
|
||||
# Find funding timestamps (every 8 hours: 00:00, 08:00, 16:00 UTC)
|
||||
funding_times = close.index[close.index.hour % interval_hours == 0]
|
||||
|
||||
total_funding = 0.0
|
||||
for ts in funding_times:
|
||||
if ts not in close.index:
|
||||
continue
|
||||
price = close.loc[ts]
|
||||
|
||||
# Long pays funding, short receives (when rate > 0)
|
||||
if isinstance(long_position, pd.DataFrame):
|
||||
long_open = long_position.loc[ts].any()
|
||||
short_open = short_position.loc[ts].any()
|
||||
else:
|
||||
long_open = long_position.loc[ts] > 0
|
||||
short_open = short_position.loc[ts] > 0
|
||||
|
||||
position_value = price * leverage
|
||||
if long_open:
|
||||
total_funding += position_value * funding_rate
|
||||
if short_open:
|
||||
total_funding -= position_value * funding_rate
|
||||
|
||||
return total_funding
|
||||
|
||||
|
||||
def inject_liquidation_exits(
|
||||
close: pd.Series,
|
||||
high: pd.Series,
|
||||
low: pd.Series,
|
||||
long_entries: pd.DataFrame | pd.Series,
|
||||
long_exits: pd.DataFrame | pd.Series,
|
||||
short_entries: pd.DataFrame | pd.Series,
|
||||
short_exits: pd.DataFrame | pd.Series,
|
||||
leverage: int,
|
||||
maintenance_margin_rate: float
|
||||
) -> tuple[pd.DataFrame | pd.Series, pd.DataFrame | pd.Series, list[LiquidationEvent]]:
|
||||
"""
|
||||
Modify exit signals to force position closure at liquidation points.
|
||||
|
||||
This function simulates realistic liquidation behavior by:
|
||||
1. Finding positions that would be liquidated before their normal exit
|
||||
2. Injecting forced exit signals at the liquidation bar
|
||||
3. Recording all liquidation events
|
||||
|
||||
Args:
|
||||
close: Close price series
|
||||
high: High price series
|
||||
low: Low price series
|
||||
long_entries: Long entry signals
|
||||
long_exits: Long exit signals
|
||||
short_entries: Short entry signals
|
||||
short_exits: Short exit signals
|
||||
leverage: Position leverage
|
||||
maintenance_margin_rate: Maintenance margin rate for liquidation
|
||||
|
||||
Returns:
|
||||
Tuple of (modified_long_exits, modified_short_exits, liquidation_events)
|
||||
"""
|
||||
if leverage <= 1:
|
||||
return long_exits, short_exits, []
|
||||
|
||||
liquidation_events: list[LiquidationEvent] = []
|
||||
|
||||
# Convert to DataFrame if Series for consistent handling
|
||||
is_series = isinstance(long_entries, pd.Series)
|
||||
if is_series:
|
||||
long_entries_df = long_entries.to_frame()
|
||||
long_exits_df = long_exits.to_frame()
|
||||
short_entries_df = short_entries.to_frame()
|
||||
short_exits_df = short_exits.to_frame()
|
||||
else:
|
||||
long_entries_df = long_entries
|
||||
long_exits_df = long_exits.copy()
|
||||
short_entries_df = short_entries
|
||||
short_exits_df = short_exits.copy()
|
||||
|
||||
modified_long_exits = long_exits_df.copy()
|
||||
modified_short_exits = short_exits_df.copy()
|
||||
|
||||
# Process long positions
|
||||
long_mask = long_entries_df.any(axis=1)
|
||||
for entry_idx in close.index[long_mask]:
|
||||
entry_price = close.loc[entry_idx]
|
||||
liq_price = calculate_liquidation_price(
|
||||
entry_price, leverage, is_long=True,
|
||||
maintenance_margin_rate=maintenance_margin_rate
|
||||
)
|
||||
|
||||
# Find the normal exit for this entry
|
||||
subsequent_exits = long_exits_df.loc[entry_idx:].any(axis=1)
|
||||
exit_indices = subsequent_exits[subsequent_exits].index
|
||||
normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
|
||||
|
||||
# Check if liquidation occurs before normal exit
|
||||
price_range = low.loc[entry_idx:normal_exit_idx]
|
||||
if (price_range < liq_price).any():
|
||||
liq_bar = price_range[price_range < liq_price].index[0]
|
||||
|
||||
# Inject forced exit at liquidation bar
|
||||
for col in modified_long_exits.columns:
|
||||
modified_long_exits.loc[liq_bar, col] = True
|
||||
|
||||
# Record the liquidation event
|
||||
liquidation_events.append(LiquidationEvent(
|
||||
entry_time=entry_idx,
|
||||
entry_price=entry_price,
|
||||
liquidation_time=liq_bar,
|
||||
liquidation_price=liq_price,
|
||||
actual_price=low.loc[liq_bar],
|
||||
direction='long',
|
||||
margin_lost_pct=1.0
|
||||
))
|
||||
|
||||
logger.warning(
|
||||
"LIQUIDATION (Long): Entry %s ($%.2f) -> Liquidated %s "
|
||||
"(liq=$%.2f, low=$%.2f)",
|
||||
entry_idx.strftime('%Y-%m-%d'), entry_price,
|
||||
liq_bar.strftime('%Y-%m-%d'), liq_price, low.loc[liq_bar]
|
||||
)
|
||||
|
||||
# Process short positions
|
||||
short_mask = short_entries_df.any(axis=1)
|
||||
for entry_idx in close.index[short_mask]:
|
||||
entry_price = close.loc[entry_idx]
|
||||
liq_price = calculate_liquidation_price(
|
||||
entry_price, leverage, is_long=False,
|
||||
maintenance_margin_rate=maintenance_margin_rate
|
||||
)
|
||||
|
||||
# Find the normal exit for this entry
|
||||
subsequent_exits = short_exits_df.loc[entry_idx:].any(axis=1)
|
||||
exit_indices = subsequent_exits[subsequent_exits].index
|
||||
normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
|
||||
|
||||
# Check if liquidation occurs before normal exit
|
||||
price_range = high.loc[entry_idx:normal_exit_idx]
|
||||
if (price_range > liq_price).any():
|
||||
liq_bar = price_range[price_range > liq_price].index[0]
|
||||
|
||||
# Inject forced exit at liquidation bar
|
||||
for col in modified_short_exits.columns:
|
||||
modified_short_exits.loc[liq_bar, col] = True
|
||||
|
||||
# Record the liquidation event
|
||||
liquidation_events.append(LiquidationEvent(
|
||||
entry_time=entry_idx,
|
||||
entry_price=entry_price,
|
||||
liquidation_time=liq_bar,
|
||||
liquidation_price=liq_price,
|
||||
actual_price=high.loc[liq_bar],
|
||||
direction='short',
|
||||
margin_lost_pct=1.0
|
||||
))
|
||||
|
||||
logger.warning(
|
||||
"LIQUIDATION (Short): Entry %s ($%.2f) -> Liquidated %s "
|
||||
"(liq=$%.2f, high=$%.2f)",
|
||||
entry_idx.strftime('%Y-%m-%d'), entry_price,
|
||||
liq_bar.strftime('%Y-%m-%d'), liq_price, high.loc[liq_bar]
|
||||
)
|
||||
|
||||
# Convert back to Series if input was Series
|
||||
if is_series:
|
||||
modified_long_exits = modified_long_exits.iloc[:, 0]
|
||||
modified_short_exits = modified_short_exits.iloc[:, 0]
|
||||
|
||||
return modified_long_exits, modified_short_exits, liquidation_events
|
||||
|
||||
|
||||
def calculate_liquidation_adjustment(
|
||||
liquidation_events: list[LiquidationEvent],
|
||||
init_cash: float,
|
||||
leverage: int
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate the return adjustment for liquidated positions.
|
||||
|
||||
VectorBT calculates trade P&L using close price at exit bar.
|
||||
For liquidations, the actual loss is 100% of the position margin.
|
||||
This function calculates the difference between what VectorBT
|
||||
recorded and what actually would have happened.
|
||||
|
||||
In our portfolio setup:
|
||||
- Long/short each get half the capital (init_cash * leverage / 2)
|
||||
- Each trade uses 100% of that allocation (size=1.0, percent)
|
||||
- On liquidation, the margin for that trade is lost entirely
|
||||
|
||||
The adjustment is the DIFFERENCE between:
|
||||
- VectorBT's calculated P&L (exit at close price)
|
||||
- Actual liquidation P&L (100% margin loss)
|
||||
|
||||
Args:
|
||||
liquidation_events: List of liquidation events
|
||||
init_cash: Initial portfolio cash (before leverage)
|
||||
leverage: Position leverage used
|
||||
|
||||
Returns:
|
||||
Tuple of (total_margin_lost, adjustment_pct)
|
||||
- total_margin_lost: Estimated total margin lost from liquidations
|
||||
- adjustment_pct: Percentage adjustment to apply to returns
|
||||
"""
|
||||
if not liquidation_events:
|
||||
return 0.0, 0.0
|
||||
|
||||
# In our setup, each side (long/short) gets half the capital
|
||||
# Margin per side = init_cash / 2
|
||||
margin_per_side = init_cash / 2
|
||||
|
||||
# For each liquidation, VectorBT recorded some P&L based on close price
|
||||
# The actual P&L should be -100% of the margin used for that trade
|
||||
#
|
||||
# We estimate the adjustment as:
|
||||
# - Each liquidation should have resulted in ~-20% loss (at 5x leverage)
|
||||
# - VectorBT may have recorded a different value
|
||||
# - The margin loss is (1/leverage) per trade that gets liquidated
|
||||
|
||||
# Calculate liquidation loss rate based on leverage
|
||||
# At 5x leverage, liquidation = ~19.6% adverse move = 100% margin loss
|
||||
liq_loss_rate = 1.0 / leverage # Approximate loss per trade as % of position
|
||||
|
||||
# Count liquidations
|
||||
n_liquidations = len(liquidation_events)
|
||||
|
||||
# Estimate total margin lost:
|
||||
# Each liquidation on average loses the margin for that trade
|
||||
# Since VectorBT uses half capital per side, and we trade 100% size,
|
||||
# each liquidation loses approximately margin_per_side
|
||||
# But we cap at available capital
|
||||
total_margin_lost = min(n_liquidations * margin_per_side * liq_loss_rate, init_cash)
|
||||
|
||||
# Calculate as percentage of initial capital
|
||||
adjustment_pct = (total_margin_lost / init_cash) * 100
|
||||
|
||||
return total_margin_lost, adjustment_pct
|
||||
|
||||
|
||||
def check_liquidations(
|
||||
close: pd.Series,
|
||||
high: pd.Series,
|
||||
low: pd.Series,
|
||||
long_entries: pd.DataFrame,
|
||||
long_exits: pd.DataFrame,
|
||||
short_entries: pd.DataFrame,
|
||||
short_exits: pd.DataFrame,
|
||||
leverage: int,
|
||||
maintenance_margin_rate: float
|
||||
) -> int:
|
||||
"""
|
||||
Check for liquidation events and log warnings.
|
||||
|
||||
Args:
|
||||
close: Close price series
|
||||
high: High price series
|
||||
low: Low price series
|
||||
long_entries: Long entry signals
|
||||
long_exits: Long exit signals
|
||||
short_entries: Short entry signals
|
||||
short_exits: Short exit signals
|
||||
leverage: Position leverage
|
||||
maintenance_margin_rate: Maintenance margin rate for liquidation
|
||||
|
||||
Returns:
|
||||
Count of liquidation warnings
|
||||
"""
|
||||
warnings = 0
|
||||
|
||||
# For long positions
|
||||
long_mask = (
|
||||
long_entries.any(axis=1)
|
||||
if isinstance(long_entries, pd.DataFrame)
|
||||
else long_entries
|
||||
)
|
||||
|
||||
for entry_idx in close.index[long_mask]:
|
||||
entry_price = close.loc[entry_idx]
|
||||
liq_price = calculate_liquidation_price(
|
||||
entry_price, leverage, is_long=True,
|
||||
maintenance_margin_rate=maintenance_margin_rate
|
||||
)
|
||||
|
||||
subsequent = low.loc[entry_idx:]
|
||||
if (subsequent < liq_price).any():
|
||||
liq_bar = subsequent[subsequent < liq_price].index[0]
|
||||
logger.warning(
|
||||
"LIQUIDATION WARNING (Long): Entry at %s ($%.2f), "
|
||||
"would liquidate at %s (liq_price=$%.2f, low=$%.2f)",
|
||||
entry_idx, entry_price, liq_bar, liq_price, low.loc[liq_bar]
|
||||
)
|
||||
warnings += 1
|
||||
|
||||
# For short positions
|
||||
short_mask = (
|
||||
short_entries.any(axis=1)
|
||||
if isinstance(short_entries, pd.DataFrame)
|
||||
else short_entries
|
||||
)
|
||||
|
||||
for entry_idx in close.index[short_mask]:
|
||||
entry_price = close.loc[entry_idx]
|
||||
liq_price = calculate_liquidation_price(
|
||||
entry_price, leverage, is_long=False,
|
||||
maintenance_margin_rate=maintenance_margin_rate
|
||||
)
|
||||
|
||||
subsequent = high.loc[entry_idx:]
|
||||
if (subsequent > liq_price).any():
|
||||
liq_bar = subsequent[subsequent > liq_price].index[0]
|
||||
logger.warning(
|
||||
"LIQUIDATION WARNING (Short): Entry at %s ($%.2f), "
|
||||
"would liquidate at %s (liq_price=$%.2f, high=$%.2f)",
|
||||
entry_idx, entry_price, liq_bar, liq_price, high.loc[liq_bar]
|
||||
)
|
||||
warnings += 1
|
||||
|
||||
return warnings
|
||||
24
frontend/.gitignore
vendored
Normal file
24
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
3
frontend/.vscode/extensions.json
vendored
Normal file
3
frontend/.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"recommendations": ["Vue.volar"]
|
||||
}
|
||||
5
frontend/README.md
Normal file
5
frontend/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Vue 3 + TypeScript + Vite
|
||||
|
||||
This template should help get you started developing with Vue 3 and TypeScript in Vite. The template uses Vue 3 `<script setup>` SFCs, check out the [script setup docs](https://v3.vuejs.org/api/sfc-script-setup.html#sfc-script-setup) to learn more.
|
||||
|
||||
Learn more about the recommended Project Setup and IDE Support in the [Vue Docs TypeScript Guide](https://vuejs.org/guide/typescript/overview.html#project-setup).
|
||||
24
frontend/index.html
Normal file
24
frontend/index.html
Normal file
@@ -0,0 +1,24 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Lowkey Backtest</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
|
||||
}
|
||||
code, pre, .font-mono, input, select {
|
||||
font-family: 'JetBrains Mono', 'Fira Code', monospace;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.ts"></script>
|
||||
</body>
|
||||
</html>
|
||||
2427
frontend/package-lock.json
generated
Normal file
2427
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
27
frontend/package.json
Normal file
27
frontend/package.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"name": "frontend",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vue-tsc -b && vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.13.2",
|
||||
"plotly.js-dist-min": "^3.3.1",
|
||||
"vue": "^3.5.24",
|
||||
"vue-router": "^4.6.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/vite": "^4.1.18",
|
||||
"@types/node": "^24.10.1",
|
||||
"@vitejs/plugin-vue": "^6.0.1",
|
||||
"@vue/tsconfig": "^0.8.1",
|
||||
"tailwindcss": "^4.1.18",
|
||||
"typescript": "~5.9.3",
|
||||
"vite": "^7.2.4",
|
||||
"vue-tsc": "^3.1.4"
|
||||
}
|
||||
}
|
||||
1
frontend/public/vite.svg
Normal file
1
frontend/public/vite.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="31.88" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 257"><defs><linearGradient id="IconifyId1813088fe1fbc01fb466" x1="-.828%" x2="57.636%" y1="7.652%" y2="78.411%"><stop offset="0%" stop-color="#41D1FF"></stop><stop offset="100%" stop-color="#BD34FE"></stop></linearGradient><linearGradient id="IconifyId1813088fe1fbc01fb467" x1="43.376%" x2="50.316%" y1="2.242%" y2="89.03%"><stop offset="0%" stop-color="#FFEA83"></stop><stop offset="8.333%" stop-color="#FFDD35"></stop><stop offset="100%" stop-color="#FFA800"></stop></linearGradient></defs><path fill="url(#IconifyId1813088fe1fbc01fb466)" d="M255.153 37.938L134.897 252.976c-2.483 4.44-8.862 4.466-11.382.048L.875 37.958c-2.746-4.814 1.371-10.646 6.827-9.67l120.385 21.517a6.537 6.537 0 0 0 2.322-.004l117.867-21.483c5.438-.991 9.574 4.796 6.877 9.62Z"></path><path fill="url(#IconifyId1813088fe1fbc01fb467)" d="M185.432.063L96.44 17.501a3.268 3.268 0 0 0-2.634 3.014l-5.474 92.456a3.268 3.268 0 0 0 3.997 3.378l24.777-5.718c2.318-.535 4.413 1.507 3.936 3.838l-7.361 36.047c-.495 2.426 1.782 4.5 4.151 3.78l15.304-4.649c2.372-.72 4.652 1.36 4.15 3.788l-11.698 56.621c-.732 3.542 3.979 5.473 5.943 2.437l1.313-2.028l72.516-144.72c1.215-2.423-.88-5.186-3.54-4.672l-25.505 4.922c-2.396.462-4.435-1.77-3.759-4.114l16.646-57.705c.677-2.35-1.37-4.583-3.769-4.113Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
72
frontend/src/App.vue
Normal file
72
frontend/src/App.vue
Normal file
@@ -0,0 +1,72 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import { RouterLink, RouterView } from 'vue-router'
|
||||
import RunHistory from '@/components/RunHistory.vue'
|
||||
|
||||
const historyOpen = ref(true)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="flex h-screen overflow-hidden">
|
||||
<!-- Sidebar Navigation -->
|
||||
<aside class="w-16 bg-bg-secondary border-r border-border flex flex-col items-center py-4 gap-4">
|
||||
<!-- Logo -->
|
||||
<div class="w-10 h-10 rounded-lg bg-accent-blue flex items-center justify-center text-black font-bold text-lg">
|
||||
LB
|
||||
</div>
|
||||
|
||||
<!-- Nav Links -->
|
||||
<nav class="flex flex-col gap-2 mt-4">
|
||||
<RouterLink
|
||||
to="/"
|
||||
class="w-10 h-10 rounded-lg flex items-center justify-center hover:bg-bg-hover transition-colors"
|
||||
:class="{ 'bg-bg-tertiary': $route.path === '/' }"
|
||||
title="Dashboard"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
</RouterLink>
|
||||
|
||||
<RouterLink
|
||||
to="/compare"
|
||||
class="w-10 h-10 rounded-lg flex items-center justify-center hover:bg-bg-hover transition-colors"
|
||||
:class="{ 'bg-bg-tertiary': $route.path === '/compare' }"
|
||||
title="Compare Runs"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 17V7m0 10a2 2 0 01-2 2H5a2 2 0 01-2-2V7a2 2 0 012-2h2a2 2 0 012 2m0 10a2 2 0 002 2h2a2 2 0 002-2M9 7a2 2 0 012-2h2a2 2 0 012 2m0 10V7m0 10a2 2 0 002 2h2a2 2 0 002-2V7a2 2 0 00-2-2h-2a2 2 0 00-2 2" />
|
||||
</svg>
|
||||
</RouterLink>
|
||||
</nav>
|
||||
|
||||
<!-- Spacer -->
|
||||
<div class="flex-1"></div>
|
||||
|
||||
<!-- Toggle History -->
|
||||
<button
|
||||
@click="historyOpen = !historyOpen"
|
||||
class="w-10 h-10 rounded-lg flex items-center justify-center hover:bg-bg-hover transition-colors"
|
||||
:class="{ 'bg-bg-tertiary': historyOpen }"
|
||||
title="Toggle Run History"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</button>
|
||||
</aside>
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="flex-1 overflow-auto">
|
||||
<RouterView />
|
||||
</main>
|
||||
|
||||
<!-- Run History Sidebar -->
|
||||
<aside
|
||||
v-if="historyOpen"
|
||||
class="w-72 bg-bg-secondary border-l border-border overflow-hidden flex flex-col"
|
||||
>
|
||||
<RunHistory />
|
||||
</aside>
|
||||
</div>
|
||||
</template>
|
||||
81
frontend/src/api/client.ts
Normal file
81
frontend/src/api/client.ts
Normal file
@@ -0,0 +1,81 @@
|
||||
/**
|
||||
* API client for Lowkey Backtest backend.
|
||||
*/
|
||||
import axios from 'axios'
|
||||
import type {
|
||||
StrategiesResponse,
|
||||
DataStatusResponse,
|
||||
BacktestRequest,
|
||||
BacktestResult,
|
||||
BacktestListResponse,
|
||||
CompareResult,
|
||||
} from './types'
|
||||
|
||||
const api = axios.create({
|
||||
baseURL: '/api',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
/**
|
||||
* Get list of available strategies with parameters.
|
||||
*/
|
||||
export async function getStrategies(): Promise<StrategiesResponse> {
|
||||
const response = await api.get<StrategiesResponse>('/strategies')
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of available symbols with data status.
|
||||
*/
|
||||
export async function getSymbols(): Promise<DataStatusResponse> {
|
||||
const response = await api.get<DataStatusResponse>('/symbols')
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a backtest with the given configuration.
|
||||
*/
|
||||
export async function runBacktest(request: BacktestRequest): Promise<BacktestResult> {
|
||||
const response = await api.post<BacktestResult>('/backtest', request)
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of saved backtest runs.
|
||||
*/
|
||||
export async function getBacktests(params?: {
|
||||
limit?: number
|
||||
offset?: number
|
||||
strategy?: string
|
||||
symbol?: string
|
||||
}): Promise<BacktestListResponse> {
|
||||
const response = await api.get<BacktestListResponse>('/backtests', { params })
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a specific backtest run by ID.
|
||||
*/
|
||||
export async function getBacktest(runId: string): Promise<BacktestResult> {
|
||||
const response = await api.get<BacktestResult>(`/backtest/${runId}`)
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a backtest run.
|
||||
*/
|
||||
export async function deleteBacktest(runId: string): Promise<void> {
|
||||
await api.delete(`/backtest/${runId}`)
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare multiple backtest runs.
|
||||
*/
|
||||
export async function compareRuns(runIds: string[]): Promise<CompareResult> {
|
||||
const response = await api.post<CompareResult>('/compare', { run_ids: runIds })
|
||||
return response.data
|
||||
}
|
||||
|
||||
export default api
|
||||
131
frontend/src/api/types.ts
Normal file
131
frontend/src/api/types.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
/**
|
||||
* TypeScript types matching the FastAPI Pydantic schemas.
|
||||
*/
|
||||
|
||||
// Strategy types
|
||||
export interface StrategyInfo {
|
||||
name: string
|
||||
display_name: string
|
||||
market_type: string
|
||||
default_leverage: number
|
||||
default_params: Record<string, unknown>
|
||||
grid_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface StrategiesResponse {
|
||||
strategies: StrategyInfo[]
|
||||
}
|
||||
|
||||
// Symbol/Data types
|
||||
export interface SymbolInfo {
|
||||
symbol: string
|
||||
exchange: string
|
||||
market_type: string
|
||||
timeframes: string[]
|
||||
start_date: string | null
|
||||
end_date: string | null
|
||||
row_count: number
|
||||
}
|
||||
|
||||
export interface DataStatusResponse {
|
||||
symbols: SymbolInfo[]
|
||||
}
|
||||
|
||||
// Backtest types
|
||||
export interface BacktestRequest {
|
||||
strategy: string
|
||||
symbol: string
|
||||
exchange?: string
|
||||
timeframe?: string
|
||||
market_type?: string
|
||||
start_date?: string | null
|
||||
end_date?: string | null
|
||||
init_cash?: number
|
||||
leverage?: number | null
|
||||
fees?: number | null
|
||||
slippage?: number
|
||||
sl_stop?: number | null
|
||||
tp_stop?: number | null
|
||||
sl_trail?: boolean
|
||||
params?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface TradeRecord {
|
||||
entry_time: string
|
||||
exit_time: string | null
|
||||
entry_price: number
|
||||
exit_price: number | null
|
||||
size: number
|
||||
direction: string
|
||||
pnl: number | null
|
||||
return_pct: number | null
|
||||
status: string
|
||||
}
|
||||
|
||||
export interface EquityPoint {
|
||||
timestamp: string
|
||||
value: number
|
||||
drawdown: number
|
||||
}
|
||||
|
||||
export interface BacktestMetrics {
|
||||
total_return: number
|
||||
benchmark_return: number
|
||||
alpha: number
|
||||
sharpe_ratio: number
|
||||
max_drawdown: number
|
||||
win_rate: number
|
||||
total_trades: number
|
||||
profit_factor: number | null
|
||||
avg_trade_return: number | null
|
||||
total_fees: number
|
||||
total_funding: number
|
||||
liquidation_count: number
|
||||
liquidation_loss: number
|
||||
adjusted_return: number | null
|
||||
}
|
||||
|
||||
export interface BacktestResult {
|
||||
run_id: string
|
||||
strategy: string
|
||||
symbol: string
|
||||
market_type: string
|
||||
timeframe: string
|
||||
start_date: string
|
||||
end_date: string
|
||||
leverage: number
|
||||
params: Record<string, unknown>
|
||||
metrics: BacktestMetrics
|
||||
equity_curve: EquityPoint[]
|
||||
trades: TradeRecord[]
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface BacktestSummary {
|
||||
run_id: string
|
||||
strategy: string
|
||||
symbol: string
|
||||
market_type: string
|
||||
timeframe: string
|
||||
total_return: number
|
||||
sharpe_ratio: number
|
||||
max_drawdown: number
|
||||
total_trades: number
|
||||
created_at: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface BacktestListResponse {
|
||||
runs: BacktestSummary[]
|
||||
total: number
|
||||
}
|
||||
|
||||
// Comparison types
|
||||
export interface CompareRequest {
|
||||
run_ids: string[]
|
||||
}
|
||||
|
||||
export interface CompareResult {
|
||||
runs: BacktestResult[]
|
||||
param_diff: Record<string, unknown[]>
|
||||
}
|
||||
1
frontend/src/assets/vue.svg
Normal file
1
frontend/src/assets/vue.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="37.07" height="36" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 198"><path fill="#41B883" d="M204.8 0H256L128 220.8L0 0h97.92L128 51.2L157.44 0h47.36Z"></path><path fill="#41B883" d="m0 0l128 220.8L256 0h-51.2L128 132.48L50.56 0H0Z"></path><path fill="#35495E" d="M50.56 0L128 133.12L204.8 0h-47.36L128 51.2L97.92 0H50.56Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 496 B |
186
frontend/src/components/BacktestConfig.vue
Normal file
186
frontend/src/components/BacktestConfig.vue
Normal file
@@ -0,0 +1,186 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, onMounted } from 'vue'
|
||||
import { useBacktest } from '@/composables/useBacktest'
|
||||
import type { BacktestRequest } from '@/api/types'
|
||||
|
||||
const { strategies, symbols, loading, init, executeBacktest } = useBacktest()
|
||||
|
||||
// Form state
|
||||
const selectedStrategy = ref('')
|
||||
const selectedSymbol = ref('')
|
||||
const selectedMarket = ref('perpetual')
|
||||
const timeframe = ref('1h')
|
||||
const initCash = ref(10000)
|
||||
const leverage = ref<number | null>(null)
|
||||
const slStop = ref<number | null>(null)
|
||||
const tpStop = ref<number | null>(null)
|
||||
const params = ref<Record<string, number | boolean>>({})
|
||||
|
||||
// Initialize
|
||||
onMounted(async () => {
|
||||
await init()
|
||||
if (strategies.value.length > 0 && strategies.value[0]) {
|
||||
selectedStrategy.value = strategies.value[0].name
|
||||
}
|
||||
})
|
||||
|
||||
// Get current strategy config
|
||||
const currentStrategy = computed(() =>
|
||||
strategies.value.find(s => s.name === selectedStrategy.value)
|
||||
)
|
||||
|
||||
// Filter symbols by market type
|
||||
const filteredSymbols = computed(() =>
|
||||
symbols.value.filter(s => s.market_type === selectedMarket.value)
|
||||
)
|
||||
|
||||
// Update params when strategy changes
|
||||
watch(selectedStrategy, (name) => {
|
||||
const strategy = strategies.value.find(s => s.name === name)
|
||||
if (strategy) {
|
||||
params.value = { ...strategy.default_params } as Record<string, number | boolean>
|
||||
selectedMarket.value = strategy.market_type
|
||||
leverage.value = strategy.default_leverage > 1 ? strategy.default_leverage : null
|
||||
}
|
||||
})
|
||||
|
||||
// Update symbol when market changes
|
||||
watch([filteredSymbols, selectedMarket], () => {
|
||||
const firstSymbol = filteredSymbols.value[0]
|
||||
if (filteredSymbols.value.length > 0 && firstSymbol && !filteredSymbols.value.find(s => s.symbol === selectedSymbol.value)) {
|
||||
selectedSymbol.value = firstSymbol.symbol
|
||||
}
|
||||
})
|
||||
|
||||
async function handleSubmit() {
|
||||
if (!selectedStrategy.value || !selectedSymbol.value) return
|
||||
|
||||
const request: BacktestRequest = {
|
||||
strategy: selectedStrategy.value,
|
||||
symbol: selectedSymbol.value,
|
||||
market_type: selectedMarket.value,
|
||||
timeframe: timeframe.value,
|
||||
init_cash: initCash.value,
|
||||
leverage: leverage.value,
|
||||
sl_stop: slStop.value,
|
||||
tp_stop: tpStop.value,
|
||||
params: params.value,
|
||||
}
|
||||
|
||||
await executeBacktest(request)
|
||||
}
|
||||
|
||||
function getParamType(value: unknown): 'number' | 'boolean' | 'unknown' {
|
||||
if (typeof value === 'boolean') return 'boolean'
|
||||
if (typeof value === 'number') return 'number'
|
||||
return 'unknown'
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="card">
|
||||
<h2 class="text-lg font-semibold mb-4">Backtest Configuration</h2>
|
||||
|
||||
<form @submit.prevent="handleSubmit" class="space-y-4">
|
||||
<!-- Strategy -->
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Strategy</label>
|
||||
<select v-model="selectedStrategy" class="w-full">
|
||||
<option v-for="s in strategies" :key="s.name" :value="s.name">
|
||||
{{ s.display_name }}
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- Market Type & Symbol -->
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Market</label>
|
||||
<select v-model="selectedMarket" class="w-full">
|
||||
<option value="spot">Spot</option>
|
||||
<option value="perpetual">Perpetual</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Symbol</label>
|
||||
<select v-model="selectedSymbol" class="w-full">
|
||||
<option v-for="s in filteredSymbols" :key="s.symbol" :value="s.symbol">
|
||||
{{ s.symbol }}
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Timeframe & Cash -->
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Timeframe</label>
|
||||
<select v-model="timeframe" class="w-full">
|
||||
<option value="1h">1 Hour</option>
|
||||
<option value="4h">4 Hours</option>
|
||||
<option value="1d">1 Day</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Initial Cash</label>
|
||||
<input type="number" v-model.number="initCash" class="w-full" min="100" step="100" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Leverage (perpetual only) -->
|
||||
<div v-if="selectedMarket === 'perpetual'" class="grid grid-cols-3 gap-3">
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Leverage</label>
|
||||
<input type="number" v-model.number="leverage" class="w-full" min="1" max="100" placeholder="1" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Stop Loss %</label>
|
||||
<input type="number" v-model.number="slStop" class="w-full" min="0" max="100" step="0.1" placeholder="None" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">Take Profit %</label>
|
||||
<input type="number" v-model.number="tpStop" class="w-full" min="0" max="100" step="0.1" placeholder="None" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Strategy Parameters -->
|
||||
<div v-if="currentStrategy && Object.keys(params).length > 0">
|
||||
<h3 class="text-sm font-medium text-text-secondary mb-2">Strategy Parameters</h3>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div v-for="(value, key) in params" :key="key">
|
||||
<label class="block text-xs text-text-secondary uppercase mb-1">
|
||||
{{ String(key).replace(/_/g, ' ') }}
|
||||
</label>
|
||||
<template v-if="getParamType(value) === 'boolean'">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="Boolean(value)"
|
||||
@change="params[key] = ($event.target as HTMLInputElement).checked"
|
||||
class="w-5 h-5"
|
||||
/>
|
||||
</template>
|
||||
<template v-else>
|
||||
<input
|
||||
type="number"
|
||||
:value="value"
|
||||
@input="params[key] = parseFloat(($event.target as HTMLInputElement).value)"
|
||||
class="w-full"
|
||||
step="any"
|
||||
/>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Submit -->
|
||||
<button
|
||||
type="submit"
|
||||
class="btn btn-primary w-full"
|
||||
:disabled="loading || !selectedStrategy || !selectedSymbol"
|
||||
>
|
||||
<span v-if="loading" class="spinner"></span>
|
||||
<span v-else>Run Backtest</span>
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</template>
|
||||
88
frontend/src/components/EquityCurve.vue
Normal file
88
frontend/src/components/EquityCurve.vue
Normal file
@@ -0,0 +1,88 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, onMounted, onUnmounted } from 'vue'
|
||||
import Plotly from 'plotly.js-dist-min'
|
||||
import type { EquityPoint } from '@/api/types'
|
||||
|
||||
const props = defineProps<{
|
||||
data: EquityPoint[]
|
||||
title?: string
|
||||
}>()
|
||||
|
||||
const chartRef = ref<HTMLDivElement | null>(null)
|
||||
|
||||
const CHART_COLORS = {
|
||||
equity: '#58a6ff',
|
||||
grid: '#30363d',
|
||||
text: '#8b949e',
|
||||
}
|
||||
|
||||
function renderChart() {
|
||||
if (!chartRef.value || props.data.length === 0) return
|
||||
|
||||
const timestamps = props.data.map(p => p.timestamp)
|
||||
const values = props.data.map(p => p.value)
|
||||
|
||||
const traces: Plotly.Data[] = [
|
||||
{
|
||||
x: timestamps,
|
||||
y: values,
|
||||
type: 'scatter',
|
||||
mode: 'lines',
|
||||
name: 'Portfolio Value',
|
||||
line: { color: CHART_COLORS.equity, width: 2 },
|
||||
hovertemplate: '%{x}<br>Value: $%{y:,.2f}<extra></extra>',
|
||||
},
|
||||
]
|
||||
|
||||
const layout: Partial<Plotly.Layout> = {
|
||||
title: props.title ? {
|
||||
text: props.title,
|
||||
font: { color: CHART_COLORS.text, size: 14 },
|
||||
} : undefined,
|
||||
paper_bgcolor: 'transparent',
|
||||
plot_bgcolor: 'transparent',
|
||||
margin: { l: 60, r: 20, t: props.title ? 40 : 20, b: 40 },
|
||||
xaxis: {
|
||||
showgrid: true,
|
||||
gridcolor: CHART_COLORS.grid,
|
||||
tickfont: { color: CHART_COLORS.text, size: 10 },
|
||||
linecolor: CHART_COLORS.grid,
|
||||
},
|
||||
yaxis: {
|
||||
showgrid: true,
|
||||
gridcolor: CHART_COLORS.grid,
|
||||
tickfont: { color: CHART_COLORS.text, size: 10 },
|
||||
linecolor: CHART_COLORS.grid,
|
||||
tickprefix: '$',
|
||||
hoverformat: ',.2f',
|
||||
},
|
||||
showlegend: false,
|
||||
hovermode: 'x unified',
|
||||
}
|
||||
|
||||
const config: Partial<Plotly.Config> = {
|
||||
responsive: true,
|
||||
displayModeBar: false,
|
||||
}
|
||||
|
||||
Plotly.react(chartRef.value, traces, layout, config)
|
||||
}
|
||||
|
||||
watch(() => props.data, renderChart, { deep: true })
|
||||
|
||||
onMounted(() => {
|
||||
renderChart()
|
||||
window.addEventListener('resize', renderChart)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
window.removeEventListener('resize', renderChart)
|
||||
if (chartRef.value) {
|
||||
Plotly.purge(chartRef.value)
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div ref="chartRef" class="w-full h-full min-h-[300px]"></div>
|
||||
</template>
|
||||
41
frontend/src/components/HelloWorld.vue
Normal file
41
frontend/src/components/HelloWorld.vue
Normal file
@@ -0,0 +1,41 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
|
||||
defineProps<{ msg: string }>()
|
||||
|
||||
const count = ref(0)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<h1>{{ msg }}</h1>
|
||||
|
||||
<div class="card">
|
||||
<button type="button" @click="count++">count is {{ count }}</button>
|
||||
<p>
|
||||
Edit
|
||||
<code>components/HelloWorld.vue</code> to test HMR
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<p>
|
||||
Check out
|
||||
<a href="https://vuejs.org/guide/quick-start.html#local" target="_blank"
|
||||
>create-vue</a
|
||||
>, the official Vue + Vite starter
|
||||
</p>
|
||||
<p>
|
||||
Learn more about IDE Support for Vue in the
|
||||
<a
|
||||
href="https://vuejs.org/guide/scaling-up/tooling.html#ide-support"
|
||||
target="_blank"
|
||||
>Vue Docs Scaling up Guide</a
|
||||
>.
|
||||
</p>
|
||||
<p class="read-the-docs">Click on the Vite and Vue logos to learn more</p>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.read-the-docs {
|
||||
color: #888;
|
||||
}
|
||||
</style>
|
||||
144
frontend/src/components/MetricsPanel.vue
Normal file
144
frontend/src/components/MetricsPanel.vue
Normal file
@@ -0,0 +1,144 @@
|
||||
<script setup lang="ts">
|
||||
import type { BacktestMetrics } from '@/api/types'
|
||||
|
||||
const props = defineProps<{
|
||||
metrics: BacktestMetrics
|
||||
leverage?: number
|
||||
marketType?: string
|
||||
}>()
|
||||
|
||||
function formatPercent(val: number): string {
|
||||
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
|
||||
}
|
||||
|
||||
function formatNumber(val: number | null | undefined, decimals = 2): string {
|
||||
if (val === null || val === undefined) return '-'
|
||||
return val.toFixed(decimals)
|
||||
}
|
||||
|
||||
function formatCurrency(val: number): string {
|
||||
return '$' + val.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 })
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-4">
|
||||
<!-- Total Return -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Strategy Return</div>
|
||||
<div
|
||||
class="metric-value"
|
||||
:class="metrics.total_return >= 0 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatPercent(metrics.total_return) }}
|
||||
</div>
|
||||
<div v-if="metrics.adjusted_return !== null && metrics.adjusted_return !== metrics.total_return" class="text-xs text-text-muted mt-1">
|
||||
Adj: {{ formatPercent(metrics.adjusted_return) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Benchmark Return -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Benchmark (B&H)</div>
|
||||
<div
|
||||
class="metric-value"
|
||||
:class="metrics.benchmark_return >= 0 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatPercent(metrics.benchmark_return) }}
|
||||
</div>
|
||||
<div class="text-xs text-text-muted mt-1">
|
||||
Market change
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Alpha -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Alpha</div>
|
||||
<div
|
||||
class="metric-value"
|
||||
:class="metrics.alpha >= 0 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatPercent(metrics.alpha) }}
|
||||
</div>
|
||||
<div class="text-xs text-text-muted mt-1">
|
||||
vs Buy & Hold
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Sharpe Ratio -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Sharpe Ratio</div>
|
||||
<div
|
||||
class="metric-value"
|
||||
:class="metrics.sharpe_ratio >= 1 ? 'profit' : metrics.sharpe_ratio < 0 ? 'loss' : ''"
|
||||
>
|
||||
{{ formatNumber(metrics.sharpe_ratio) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Max Drawdown -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Max Drawdown</div>
|
||||
<div class="metric-value loss">
|
||||
{{ formatPercent(metrics.max_drawdown) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Win Rate -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Win Rate</div>
|
||||
<div
|
||||
class="metric-value"
|
||||
:class="metrics.win_rate >= 50 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatNumber(metrics.win_rate, 1) }}%
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Total Trades -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Total Trades</div>
|
||||
<div class="metric-value">
|
||||
{{ metrics.total_trades }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Profit Factor -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Profit Factor</div>
|
||||
<div
|
||||
class="metric-value"
|
||||
:class="(metrics.profit_factor || 0) >= 1 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatNumber(metrics.profit_factor) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Total Fees -->
|
||||
<div class="card">
|
||||
<div class="metric-label">Total Fees</div>
|
||||
<div class="metric-value text-warning">
|
||||
{{ formatCurrency(metrics.total_fees) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Funding (perpetual only) -->
|
||||
<div v-if="marketType === 'perpetual'" class="card">
|
||||
<div class="metric-label">Funding Paid</div>
|
||||
<div class="metric-value text-warning">
|
||||
{{ formatCurrency(metrics.total_funding) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Liquidations (if any) -->
|
||||
<div v-if="metrics.liquidation_count > 0" class="card">
|
||||
<div class="metric-label">Liquidations</div>
|
||||
<div class="metric-value loss">
|
||||
{{ metrics.liquidation_count }}
|
||||
</div>
|
||||
<div class="text-xs text-text-muted mt-1">
|
||||
Lost: {{ formatCurrency(metrics.liquidation_loss) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
141
frontend/src/components/RunHistory.vue
Normal file
141
frontend/src/components/RunHistory.vue
Normal file
@@ -0,0 +1,141 @@
|
||||
<script setup lang="ts">
|
||||
import { onMounted } from 'vue'
|
||||
import { useBacktest } from '@/composables/useBacktest'
|
||||
import { useRouter } from 'vue-router'
|
||||
|
||||
const router = useRouter()
|
||||
const {
|
||||
runs,
|
||||
currentResult,
|
||||
selectedRuns,
|
||||
refreshRuns,
|
||||
loadRun,
|
||||
removeRun,
|
||||
toggleRunSelection
|
||||
} = useBacktest()
|
||||
|
||||
onMounted(() => {
|
||||
refreshRuns()
|
||||
})
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
const d = new Date(iso)
|
||||
return d.toLocaleDateString('en-US', {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
})
|
||||
}
|
||||
|
||||
function formatReturn(val: number): string {
|
||||
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
|
||||
}
|
||||
|
||||
async function handleClick(runId: string) {
|
||||
await loadRun(runId)
|
||||
router.push('/')
|
||||
}
|
||||
|
||||
function handleCheckbox(e: Event, runId: string) {
|
||||
e.stopPropagation()
|
||||
toggleRunSelection(runId)
|
||||
}
|
||||
|
||||
function handleDelete(e: Event, runId: string) {
|
||||
e.stopPropagation()
|
||||
if (confirm('Delete this run?')) {
|
||||
removeRun(runId)
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Header -->
|
||||
<div class="p-4 border-b border-border">
|
||||
<h2 class="text-sm font-semibold text-text-secondary uppercase tracking-wide">
|
||||
Run History
|
||||
</h2>
|
||||
<p class="text-xs text-text-muted mt-1">
|
||||
{{ runs.length }} runs | {{ selectedRuns.length }} selected
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Run List -->
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
<div
|
||||
v-for="run in runs"
|
||||
:key="run.run_id"
|
||||
@click="handleClick(run.run_id)"
|
||||
class="p-3 border-b border-border-muted cursor-pointer hover:bg-bg-hover transition-colors"
|
||||
:class="{ 'bg-bg-tertiary': currentResult?.run_id === run.run_id }"
|
||||
>
|
||||
<div class="flex items-start gap-2">
|
||||
<!-- Checkbox for comparison -->
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="selectedRuns.includes(run.run_id)"
|
||||
@click="handleCheckbox($event, run.run_id)"
|
||||
class="mt-1 w-4 h-4 rounded border-border bg-bg-tertiary"
|
||||
/>
|
||||
|
||||
<div class="flex-1 min-w-0">
|
||||
<!-- Strategy & Symbol -->
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="font-medium text-sm truncate">{{ run.strategy }}</span>
|
||||
<span class="text-xs px-1.5 py-0.5 rounded bg-bg-tertiary text-text-secondary">
|
||||
{{ run.symbol }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Metrics -->
|
||||
<div class="flex items-center gap-3 mt-1">
|
||||
<span
|
||||
class="text-sm font-mono"
|
||||
:class="run.total_return >= 0 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatReturn(run.total_return) }}
|
||||
</span>
|
||||
<span class="text-xs text-text-muted">
|
||||
SR {{ run.sharpe_ratio.toFixed(2) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Date -->
|
||||
<div class="text-xs text-text-muted mt-1">
|
||||
{{ formatDate(run.created_at) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
@click="handleDelete($event, run.run_id)"
|
||||
class="p-1 rounded hover:bg-loss/20 text-text-muted hover:text-loss transition-colors"
|
||||
title="Delete run"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Empty state -->
|
||||
<div v-if="runs.length === 0" class="p-8 text-center text-text-muted">
|
||||
<p>No runs yet.</p>
|
||||
<p class="text-xs mt-1">Run a backtest to see results here.</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Compare Button -->
|
||||
<div v-if="selectedRuns.length >= 2" class="p-4 border-t border-border">
|
||||
<router-link
|
||||
to="/compare"
|
||||
class="btn btn-primary w-full"
|
||||
>
|
||||
Compare {{ selectedRuns.length }} Runs
|
||||
</router-link>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
157
frontend/src/components/TradeLog.vue
Normal file
157
frontend/src/components/TradeLog.vue
Normal file
@@ -0,0 +1,157 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import type { TradeRecord } from '@/api/types'
|
||||
|
||||
const props = defineProps<{
|
||||
trades: TradeRecord[]
|
||||
}>()
|
||||
|
||||
type SortKey = 'entry_time' | 'pnl' | 'return_pct' | 'size'
|
||||
const sortKey = ref<SortKey>('entry_time')
|
||||
const sortDesc = ref(true)
|
||||
|
||||
const sortedTrades = computed(() => {
|
||||
return [...props.trades].sort((a, b) => {
|
||||
let aVal: number | string = 0
|
||||
let bVal: number | string = 0
|
||||
|
||||
switch (sortKey.value) {
|
||||
case 'entry_time':
|
||||
aVal = a.entry_time
|
||||
bVal = b.entry_time
|
||||
break
|
||||
case 'pnl':
|
||||
aVal = a.pnl ?? 0
|
||||
bVal = b.pnl ?? 0
|
||||
break
|
||||
case 'return_pct':
|
||||
aVal = a.return_pct ?? 0
|
||||
bVal = b.return_pct ?? 0
|
||||
break
|
||||
case 'size':
|
||||
aVal = a.size
|
||||
bVal = b.size
|
||||
break
|
||||
}
|
||||
|
||||
if (aVal < bVal) return sortDesc.value ? 1 : -1
|
||||
if (aVal > bVal) return sortDesc.value ? -1 : 1
|
||||
return 0
|
||||
})
|
||||
})
|
||||
|
||||
function toggleSort(key: SortKey) {
|
||||
if (sortKey.value === key) {
|
||||
sortDesc.value = !sortDesc.value
|
||||
} else {
|
||||
sortKey.value = key
|
||||
sortDesc.value = true
|
||||
}
|
||||
}
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
if (!iso) return '-'
|
||||
const d = new Date(iso)
|
||||
return d.toLocaleDateString('en-US', {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
})
|
||||
}
|
||||
|
||||
function formatPrice(val: number | null): string {
|
||||
if (val === null) return '-'
|
||||
return val.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 })
|
||||
}
|
||||
|
||||
function formatPnL(val: number | null): string {
|
||||
if (val === null) return '-'
|
||||
const sign = val >= 0 ? '+' : ''
|
||||
return sign + val.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 })
|
||||
}
|
||||
|
||||
function formatReturn(val: number | null): string {
|
||||
if (val === null) return '-'
|
||||
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="card overflow-hidden">
|
||||
<div class="flex items-center justify-between mb-4">
|
||||
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide">
|
||||
Trade Log
|
||||
</h3>
|
||||
<span class="text-xs text-text-muted">{{ trades.length }} trades</span>
|
||||
</div>
|
||||
|
||||
<div class="overflow-x-auto max-h-[400px] overflow-y-auto">
|
||||
<table class="min-w-full">
|
||||
<thead class="sticky top-0 bg-bg-card">
|
||||
<tr>
|
||||
<th
|
||||
@click="toggleSort('entry_time')"
|
||||
class="cursor-pointer hover:text-text-primary"
|
||||
>
|
||||
Entry Time
|
||||
<span v-if="sortKey === 'entry_time'">{{ sortDesc ? ' v' : ' ^' }}</span>
|
||||
</th>
|
||||
<th>Exit Time</th>
|
||||
<th>Direction</th>
|
||||
<th>Entry</th>
|
||||
<th>Exit</th>
|
||||
<th
|
||||
@click="toggleSort('size')"
|
||||
class="cursor-pointer hover:text-text-primary"
|
||||
>
|
||||
Size
|
||||
<span v-if="sortKey === 'size'">{{ sortDesc ? ' v' : ' ^' }}</span>
|
||||
</th>
|
||||
<th
|
||||
@click="toggleSort('pnl')"
|
||||
class="cursor-pointer hover:text-text-primary"
|
||||
>
|
||||
PnL
|
||||
<span v-if="sortKey === 'pnl'">{{ sortDesc ? ' v' : ' ^' }}</span>
|
||||
</th>
|
||||
<th
|
||||
@click="toggleSort('return_pct')"
|
||||
class="cursor-pointer hover:text-text-primary"
|
||||
>
|
||||
Return
|
||||
<span v-if="sortKey === 'return_pct'">{{ sortDesc ? ' v' : ' ^' }}</span>
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="(trade, idx) in sortedTrades" :key="idx">
|
||||
<td class="text-text-secondary">{{ formatDate(trade.entry_time) }}</td>
|
||||
<td class="text-text-secondary">{{ formatDate(trade.exit_time || '') }}</td>
|
||||
<td>
|
||||
<span
|
||||
class="px-2 py-0.5 rounded text-xs font-medium"
|
||||
:class="trade.direction === 'Long' ? 'bg-profit/20 text-profit' : 'bg-loss/20 text-loss'"
|
||||
>
|
||||
{{ trade.direction }}
|
||||
</span>
|
||||
</td>
|
||||
<td>${{ formatPrice(trade.entry_price) }}</td>
|
||||
<td>${{ formatPrice(trade.exit_price) }}</td>
|
||||
<td>{{ trade.size.toFixed(4) }}</td>
|
||||
<td :class="(trade.pnl ?? 0) >= 0 ? 'profit' : 'loss'">
|
||||
${{ formatPnL(trade.pnl) }}
|
||||
</td>
|
||||
<td :class="(trade.return_pct ?? 0) >= 0 ? 'profit' : 'loss'">
|
||||
{{ formatReturn(trade.return_pct) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<div v-if="trades.length === 0" class="p-8 text-center text-text-muted">
|
||||
No trades executed.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
150
frontend/src/composables/useBacktest.ts
Normal file
150
frontend/src/composables/useBacktest.ts
Normal file
@@ -0,0 +1,150 @@
|
||||
/**
|
||||
* Composable for managing backtest state across components.
|
||||
*/
|
||||
import { ref, computed } from 'vue'
|
||||
import type { BacktestResult, BacktestSummary, StrategyInfo, SymbolInfo } from '@/api/types'
|
||||
import { getStrategies, getSymbols, getBacktests, getBacktest, runBacktest, deleteBacktest } from '@/api/client'
|
||||
import type { BacktestRequest } from '@/api/types'
|
||||
|
||||
// Shared state
|
||||
const strategies = ref<StrategyInfo[]>([])
|
||||
const symbols = ref<SymbolInfo[]>([])
|
||||
const runs = ref<BacktestSummary[]>([])
|
||||
const currentResult = ref<BacktestResult | null>(null)
|
||||
const selectedRuns = ref<string[]>([])
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
// Computed
|
||||
const symbolsByMarket = computed(() => {
|
||||
const grouped: Record<string, SymbolInfo[]> = {}
|
||||
for (const s of symbols.value) {
|
||||
const key = `${s.market_type}`
|
||||
if (!grouped[key]) grouped[key] = []
|
||||
grouped[key].push(s)
|
||||
}
|
||||
return grouped
|
||||
})
|
||||
|
||||
export function useBacktest() {
|
||||
/**
|
||||
* Load strategies and symbols on app init.
|
||||
*/
|
||||
async function init() {
|
||||
try {
|
||||
const [stratRes, symRes] = await Promise.all([
|
||||
getStrategies(),
|
||||
getSymbols(),
|
||||
])
|
||||
strategies.value = stratRes.strategies
|
||||
symbols.value = symRes.symbols
|
||||
} catch (e) {
|
||||
error.value = `Failed to load initial data: ${e}`
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh the run history list.
|
||||
*/
|
||||
async function refreshRuns() {
|
||||
try {
|
||||
const res = await getBacktests({ limit: 100 })
|
||||
runs.value = res.runs
|
||||
} catch (e) {
|
||||
error.value = `Failed to load runs: ${e}`
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a new backtest.
|
||||
*/
|
||||
async function executeBacktest(request: BacktestRequest) {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const result = await runBacktest(request)
|
||||
currentResult.value = result
|
||||
await refreshRuns()
|
||||
return result
|
||||
} catch (e: unknown) {
|
||||
const msg = e instanceof Error ? e.message : String(e)
|
||||
error.value = `Backtest failed: ${msg}`
|
||||
throw e
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a specific run by ID.
|
||||
*/
|
||||
async function loadRun(runId: string) {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const result = await getBacktest(runId)
|
||||
currentResult.value = result
|
||||
return result
|
||||
} catch (e) {
|
||||
error.value = `Failed to load run: ${e}`
|
||||
throw e
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a run.
|
||||
*/
|
||||
async function removeRun(runId: string) {
|
||||
try {
|
||||
await deleteBacktest(runId)
|
||||
await refreshRuns()
|
||||
if (currentResult.value?.run_id === runId) {
|
||||
currentResult.value = null
|
||||
}
|
||||
selectedRuns.value = selectedRuns.value.filter(id => id !== runId)
|
||||
} catch (e) {
|
||||
error.value = `Failed to delete run: ${e}`
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle run selection for comparison.
|
||||
*/
|
||||
function toggleRunSelection(runId: string) {
|
||||
const idx = selectedRuns.value.indexOf(runId)
|
||||
if (idx >= 0) {
|
||||
selectedRuns.value.splice(idx, 1)
|
||||
} else if (selectedRuns.value.length < 5) {
|
||||
selectedRuns.value.push(runId)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all selections.
|
||||
*/
|
||||
function clearSelections() {
|
||||
selectedRuns.value = []
|
||||
}
|
||||
|
||||
return {
|
||||
// State
|
||||
strategies,
|
||||
symbols,
|
||||
symbolsByMarket,
|
||||
runs,
|
||||
currentResult,
|
||||
selectedRuns,
|
||||
loading,
|
||||
error,
|
||||
// Actions
|
||||
init,
|
||||
refreshRuns,
|
||||
executeBacktest,
|
||||
loadRun,
|
||||
removeRun,
|
||||
toggleRunSelection,
|
||||
clearSelections,
|
||||
}
|
||||
}
|
||||
8
frontend/src/main.ts
Normal file
8
frontend/src/main.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { createApp } from 'vue'
|
||||
import App from './App.vue'
|
||||
import router from './router'
|
||||
import './style.css'
|
||||
|
||||
const app = createApp(App)
|
||||
app.use(router)
|
||||
app.mount('#app')
|
||||
21
frontend/src/router/index.ts
Normal file
21
frontend/src/router/index.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { createRouter, createWebHistory } from 'vue-router'
|
||||
import DashboardView from '@/views/DashboardView.vue'
|
||||
import CompareView from '@/views/CompareView.vue'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHistory(import.meta.env.BASE_URL),
|
||||
routes: [
|
||||
{
|
||||
path: '/',
|
||||
name: 'dashboard',
|
||||
component: DashboardView,
|
||||
},
|
||||
{
|
||||
path: '/compare',
|
||||
name: 'compare',
|
||||
component: CompareView,
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
export default router
|
||||
198
frontend/src/style.css
Normal file
198
frontend/src/style.css
Normal file
@@ -0,0 +1,198 @@
|
||||
@import "tailwindcss";
|
||||
|
||||
/* QuantConnect-inspired dark theme */
|
||||
@theme {
|
||||
/* Background colors */
|
||||
--color-bg-primary: #0d1117;
|
||||
--color-bg-secondary: #161b22;
|
||||
--color-bg-tertiary: #21262d;
|
||||
--color-bg-card: #1c2128;
|
||||
--color-bg-hover: #30363d;
|
||||
|
||||
/* Text colors */
|
||||
--color-text-primary: #e6edf3;
|
||||
--color-text-secondary: #8b949e;
|
||||
--color-text-muted: #6e7681;
|
||||
|
||||
/* Accent colors */
|
||||
--color-accent-blue: #58a6ff;
|
||||
--color-accent-purple: #a371f7;
|
||||
--color-accent-cyan: #39d4e8;
|
||||
|
||||
/* Status colors */
|
||||
--color-profit: #3fb950;
|
||||
--color-loss: #f85149;
|
||||
--color-warning: #d29922;
|
||||
|
||||
/* Border colors */
|
||||
--color-border: #30363d;
|
||||
--color-border-muted: #21262d;
|
||||
|
||||
/* Chart colors for comparison */
|
||||
--color-chart-1: #58a6ff;
|
||||
--color-chart-2: #a371f7;
|
||||
--color-chart-3: #39d4e8;
|
||||
--color-chart-4: #f0883e;
|
||||
--color-chart-5: #db61a2;
|
||||
}
|
||||
|
||||
/* Base styles */
|
||||
body {
|
||||
background-color: var(--color-bg-primary);
|
||||
color: var(--color-text-primary);
|
||||
font-family: 'JetBrains Mono', 'Fira Code', 'SF Mono', Consolas, monospace;
|
||||
}
|
||||
|
||||
/* Scrollbar styling */
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
background: var(--color-bg-secondary);
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: var(--color-bg-hover);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: var(--color-border);
|
||||
}
|
||||
|
||||
/* Input styling */
|
||||
input[type="number"],
|
||||
input[type="text"],
|
||||
select {
|
||||
background-color: var(--color-bg-tertiary);
|
||||
border: 1px solid var(--color-border);
|
||||
color: var(--color-text-primary);
|
||||
border-radius: 6px;
|
||||
padding: 0.5rem 0.75rem;
|
||||
font-size: 0.875rem;
|
||||
transition: border-color 0.2s, box-shadow 0.2s;
|
||||
}
|
||||
|
||||
input[type="number"]:focus,
|
||||
input[type="text"]:focus,
|
||||
select:focus {
|
||||
outline: none;
|
||||
border-color: var(--color-accent-blue);
|
||||
box-shadow: 0 0 0 3px rgba(88, 166, 255, 0.2);
|
||||
}
|
||||
|
||||
/* Button base */
|
||||
.btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.5rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
border-radius: 6px;
|
||||
transition: all 0.2s;
|
||||
cursor: pointer;
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background-color: var(--color-accent-blue);
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background-color: #79b8ff;
|
||||
}
|
||||
|
||||
.btn-primary:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-secondary {
|
||||
background-color: var(--color-bg-tertiary);
|
||||
border-color: var(--color-border);
|
||||
color: var(--color-text-primary);
|
||||
}
|
||||
|
||||
.btn-secondary:hover {
|
||||
background-color: var(--color-bg-hover);
|
||||
}
|
||||
|
||||
/* Card styling */
|
||||
.card {
|
||||
background-color: var(--color-bg-card);
|
||||
border: 1px solid var(--color-border-muted);
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
}
|
||||
|
||||
/* Profit/Loss coloring */
|
||||
.profit {
|
||||
color: var(--color-profit);
|
||||
}
|
||||
|
||||
.loss {
|
||||
color: var(--color-loss);
|
||||
}
|
||||
|
||||
/* Metric value styling */
|
||||
.metric-value {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
|
||||
/* Table styling */
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
}
|
||||
|
||||
th {
|
||||
text-align: left;
|
||||
padding: 0.75rem;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
color: var(--color-text-secondary);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
border-bottom: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
td {
|
||||
padding: 0.75rem;
|
||||
font-size: 0.875rem;
|
||||
border-bottom: 1px solid var(--color-border-muted);
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
tr:hover td {
|
||||
background-color: var(--color-bg-hover);
|
||||
}
|
||||
|
||||
/* Loading spinner */
|
||||
.spinner {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 2px solid var(--color-border);
|
||||
border-top-color: var(--color-accent-blue);
|
||||
border-radius: 50%;
|
||||
animation: spin 0.8s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
4
frontend/src/types/plotly.d.ts
vendored
Normal file
4
frontend/src/types/plotly.d.ts
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
declare module 'plotly.js-dist-min' {
|
||||
import Plotly from 'plotly.js'
|
||||
export default Plotly
|
||||
}
|
||||
362
frontend/src/views/CompareView.vue
Normal file
362
frontend/src/views/CompareView.vue
Normal file
@@ -0,0 +1,362 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, onMounted, onUnmounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import Plotly from 'plotly.js-dist-min'
|
||||
import { useBacktest } from '@/composables/useBacktest'
|
||||
import { compareRuns } from '@/api/client'
|
||||
import type { BacktestResult, CompareResult } from '@/api/types'
|
||||
|
||||
const router = useRouter()
|
||||
const { selectedRuns, clearSelections } = useBacktest()
|
||||
|
||||
const chartRef = ref<HTMLDivElement | null>(null)
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const compareResult = ref<CompareResult | null>(null)
|
||||
|
||||
const CHART_COLORS = [
|
||||
'#58a6ff', // blue
|
||||
'#a371f7', // purple
|
||||
'#39d4e8', // cyan
|
||||
'#f0883e', // orange
|
||||
'#db61a2', // pink
|
||||
]
|
||||
|
||||
async function loadComparison() {
|
||||
if (selectedRuns.value.length < 2) {
|
||||
router.push('/')
|
||||
return
|
||||
}
|
||||
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
compareResult.value = await compareRuns(selectedRuns.value)
|
||||
renderChart()
|
||||
} catch (e) {
|
||||
error.value = `Failed to load comparison: ${e}`
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function renderChart() {
|
||||
if (!chartRef.value || !compareResult.value) return
|
||||
|
||||
const traces: Plotly.Data[] = compareResult.value.runs.map((run, idx) => {
|
||||
// Normalize equity curves to start at 100 for comparison
|
||||
const startValue = run.equity_curve[0]?.value || 1
|
||||
const normalizedValues = run.equity_curve.map(p => (p.value / startValue) * 100)
|
||||
|
||||
return {
|
||||
x: run.equity_curve.map(p => p.timestamp),
|
||||
y: normalizedValues,
|
||||
type: 'scatter',
|
||||
mode: 'lines',
|
||||
name: `${run.strategy} (${run.params.period || ''})`,
|
||||
line: { color: CHART_COLORS[idx % CHART_COLORS.length], width: 2 },
|
||||
hovertemplate: `%{x}<br>${run.strategy}: %{y:.2f}<extra></extra>`,
|
||||
}
|
||||
})
|
||||
|
||||
const layout: Partial<Plotly.Layout> = {
|
||||
title: {
|
||||
text: 'Normalized Equity Comparison (Base 100)',
|
||||
font: { color: '#8b949e', size: 14 },
|
||||
},
|
||||
paper_bgcolor: 'transparent',
|
||||
plot_bgcolor: 'transparent',
|
||||
margin: { l: 60, r: 20, t: 50, b: 40 },
|
||||
xaxis: {
|
||||
showgrid: true,
|
||||
gridcolor: '#30363d',
|
||||
tickfont: { color: '#8b949e', size: 10 },
|
||||
linecolor: '#30363d',
|
||||
},
|
||||
yaxis: {
|
||||
showgrid: true,
|
||||
gridcolor: '#30363d',
|
||||
tickfont: { color: '#8b949e', size: 10 },
|
||||
linecolor: '#30363d',
|
||||
title: { text: 'Normalized Value', font: { color: '#8b949e' } },
|
||||
},
|
||||
legend: {
|
||||
orientation: 'h',
|
||||
yanchor: 'bottom',
|
||||
y: 1.02,
|
||||
xanchor: 'left',
|
||||
x: 0,
|
||||
font: { color: '#8b949e' },
|
||||
},
|
||||
hovermode: 'x unified',
|
||||
}
|
||||
|
||||
Plotly.react(chartRef.value, traces, layout, { responsive: true, displayModeBar: false })
|
||||
}
|
||||
|
||||
function formatPercent(val: number): string {
|
||||
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
|
||||
}
|
||||
|
||||
function formatNumber(val: number | null): string {
|
||||
if (val === null) return '-'
|
||||
return val.toFixed(2)
|
||||
}
|
||||
|
||||
function getBestIndex(runs: BacktestResult[], metric: keyof BacktestResult['metrics'], higher = true): number {
|
||||
let bestIdx = 0
|
||||
let bestVal = runs[0]?.metrics[metric] ?? 0
|
||||
|
||||
for (let i = 1; i < runs.length; i++) {
|
||||
const val = runs[i]?.metrics[metric] ?? 0
|
||||
const isBetter = higher ? (val as number) > (bestVal as number) : (val as number) < (bestVal as number)
|
||||
if (isBetter) {
|
||||
bestIdx = i
|
||||
bestVal = val
|
||||
}
|
||||
}
|
||||
|
||||
return bestIdx
|
||||
}
|
||||
|
||||
function handleClearAndBack() {
|
||||
clearSelections()
|
||||
router.push('/')
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadComparison()
|
||||
window.addEventListener('resize', renderChart)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
window.removeEventListener('resize', renderChart)
|
||||
if (chartRef.value) {
|
||||
Plotly.purge(chartRef.value)
|
||||
}
|
||||
})
|
||||
|
||||
watch(selectedRuns, () => {
|
||||
if (selectedRuns.value.length >= 2) {
|
||||
loadComparison()
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="p-6 space-y-6">
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<h1 class="text-2xl font-bold">Compare Runs</h1>
|
||||
<p class="text-text-secondary text-sm mt-1">
|
||||
Comparing {{ selectedRuns.length }} backtest runs
|
||||
</p>
|
||||
</div>
|
||||
<button @click="handleClearAndBack" class="btn btn-secondary">
|
||||
Clear & Back
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Error -->
|
||||
<div v-if="error" class="p-4 rounded-lg bg-loss/10 border border-loss/30 text-loss">
|
||||
{{ error }}
|
||||
</div>
|
||||
|
||||
<!-- Loading -->
|
||||
<div v-if="loading" class="card flex items-center justify-center h-[400px]">
|
||||
<div class="spinner" style="width: 40px; height: 40px;"></div>
|
||||
</div>
|
||||
|
||||
<!-- Comparison Results -->
|
||||
<template v-else-if="compareResult">
|
||||
<!-- Equity Curve Comparison -->
|
||||
<div class="card">
|
||||
<div ref="chartRef" class="h-[400px]"></div>
|
||||
</div>
|
||||
|
||||
<!-- Metrics Comparison Table -->
|
||||
<div class="card overflow-x-auto">
|
||||
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-4">
|
||||
Metrics Comparison
|
||||
</h3>
|
||||
<table class="min-w-full">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Metric</th>
|
||||
<th
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
>
|
||||
<span
|
||||
class="inline-block w-3 h-3 rounded-full mr-2"
|
||||
:style="{ backgroundColor: CHART_COLORS[idx] }"
|
||||
></span>
|
||||
{{ run.strategy }}
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<!-- Total Return -->
|
||||
<tr>
|
||||
<td class="font-medium">Strategy Return</td>
|
||||
<td
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
:class="[
|
||||
run.metrics.total_return >= 0 ? 'profit' : 'loss',
|
||||
idx === getBestIndex(compareResult.runs, 'total_return') ? 'font-bold' : ''
|
||||
]"
|
||||
>
|
||||
{{ formatPercent(run.metrics.total_return) }}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Benchmark Return -->
|
||||
<tr>
|
||||
<td class="font-medium">Benchmark (B&H)</td>
|
||||
<td
|
||||
v-for="run in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
:class="run.metrics.benchmark_return >= 0 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ formatPercent(run.metrics.benchmark_return) }}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Alpha -->
|
||||
<tr>
|
||||
<td class="font-medium">Alpha</td>
|
||||
<td
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
:class="[
|
||||
run.metrics.alpha >= 0 ? 'profit' : 'loss',
|
||||
idx === getBestIndex(compareResult.runs, 'alpha') ? 'font-bold' : ''
|
||||
]"
|
||||
>
|
||||
{{ formatPercent(run.metrics.alpha) }}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Sharpe Ratio -->
|
||||
<tr>
|
||||
<td class="font-medium">Sharpe Ratio</td>
|
||||
<td
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
:class="idx === getBestIndex(compareResult.runs, 'sharpe_ratio') ? 'font-bold text-accent-blue' : ''"
|
||||
>
|
||||
{{ formatNumber(run.metrics.sharpe_ratio) }}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Max Drawdown -->
|
||||
<tr>
|
||||
<td class="font-medium">Max Drawdown</td>
|
||||
<td
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center loss"
|
||||
:class="idx === getBestIndex(compareResult.runs, 'max_drawdown', false) ? 'font-bold' : ''"
|
||||
>
|
||||
{{ formatPercent(run.metrics.max_drawdown) }}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Win Rate -->
|
||||
<tr>
|
||||
<td class="font-medium">Win Rate</td>
|
||||
<td
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
:class="idx === getBestIndex(compareResult.runs, 'win_rate') ? 'font-bold text-profit' : ''"
|
||||
>
|
||||
{{ formatNumber(run.metrics.win_rate) }}%
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Total Trades -->
|
||||
<tr>
|
||||
<td class="font-medium">Total Trades</td>
|
||||
<td
|
||||
v-for="run in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
>
|
||||
{{ run.metrics.total_trades }}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Profit Factor -->
|
||||
<tr>
|
||||
<td class="font-medium">Profit Factor</td>
|
||||
<td
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
:class="idx === getBestIndex(compareResult.runs, 'profit_factor') ? 'font-bold text-profit' : ''"
|
||||
>
|
||||
{{ formatNumber(run.metrics.profit_factor) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- Parameter Differences -->
|
||||
<div v-if="Object.keys(compareResult.param_diff).length > 0" class="card">
|
||||
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-4">
|
||||
Parameter Differences
|
||||
</h3>
|
||||
<table class="min-w-full">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Parameter</th>
|
||||
<th
|
||||
v-for="(run, idx) in compareResult.runs"
|
||||
:key="run.run_id"
|
||||
class="text-center"
|
||||
>
|
||||
<span
|
||||
class="inline-block w-3 h-3 rounded-full mr-2"
|
||||
:style="{ backgroundColor: CHART_COLORS[idx] }"
|
||||
></span>
|
||||
Run {{ idx + 1 }}
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="(values, key) in compareResult.param_diff" :key="key">
|
||||
<td class="font-medium font-mono">{{ key }}</td>
|
||||
<td
|
||||
v-for="(val, idx) in values"
|
||||
:key="idx"
|
||||
class="text-center font-mono"
|
||||
>
|
||||
{{ val ?? '-' }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- No Selection -->
|
||||
<div v-else class="card flex items-center justify-center h-[400px]">
|
||||
<div class="text-center text-text-muted">
|
||||
<p>Select at least 2 runs from the history to compare.</p>
|
||||
<button @click="router.push('/')" class="btn btn-secondary mt-4">
|
||||
Go to Dashboard
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
136
frontend/src/views/DashboardView.vue
Normal file
136
frontend/src/views/DashboardView.vue
Normal file
@@ -0,0 +1,136 @@
|
||||
<script setup lang="ts">
|
||||
import { useBacktest } from '@/composables/useBacktest'
|
||||
import BacktestConfig from '@/components/BacktestConfig.vue'
|
||||
import EquityCurve from '@/components/EquityCurve.vue'
|
||||
import MetricsPanel from '@/components/MetricsPanel.vue'
|
||||
import TradeLog from '@/components/TradeLog.vue'
|
||||
|
||||
const { currentResult, loading, error } = useBacktest()
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="p-6 space-y-6">
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<h1 class="text-2xl font-bold">Lowkey Backtest</h1>
|
||||
<p class="text-text-secondary text-sm mt-1">
|
||||
Run and analyze trading strategy backtests
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Error Banner -->
|
||||
<div
|
||||
v-if="error"
|
||||
class="p-4 rounded-lg bg-loss/10 border border-loss/30 text-loss"
|
||||
>
|
||||
{{ error }}
|
||||
</div>
|
||||
|
||||
<!-- Main Grid -->
|
||||
<div class="grid grid-cols-1 lg:grid-cols-3 gap-6">
|
||||
<!-- Config Panel (Left) -->
|
||||
<div class="lg:col-span-1">
|
||||
<BacktestConfig />
|
||||
</div>
|
||||
|
||||
<!-- Results (Right) -->
|
||||
<div class="lg:col-span-2 space-y-6">
|
||||
<!-- Loading State -->
|
||||
<div
|
||||
v-if="loading"
|
||||
class="card flex items-center justify-center h-[400px]"
|
||||
>
|
||||
<div class="text-center">
|
||||
<div class="spinner mx-auto mb-4" style="width: 40px; height: 40px;"></div>
|
||||
<p class="text-text-secondary">Running backtest...</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Results Display -->
|
||||
<template v-else-if="currentResult">
|
||||
<!-- Result Header -->
|
||||
<div class="card">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<h2 class="text-lg font-semibold">
|
||||
{{ currentResult.strategy }} on {{ currentResult.symbol }}
|
||||
</h2>
|
||||
<p class="text-sm text-text-secondary mt-1">
|
||||
{{ currentResult.start_date }} - {{ currentResult.end_date }}
|
||||
<span class="mx-2">|</span>
|
||||
{{ currentResult.market_type.toUpperCase() }}
|
||||
<span v-if="currentResult.leverage > 1" class="ml-2">
|
||||
{{ currentResult.leverage }}x
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<div class="text-right">
|
||||
<div
|
||||
class="text-2xl font-bold"
|
||||
:class="currentResult.metrics.total_return >= 0 ? 'profit' : 'loss'"
|
||||
>
|
||||
{{ currentResult.metrics.total_return >= 0 ? '+' : '' }}{{ currentResult.metrics.total_return.toFixed(2) }}%
|
||||
</div>
|
||||
<div class="text-sm text-text-secondary">
|
||||
Sharpe: {{ currentResult.metrics.sharpe_ratio.toFixed(2) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Equity Curve -->
|
||||
<div class="card">
|
||||
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-4">
|
||||
Equity Curve
|
||||
</h3>
|
||||
<div class="h-[350px]">
|
||||
<EquityCurve :data="currentResult.equity_curve" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Metrics -->
|
||||
<MetricsPanel
|
||||
:metrics="currentResult.metrics"
|
||||
:leverage="currentResult.leverage"
|
||||
:market-type="currentResult.market_type"
|
||||
/>
|
||||
|
||||
<!-- Trade Log -->
|
||||
<TradeLog :trades="currentResult.trades" />
|
||||
|
||||
<!-- Parameters Used -->
|
||||
<div class="card">
|
||||
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-3">
|
||||
Parameters
|
||||
</h3>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<span
|
||||
v-for="(value, key) in currentResult.params"
|
||||
:key="key"
|
||||
class="px-2 py-1 rounded bg-bg-tertiary text-sm font-mono"
|
||||
>
|
||||
{{ key }}: {{ value }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Empty State -->
|
||||
<div
|
||||
v-else
|
||||
class="card flex items-center justify-center h-[400px]"
|
||||
>
|
||||
<div class="text-center text-text-muted">
|
||||
<svg class="w-16 h-16 mx-auto mb-4 opacity-50" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
<p>Configure and run a backtest to see results.</p>
|
||||
<p class="text-xs mt-2">Or select a run from history.</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
20
frontend/tsconfig.app.json
Normal file
20
frontend/tsconfig.app.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"extends": "@vue/tsconfig/tsconfig.dom.json",
|
||||
"compilerOptions": {
|
||||
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
||||
"types": ["vite/client"],
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"@/*": ["./src/*"]
|
||||
},
|
||||
|
||||
/* Linting */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"erasableSyntaxOnly": true,
|
||||
"noFallthroughCasesInSwitch": true,
|
||||
"noUncheckedSideEffectImports": true
|
||||
},
|
||||
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"]
|
||||
}
|
||||
7
frontend/tsconfig.json
Normal file
7
frontend/tsconfig.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"files": [],
|
||||
"references": [
|
||||
{ "path": "./tsconfig.app.json" },
|
||||
{ "path": "./tsconfig.node.json" }
|
||||
]
|
||||
}
|
||||
26
frontend/tsconfig.node.json
Normal file
26
frontend/tsconfig.node.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
||||
"target": "ES2023",
|
||||
"lib": ["ES2023"],
|
||||
"module": "ESNext",
|
||||
"types": ["node"],
|
||||
"skipLibCheck": true,
|
||||
|
||||
/* Bundler mode */
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"verbatimModuleSyntax": true,
|
||||
"moduleDetection": "force",
|
||||
"noEmit": true,
|
||||
|
||||
/* Linting */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"erasableSyntaxOnly": true,
|
||||
"noFallthroughCasesInSwitch": true,
|
||||
"noUncheckedSideEffectImports": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
22
frontend/vite.config.ts
Normal file
22
frontend/vite.config.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
import tailwindcss from '@tailwindcss/vite'
|
||||
import { fileURLToPath, URL } from 'node:url'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [vue(), tailwindcss()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': fileURLToPath(new URL('./src', import.meta.url))
|
||||
}
|
||||
},
|
||||
server: {
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://127.0.0.1:8000',
|
||||
changeOrigin: true
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1,3 +0,0 @@
|
||||
from .supertrend import add_supertrends, compute_meta_trend
|
||||
|
||||
__all__ = ["add_supertrends", "compute_meta_trend"]
|
||||
@@ -1,58 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _atr(high: pd.Series, low: pd.Series, close: pd.Series, period: int) -> pd.Series:
|
||||
hl = (high - low).abs()
|
||||
hc = (high - close.shift()).abs()
|
||||
lc = (low - close.shift()).abs()
|
||||
tr = pd.concat([hl, hc, lc], axis=1).max(axis=1)
|
||||
return tr.rolling(period, min_periods=period).mean()
|
||||
|
||||
|
||||
def supertrend_series(df: pd.DataFrame, length: int, multiplier: float) -> pd.Series:
|
||||
atr = _atr(df["High"], df["Low"], df["Close"], length)
|
||||
hl2 = (df["High"] + df["Low"]) / 2
|
||||
upper = hl2 + multiplier * atr
|
||||
lower = hl2 - multiplier * atr
|
||||
|
||||
trend = pd.Series(index=df.index, dtype=float)
|
||||
dir_up = True
|
||||
prev_upper = np.nan
|
||||
prev_lower = np.nan
|
||||
|
||||
for i in range(len(df)):
|
||||
if i == 0 or pd.isna(atr.iat[i]):
|
||||
trend.iat[i] = np.nan
|
||||
prev_upper = upper.iat[i]
|
||||
prev_lower = lower.iat[i]
|
||||
continue
|
||||
|
||||
cu = min(upper.iat[i], prev_upper) if dir_up else upper.iat[i]
|
||||
cl = max(lower.iat[i], prev_lower) if not dir_up else lower.iat[i]
|
||||
|
||||
if df["Close"].iat[i] > cu:
|
||||
dir_up = True
|
||||
elif df["Close"].iat[i] < cl:
|
||||
dir_up = False
|
||||
|
||||
prev_upper = cu if dir_up else upper.iat[i]
|
||||
prev_lower = lower.iat[i] if dir_up else cl
|
||||
trend.iat[i] = cl if dir_up else cu
|
||||
|
||||
return trend
|
||||
|
||||
|
||||
def add_supertrends(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.DataFrame:
|
||||
out = df.copy()
|
||||
for length, mult in settings:
|
||||
col = f"supertrend_{length}_{mult}"
|
||||
out[col] = supertrend_series(out, length, mult)
|
||||
out[f"bull_{length}_{mult}"] = (out["Close"] >= out[col]).astype(int)
|
||||
return out
|
||||
|
||||
|
||||
def compute_meta_trend(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.Series:
|
||||
bull_cols = [f"bull_{l}_{m}" for l, m in settings]
|
||||
return (df[bull_cols].sum(axis=1) == len(bull_cols)).astype(int)
|
||||
10
intrabar.py
10
intrabar.py
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def precompute_slices(df: pd.DataFrame) -> pd.DataFrame:
|
||||
return df # hook for future use
|
||||
|
||||
|
||||
def entry_slippage_row(price: float, qty: float, slippage_bps: float) -> float:
|
||||
return price + price * (slippage_bps / 1e4)
|
||||
135
live_trading/README.md
Normal file
135
live_trading/README.md
Normal file
@@ -0,0 +1,135 @@
|
||||
# Live Trading - Regime Reversion Strategy
|
||||
|
||||
This module implements live trading for the ML-based regime detection and mean reversion strategy on OKX perpetual futures.
|
||||
|
||||
## Overview
|
||||
|
||||
The strategy trades ETH perpetual futures based on:
|
||||
1. **BTC/ETH Spread Z-Score**: Identifies when ETH is cheap or expensive relative to BTC
|
||||
2. **Random Forest ML Model**: Predicts probability of successful mean reversion
|
||||
3. **Funding Rate Filter**: Avoids trades in overheated/oversold market conditions
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. API Keys
|
||||
|
||||
The bot loads OKX API credentials from `../BTC_spot_MVRV/.env`.
|
||||
|
||||
**IMPORTANT: OKX uses SEPARATE API keys for live vs demo trading!**
|
||||
|
||||
#### Option A: Demo Trading (Recommended for Testing)
|
||||
1. Go to [OKX Demo Trading](https://www.okx.com/demo-trading)
|
||||
2. Create a demo account if you haven't
|
||||
3. Generate API keys from the demo environment
|
||||
4. Set in `.env`:
|
||||
```env
|
||||
OKX_API_KEY=your_demo_api_key
|
||||
OKX_SECRET=your_demo_secret
|
||||
OKX_PASSWORD=your_demo_passphrase
|
||||
OKX_DEMO_MODE=true
|
||||
```
|
||||
|
||||
#### Option B: Live Trading (Real Funds)
|
||||
Use your existing live API keys with:
|
||||
```env
|
||||
OKX_API_KEY=your_live_api_key
|
||||
OKX_SECRET=your_live_secret
|
||||
OKX_PASSWORD=your_live_passphrase
|
||||
OKX_DEMO_MODE=false
|
||||
```
|
||||
|
||||
**Note:** You cannot use live API keys with `OKX_DEMO_MODE=true` or vice versa.
|
||||
OKX will return error `50101: APIKey does not match current environment`.
|
||||
|
||||
### 2. Dependencies
|
||||
|
||||
All dependencies are already in the project's `pyproject.toml`. No additional installation needed.
|
||||
|
||||
## Usage
|
||||
|
||||
### Run with Demo Account (Recommended First)
|
||||
|
||||
```bash
|
||||
cd /path/to/lowkey_backtest
|
||||
uv run python -m live_trading.main
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
```bash
|
||||
# Custom position size
|
||||
uv run python -m live_trading.main --max-position 500
|
||||
|
||||
# Custom leverage
|
||||
uv run python -m live_trading.main --leverage 2
|
||||
|
||||
# Custom cycle interval (in seconds)
|
||||
uv run python -m live_trading.main --interval 1800
|
||||
|
||||
# Combine options
|
||||
uv run python -m live_trading.main --max-position 1000 --leverage 3 --interval 3600
|
||||
```
|
||||
|
||||
### Live Trading (Use with Caution)
|
||||
|
||||
```bash
|
||||
# Requires OKX_DEMO_MODE=false in .env
|
||||
uv run python -m live_trading.main --live
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
live_trading/
|
||||
__init__.py # Module initialization
|
||||
config.py # Configuration loading
|
||||
okx_client.py # OKX API wrapper
|
||||
data_feed.py # Real-time OHLCV data
|
||||
position_manager.py # Position tracking
|
||||
live_regime_strategy.py # Strategy logic
|
||||
main.py # Entry point
|
||||
.env.example # Environment template
|
||||
README.md # This file
|
||||
```
|
||||
|
||||
## Strategy Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
|
||||
| `z_window` | 24 | Rolling window for Z-Score (hours) |
|
||||
| `model_prob_threshold` | 0.5 | ML probability threshold for entry |
|
||||
| `funding_threshold` | 0.005 | Funding rate filter threshold |
|
||||
| `stop_loss_pct` | 6% | Stop-loss percentage |
|
||||
| `take_profit_pct` | 5% | Take-profit percentage |
|
||||
|
||||
## Files Generated
|
||||
|
||||
- `live_trading/positions.json` - Open positions persistence
|
||||
- `live_trading/trade_log.csv` - Trade history
|
||||
- `live_trading/regime_model.pkl` - Trained ML model
|
||||
- `logs/live_trading.log` - Trading logs
|
||||
|
||||
## Risk Warning
|
||||
|
||||
This is experimental trading software. Use at your own risk:
|
||||
- Always start with demo trading
|
||||
- Never risk more than you can afford to lose
|
||||
- Monitor the bot regularly
|
||||
- Have a kill switch ready (Ctrl+C)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### API Key Issues
|
||||
- Ensure API keys have trading permissions
|
||||
- For demo trading, use demo-specific API keys
|
||||
- Check that passphrase matches exactly
|
||||
|
||||
### No Signals Generated
|
||||
- The strategy requires the ML model to be trained
|
||||
- Need at least 200 candles of data
|
||||
- Model trains automatically on first run
|
||||
|
||||
### Position Sync Issues
|
||||
- The bot syncs with exchange positions on each cycle
|
||||
- If positions are closed manually, the bot will detect this
|
||||
6
live_trading/__init__.py
Normal file
6
live_trading/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Live Trading Module for Regime Reversion Strategy on OKX.
|
||||
|
||||
This module implements live trading using the ML-based regime detection
|
||||
and mean reversion strategy on OKX perpetual futures.
|
||||
"""
|
||||
114
live_trading/config.py
Normal file
114
live_trading/config.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Configuration for Live Trading.
|
||||
|
||||
Loads OKX API credentials from environment variables.
|
||||
Uses demo/sandbox mode by default for paper trading.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OKXConfig:
|
||||
"""OKX API configuration."""
|
||||
api_key: str = field(default_factory=lambda: "")
|
||||
secret: str = field(default_factory=lambda: "")
|
||||
password: str = field(default_factory=lambda: "")
|
||||
demo_mode: bool = field(default_factory=lambda: True)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Load credentials based on demo mode setting."""
|
||||
# Check demo mode first
|
||||
self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
if self.demo_mode:
|
||||
# Load demo-specific credentials if available
|
||||
self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
|
||||
self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
|
||||
self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
|
||||
else:
|
||||
# Load live credentials
|
||||
self.api_key = os.getenv("OKX_API_KEY", "")
|
||||
self.secret = os.getenv("OKX_SECRET", "")
|
||||
self.password = os.getenv("OKX_PASSWORD", "")
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate that required credentials are present."""
|
||||
mode = "demo" if self.demo_mode else "live"
|
||||
if not self.api_key:
|
||||
raise ValueError(f"OKX API key not set for {mode} mode")
|
||||
if not self.secret:
|
||||
raise ValueError(f"OKX secret not set for {mode} mode")
|
||||
if not self.password:
|
||||
raise ValueError(f"OKX password not set for {mode} mode")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradingConfig:
|
||||
"""Trading parameters configuration."""
|
||||
# Trading pairs
|
||||
eth_symbol: str = "ETH/USDT:USDT" # ETH perpetual (primary trading asset)
|
||||
btc_symbol: str = "BTC/USDT:USDT" # BTC perpetual (context asset)
|
||||
|
||||
# Timeframe
|
||||
timeframe: str = "1h"
|
||||
candles_to_fetch: int = 500 # Enough for feature calculation
|
||||
|
||||
# Position sizing
|
||||
max_position_usdt: float = -1.0 # Max position size in USDT. If <= 0, use all available funds
|
||||
min_position_usdt: float = 10.0 # Min position size in USDT
|
||||
leverage: int = 1 # Leverage (1x = no leverage)
|
||||
margin_mode: str = "cross" # "cross" or "isolated"
|
||||
|
||||
# Risk management
|
||||
stop_loss_pct: float = 0.06 # 6% stop loss
|
||||
take_profit_pct: float = 0.05 # 5% take profit
|
||||
max_concurrent_positions: int = 1 # Max open positions
|
||||
|
||||
# Strategy parameters (from regime_strategy.py)
|
||||
z_entry_threshold: float = 1.0 # Enter when |Z| > 1.0
|
||||
z_window: int = 24 # 24h rolling Z-score window
|
||||
model_prob_threshold: float = 0.5 # ML model probability threshold
|
||||
funding_threshold: float = 0.005 # Funding rate filter threshold
|
||||
|
||||
# Execution
|
||||
sleep_seconds: int = 3600 # Run every hour (1h candles)
|
||||
slippage_pct: float = 0.001 # 0.1% slippage buffer
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathConfig:
|
||||
"""File paths configuration."""
|
||||
base_dir: Path = field(
|
||||
default_factory=lambda: Path(__file__).parent.parent
|
||||
)
|
||||
data_dir: Path = field(default=None)
|
||||
logs_dir: Path = field(default=None)
|
||||
model_path: Path = field(default=None)
|
||||
positions_file: Path = field(default=None)
|
||||
trade_log_file: Path = field(default=None)
|
||||
cq_data_path: Path = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self.data_dir = self.base_dir / "data"
|
||||
self.logs_dir = self.base_dir / "logs"
|
||||
self.model_path = self.base_dir / "live_trading" / "regime_model.pkl"
|
||||
self.positions_file = self.base_dir / "live_trading" / "positions.json"
|
||||
self.trade_log_file = self.base_dir / "live_trading" / "trade_log.csv"
|
||||
self.cq_data_path = self.data_dir / "cq_training_data.csv"
|
||||
|
||||
# Ensure directories exist
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Get all configuration objects."""
|
||||
okx = OKXConfig()
|
||||
trading = TradingConfig()
|
||||
paths = PathConfig()
|
||||
return okx, trading, paths
|
||||
216
live_trading/data_feed.py
Normal file
216
live_trading/data_feed.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Data Feed for Live Trading.
|
||||
|
||||
Fetches real-time OHLCV data from OKX and prepares features
|
||||
for the regime strategy.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
|
||||
from .okx_client import OKXClient
|
||||
from .config import TradingConfig, PathConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataFeed:
|
||||
"""
|
||||
Real-time data feed for the regime strategy.
|
||||
|
||||
Fetches BTC and ETH OHLCV data from OKX and calculates
|
||||
the spread-based features required by the ML model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_client: OKXClient,
|
||||
trading_config: TradingConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.client = okx_client
|
||||
self.config = trading_config
|
||||
self.paths = path_config
|
||||
self.cq_data: Optional[pd.DataFrame] = None
|
||||
self._load_cq_data()
|
||||
|
||||
def _load_cq_data(self) -> None:
|
||||
"""Load CryptoQuant on-chain data if available."""
|
||||
try:
|
||||
if self.paths.cq_data_path.exists():
|
||||
self.cq_data = pd.read_csv(
|
||||
self.paths.cq_data_path,
|
||||
index_col='timestamp',
|
||||
parse_dates=True
|
||||
)
|
||||
if self.cq_data.index.tz is None:
|
||||
self.cq_data.index = self.cq_data.index.tz_localize('UTC')
|
||||
logger.info(f"Loaded CryptoQuant data: {len(self.cq_data)} rows")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load CryptoQuant data: {e}")
|
||||
self.cq_data = None
|
||||
|
||||
def fetch_ohlcv_data(self) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Fetch OHLCV data for BTC and ETH.
|
||||
|
||||
Returns:
|
||||
Tuple of (btc_df, eth_df) DataFrames
|
||||
"""
|
||||
# Fetch BTC data
|
||||
btc_ohlcv = self.client.fetch_ohlcv(
|
||||
self.config.btc_symbol,
|
||||
self.config.timeframe,
|
||||
self.config.candles_to_fetch
|
||||
)
|
||||
btc_df = self._ohlcv_to_dataframe(btc_ohlcv)
|
||||
|
||||
# Fetch ETH data
|
||||
eth_ohlcv = self.client.fetch_ohlcv(
|
||||
self.config.eth_symbol,
|
||||
self.config.timeframe,
|
||||
self.config.candles_to_fetch
|
||||
)
|
||||
eth_df = self._ohlcv_to_dataframe(eth_ohlcv)
|
||||
|
||||
logger.info(
|
||||
f"Fetched {len(btc_df)} BTC candles and {len(eth_df)} ETH candles"
|
||||
)
|
||||
|
||||
return btc_df, eth_df
|
||||
|
||||
def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
|
||||
"""Convert OHLCV list to DataFrame."""
|
||||
df = pd.DataFrame(
|
||||
ohlcv,
|
||||
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
df.set_index('timestamp', inplace=True)
|
||||
return df
|
||||
|
||||
def calculate_features(
|
||||
self,
|
||||
btc_df: pd.DataFrame,
|
||||
eth_df: pd.DataFrame
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Calculate spread-based features for the regime strategy.
|
||||
|
||||
Args:
|
||||
btc_df: BTC OHLCV DataFrame
|
||||
eth_df: ETH OHLCV DataFrame
|
||||
|
||||
Returns:
|
||||
DataFrame with calculated features
|
||||
"""
|
||||
# Align indices
|
||||
common_idx = btc_df.index.intersection(eth_df.index)
|
||||
df_btc = btc_df.loc[common_idx].copy()
|
||||
df_eth = eth_df.loc[common_idx].copy()
|
||||
|
||||
# Calculate spread (ETH/BTC ratio)
|
||||
spread = df_eth['close'] / df_btc['close']
|
||||
|
||||
# Z-Score of spread
|
||||
z_window = self.config.z_window
|
||||
rolling_mean = spread.rolling(window=z_window).mean()
|
||||
rolling_std = spread.rolling(window=z_window).std()
|
||||
z_score = (spread - rolling_mean) / rolling_std
|
||||
|
||||
# Spread technicals
|
||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||
spread_roc = spread.pct_change(periods=5) * 100
|
||||
spread_change_1h = spread.pct_change(periods=1)
|
||||
|
||||
# Volume ratio
|
||||
vol_ratio = df_eth['volume'] / df_btc['volume']
|
||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||
|
||||
# Volatility
|
||||
ret_btc = df_btc['close'].pct_change()
|
||||
ret_eth = df_eth['close'].pct_change()
|
||||
vol_btc = ret_btc.rolling(window=z_window).std()
|
||||
vol_eth = ret_eth.rolling(window=z_window).std()
|
||||
vol_spread_ratio = vol_eth / vol_btc
|
||||
|
||||
# Build features DataFrame
|
||||
features = pd.DataFrame(index=spread.index)
|
||||
features['spread'] = spread
|
||||
features['z_score'] = z_score
|
||||
features['spread_rsi'] = spread_rsi
|
||||
features['spread_roc'] = spread_roc
|
||||
features['spread_change_1h'] = spread_change_1h
|
||||
features['vol_ratio'] = vol_ratio
|
||||
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
|
||||
features['vol_diff_ratio'] = vol_spread_ratio
|
||||
|
||||
# Add price data for reference
|
||||
features['btc_close'] = df_btc['close']
|
||||
features['eth_close'] = df_eth['close']
|
||||
features['eth_volume'] = df_eth['volume']
|
||||
|
||||
# Merge CryptoQuant data if available
|
||||
if self.cq_data is not None:
|
||||
cq_aligned = self.cq_data.reindex(features.index, method='ffill')
|
||||
|
||||
# Calculate derived features
|
||||
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
|
||||
cq_aligned['funding_diff'] = (
|
||||
cq_aligned['eth_funding'] - cq_aligned['btc_funding']
|
||||
)
|
||||
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
|
||||
cq_aligned['inflow_ratio'] = (
|
||||
cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
|
||||
)
|
||||
|
||||
features = features.join(cq_aligned)
|
||||
|
||||
return features.dropna()
|
||||
|
||||
def get_latest_data(self) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Fetch and process latest market data.
|
||||
|
||||
Returns:
|
||||
DataFrame with features or None on error
|
||||
"""
|
||||
try:
|
||||
btc_df, eth_df = self.fetch_ohlcv_data()
|
||||
features = self.calculate_features(btc_df, eth_df)
|
||||
|
||||
if features.empty:
|
||||
logger.warning("No valid features calculated")
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Latest data: ETH={features['eth_close'].iloc[-1]:.2f}, "
|
||||
f"BTC={features['btc_close'].iloc[-1]:.2f}, "
|
||||
f"Z-Score={features['z_score'].iloc[-1]:.3f}"
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching market data: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_current_funding_rates(self) -> dict:
|
||||
"""
|
||||
Get current funding rates for BTC and ETH.
|
||||
|
||||
Returns:
|
||||
Dictionary with 'btc_funding' and 'eth_funding' rates
|
||||
"""
|
||||
btc_funding = self.client.get_funding_rate(self.config.btc_symbol)
|
||||
eth_funding = self.client.get_funding_rate(self.config.eth_symbol)
|
||||
|
||||
return {
|
||||
'btc_funding': btc_funding,
|
||||
'eth_funding': eth_funding,
|
||||
'funding_diff': eth_funding - btc_funding,
|
||||
}
|
||||
15
live_trading/env.template
Normal file
15
live_trading/env.template
Normal file
@@ -0,0 +1,15 @@
|
||||
# OKX API Credentials Template
|
||||
# Copy this file to .env and fill in your credentials
|
||||
# For demo trading, use your OKX demo account API keys
|
||||
# Generate keys at: https://www.okx.com/account/my-api (Demo Trading section)
|
||||
|
||||
OKX_API_KEY=your_api_key_here
|
||||
OKX_SECRET=your_secret_key_here
|
||||
OKX_PASSWORD=your_passphrase_here
|
||||
|
||||
# Demo Mode: Set to "true" for paper trading (sandbox)
|
||||
# Set to "false" for live trading with real funds
|
||||
OKX_DEMO_MODE=true
|
||||
|
||||
# CryptoQuant API (optional, for on-chain features)
|
||||
CRYPTOQUANT_API_KEY=your_cryptoquant_api_key_here
|
||||
287
live_trading/live_regime_strategy.py
Normal file
287
live_trading/live_regime_strategy.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Live Regime Reversion Strategy.
|
||||
|
||||
Adapts the backtest regime strategy for live trading.
|
||||
Uses a pre-trained ML model or trains on historical data.
|
||||
"""
|
||||
import logging
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
from .config import TradingConfig, PathConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LiveRegimeStrategy:
|
||||
"""
|
||||
Live trading implementation of the ML-based regime detection
|
||||
and mean reversion strategy.
|
||||
|
||||
Logic:
|
||||
1. Calculates BTC/ETH spread Z-Score
|
||||
2. Uses Random Forest to predict reversion probability
|
||||
3. Applies funding rate filter
|
||||
4. Generates long/short signals on ETH perpetual
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trading_config: TradingConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.config = trading_config
|
||||
self.paths = path_config
|
||||
self.model: Optional[RandomForestClassifier] = None
|
||||
self.feature_cols: Optional[list] = None
|
||||
self._load_or_train_model()
|
||||
|
||||
def _load_or_train_model(self) -> None:
|
||||
"""Load pre-trained model or train a new one."""
|
||||
if self.paths.model_path.exists():
|
||||
try:
|
||||
with open(self.paths.model_path, 'rb') as f:
|
||||
saved = pickle.load(f)
|
||||
self.model = saved['model']
|
||||
self.feature_cols = saved['feature_cols']
|
||||
logger.info(f"Loaded model from {self.paths.model_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load model: {e}")
|
||||
|
||||
logger.info("No pre-trained model found. Will train on first data batch.")
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save trained model to file."""
|
||||
if self.model is None:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.paths.model_path, 'wb') as f:
|
||||
pickle.dump({
|
||||
'model': self.model,
|
||||
'feature_cols': self.feature_cols,
|
||||
}, f)
|
||||
logger.info(f"Saved model to {self.paths.model_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Could not save model: {e}")
|
||||
|
||||
def train_model(self, features: pd.DataFrame) -> None:
|
||||
"""
|
||||
Train the Random Forest model on historical data.
|
||||
|
||||
Args:
|
||||
features: DataFrame with calculated features
|
||||
"""
|
||||
logger.info(f"Training model on {len(features)} samples...")
|
||||
|
||||
z_thresh = self.config.z_entry_threshold
|
||||
horizon = 102 # Optimal horizon from research
|
||||
profit_target = 0.005 # 0.5% profit threshold
|
||||
|
||||
# Define targets
|
||||
future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
|
||||
future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
|
||||
|
||||
target_short = features['spread'] * (1 - profit_target)
|
||||
target_long = features['spread'] * (1 + profit_target)
|
||||
|
||||
success_short = (features['z_score'] > z_thresh) & (future_min < target_short)
|
||||
success_long = (features['z_score'] < -z_thresh) & (future_max > target_long)
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
|
||||
# Exclude non-feature columns
|
||||
exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
|
||||
self.feature_cols = [c for c in features.columns if c not in exclude]
|
||||
|
||||
# Clean features
|
||||
X = features[self.feature_cols].fillna(0)
|
||||
X = X.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Remove rows with invalid targets
|
||||
valid_mask = ~np.isnan(targets) & future_min.notna().values & future_max.notna().values
|
||||
X_clean = X[valid_mask]
|
||||
y_clean = targets[valid_mask]
|
||||
|
||||
if len(X_clean) < 100:
|
||||
logger.warning("Not enough data to train model")
|
||||
return
|
||||
|
||||
# Train model
|
||||
self.model = RandomForestClassifier(
|
||||
n_estimators=300,
|
||||
max_depth=5,
|
||||
min_samples_leaf=30,
|
||||
class_weight={0: 1, 1: 3},
|
||||
random_state=42
|
||||
)
|
||||
self.model.fit(X_clean, y_clean)
|
||||
|
||||
logger.info(f"Model trained on {len(X_clean)} samples")
|
||||
self.save_model()
|
||||
|
||||
def generate_signal(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
current_funding: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Generate trading signal from latest features.
|
||||
|
||||
Args:
|
||||
features: DataFrame with calculated features
|
||||
current_funding: Dictionary with funding rate data
|
||||
|
||||
Returns:
|
||||
Signal dictionary with action, side, confidence, etc.
|
||||
"""
|
||||
if self.model is None:
|
||||
# Train model if not available
|
||||
if len(features) >= 200:
|
||||
self.train_model(features)
|
||||
else:
|
||||
return {'action': 'hold', 'reason': 'model_not_trained'}
|
||||
|
||||
if self.model is None:
|
||||
return {'action': 'hold', 'reason': 'insufficient_data_for_training'}
|
||||
|
||||
# Get latest row
|
||||
latest = features.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
eth_price = latest['eth_close']
|
||||
btc_price = latest['btc_close']
|
||||
|
||||
# Prepare features for prediction
|
||||
X = features[self.feature_cols].iloc[[-1]].fillna(0)
|
||||
X = X.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Get prediction probability
|
||||
prob = self.model.predict_proba(X)[0, 1]
|
||||
|
||||
# Apply thresholds
|
||||
z_thresh = self.config.z_entry_threshold
|
||||
prob_thresh = self.config.model_prob_threshold
|
||||
|
||||
# Determine signal direction
|
||||
signal = {
|
||||
'action': 'hold',
|
||||
'side': None,
|
||||
'probability': prob,
|
||||
'z_score': z_score,
|
||||
'eth_price': eth_price,
|
||||
'btc_price': btc_price,
|
||||
'reason': '',
|
||||
}
|
||||
|
||||
# Check for entry conditions
|
||||
if prob > prob_thresh:
|
||||
if z_score > z_thresh:
|
||||
# Spread high (ETH expensive relative to BTC) -> Short ETH
|
||||
signal['action'] = 'entry'
|
||||
signal['side'] = 'short'
|
||||
signal['reason'] = f'z_score={z_score:.2f}>threshold, prob={prob:.2f}'
|
||||
elif z_score < -z_thresh:
|
||||
# Spread low (ETH cheap relative to BTC) -> Long ETH
|
||||
signal['action'] = 'entry'
|
||||
signal['side'] = 'long'
|
||||
signal['reason'] = f'z_score={z_score:.2f}<-threshold, prob={prob:.2f}'
|
||||
else:
|
||||
signal['reason'] = f'z_score={z_score:.2f} within threshold'
|
||||
else:
|
||||
signal['reason'] = f'prob={prob:.2f}<threshold'
|
||||
|
||||
# Apply funding rate filter
|
||||
if signal['action'] == 'entry':
|
||||
btc_funding = current_funding.get('btc_funding', 0)
|
||||
funding_thresh = self.config.funding_threshold
|
||||
|
||||
if signal['side'] == 'long' and btc_funding > funding_thresh:
|
||||
# High positive funding = overheated, don't go long
|
||||
signal['action'] = 'hold'
|
||||
signal['reason'] = f'funding_filter_blocked_long (funding={btc_funding:.4f})'
|
||||
elif signal['side'] == 'short' and btc_funding < -funding_thresh:
|
||||
# High negative funding = oversold, don't go short
|
||||
signal['action'] = 'hold'
|
||||
signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'
|
||||
|
||||
# Check for exit conditions (mean reversion complete)
|
||||
if signal['action'] == 'hold':
|
||||
# Z-score crossed back through 0
|
||||
if abs(z_score) < 0.3:
|
||||
signal['action'] = 'check_exit'
|
||||
signal['reason'] = f'z_score_reverted_to_mean ({z_score:.2f})'
|
||||
|
||||
logger.info(
|
||||
f"Signal: {signal['action']} {signal['side'] or ''} "
|
||||
f"(prob={prob:.2f}, z={z_score:.2f}, reason={signal['reason']})"
|
||||
)
|
||||
|
||||
return signal
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
signal: dict,
|
||||
available_usdt: float
|
||||
) -> float:
|
||||
"""
|
||||
Calculate position size based on signal confidence.
|
||||
|
||||
Args:
|
||||
signal: Signal dictionary with probability
|
||||
available_usdt: Available USDT balance
|
||||
|
||||
Returns:
|
||||
Position size in USDT
|
||||
"""
|
||||
prob = signal.get('probability', 0.5)
|
||||
|
||||
# Base size: if max_position_usdt <= 0, use all available funds
|
||||
if self.config.max_position_usdt <= 0:
|
||||
base_size = available_usdt
|
||||
else:
|
||||
base_size = min(available_usdt, self.config.max_position_usdt)
|
||||
|
||||
# Scale by probability (1.0x at 0.5 prob, up to 1.6x at 0.8 prob)
|
||||
scale = 1.0 + (prob - 0.5) * 2.0
|
||||
scale = max(1.0, min(scale, 2.0)) # Clamp between 1x and 2x
|
||||
|
||||
size = base_size * scale
|
||||
|
||||
# Ensure minimum position size
|
||||
if size < self.config.min_position_usdt:
|
||||
return 0.0
|
||||
|
||||
return min(size, available_usdt * 0.95) # Leave 5% buffer
|
||||
|
||||
def calculate_sl_tp(
|
||||
self,
|
||||
entry_price: float,
|
||||
side: str
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate stop-loss and take-profit prices.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
side: "long" or "short"
|
||||
|
||||
Returns:
|
||||
Tuple of (stop_loss_price, take_profit_price)
|
||||
"""
|
||||
sl_pct = self.config.stop_loss_pct
|
||||
tp_pct = self.config.take_profit_pct
|
||||
|
||||
if side == "long":
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else: # short
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
return stop_loss, take_profit
|
||||
390
live_trading/main.py
Normal file
390
live_trading/main.py
Normal file
@@ -0,0 +1,390 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Live Trading Bot for Regime Reversion Strategy on OKX.
|
||||
|
||||
This script runs the regime-based mean reversion strategy
|
||||
on ETH perpetual futures using OKX exchange.
|
||||
|
||||
Usage:
|
||||
# Run with demo account (default)
|
||||
uv run python -m live_trading.main
|
||||
|
||||
# Run with specific settings
|
||||
uv run python -m live_trading.main --max-position 500 --leverage 2
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from live_trading.config import get_config, OKXConfig, TradingConfig, PathConfig
|
||||
from live_trading.okx_client import OKXClient
|
||||
from live_trading.data_feed import DataFeed
|
||||
from live_trading.position_manager import PositionManager
|
||||
from live_trading.live_regime_strategy import LiveRegimeStrategy
|
||||
|
||||
|
||||
def setup_logging(log_dir: Path) -> logging.Logger:
|
||||
"""Configure logging for the trading bot."""
|
||||
log_file = log_dir / "live_trading.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
force=True
|
||||
)
|
||||
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LiveTradingBot:
|
||||
"""
|
||||
Main trading bot orchestrator.
|
||||
|
||||
Coordinates data fetching, signal generation, and order execution
|
||||
in a continuous loop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_config: OKXConfig,
|
||||
trading_config: TradingConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.okx_config = okx_config
|
||||
self.trading_config = trading_config
|
||||
self.path_config = path_config
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.running = True
|
||||
|
||||
# Initialize components
|
||||
self.logger.info("Initializing trading bot components...")
|
||||
|
||||
self.okx_client = OKXClient(okx_config, trading_config)
|
||||
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
|
||||
self.position_manager = PositionManager(
|
||||
self.okx_client, trading_config, path_config
|
||||
)
|
||||
self.strategy = LiveRegimeStrategy(trading_config, path_config)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, self._handle_shutdown)
|
||||
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
||||
|
||||
self._print_startup_banner()
|
||||
|
||||
def _print_startup_banner(self) -> None:
|
||||
"""Print startup information."""
|
||||
mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
|
||||
|
||||
print("=" * 60)
|
||||
print(f" Regime Reversion Strategy - Live Trading Bot")
|
||||
print("=" * 60)
|
||||
print(f" Mode: {mode}")
|
||||
print(f" Trading Pair: {self.trading_config.eth_symbol}")
|
||||
print(f" Context Pair: {self.trading_config.btc_symbol}")
|
||||
print(f" Timeframe: {self.trading_config.timeframe}")
|
||||
print(f" Max Position: ${self.trading_config.max_position_usdt if self.trading_config.max_position_usdt > 0 else 'All available'}")
|
||||
print(f" Leverage: {self.trading_config.leverage}x")
|
||||
print(f" Stop Loss: {self.trading_config.stop_loss_pct * 100:.1f}%")
|
||||
print(f" Take Profit: {self.trading_config.take_profit_pct * 100:.1f}%")
|
||||
print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
|
||||
print("=" * 60)
|
||||
|
||||
if not self.okx_config.demo_mode:
|
||||
print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")
|
||||
|
||||
def _handle_shutdown(self, signum, frame) -> None:
|
||||
"""Handle shutdown signals gracefully."""
|
||||
self.logger.info("Shutdown signal received, stopping...")
|
||||
self.running = False
|
||||
|
||||
def run_trading_cycle(self) -> None:
|
||||
"""
|
||||
Execute one trading cycle.
|
||||
|
||||
1. Fetch latest market data
|
||||
2. Update open positions
|
||||
3. Generate trading signal
|
||||
4. Execute trades if signal triggers
|
||||
"""
|
||||
cycle_start = datetime.now(timezone.utc)
|
||||
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
|
||||
|
||||
try:
|
||||
# 1. Fetch market data
|
||||
features = self.data_feed.get_latest_data()
|
||||
if features is None or features.empty:
|
||||
self.logger.warning("No market data available, skipping cycle")
|
||||
return
|
||||
|
||||
# Get current prices
|
||||
eth_price = features['eth_close'].iloc[-1]
|
||||
btc_price = features['btc_close'].iloc[-1]
|
||||
|
||||
current_prices = {
|
||||
self.trading_config.eth_symbol: eth_price,
|
||||
self.trading_config.btc_symbol: btc_price,
|
||||
}
|
||||
|
||||
# 2. Update existing positions (check SL/TP)
|
||||
closed_trades = self.position_manager.update_positions(current_prices)
|
||||
if closed_trades:
|
||||
for trade in closed_trades:
|
||||
self.logger.info(
|
||||
f"Trade closed: {trade['trade_id']} "
|
||||
f"PnL=${trade['pnl_usd']:.2f} ({trade['reason']})"
|
||||
)
|
||||
|
||||
# 3. Sync with exchange positions
|
||||
self.position_manager.sync_with_exchange()
|
||||
|
||||
# 4. Get current funding rates
|
||||
funding = self.data_feed.get_current_funding_rates()
|
||||
|
||||
# 5. Generate trading signal
|
||||
signal = self.strategy.generate_signal(features, funding)
|
||||
|
||||
# 6. Execute trades based on signal
|
||||
if signal['action'] == 'entry':
|
||||
self._execute_entry(signal, eth_price)
|
||||
elif signal['action'] == 'check_exit':
|
||||
self._execute_exit(signal)
|
||||
|
||||
# 7. Log portfolio summary
|
||||
summary = self.position_manager.get_portfolio_summary()
|
||||
self.logger.info(
|
||||
f"Portfolio: {summary['open_positions']} positions, "
|
||||
f"exposure=${summary['total_exposure_usdt']:.2f}, "
|
||||
f"unrealized_pnl=${summary['total_unrealized_pnl']:.2f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Trading cycle error: {e}", exc_info=True)
|
||||
# Save positions on error
|
||||
self.position_manager.save_positions()
|
||||
|
||||
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
|
||||
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
|
||||
|
||||
def _execute_entry(self, signal: dict, current_price: float) -> None:
|
||||
"""Execute entry trade."""
|
||||
symbol = self.trading_config.eth_symbol
|
||||
side = signal['side']
|
||||
|
||||
# Check if we can open a position
|
||||
if not self.position_manager.can_open_position():
|
||||
self.logger.info("Cannot open position: max positions reached")
|
||||
return
|
||||
|
||||
# Get account balance
|
||||
balance = self.okx_client.get_balance()
|
||||
available_usdt = balance['free']
|
||||
|
||||
# Calculate position size
|
||||
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
|
||||
if size_usdt <= 0:
|
||||
self.logger.info("Position size too small, skipping entry")
|
||||
return
|
||||
|
||||
size_eth = size_usdt / current_price
|
||||
|
||||
# Calculate SL/TP
|
||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(current_price, side)
|
||||
|
||||
self.logger.info(
|
||||
f"Executing {side.upper()} entry: {size_eth:.4f} ETH @ {current_price:.2f} "
|
||||
f"(${size_usdt:.2f}), SL={stop_loss:.2f}, TP={take_profit:.2f}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Place market order
|
||||
order_side = "buy" if side == "long" else "sell"
|
||||
order = self.okx_client.place_market_order(symbol, order_side, size_eth)
|
||||
|
||||
# Get filled price (handle None values from OKX response)
|
||||
filled_price = order.get('average') or order.get('price') or current_price
|
||||
filled_amount = order.get('filled') or order.get('amount') or size_eth
|
||||
|
||||
# Ensure we have valid numeric values
|
||||
if filled_price is None or filled_price == 0:
|
||||
self.logger.warning(f"No fill price in order response, using current price: {current_price}")
|
||||
filled_price = current_price
|
||||
if filled_amount is None or filled_amount == 0:
|
||||
self.logger.warning(f"No fill amount in order response, using requested: {size_eth}")
|
||||
filled_amount = size_eth
|
||||
|
||||
# Recalculate SL/TP with filled price
|
||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(filled_price, side)
|
||||
|
||||
# Get order ID from response
|
||||
order_id = order.get('id', '')
|
||||
|
||||
# Record position locally
|
||||
position = self.position_manager.open_position(
|
||||
symbol=symbol,
|
||||
side=side,
|
||||
entry_price=filled_price,
|
||||
size=filled_amount,
|
||||
stop_loss_price=stop_loss,
|
||||
take_profit_price=take_profit,
|
||||
order_id=order_id,
|
||||
)
|
||||
|
||||
if position:
|
||||
self.logger.info(
|
||||
f"Position opened: {position.trade_id}, "
|
||||
f"{filled_amount:.4f} ETH @ {filled_price:.2f}"
|
||||
)
|
||||
|
||||
# Try to set SL/TP on exchange
|
||||
try:
|
||||
self.okx_client.set_stop_loss_take_profit(
|
||||
symbol, side, filled_amount, stop_loss, take_profit
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not set SL/TP on exchange: {e}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Order execution failed: {e}", exc_info=True)
|
||||
|
||||
def _execute_exit(self, signal: dict) -> None:
|
||||
"""Execute exit based on mean reversion signal."""
|
||||
symbol = self.trading_config.eth_symbol
|
||||
|
||||
# Get position for ETH
|
||||
position = self.position_manager.get_position_for_symbol(symbol)
|
||||
if not position:
|
||||
return
|
||||
|
||||
current_price = signal.get('eth_price', position.current_price)
|
||||
|
||||
self.logger.info(
|
||||
f"Mean reversion exit signal: closing {position.trade_id} "
|
||||
f"@ {current_price:.2f}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Close position on exchange
|
||||
exit_order = self.okx_client.close_position(symbol)
|
||||
exit_order_id = exit_order.get('id', '') if exit_order else ''
|
||||
|
||||
# Record closure locally
|
||||
self.position_manager.close_position(
|
||||
position.trade_id,
|
||||
current_price,
|
||||
reason="mean_reversion_complete",
|
||||
exit_order_id=exit_order_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Main trading loop."""
|
||||
self.logger.info("Starting trading loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self.run_trading_cycle()
|
||||
|
||||
if self.running:
|
||||
sleep_seconds = self.trading_config.sleep_seconds
|
||||
minutes = sleep_seconds // 60
|
||||
self.logger.info(f"Sleeping for {minutes} minutes...")
|
||||
|
||||
# Sleep in smaller chunks to allow faster shutdown
|
||||
for _ in range(sleep_seconds):
|
||||
if not self.running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
self.logger.info("Keyboard interrupt received")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error in main loop: {e}", exc_info=True)
|
||||
time.sleep(60) # Wait before retry
|
||||
|
||||
# Cleanup
|
||||
self.logger.info("Shutting down...")
|
||||
self.position_manager.save_positions()
|
||||
self.logger.info("Shutdown complete")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Live Trading Bot for Regime Reversion Strategy"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-position",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Maximum position size in USDT"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leverage",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Trading leverage (1-125)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interval",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Trading cycle interval in seconds"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--live",
|
||||
action="store_true",
|
||||
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
args = parse_args()
|
||||
|
||||
# Load configuration
|
||||
okx_config, trading_config, path_config = get_config()
|
||||
|
||||
# Apply command line overrides
|
||||
if args.max_position is not None:
|
||||
trading_config.max_position_usdt = args.max_position
|
||||
if args.leverage is not None:
|
||||
trading_config.leverage = args.leverage
|
||||
if args.interval is not None:
|
||||
trading_config.sleep_seconds = args.interval
|
||||
if args.live:
|
||||
okx_config.demo_mode = False
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(path_config.logs_dir)
|
||||
|
||||
try:
|
||||
# Create and run bot
|
||||
bot = LiveTradingBot(okx_config, trading_config, path_config)
|
||||
bot.run()
|
||||
except ValueError as e:
|
||||
logger.error(f"Configuration error: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
live_trading/multi_pair/README.md
Normal file
145
live_trading/multi_pair/README.md
Normal file
@@ -0,0 +1,145 @@
|
||||
# Multi-Pair Divergence Live Trading
|
||||
|
||||
This module implements live trading for the Multi-Pair Divergence Selection Strategy on OKX perpetual futures.
|
||||
|
||||
## Overview
|
||||
|
||||
The strategy scans 10 cryptocurrency pairs for spread divergence opportunities:
|
||||
|
||||
1. **Pair Universe**: Top 10 assets by market cap (BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT)
|
||||
2. **Spread Z-Score**: Identifies when pairs are divergent from their historical mean
|
||||
3. **Universal ML Model**: Predicts probability of successful mean reversion
|
||||
4. **Dynamic Selection**: Trades the pair with highest divergence score
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before running live trading, you must train the model via backtesting:
|
||||
|
||||
```bash
|
||||
uv run python scripts/run_multi_pair_backtest.py
|
||||
```
|
||||
|
||||
This creates `data/multi_pair_model.pkl` which the live trading bot requires.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. API Keys
|
||||
|
||||
Same as single-pair trading. Set in `.env`:
|
||||
|
||||
```env
|
||||
OKX_API_KEY=your_api_key
|
||||
OKX_SECRET=your_secret
|
||||
OKX_PASSWORD=your_passphrase
|
||||
OKX_DEMO_MODE=true # Use demo for testing
|
||||
```
|
||||
|
||||
### 2. Dependencies
|
||||
|
||||
All dependencies are in `pyproject.toml`. No additional installation needed.
|
||||
|
||||
## Usage
|
||||
|
||||
### Run with Demo Account (Recommended First)
|
||||
|
||||
```bash
|
||||
uv run python -m live_trading.multi_pair.main
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
```bash
|
||||
# Custom position size
|
||||
uv run python -m live_trading.multi_pair.main --max-position 500
|
||||
|
||||
# Custom leverage
|
||||
uv run python -m live_trading.multi_pair.main --leverage 2
|
||||
|
||||
# Custom cycle interval (in seconds)
|
||||
uv run python -m live_trading.multi_pair.main --interval 1800
|
||||
|
||||
# Combine options
|
||||
uv run python -m live_trading.multi_pair.main --max-position 1000 --leverage 3 --interval 3600
|
||||
```
|
||||
|
||||
### Live Trading (Use with Caution)
|
||||
|
||||
```bash
|
||||
uv run python -m live_trading.multi_pair.main --live
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Each Trading Cycle
|
||||
|
||||
1. **Fetch Data**: Gets OHLCV for all 10 assets from OKX
|
||||
2. **Calculate Features**: Computes Z-Score, RSI, volatility for all 45 pair combinations
|
||||
3. **Score Pairs**: Uses ML model to rank pairs by divergence score (|Z| x probability)
|
||||
4. **Check Exits**: If holding, check mean reversion or SL/TP
|
||||
5. **Enter Best**: If no position, enter the highest-scoring divergent pair
|
||||
|
||||
### Entry Conditions
|
||||
|
||||
- |Z-Score| > 1.0 (spread diverged from mean)
|
||||
- ML probability > 0.5 (model predicts successful reversion)
|
||||
- Funding rate filter passes (avoid crowded trades)
|
||||
|
||||
### Exit Conditions
|
||||
|
||||
- Mean reversion: |Z-Score| returns to ~0
|
||||
- Stop-loss: ATR-based (default ~6%)
|
||||
- Take-profit: ATR-based (default ~5%)
|
||||
|
||||
## Strategy Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
|
||||
| `z_exit_threshold` | 0.0 | Exit when Z reverts to mean |
|
||||
| `z_window` | 24 | Rolling window for Z-Score (hours) |
|
||||
| `prob_threshold` | 0.5 | ML probability threshold for entry |
|
||||
| `funding_threshold` | 0.0005 | Funding rate filter (0.05%) |
|
||||
| `sl_atr_multiplier` | 10.0 | Stop-loss as ATR multiple |
|
||||
| `tp_atr_multiplier` | 8.0 | Take-profit as ATR multiple |
|
||||
|
||||
## Files
|
||||
|
||||
### Input
|
||||
|
||||
- `data/multi_pair_model.pkl` - Pre-trained ML model (required)
|
||||
|
||||
### Output
|
||||
|
||||
- `logs/multi_pair_live.log` - Trading logs
|
||||
- `live_trading/multi_pair_positions.json` - Position persistence
|
||||
- `live_trading/multi_pair_trade_log.csv` - Trade history
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
live_trading/multi_pair/
|
||||
__init__.py # Module exports
|
||||
config.py # Configuration classes
|
||||
data_feed.py # Multi-asset OHLCV fetcher
|
||||
strategy.py # ML scoring and signal generation
|
||||
main.py # Bot orchestrator
|
||||
README.md # This file
|
||||
```
|
||||
|
||||
## Differences from Single-Pair
|
||||
|
||||
| Aspect | Single-Pair | Multi-Pair |
|
||||
|--------|-------------|------------|
|
||||
| Assets | ETH only (BTC context) | 10 assets, 45 pairs |
|
||||
| Model | ETH-specific | Universal across pairs |
|
||||
| Selection | Fixed pair | Dynamic best pair |
|
||||
| Stops | Fixed 6%/5% | ATR-based dynamic |
|
||||
|
||||
## Risk Warning
|
||||
|
||||
This is experimental trading software. Use at your own risk:
|
||||
|
||||
- Always start with demo trading
|
||||
- Never risk more than you can afford to lose
|
||||
- Monitor the bot regularly
|
||||
- The model was trained on historical data and may not predict future performance
|
||||
11
live_trading/multi_pair/__init__.py
Normal file
11
live_trading/multi_pair/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Multi-Pair Divergence Live Trading Module."""
|
||||
from .config import MultiPairLiveConfig, get_multi_pair_config
|
||||
from .data_feed import MultiPairDataFeed
|
||||
from .strategy import LiveMultiPairStrategy
|
||||
|
||||
__all__ = [
|
||||
"MultiPairLiveConfig",
|
||||
"get_multi_pair_config",
|
||||
"MultiPairDataFeed",
|
||||
"LiveMultiPairStrategy",
|
||||
]
|
||||
145
live_trading/multi_pair/config.py
Normal file
145
live_trading/multi_pair/config.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Configuration for Multi-Pair Live Trading.
|
||||
|
||||
Extends the base live trading config with multi-pair specific settings.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OKXConfig:
|
||||
"""OKX API configuration."""
|
||||
api_key: str = field(default_factory=lambda: "")
|
||||
secret: str = field(default_factory=lambda: "")
|
||||
password: str = field(default_factory=lambda: "")
|
||||
demo_mode: bool = field(default_factory=lambda: True)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Load credentials based on demo mode setting."""
|
||||
self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
if self.demo_mode:
|
||||
self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
|
||||
self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
|
||||
self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
|
||||
else:
|
||||
self.api_key = os.getenv("OKX_API_KEY", "")
|
||||
self.secret = os.getenv("OKX_SECRET", "")
|
||||
self.password = os.getenv("OKX_PASSWORD", "")
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate that required credentials are present."""
|
||||
mode = "demo" if self.demo_mode else "live"
|
||||
if not self.api_key:
|
||||
raise ValueError(f"OKX API key not set for {mode} mode")
|
||||
if not self.secret:
|
||||
raise ValueError(f"OKX secret not set for {mode} mode")
|
||||
if not self.password:
|
||||
raise ValueError(f"OKX password not set for {mode} mode")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiPairLiveConfig:
|
||||
"""
|
||||
Configuration for multi-pair live trading.
|
||||
|
||||
Combines trading parameters, strategy settings, and risk management.
|
||||
"""
|
||||
# Asset Universe (top 10 by market cap perpetuals)
|
||||
assets: list[str] = field(default_factory=lambda: [
|
||||
"BTC/USDT:USDT", "ETH/USDT:USDT", "SOL/USDT:USDT", "XRP/USDT:USDT",
|
||||
"BNB/USDT:USDT", "DOGE/USDT:USDT", "ADA/USDT:USDT", "AVAX/USDT:USDT",
|
||||
"LINK/USDT:USDT", "DOT/USDT:USDT"
|
||||
])
|
||||
|
||||
# Timeframe
|
||||
timeframe: str = "1h"
|
||||
candles_to_fetch: int = 500 # Enough for feature calculation
|
||||
|
||||
# Z-Score Thresholds
|
||||
z_window: int = 24
|
||||
z_entry_threshold: float = 1.0
|
||||
z_exit_threshold: float = 0.0 # Exit at mean reversion
|
||||
|
||||
# ML Thresholds
|
||||
prob_threshold: float = 0.5
|
||||
|
||||
# Position sizing
|
||||
max_position_usdt: float = -1.0 # If <= 0, use all available funds
|
||||
min_position_usdt: float = 10.0
|
||||
leverage: int = 1
|
||||
margin_mode: str = "cross"
|
||||
max_concurrent_positions: int = 1 # Trade one pair at a time
|
||||
|
||||
# Risk Management - ATR-Based Stops
|
||||
atr_period: int = 14
|
||||
sl_atr_multiplier: float = 10.0
|
||||
tp_atr_multiplier: float = 8.0
|
||||
|
||||
# Fallback fixed percentages
|
||||
base_sl_pct: float = 0.06
|
||||
base_tp_pct: float = 0.05
|
||||
|
||||
# ATR bounds
|
||||
min_sl_pct: float = 0.02
|
||||
max_sl_pct: float = 0.10
|
||||
min_tp_pct: float = 0.02
|
||||
max_tp_pct: float = 0.15
|
||||
|
||||
# Funding Rate Filter
|
||||
funding_threshold: float = 0.0005 # 0.05%
|
||||
|
||||
# Trade Management
|
||||
min_hold_bars: int = 0
|
||||
cooldown_bars: int = 0
|
||||
|
||||
# Execution
|
||||
sleep_seconds: int = 3600 # Run every hour
|
||||
slippage_pct: float = 0.001
|
||||
|
||||
def get_asset_short_name(self, symbol: str) -> str:
|
||||
"""Convert symbol to short name (e.g., BTC/USDT:USDT -> btc)."""
|
||||
return symbol.split("/")[0].lower()
|
||||
|
||||
def get_pair_count(self) -> int:
|
||||
"""Calculate number of unique pairs from asset list."""
|
||||
n = len(self.assets)
|
||||
return n * (n - 1) // 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathConfig:
|
||||
"""File paths configuration."""
|
||||
base_dir: Path = field(
|
||||
default_factory=lambda: Path(__file__).parent.parent.parent
|
||||
)
|
||||
data_dir: Path = field(default=None)
|
||||
logs_dir: Path = field(default=None)
|
||||
model_path: Path = field(default=None)
|
||||
positions_file: Path = field(default=None)
|
||||
trade_log_file: Path = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self.data_dir = self.base_dir / "data"
|
||||
self.logs_dir = self.base_dir / "logs"
|
||||
# Use the same model as backtesting
|
||||
self.model_path = self.base_dir / "data" / "multi_pair_model.pkl"
|
||||
self.positions_file = self.base_dir / "live_trading" / "multi_pair_positions.json"
|
||||
self.trade_log_file = self.base_dir / "live_trading" / "multi_pair_trade_log.csv"
|
||||
|
||||
# Ensure directories exist
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_multi_pair_config() -> tuple[OKXConfig, MultiPairLiveConfig, PathConfig]:
|
||||
"""Get all configuration objects for multi-pair trading."""
|
||||
okx = OKXConfig()
|
||||
trading = MultiPairLiveConfig()
|
||||
paths = PathConfig()
|
||||
return okx, trading, paths
|
||||
336
live_trading/multi_pair/data_feed.py
Normal file
336
live_trading/multi_pair/data_feed.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Multi-Pair Data Feed for Live Trading.
|
||||
|
||||
Fetches real-time OHLCV and funding data for all assets in the universe.
|
||||
"""
|
||||
import logging
|
||||
from itertools import combinations
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
|
||||
from live_trading.okx_client import OKXClient
|
||||
from .config import MultiPairLiveConfig, PathConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradingPair:
|
||||
"""
|
||||
Represents a tradeable pair for spread analysis.
|
||||
|
||||
Attributes:
|
||||
base_asset: First asset symbol (e.g., ETH/USDT:USDT)
|
||||
quote_asset: Second asset symbol (e.g., BTC/USDT:USDT)
|
||||
pair_id: Unique identifier
|
||||
"""
|
||||
def __init__(self, base_asset: str, quote_asset: str):
|
||||
self.base_asset = base_asset
|
||||
self.quote_asset = quote_asset
|
||||
self.pair_id = f"{base_asset}__{quote_asset}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable pair name."""
|
||||
base = self.base_asset.split("/")[0]
|
||||
quote = self.quote_asset.split("/")[0]
|
||||
return f"{base}/{quote}"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.pair_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TradingPair):
|
||||
return False
|
||||
return self.pair_id == other.pair_id
|
||||
|
||||
|
||||
class MultiPairDataFeed:
|
||||
"""
|
||||
Real-time data feed for multi-pair strategy.
|
||||
|
||||
Fetches OHLCV data for all assets and calculates spread features
|
||||
for all pair combinations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_client: OKXClient,
|
||||
config: MultiPairLiveConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.client = okx_client
|
||||
self.config = config
|
||||
self.paths = path_config
|
||||
|
||||
# Cache for asset data
|
||||
self._asset_data: dict[str, pd.DataFrame] = {}
|
||||
self._funding_rates: dict[str, float] = {}
|
||||
self._pairs: list[TradingPair] = []
|
||||
|
||||
# Generate pairs
|
||||
self._generate_pairs()
|
||||
|
||||
def _generate_pairs(self) -> None:
|
||||
"""Generate all unique pairs from asset universe."""
|
||||
self._pairs = []
|
||||
for base, quote in combinations(self.config.assets, 2):
|
||||
pair = TradingPair(base_asset=base, quote_asset=quote)
|
||||
self._pairs.append(pair)
|
||||
|
||||
logger.info("Generated %d pairs from %d assets",
|
||||
len(self._pairs), len(self.config.assets))
|
||||
|
||||
@property
|
||||
def pairs(self) -> list[TradingPair]:
|
||||
"""Get list of trading pairs."""
|
||||
return self._pairs
|
||||
|
||||
def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Fetch OHLCV data for all assets.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to OHLCV DataFrame
|
||||
"""
|
||||
self._asset_data = {}
|
||||
|
||||
for symbol in self.config.assets:
|
||||
try:
|
||||
ohlcv = self.client.fetch_ohlcv(
|
||||
symbol,
|
||||
self.config.timeframe,
|
||||
self.config.candles_to_fetch
|
||||
)
|
||||
df = self._ohlcv_to_dataframe(ohlcv)
|
||||
|
||||
if len(df) >= 200:
|
||||
self._asset_data[symbol] = df
|
||||
logger.debug("Fetched %s: %d candles", symbol, len(df))
|
||||
else:
|
||||
logger.warning("Skipping %s: insufficient data (%d)",
|
||||
symbol, len(df))
|
||||
except Exception as e:
|
||||
logger.error("Error fetching %s: %s", symbol, e)
|
||||
|
||||
logger.info("Fetched data for %d/%d assets",
|
||||
len(self._asset_data), len(self.config.assets))
|
||||
return self._asset_data
|
||||
|
||||
def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
|
||||
"""Convert OHLCV list to DataFrame."""
|
||||
df = pd.DataFrame(
|
||||
ohlcv,
|
||||
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
df.set_index('timestamp', inplace=True)
|
||||
return df
|
||||
|
||||
def fetch_all_funding_rates(self) -> dict[str, float]:
|
||||
"""
|
||||
Fetch current funding rates for all assets.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to funding rate
|
||||
"""
|
||||
self._funding_rates = {}
|
||||
|
||||
for symbol in self.config.assets:
|
||||
try:
|
||||
rate = self.client.get_funding_rate(symbol)
|
||||
self._funding_rates[symbol] = rate
|
||||
except Exception as e:
|
||||
logger.warning("Could not get funding for %s: %s", symbol, e)
|
||||
self._funding_rates[symbol] = 0.0
|
||||
|
||||
return self._funding_rates
|
||||
|
||||
def calculate_pair_features(
|
||||
self,
|
||||
pair: TradingPair
|
||||
) -> pd.DataFrame | None:
|
||||
"""
|
||||
Calculate features for a single pair.
|
||||
|
||||
Args:
|
||||
pair: Trading pair
|
||||
|
||||
Returns:
|
||||
DataFrame with features, or None if insufficient data
|
||||
"""
|
||||
base = pair.base_asset
|
||||
quote = pair.quote_asset
|
||||
|
||||
if base not in self._asset_data or quote not in self._asset_data:
|
||||
return None
|
||||
|
||||
df_base = self._asset_data[base]
|
||||
df_quote = self._asset_data[quote]
|
||||
|
||||
# Align indices
|
||||
common_idx = df_base.index.intersection(df_quote.index)
|
||||
if len(common_idx) < 200:
|
||||
return None
|
||||
|
||||
df_a = df_base.loc[common_idx]
|
||||
df_b = df_quote.loc[common_idx]
|
||||
|
||||
# Calculate spread (base / quote)
|
||||
spread = df_a['close'] / df_b['close']
|
||||
|
||||
# Z-Score
|
||||
z_window = self.config.z_window
|
||||
rolling_mean = spread.rolling(window=z_window).mean()
|
||||
rolling_std = spread.rolling(window=z_window).std()
|
||||
z_score = (spread - rolling_mean) / rolling_std
|
||||
|
||||
# Spread Technicals
|
||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||
spread_roc = spread.pct_change(periods=5) * 100
|
||||
spread_change_1h = spread.pct_change(periods=1)
|
||||
|
||||
# Volume Analysis
|
||||
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
|
||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
|
||||
|
||||
# Volatility
|
||||
ret_a = df_a['close'].pct_change()
|
||||
ret_b = df_b['close'].pct_change()
|
||||
vol_a = ret_a.rolling(window=z_window).std()
|
||||
vol_b = ret_b.rolling(window=z_window).std()
|
||||
vol_spread_ratio = vol_a / (vol_b + 1e-10)
|
||||
|
||||
# Realized Volatility
|
||||
realized_vol_a = ret_a.rolling(window=24).std()
|
||||
realized_vol_b = ret_b.rolling(window=24).std()
|
||||
|
||||
# ATR (Average True Range)
|
||||
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
|
||||
|
||||
tr_a = pd.concat([
|
||||
high_a - low_a,
|
||||
(high_a - close_a.shift(1)).abs(),
|
||||
(low_a - close_a.shift(1)).abs()
|
||||
], axis=1).max(axis=1)
|
||||
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
|
||||
atr_pct_a = atr_a / close_a
|
||||
|
||||
# Build feature DataFrame
|
||||
features = pd.DataFrame(index=common_idx)
|
||||
features['pair_id'] = pair.pair_id
|
||||
features['base_asset'] = base
|
||||
features['quote_asset'] = quote
|
||||
|
||||
# Price data
|
||||
features['spread'] = spread
|
||||
features['base_close'] = df_a['close']
|
||||
features['quote_close'] = df_b['close']
|
||||
features['base_volume'] = df_a['volume']
|
||||
|
||||
# Core Features
|
||||
features['z_score'] = z_score
|
||||
features['spread_rsi'] = spread_rsi
|
||||
features['spread_roc'] = spread_roc
|
||||
features['spread_change_1h'] = spread_change_1h
|
||||
features['vol_ratio'] = vol_ratio
|
||||
features['vol_ratio_rel'] = vol_ratio_rel
|
||||
features['vol_diff_ratio'] = vol_spread_ratio
|
||||
|
||||
# Volatility
|
||||
features['realized_vol_base'] = realized_vol_a
|
||||
features['realized_vol_quote'] = realized_vol_b
|
||||
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
|
||||
|
||||
# ATR
|
||||
features['atr_base'] = atr_a
|
||||
features['atr_pct_base'] = atr_pct_a
|
||||
|
||||
# Pair encoding
|
||||
assets = self.config.assets
|
||||
features['base_idx'] = assets.index(base) if base in assets else -1
|
||||
features['quote_idx'] = assets.index(quote) if quote in assets else -1
|
||||
|
||||
# Funding rates
|
||||
base_funding = self._funding_rates.get(base, 0.0)
|
||||
quote_funding = self._funding_rates.get(quote, 0.0)
|
||||
features['base_funding'] = base_funding
|
||||
features['quote_funding'] = quote_funding
|
||||
features['funding_diff'] = base_funding - quote_funding
|
||||
features['funding_avg'] = (base_funding + quote_funding) / 2
|
||||
|
||||
# Drop NaN rows in core features
|
||||
core_cols = [
|
||||
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
|
||||
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
|
||||
'realized_vol_base', 'atr_base', 'atr_pct_base'
|
||||
]
|
||||
features = features.dropna(subset=core_cols)
|
||||
|
||||
return features
|
||||
|
||||
def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Calculate features for all pairs.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping pair_id to feature DataFrame
|
||||
"""
|
||||
all_features = {}
|
||||
|
||||
for pair in self._pairs:
|
||||
features = self.calculate_pair_features(pair)
|
||||
if features is not None and len(features) > 0:
|
||||
all_features[pair.pair_id] = features
|
||||
|
||||
logger.info("Calculated features for %d/%d pairs",
|
||||
len(all_features), len(self._pairs))
|
||||
|
||||
return all_features
|
||||
|
||||
def get_latest_data(self) -> dict[str, pd.DataFrame] | None:
|
||||
"""
|
||||
Fetch and process latest market data for all pairs.
|
||||
|
||||
Returns:
|
||||
Dictionary of pair features or None on error
|
||||
"""
|
||||
try:
|
||||
# Fetch OHLCV for all assets
|
||||
self.fetch_all_ohlcv()
|
||||
|
||||
if len(self._asset_data) < 2:
|
||||
logger.warning("Insufficient assets fetched")
|
||||
return None
|
||||
|
||||
# Fetch funding rates
|
||||
self.fetch_all_funding_rates()
|
||||
|
||||
# Calculate features for all pairs
|
||||
pair_features = self.calculate_all_pair_features()
|
||||
|
||||
if not pair_features:
|
||||
logger.warning("No pair features calculated")
|
||||
return None
|
||||
|
||||
logger.info("Processed %d pairs with valid features", len(pair_features))
|
||||
return pair_features
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error fetching market data: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def get_pair_by_id(self, pair_id: str) -> TradingPair | None:
|
||||
"""Get pair object by ID."""
|
||||
for pair in self._pairs:
|
||||
if pair.pair_id == pair_id:
|
||||
return pair
|
||||
return None
|
||||
|
||||
def get_current_price(self, symbol: str) -> float | None:
|
||||
"""Get current price for a symbol."""
|
||||
if symbol in self._asset_data:
|
||||
return self._asset_data[symbol]['close'].iloc[-1]
|
||||
return None
|
||||
609
live_trading/multi_pair/main.py
Normal file
609
live_trading/multi_pair/main.py
Normal file
@@ -0,0 +1,609 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Pair Divergence Live Trading Bot.
|
||||
|
||||
Trades the top 10 cryptocurrency pairs based on spread divergence
|
||||
using a universal ML model for signal generation.
|
||||
|
||||
Usage:
|
||||
# Run with demo account (default)
|
||||
uv run python -m live_trading.multi_pair.main
|
||||
|
||||
# Run with specific settings
|
||||
uv run python -m live_trading.multi_pair.main --max-position 500 --leverage 2
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from live_trading.okx_client import OKXClient
|
||||
from live_trading.position_manager import PositionManager
|
||||
from live_trading.multi_pair.config import (
|
||||
OKXConfig, MultiPairLiveConfig, PathConfig, get_multi_pair_config
|
||||
)
|
||||
from live_trading.multi_pair.data_feed import MultiPairDataFeed, TradingPair
|
||||
from live_trading.multi_pair.strategy import LiveMultiPairStrategy
|
||||
|
||||
|
||||
def setup_logging(log_dir: Path) -> logging.Logger:
|
||||
"""Configure logging for the trading bot."""
|
||||
log_file = log_dir / "multi_pair_live.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
force=True
|
||||
)
|
||||
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PositionState:
|
||||
"""Track current position state for multi-pair."""
|
||||
pair: TradingPair | None = None
|
||||
pair_id: str | None = None
|
||||
direction: str | None = None
|
||||
entry_price: float = 0.0
|
||||
size: float = 0.0
|
||||
stop_loss: float = 0.0
|
||||
take_profit: float = 0.0
|
||||
entry_time: datetime | None = None
|
||||
|
||||
|
||||
class MultiPairLiveTradingBot:
|
||||
"""
|
||||
Main trading bot for multi-pair divergence strategy.
|
||||
|
||||
Coordinates data fetching, pair scoring, and order execution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_config: OKXConfig,
|
||||
trading_config: MultiPairLiveConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.okx_config = okx_config
|
||||
self.trading_config = trading_config
|
||||
self.path_config = path_config
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.running = True
|
||||
|
||||
# Initialize components
|
||||
self.logger.info("Initializing multi-pair trading bot...")
|
||||
|
||||
# Create OKX client with adapted config
|
||||
self._adapted_trading_config = self._adapt_config_for_okx_client()
|
||||
self.okx_client = OKXClient(okx_config, self._adapted_trading_config)
|
||||
|
||||
# Initialize data feed
|
||||
self.data_feed = MultiPairDataFeed(
|
||||
self.okx_client, trading_config, path_config
|
||||
)
|
||||
|
||||
# Initialize position manager (reuse from single-pair)
|
||||
self.position_manager = PositionManager(
|
||||
self.okx_client, self._adapted_trading_config, path_config
|
||||
)
|
||||
|
||||
# Initialize strategy
|
||||
self.strategy = LiveMultiPairStrategy(trading_config, path_config)
|
||||
|
||||
# Current position state
|
||||
self.position = PositionState()
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, self._handle_shutdown)
|
||||
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
||||
|
||||
self._print_startup_banner()
|
||||
|
||||
# Sync with exchange positions on startup
|
||||
self._sync_position_from_exchange()
|
||||
|
||||
def _adapt_config_for_okx_client(self):
|
||||
"""Create config compatible with OKXClient."""
|
||||
# OKXClient expects specific attributes
|
||||
@dataclass
|
||||
class AdaptedConfig:
|
||||
eth_symbol: str = "ETH/USDT:USDT"
|
||||
btc_symbol: str = "BTC/USDT:USDT"
|
||||
timeframe: str = "1h"
|
||||
candles_to_fetch: int = 500
|
||||
max_position_usdt: float = -1.0
|
||||
min_position_usdt: float = 10.0
|
||||
leverage: int = 1
|
||||
margin_mode: str = "cross"
|
||||
stop_loss_pct: float = 0.06
|
||||
take_profit_pct: float = 0.05
|
||||
max_concurrent_positions: int = 1
|
||||
z_entry_threshold: float = 1.0
|
||||
z_window: int = 24
|
||||
model_prob_threshold: float = 0.5
|
||||
funding_threshold: float = 0.0005
|
||||
sleep_seconds: int = 3600
|
||||
slippage_pct: float = 0.001
|
||||
|
||||
adapted = AdaptedConfig()
|
||||
adapted.timeframe = self.trading_config.timeframe
|
||||
adapted.candles_to_fetch = self.trading_config.candles_to_fetch
|
||||
adapted.max_position_usdt = self.trading_config.max_position_usdt
|
||||
adapted.min_position_usdt = self.trading_config.min_position_usdt
|
||||
adapted.leverage = self.trading_config.leverage
|
||||
adapted.margin_mode = self.trading_config.margin_mode
|
||||
adapted.max_concurrent_positions = self.trading_config.max_concurrent_positions
|
||||
adapted.sleep_seconds = self.trading_config.sleep_seconds
|
||||
adapted.slippage_pct = self.trading_config.slippage_pct
|
||||
|
||||
return adapted
|
||||
|
||||
def _print_startup_banner(self) -> None:
|
||||
"""Print startup information."""
|
||||
mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
|
||||
|
||||
print("=" * 60)
|
||||
print(" Multi-Pair Divergence Strategy - Live Trading Bot")
|
||||
print("=" * 60)
|
||||
print(f" Mode: {mode}")
|
||||
print(f" Assets: {len(self.trading_config.assets)} assets")
|
||||
print(f" Pairs: {self.trading_config.get_pair_count()} pairs")
|
||||
print(f" Timeframe: {self.trading_config.timeframe}")
|
||||
print(f" Max Position: ${self.trading_config.max_position_usdt if self.trading_config.max_position_usdt > 0 else 'All available'}")
|
||||
print(f" Leverage: {self.trading_config.leverage}x")
|
||||
print(f" Z-Entry: > {self.trading_config.z_entry_threshold}")
|
||||
print(f" Prob Threshold: > {self.trading_config.prob_threshold}")
|
||||
print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
|
||||
print("=" * 60)
|
||||
print(f" Assets: {', '.join([a.split('/')[0] for a in self.trading_config.assets])}")
|
||||
print("=" * 60)
|
||||
|
||||
if not self.okx_config.demo_mode:
|
||||
print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")
|
||||
|
||||
def _handle_shutdown(self, signum, frame) -> None:
|
||||
"""Handle shutdown signals gracefully."""
|
||||
self.logger.info("Shutdown signal received, stopping...")
|
||||
self.running = False
|
||||
|
||||
def _sync_position_from_exchange(self) -> bool:
|
||||
"""
|
||||
Sync internal position state with exchange positions.
|
||||
|
||||
Checks for existing open positions on the exchange and updates
|
||||
internal state to match. This prevents stacking positions when
|
||||
the bot is restarted.
|
||||
|
||||
Returns:
|
||||
True if a position was synced, False otherwise
|
||||
"""
|
||||
try:
|
||||
positions = self.okx_client.get_positions()
|
||||
|
||||
if not positions:
|
||||
if self.position.pair is not None:
|
||||
# Position was closed externally (e.g., SL/TP hit)
|
||||
self.logger.info(
|
||||
"Position %s was closed externally, resetting state",
|
||||
self.position.pair.name if self.position.pair else "unknown"
|
||||
)
|
||||
self.position = PositionState()
|
||||
return False
|
||||
|
||||
# Check each position against our tradeable assets
|
||||
our_assets = set(self.trading_config.assets)
|
||||
|
||||
for pos in positions:
|
||||
pos_symbol = pos.get('symbol', '')
|
||||
contracts = abs(float(pos.get('contracts', 0)))
|
||||
|
||||
if contracts == 0:
|
||||
continue
|
||||
|
||||
# Check if this position is for one of our assets
|
||||
if pos_symbol not in our_assets:
|
||||
continue
|
||||
|
||||
# Found a position for one of our assets
|
||||
side = pos.get('side', 'long')
|
||||
entry_price = float(pos.get('entryPrice', 0))
|
||||
unrealized_pnl = float(pos.get('unrealizedPnl', 0))
|
||||
|
||||
# If we already track this position, just update
|
||||
if (self.position.pair is not None and
|
||||
self.position.pair.base_asset == pos_symbol):
|
||||
self.logger.debug(
|
||||
"Position already tracked: %s %s %.2f contracts",
|
||||
side, pos_symbol, contracts
|
||||
)
|
||||
return True
|
||||
|
||||
# New position found - sync it
|
||||
# Find or create a TradingPair for this position
|
||||
matched_pair = None
|
||||
for pair in self.data_feed.pairs:
|
||||
if pair.base_asset == pos_symbol:
|
||||
matched_pair = pair
|
||||
break
|
||||
|
||||
if matched_pair is None:
|
||||
# Create a placeholder pair (we don't know the quote asset)
|
||||
matched_pair = TradingPair(
|
||||
base_asset=pos_symbol,
|
||||
quote_asset="UNKNOWN"
|
||||
)
|
||||
|
||||
# Calculate approximate SL/TP based on config defaults
|
||||
sl_pct = self.trading_config.base_sl_pct
|
||||
tp_pct = self.trading_config.base_tp_pct
|
||||
|
||||
if side == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else:
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
self.position = PositionState(
|
||||
pair=matched_pair,
|
||||
pair_id=matched_pair.pair_id,
|
||||
direction=side,
|
||||
entry_price=entry_price,
|
||||
size=contracts,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=take_profit,
|
||||
entry_time=None # Unknown for synced positions
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Synced existing position from exchange: %s %s %.4f @ %.4f (PnL: %.2f)",
|
||||
side.upper(),
|
||||
pos_symbol,
|
||||
contracts,
|
||||
entry_price,
|
||||
unrealized_pnl
|
||||
)
|
||||
return True
|
||||
|
||||
# No matching positions found
|
||||
if self.position.pair is not None:
|
||||
self.logger.info(
|
||||
"Position %s no longer exists on exchange, resetting state",
|
||||
self.position.pair.name
|
||||
)
|
||||
self.position = PositionState()
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to sync position from exchange: %s", e)
|
||||
return False
|
||||
|
||||
def run_trading_cycle(self) -> None:
|
||||
"""
|
||||
Execute one trading cycle.
|
||||
|
||||
1. Sync position state with exchange
|
||||
2. Fetch latest market data for all assets
|
||||
3. Calculate features for all pairs
|
||||
4. Score pairs and find best opportunity
|
||||
5. Check exit conditions for current position
|
||||
6. Execute trades if needed
|
||||
"""
|
||||
cycle_start = datetime.now(timezone.utc)
|
||||
self.logger.info("--- Trading Cycle Start: %s ---", cycle_start.isoformat())
|
||||
|
||||
try:
|
||||
# 1. Sync position state with exchange (detect SL/TP closures)
|
||||
self._sync_position_from_exchange()
|
||||
|
||||
# 2. Fetch all market data
|
||||
pair_features = self.data_feed.get_latest_data()
|
||||
if pair_features is None:
|
||||
self.logger.warning("No market data available, skipping cycle")
|
||||
return
|
||||
|
||||
# 2. Check exit conditions for current position
|
||||
if self.position.pair is not None:
|
||||
exit_signal = self.strategy.check_exit_signal(
|
||||
pair_features,
|
||||
self.position.pair_id
|
||||
)
|
||||
|
||||
if exit_signal['action'] == 'exit':
|
||||
self._execute_exit(exit_signal)
|
||||
else:
|
||||
# Check SL/TP
|
||||
current_price = self.data_feed.get_current_price(
|
||||
self.position.pair.base_asset
|
||||
)
|
||||
if current_price:
|
||||
sl_tp_exit = self._check_sl_tp(current_price)
|
||||
if sl_tp_exit:
|
||||
self._execute_exit({'reason': sl_tp_exit})
|
||||
|
||||
# 3. Generate entry signal if no position
|
||||
if self.position.pair is None:
|
||||
entry_signal = self.strategy.generate_signal(
|
||||
pair_features,
|
||||
self.data_feed.pairs
|
||||
)
|
||||
|
||||
if entry_signal['action'] == 'entry':
|
||||
self._execute_entry(entry_signal)
|
||||
|
||||
# 4. Log status
|
||||
if self.position.pair:
|
||||
self.logger.info(
|
||||
"Position: %s %s, entry=%.4f, current PnL check pending",
|
||||
self.position.direction,
|
||||
self.position.pair.name,
|
||||
self.position.entry_price
|
||||
)
|
||||
else:
|
||||
self.logger.info("No open position")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Trading cycle error: %s", e, exc_info=True)
|
||||
|
||||
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
|
||||
self.logger.info("--- Cycle completed in %.1fs ---", cycle_duration)
|
||||
|
||||
def _check_sl_tp(self, current_price: float) -> str | None:
|
||||
"""Check stop-loss and take-profit levels."""
|
||||
if self.position.direction == 'long':
|
||||
if current_price <= self.position.stop_loss:
|
||||
return f"stop_loss ({current_price:.4f} <= {self.position.stop_loss:.4f})"
|
||||
if current_price >= self.position.take_profit:
|
||||
return f"take_profit ({current_price:.4f} >= {self.position.take_profit:.4f})"
|
||||
else: # short
|
||||
if current_price >= self.position.stop_loss:
|
||||
return f"stop_loss ({current_price:.4f} >= {self.position.stop_loss:.4f})"
|
||||
if current_price <= self.position.take_profit:
|
||||
return f"take_profit ({current_price:.4f} <= {self.position.take_profit:.4f})"
|
||||
return None
|
||||
|
||||
def _execute_entry(self, signal: dict) -> None:
|
||||
"""Execute entry trade."""
|
||||
pair = signal['pair']
|
||||
symbol = pair.base_asset # Trade the base asset
|
||||
direction = signal['direction']
|
||||
|
||||
self.logger.info(
|
||||
"Entry signal: %s %s (z=%.2f, p=%.2f, score=%.3f)",
|
||||
direction.upper(),
|
||||
pair.name,
|
||||
signal['z_score'],
|
||||
signal['probability'],
|
||||
signal['divergence_score']
|
||||
)
|
||||
|
||||
# Get account balance
|
||||
try:
|
||||
balance = self.okx_client.get_balance()
|
||||
available_usdt = balance['free']
|
||||
except Exception as e:
|
||||
self.logger.error("Could not get balance: %s", e)
|
||||
return
|
||||
|
||||
# Calculate position size
|
||||
size_usdt = self.strategy.calculate_position_size(
|
||||
signal['divergence_score'],
|
||||
available_usdt
|
||||
)
|
||||
|
||||
if size_usdt <= 0:
|
||||
self.logger.info("Position size too small, skipping entry")
|
||||
return
|
||||
|
||||
current_price = signal['base_price']
|
||||
size_asset = size_usdt / current_price
|
||||
|
||||
# Calculate SL/TP
|
||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(
|
||||
current_price,
|
||||
direction,
|
||||
signal['atr'],
|
||||
signal['atr_pct']
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Executing %s entry: %.6f %s @ %.4f ($%.2f), SL=%.4f, TP=%.4f",
|
||||
direction.upper(),
|
||||
size_asset,
|
||||
symbol.split('/')[0],
|
||||
current_price,
|
||||
size_usdt,
|
||||
stop_loss,
|
||||
take_profit
|
||||
)
|
||||
|
||||
try:
|
||||
# Place market order
|
||||
order_side = "buy" if direction == "long" else "sell"
|
||||
order = self.okx_client.place_market_order(symbol, order_side, size_asset)
|
||||
|
||||
filled_price = order.get('average') or order.get('price') or current_price
|
||||
filled_amount = order.get('filled') or order.get('amount') or size_asset
|
||||
|
||||
if filled_price is None or filled_price == 0:
|
||||
filled_price = current_price
|
||||
if filled_amount is None or filled_amount == 0:
|
||||
filled_amount = size_asset
|
||||
|
||||
# Recalculate SL/TP with filled price
|
||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(
|
||||
filled_price, direction, signal['atr'], signal['atr_pct']
|
||||
)
|
||||
|
||||
# Update position state
|
||||
self.position = PositionState(
|
||||
pair=pair,
|
||||
pair_id=pair.pair_id,
|
||||
direction=direction,
|
||||
entry_price=filled_price,
|
||||
size=filled_amount,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=take_profit,
|
||||
entry_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Position opened: %s %s %.6f @ %.4f",
|
||||
direction.upper(),
|
||||
pair.name,
|
||||
filled_amount,
|
||||
filled_price
|
||||
)
|
||||
|
||||
# Try to set SL/TP on exchange
|
||||
try:
|
||||
self.okx_client.set_stop_loss_take_profit(
|
||||
symbol, direction, filled_amount, stop_loss, take_profit
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning("Could not set SL/TP on exchange: %s", e)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Order execution failed: %s", e, exc_info=True)
|
||||
|
||||
def _execute_exit(self, signal: dict) -> None:
|
||||
"""Execute exit trade."""
|
||||
if self.position.pair is None:
|
||||
return
|
||||
|
||||
symbol = self.position.pair.base_asset
|
||||
reason = signal.get('reason', 'unknown')
|
||||
|
||||
self.logger.info(
|
||||
"Exit signal: %s %s, reason: %s",
|
||||
self.position.direction,
|
||||
self.position.pair.name,
|
||||
reason
|
||||
)
|
||||
|
||||
try:
|
||||
# Close position on exchange
|
||||
self.okx_client.close_position(symbol)
|
||||
|
||||
self.logger.info(
|
||||
"Position closed: %s %s",
|
||||
self.position.direction,
|
||||
self.position.pair.name
|
||||
)
|
||||
|
||||
# Reset position state
|
||||
self.position = PositionState()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Exit execution failed: %s", e, exc_info=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Main trading loop."""
|
||||
self.logger.info("Starting multi-pair trading loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self.run_trading_cycle()
|
||||
|
||||
if self.running:
|
||||
sleep_seconds = self.trading_config.sleep_seconds
|
||||
minutes = sleep_seconds // 60
|
||||
self.logger.info("Sleeping for %d minutes...", minutes)
|
||||
|
||||
for _ in range(sleep_seconds):
|
||||
if not self.running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
self.logger.info("Keyboard interrupt received")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Unexpected error in main loop: %s", e, exc_info=True)
|
||||
time.sleep(60)
|
||||
|
||||
self.logger.info("Shutting down...")
|
||||
self.logger.info("Shutdown complete")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Multi-Pair Divergence Live Trading Bot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-position",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Maximum position size in USDT"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leverage",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Trading leverage (1-125)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interval",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Trading cycle interval in seconds"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--live",
|
||||
action="store_true",
|
||||
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
args = parse_args()
|
||||
|
||||
# Load configuration
|
||||
okx_config, trading_config, path_config = get_multi_pair_config()
|
||||
|
||||
# Apply command line overrides
|
||||
if args.max_position is not None:
|
||||
trading_config.max_position_usdt = args.max_position
|
||||
if args.leverage is not None:
|
||||
trading_config.leverage = args.leverage
|
||||
if args.interval is not None:
|
||||
trading_config.sleep_seconds = args.interval
|
||||
if args.live:
|
||||
okx_config.demo_mode = False
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(path_config.logs_dir)
|
||||
|
||||
try:
|
||||
# Validate config
|
||||
okx_config.validate()
|
||||
|
||||
# Create and run bot
|
||||
bot = MultiPairLiveTradingBot(okx_config, trading_config, path_config)
|
||||
bot.run()
|
||||
except ValueError as e:
|
||||
logger.error("Configuration error: %s", e)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error("Fatal error: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
396
live_trading/multi_pair/strategy.py
Normal file
396
live_trading/multi_pair/strategy.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Live Multi-Pair Divergence Strategy.
|
||||
|
||||
Scores all pairs and selects the best divergence opportunity for trading.
|
||||
Uses the pre-trained universal ML model from backtesting.
|
||||
"""
|
||||
import logging
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
# Opt-in to future pandas behavior to silence FutureWarning on fillna
|
||||
pd.set_option('future.no_silent_downcasting', True)
|
||||
|
||||
from .config import MultiPairLiveConfig, PathConfig
|
||||
from .data_feed import TradingPair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DivergenceSignal:
|
||||
"""
|
||||
Signal for a divergent pair.
|
||||
|
||||
Attributes:
|
||||
pair: Trading pair
|
||||
z_score: Current Z-Score of the spread
|
||||
probability: ML model probability of profitable reversion
|
||||
divergence_score: Combined score (|z_score| * probability)
|
||||
direction: 'long' or 'short' (relative to base asset)
|
||||
base_price: Current price of base asset
|
||||
quote_price: Current price of quote asset
|
||||
atr: Average True Range in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
"""
|
||||
pair: TradingPair
|
||||
z_score: float
|
||||
probability: float
|
||||
divergence_score: float
|
||||
direction: str
|
||||
base_price: float
|
||||
quote_price: float
|
||||
atr: float
|
||||
atr_pct: float
|
||||
base_funding: float = 0.0
|
||||
|
||||
|
||||
class LiveMultiPairStrategy:
|
||||
"""
|
||||
Live trading implementation of multi-pair divergence strategy.
|
||||
|
||||
Scores all pairs using the universal ML model and selects
|
||||
the best opportunity for mean-reversion trading.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MultiPairLiveConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.config = config
|
||||
self.paths = path_config
|
||||
self.model: RandomForestClassifier | None = None
|
||||
self.feature_cols: list[str] | None = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Load pre-trained model from backtesting."""
|
||||
if self.paths.model_path.exists():
|
||||
try:
|
||||
with open(self.paths.model_path, 'rb') as f:
|
||||
saved = pickle.load(f)
|
||||
self.model = saved['model']
|
||||
self.feature_cols = saved['feature_cols']
|
||||
logger.info("Loaded model from %s", self.paths.model_path)
|
||||
except Exception as e:
|
||||
logger.error("Could not load model: %s", e)
|
||||
raise ValueError(
|
||||
f"Multi-pair model not found at {self.paths.model_path}. "
|
||||
"Run the backtest first to train the model."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Multi-pair model not found at {self.paths.model_path}. "
|
||||
"Run the backtest first to train the model."
|
||||
)
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
pairs: list[TradingPair]
|
||||
) -> list[DivergenceSignal]:
|
||||
"""
|
||||
Score all pairs and return ranked signals.
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
pairs: List of TradingPair objects
|
||||
|
||||
Returns:
|
||||
List of DivergenceSignal sorted by score (descending)
|
||||
"""
|
||||
if self.model is None:
|
||||
logger.warning("Model not loaded")
|
||||
return []
|
||||
|
||||
signals = []
|
||||
pair_map = {p.pair_id: p for p in pairs}
|
||||
|
||||
for pair_id, features in pair_features.items():
|
||||
if pair_id not in pair_map:
|
||||
continue
|
||||
|
||||
pair = pair_map[pair_id]
|
||||
|
||||
# Get latest features
|
||||
if len(features) == 0:
|
||||
continue
|
||||
|
||||
latest = features.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
|
||||
# Skip if Z-score below threshold
|
||||
if abs(z_score) < self.config.z_entry_threshold:
|
||||
continue
|
||||
|
||||
# Prepare features for prediction
|
||||
# Handle missing feature columns gracefully
|
||||
available_cols = [c for c in self.feature_cols if c in latest.index]
|
||||
missing_cols = [c for c in self.feature_cols if c not in latest.index]
|
||||
|
||||
if missing_cols:
|
||||
logger.debug("Missing feature columns: %s", missing_cols)
|
||||
|
||||
feature_row = latest[available_cols].fillna(0)
|
||||
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Create full feature vector with zeros for missing
|
||||
X_dict = {c: 0 for c in self.feature_cols}
|
||||
for col in available_cols:
|
||||
X_dict[col] = feature_row[col]
|
||||
|
||||
X = pd.DataFrame([X_dict])
|
||||
|
||||
# Predict probability
|
||||
prob = self.model.predict_proba(X)[0, 1]
|
||||
|
||||
# Skip if probability below threshold
|
||||
if prob < self.config.prob_threshold:
|
||||
continue
|
||||
|
||||
# Apply funding rate filter
|
||||
base_funding = latest.get('base_funding', 0) or 0
|
||||
funding_thresh = self.config.funding_threshold
|
||||
|
||||
if z_score > 0: # Short signal
|
||||
if base_funding < -funding_thresh:
|
||||
logger.debug(
|
||||
"Skipping %s short: funding too negative (%.4f)",
|
||||
pair.name, base_funding
|
||||
)
|
||||
continue
|
||||
else: # Long signal
|
||||
if base_funding > funding_thresh:
|
||||
logger.debug(
|
||||
"Skipping %s long: funding too positive (%.4f)",
|
||||
pair.name, base_funding
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate divergence score
|
||||
divergence_score = abs(z_score) * prob
|
||||
|
||||
# Determine direction
|
||||
direction = 'short' if z_score > 0 else 'long'
|
||||
|
||||
signal = DivergenceSignal(
|
||||
pair=pair,
|
||||
z_score=z_score,
|
||||
probability=prob,
|
||||
divergence_score=divergence_score,
|
||||
direction=direction,
|
||||
base_price=latest['base_close'],
|
||||
quote_price=latest['quote_close'],
|
||||
atr=latest.get('atr_base', 0),
|
||||
atr_pct=latest.get('atr_pct_base', 0.02),
|
||||
base_funding=base_funding
|
||||
)
|
||||
signals.append(signal)
|
||||
|
||||
# Sort by divergence score (highest first)
|
||||
signals.sort(key=lambda s: s.divergence_score, reverse=True)
|
||||
|
||||
if signals:
|
||||
logger.info(
|
||||
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)",
|
||||
len(signals),
|
||||
signals[0].pair.name,
|
||||
signals[0].divergence_score,
|
||||
signals[0].z_score,
|
||||
signals[0].probability,
|
||||
signals[0].direction
|
||||
)
|
||||
else:
|
||||
logger.info("No pairs meet entry criteria")
|
||||
|
||||
return signals
|
||||
|
||||
def select_best_pair(
|
||||
self,
|
||||
signals: list[DivergenceSignal]
|
||||
) -> DivergenceSignal | None:
|
||||
"""
|
||||
Select the best pair from scored signals.
|
||||
|
||||
Args:
|
||||
signals: List of DivergenceSignal (pre-sorted by score)
|
||||
|
||||
Returns:
|
||||
Best signal or None if no valid candidates
|
||||
"""
|
||||
if not signals:
|
||||
return None
|
||||
return signals[0]
|
||||
|
||||
def generate_signal(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
pairs: list[TradingPair]
|
||||
) -> dict:
|
||||
"""
|
||||
Generate trading signal from latest features.
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
pairs: List of TradingPair objects
|
||||
|
||||
Returns:
|
||||
Signal dictionary with action, pair, direction, etc.
|
||||
"""
|
||||
# Score all pairs
|
||||
signals = self.score_pairs(pair_features, pairs)
|
||||
|
||||
# Select best
|
||||
best = self.select_best_pair(signals)
|
||||
|
||||
if best is None:
|
||||
return {
|
||||
'action': 'hold',
|
||||
'reason': 'no_valid_signals'
|
||||
}
|
||||
|
||||
return {
|
||||
'action': 'entry',
|
||||
'pair': best.pair,
|
||||
'pair_id': best.pair.pair_id,
|
||||
'direction': best.direction,
|
||||
'z_score': best.z_score,
|
||||
'probability': best.probability,
|
||||
'divergence_score': best.divergence_score,
|
||||
'base_price': best.base_price,
|
||||
'quote_price': best.quote_price,
|
||||
'atr': best.atr,
|
||||
'atr_pct': best.atr_pct,
|
||||
'base_funding': best.base_funding,
|
||||
'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}'
|
||||
}
|
||||
|
||||
def check_exit_signal(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
current_pair_id: str
|
||||
) -> dict:
|
||||
"""
|
||||
Check if current position should be exited.
|
||||
|
||||
Exit conditions:
|
||||
1. Z-Score reverted to mean (|Z| < threshold)
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
current_pair_id: Current position's pair ID
|
||||
|
||||
Returns:
|
||||
Signal dictionary with action and reason
|
||||
"""
|
||||
if current_pair_id not in pair_features:
|
||||
return {
|
||||
'action': 'exit',
|
||||
'reason': 'pair_data_missing'
|
||||
}
|
||||
|
||||
features = pair_features[current_pair_id]
|
||||
if len(features) == 0:
|
||||
return {
|
||||
'action': 'exit',
|
||||
'reason': 'no_data'
|
||||
}
|
||||
|
||||
latest = features.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
|
||||
# Check mean reversion
|
||||
if abs(z_score) < self.config.z_exit_threshold:
|
||||
return {
|
||||
'action': 'exit',
|
||||
'reason': f'mean_reversion (z={z_score:.2f})'
|
||||
}
|
||||
|
||||
return {
|
||||
'action': 'hold',
|
||||
'z_score': z_score,
|
||||
'reason': f'holding (z={z_score:.2f})'
|
||||
}
|
||||
|
||||
def calculate_sl_tp(
|
||||
self,
|
||||
entry_price: float,
|
||||
direction: str,
|
||||
atr: float,
|
||||
atr_pct: float
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate ATR-based dynamic stop-loss and take-profit prices.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
direction: 'long' or 'short'
|
||||
atr: ATR in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
|
||||
Returns:
|
||||
Tuple of (stop_loss_price, take_profit_price)
|
||||
"""
|
||||
if atr > 0 and atr_pct > 0:
|
||||
sl_distance = atr * self.config.sl_atr_multiplier
|
||||
tp_distance = atr * self.config.tp_atr_multiplier
|
||||
|
||||
sl_pct = sl_distance / entry_price
|
||||
tp_pct = tp_distance / entry_price
|
||||
else:
|
||||
sl_pct = self.config.base_sl_pct
|
||||
tp_pct = self.config.base_tp_pct
|
||||
|
||||
# Apply bounds
|
||||
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
|
||||
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
|
||||
|
||||
if direction == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else:
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
return stop_loss, take_profit
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
divergence_score: float,
|
||||
available_usdt: float
|
||||
) -> float:
|
||||
"""
|
||||
Calculate position size based on divergence score.
|
||||
|
||||
Args:
|
||||
divergence_score: Combined score (|z| * prob)
|
||||
available_usdt: Available USDT balance
|
||||
|
||||
Returns:
|
||||
Position size in USDT
|
||||
"""
|
||||
if self.config.max_position_usdt <= 0:
|
||||
base_size = available_usdt
|
||||
else:
|
||||
base_size = min(available_usdt, self.config.max_position_usdt)
|
||||
|
||||
# Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score)
|
||||
base_threshold = 0.5
|
||||
if divergence_score <= base_threshold:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
|
||||
scale = min(scale, 2.0)
|
||||
|
||||
size = base_size * scale
|
||||
|
||||
if size < self.config.min_position_usdt:
|
||||
return 0.0
|
||||
|
||||
return min(size, available_usdt * 0.95)
|
||||
338
live_trading/okx_client.py
Normal file
338
live_trading/okx_client.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
OKX Exchange Client for Live Trading.
|
||||
|
||||
Handles connection to OKX API, order execution, and account management.
|
||||
Supports demo/sandbox mode for paper trading.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import ccxt
|
||||
|
||||
from .config import OKXConfig, TradingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OKXClient:
|
||||
"""
|
||||
OKX Exchange client wrapper using CCXT.
|
||||
|
||||
Supports both live and demo (sandbox) trading modes.
|
||||
Demo mode uses OKX's official sandbox environment.
|
||||
"""
|
||||
|
||||
def __init__(self, okx_config: OKXConfig, trading_config: TradingConfig):
|
||||
self.okx_config = okx_config
|
||||
self.trading_config = trading_config
|
||||
self.exchange: Optional[ccxt.okx] = None
|
||||
self._setup_exchange()
|
||||
|
||||
def _setup_exchange(self) -> None:
|
||||
"""Initialize CCXT OKX exchange instance."""
|
||||
self.okx_config.validate()
|
||||
|
||||
config = {
|
||||
'apiKey': self.okx_config.api_key,
|
||||
'secret': self.okx_config.secret,
|
||||
'password': self.okx_config.password,
|
||||
'sandbox': self.okx_config.demo_mode,
|
||||
'options': {
|
||||
'defaultType': 'swap', # Perpetual futures
|
||||
},
|
||||
'timeout': 30000,
|
||||
'enableRateLimit': True,
|
||||
}
|
||||
|
||||
self.exchange = ccxt.okx(config)
|
||||
|
||||
mode_str = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
|
||||
logger.info(f"OKX Exchange initialized in {mode_str} mode")
|
||||
|
||||
# Configure trading settings
|
||||
self._configure_trading_settings()
|
||||
|
||||
def _configure_trading_settings(self) -> None:
|
||||
"""Configure leverage and margin mode."""
|
||||
symbol = self.trading_config.eth_symbol
|
||||
leverage = self.trading_config.leverage
|
||||
margin_mode = self.trading_config.margin_mode
|
||||
|
||||
try:
|
||||
# Set position mode to one-way (net) first
|
||||
self.exchange.set_position_mode(False) # False = one-way mode
|
||||
logger.info("Position mode set to One-Way (Net)")
|
||||
except Exception as e:
|
||||
# Position mode might already be set
|
||||
logger.debug(f"Position mode setting: {e}")
|
||||
|
||||
try:
|
||||
# Set margin mode with leverage parameter (required by OKX)
|
||||
self.exchange.set_margin_mode(
|
||||
margin_mode,
|
||||
symbol,
|
||||
params={'lever': leverage}
|
||||
)
|
||||
logger.info(
|
||||
f"Margin mode set to {margin_mode} with {leverage}x leverage "
|
||||
f"for {symbol}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not set margin mode: {e}")
|
||||
# Try setting leverage separately
|
||||
try:
|
||||
self.exchange.set_leverage(leverage, symbol)
|
||||
logger.info(f"Leverage set to {leverage}x for {symbol}")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Could not set leverage: {e2}")
|
||||
|
||||
def fetch_ohlcv(
|
||||
self,
|
||||
symbol: str,
|
||||
timeframe: str = "1h",
|
||||
limit: int = 500
|
||||
) -> list:
|
||||
"""
|
||||
Fetch OHLCV candle data.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol (e.g., "ETH/USDT:USDT")
|
||||
timeframe: Candle timeframe (e.g., "1h")
|
||||
limit: Number of candles to fetch
|
||||
|
||||
Returns:
|
||||
List of OHLCV data
|
||||
"""
|
||||
return self.exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
|
||||
|
||||
def get_balance(self) -> dict:
|
||||
"""
|
||||
Get account balance.
|
||||
|
||||
Returns:
|
||||
Balance dictionary with 'total' and 'free' USDT amounts
|
||||
"""
|
||||
balance = self.exchange.fetch_balance()
|
||||
return {
|
||||
'total': balance.get('USDT', {}).get('total', 0),
|
||||
'free': balance.get('USDT', {}).get('free', 0),
|
||||
}
|
||||
|
||||
def get_positions(self) -> list:
|
||||
"""
|
||||
Get open positions.
|
||||
|
||||
Returns:
|
||||
List of open position dictionaries
|
||||
"""
|
||||
positions = self.exchange.fetch_positions()
|
||||
return [p for p in positions if float(p.get('contracts', 0)) != 0]
|
||||
|
||||
def get_position(self, symbol: str) -> Optional[dict]:
|
||||
"""
|
||||
Get position for a specific symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
|
||||
Returns:
|
||||
Position dictionary or None if no position
|
||||
"""
|
||||
positions = self.get_positions()
|
||||
for pos in positions:
|
||||
if pos.get('symbol') == symbol:
|
||||
return pos
|
||||
return None
|
||||
|
||||
def place_market_order(
|
||||
self,
|
||||
symbol: str,
|
||||
side: str,
|
||||
amount: float,
|
||||
reduce_only: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Place a market order.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
side: "buy" or "sell"
|
||||
amount: Order amount in base currency
|
||||
reduce_only: If True, only reduce existing position
|
||||
|
||||
Returns:
|
||||
Order result dictionary
|
||||
"""
|
||||
params = {
|
||||
'tdMode': self.trading_config.margin_mode,
|
||||
}
|
||||
if reduce_only:
|
||||
params['reduceOnly'] = True
|
||||
|
||||
order = self.exchange.create_market_order(
|
||||
symbol, side, amount, params=params
|
||||
)
|
||||
logger.info(
|
||||
f"Market {side.upper()} order placed: {amount} {symbol} "
|
||||
f"@ market price, order_id={order['id']}"
|
||||
)
|
||||
return order
|
||||
|
||||
def place_limit_order(
|
||||
self,
|
||||
symbol: str,
|
||||
side: str,
|
||||
amount: float,
|
||||
price: float,
|
||||
reduce_only: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Place a limit order.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
side: "buy" or "sell"
|
||||
amount: Order amount in base currency
|
||||
price: Limit price
|
||||
reduce_only: If True, only reduce existing position
|
||||
|
||||
Returns:
|
||||
Order result dictionary
|
||||
"""
|
||||
params = {
|
||||
'tdMode': self.trading_config.margin_mode,
|
||||
}
|
||||
if reduce_only:
|
||||
params['reduceOnly'] = True
|
||||
|
||||
order = self.exchange.create_limit_order(
|
||||
symbol, side, amount, price, params=params
|
||||
)
|
||||
logger.info(
|
||||
f"Limit {side.upper()} order placed: {amount} {symbol} "
|
||||
f"@ {price}, order_id={order['id']}"
|
||||
)
|
||||
return order
|
||||
|
||||
def set_stop_loss_take_profit(
|
||||
self,
|
||||
symbol: str,
|
||||
side: str,
|
||||
amount: float,
|
||||
stop_loss_price: float,
|
||||
take_profit_price: float
|
||||
) -> tuple:
|
||||
"""
|
||||
Set stop-loss and take-profit orders.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
side: Position side ("long" or "short")
|
||||
amount: Position size
|
||||
stop_loss_price: Stop-loss trigger price
|
||||
take_profit_price: Take-profit trigger price
|
||||
|
||||
Returns:
|
||||
Tuple of (sl_order, tp_order)
|
||||
"""
|
||||
# For long position: SL sells, TP sells
|
||||
# For short position: SL buys, TP buys
|
||||
close_side = "sell" if side == "long" else "buy"
|
||||
|
||||
# Stop-loss order
|
||||
sl_params = {
|
||||
'tdMode': self.trading_config.margin_mode,
|
||||
'reduceOnly': True,
|
||||
'stopLossPrice': stop_loss_price,
|
||||
}
|
||||
|
||||
try:
|
||||
sl_order = self.exchange.create_order(
|
||||
symbol, 'market', close_side, amount,
|
||||
params={
|
||||
'tdMode': self.trading_config.margin_mode,
|
||||
'reduceOnly': True,
|
||||
'slTriggerPx': str(stop_loss_price),
|
||||
'slOrdPx': '-1', # Market price
|
||||
}
|
||||
)
|
||||
logger.info(f"Stop-loss set at {stop_loss_price}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not set stop-loss: {e}")
|
||||
sl_order = None
|
||||
|
||||
# Take-profit order
|
||||
try:
|
||||
tp_order = self.exchange.create_order(
|
||||
symbol, 'market', close_side, amount,
|
||||
params={
|
||||
'tdMode': self.trading_config.margin_mode,
|
||||
'reduceOnly': True,
|
||||
'tpTriggerPx': str(take_profit_price),
|
||||
'tpOrdPx': '-1', # Market price
|
||||
}
|
||||
)
|
||||
logger.info(f"Take-profit set at {take_profit_price}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not set take-profit: {e}")
|
||||
tp_order = None
|
||||
|
||||
return sl_order, tp_order
|
||||
|
||||
def close_position(self, symbol: str) -> Optional[dict]:
|
||||
"""
|
||||
Close an open position.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
|
||||
Returns:
|
||||
Order result or None if no position
|
||||
"""
|
||||
position = self.get_position(symbol)
|
||||
if not position:
|
||||
logger.info(f"No open position for {symbol}")
|
||||
return None
|
||||
|
||||
contracts = abs(float(position.get('contracts', 0)))
|
||||
if contracts == 0:
|
||||
return None
|
||||
|
||||
side = position.get('side', 'long')
|
||||
close_side = "sell" if side == "long" else "buy"
|
||||
|
||||
order = self.place_market_order(
|
||||
symbol, close_side, contracts, reduce_only=True
|
||||
)
|
||||
logger.info(f"Position closed for {symbol}")
|
||||
return order
|
||||
|
||||
def get_ticker(self, symbol: str) -> dict:
|
||||
"""
|
||||
Get current ticker/price for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
|
||||
Returns:
|
||||
Ticker dictionary with 'last', 'bid', 'ask' prices
|
||||
"""
|
||||
return self.exchange.fetch_ticker(symbol)
|
||||
|
||||
def get_funding_rate(self, symbol: str) -> float:
|
||||
"""
|
||||
Get current funding rate for a perpetual symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
|
||||
Returns:
|
||||
Current funding rate as decimal
|
||||
"""
|
||||
try:
|
||||
funding = self.exchange.fetch_funding_rate(symbol)
|
||||
return float(funding.get('fundingRate', 0))
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch funding rate: {e}")
|
||||
return 0.0
|
||||
369
live_trading/position_manager.py
Normal file
369
live_trading/position_manager.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
Position Manager for Live Trading.
|
||||
|
||||
Tracks open positions, manages risk, and handles SL/TP logic.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
from .okx_client import OKXClient
|
||||
from .config import TradingConfig, PathConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
"""Represents an open trading position."""
|
||||
trade_id: str
|
||||
symbol: str
|
||||
side: str # "long" or "short"
|
||||
entry_price: float
|
||||
entry_time: str # ISO format
|
||||
size: float # Amount in base currency (e.g., ETH)
|
||||
size_usdt: float # Notional value in USDT
|
||||
stop_loss_price: float
|
||||
take_profit_price: float
|
||||
current_price: float = 0.0
|
||||
unrealized_pnl: float = 0.0
|
||||
unrealized_pnl_pct: float = 0.0
|
||||
order_id: str = "" # Entry order ID from exchange
|
||||
|
||||
def update_pnl(self, current_price: float) -> None:
|
||||
"""Update unrealized PnL based on current price."""
|
||||
self.current_price = current_price
|
||||
|
||||
if self.side == "long":
|
||||
self.unrealized_pnl = (current_price - self.entry_price) * self.size
|
||||
self.unrealized_pnl_pct = (current_price / self.entry_price - 1) * 100
|
||||
else: # short
|
||||
self.unrealized_pnl = (self.entry_price - current_price) * self.size
|
||||
self.unrealized_pnl_pct = (1 - current_price / self.entry_price) * 100
|
||||
|
||||
def should_stop_loss(self, current_price: float) -> bool:
|
||||
"""Check if stop-loss should trigger."""
|
||||
if self.side == "long":
|
||||
return current_price <= self.stop_loss_price
|
||||
return current_price >= self.stop_loss_price
|
||||
|
||||
def should_take_profit(self, current_price: float) -> bool:
|
||||
"""Check if take-profit should trigger."""
|
||||
if self.side == "long":
|
||||
return current_price >= self.take_profit_price
|
||||
return current_price <= self.take_profit_price
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> 'Position':
|
||||
"""Create Position from dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class PositionManager:
|
||||
"""
|
||||
Manages trading positions with persistence.
|
||||
|
||||
Tracks open positions, enforces risk limits, and handles
|
||||
position lifecycle (open, update, close).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_client: OKXClient,
|
||||
trading_config: TradingConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.client = okx_client
|
||||
self.config = trading_config
|
||||
self.paths = path_config
|
||||
self.positions: dict[str, Position] = {}
|
||||
self.trade_log: list[dict] = []
|
||||
self._load_positions()
|
||||
|
||||
def _load_positions(self) -> None:
|
||||
"""Load positions from file."""
|
||||
if self.paths.positions_file.exists():
|
||||
try:
|
||||
with open(self.paths.positions_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
for trade_id, pos_data in data.items():
|
||||
self.positions[trade_id] = Position.from_dict(pos_data)
|
||||
logger.info(f"Loaded {len(self.positions)} positions from file")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load positions: {e}")
|
||||
|
||||
def save_positions(self) -> None:
|
||||
"""Save positions to file."""
|
||||
try:
|
||||
data = {
|
||||
trade_id: pos.to_dict()
|
||||
for trade_id, pos in self.positions.items()
|
||||
}
|
||||
with open(self.paths.positions_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
logger.debug(f"Saved {len(self.positions)} positions")
|
||||
except Exception as e:
|
||||
logger.error(f"Could not save positions: {e}")
|
||||
|
||||
def can_open_position(self) -> bool:
|
||||
"""Check if we can open a new position."""
|
||||
return len(self.positions) < self.config.max_concurrent_positions
|
||||
|
||||
def get_position_for_symbol(self, symbol: str) -> Optional[Position]:
|
||||
"""Get position for a specific symbol."""
|
||||
for pos in self.positions.values():
|
||||
if pos.symbol == symbol:
|
||||
return pos
|
||||
return None
|
||||
|
||||
def open_position(
|
||||
self,
|
||||
symbol: str,
|
||||
side: str,
|
||||
entry_price: float,
|
||||
size: float,
|
||||
stop_loss_price: float,
|
||||
take_profit_price: float,
|
||||
order_id: str = ""
|
||||
) -> Optional[Position]:
|
||||
"""
|
||||
Open a new position.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
side: "long" or "short"
|
||||
entry_price: Entry price
|
||||
size: Position size in base currency
|
||||
stop_loss_price: Stop-loss price
|
||||
take_profit_price: Take-profit price
|
||||
order_id: Entry order ID from exchange
|
||||
|
||||
Returns:
|
||||
Position object or None if failed
|
||||
"""
|
||||
if not self.can_open_position():
|
||||
logger.warning("Cannot open position: max concurrent positions reached")
|
||||
return None
|
||||
|
||||
# Check if already have position for this symbol
|
||||
existing = self.get_position_for_symbol(symbol)
|
||||
if existing:
|
||||
logger.warning(f"Already have position for {symbol}")
|
||||
return None
|
||||
|
||||
# Generate trade ID
|
||||
now = datetime.now(timezone.utc)
|
||||
trade_id = f"{symbol}_{now.strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
position = Position(
|
||||
trade_id=trade_id,
|
||||
symbol=symbol,
|
||||
side=side,
|
||||
entry_price=entry_price,
|
||||
entry_time=now.isoformat(),
|
||||
size=size,
|
||||
size_usdt=entry_price * size,
|
||||
stop_loss_price=stop_loss_price,
|
||||
take_profit_price=take_profit_price,
|
||||
current_price=entry_price,
|
||||
order_id=order_id,
|
||||
)
|
||||
|
||||
self.positions[trade_id] = position
|
||||
self.save_positions()
|
||||
|
||||
logger.info(
|
||||
f"Opened {side.upper()} position: {size} {symbol} @ {entry_price}, "
|
||||
f"SL={stop_loss_price}, TP={take_profit_price}"
|
||||
)
|
||||
|
||||
return position
|
||||
|
||||
def close_position(
|
||||
self,
|
||||
trade_id: str,
|
||||
exit_price: float,
|
||||
reason: str = "manual",
|
||||
exit_order_id: str = ""
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Close a position and record the trade.
|
||||
|
||||
Args:
|
||||
trade_id: Position trade ID
|
||||
exit_price: Exit price
|
||||
reason: Reason for closing (e.g., "stop_loss", "take_profit", "signal")
|
||||
exit_order_id: Exit order ID from exchange
|
||||
|
||||
Returns:
|
||||
Trade record dictionary
|
||||
"""
|
||||
if trade_id not in self.positions:
|
||||
logger.warning(f"Position {trade_id} not found")
|
||||
return None
|
||||
|
||||
position = self.positions[trade_id]
|
||||
position.update_pnl(exit_price)
|
||||
|
||||
# Calculate final PnL
|
||||
entry_time = datetime.fromisoformat(position.entry_time)
|
||||
exit_time = datetime.now(timezone.utc)
|
||||
hold_duration = (exit_time - entry_time).total_seconds() / 3600 # hours
|
||||
|
||||
trade_record = {
|
||||
'trade_id': trade_id,
|
||||
'symbol': position.symbol,
|
||||
'side': position.side,
|
||||
'entry_price': position.entry_price,
|
||||
'exit_price': exit_price,
|
||||
'size': position.size,
|
||||
'size_usdt': position.size_usdt,
|
||||
'pnl_usd': position.unrealized_pnl,
|
||||
'pnl_pct': position.unrealized_pnl_pct,
|
||||
'entry_time': position.entry_time,
|
||||
'exit_time': exit_time.isoformat(),
|
||||
'hold_duration_hours': hold_duration,
|
||||
'reason': reason,
|
||||
'order_id_entry': position.order_id,
|
||||
'order_id_exit': exit_order_id,
|
||||
}
|
||||
|
||||
self.trade_log.append(trade_record)
|
||||
del self.positions[trade_id]
|
||||
self.save_positions()
|
||||
self._append_trade_log(trade_record)
|
||||
|
||||
logger.info(
|
||||
f"Closed {position.side.upper()} position: {position.size} {position.symbol} "
|
||||
f"@ {exit_price}, PnL=${position.unrealized_pnl:.2f} ({position.unrealized_pnl_pct:.2f}%), "
|
||||
f"reason={reason}"
|
||||
)
|
||||
|
||||
return trade_record
|
||||
|
||||
def _append_trade_log(self, trade_record: dict) -> None:
|
||||
"""Append trade record to CSV log file."""
|
||||
import csv
|
||||
|
||||
file_exists = self.paths.trade_log_file.exists()
|
||||
|
||||
with open(self.paths.trade_log_file, 'a', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
writer.writerow(trade_record)
|
||||
|
||||
def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
|
||||
"""
|
||||
Update all positions with current prices and check SL/TP.
|
||||
|
||||
Args:
|
||||
current_prices: Dictionary of symbol -> current price
|
||||
|
||||
Returns:
|
||||
List of closed trade records
|
||||
"""
|
||||
closed_trades = []
|
||||
|
||||
for trade_id in list(self.positions.keys()):
|
||||
position = self.positions[trade_id]
|
||||
|
||||
if position.symbol not in current_prices:
|
||||
continue
|
||||
|
||||
current_price = current_prices[position.symbol]
|
||||
position.update_pnl(current_price)
|
||||
|
||||
# Check stop-loss
|
||||
if position.should_stop_loss(current_price):
|
||||
logger.warning(
|
||||
f"Stop-loss triggered for {trade_id} at {current_price}"
|
||||
)
|
||||
# Close position on exchange
|
||||
exit_order_id = ""
|
||||
try:
|
||||
exit_order = self.client.close_position(position.symbol)
|
||||
exit_order_id = exit_order.get('id', '') if exit_order else ''
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close position on exchange: {e}")
|
||||
|
||||
record = self.close_position(trade_id, current_price, "stop_loss", exit_order_id)
|
||||
if record:
|
||||
closed_trades.append(record)
|
||||
continue
|
||||
|
||||
# Check take-profit
|
||||
if position.should_take_profit(current_price):
|
||||
logger.info(
|
||||
f"Take-profit triggered for {trade_id} at {current_price}"
|
||||
)
|
||||
# Close position on exchange
|
||||
exit_order_id = ""
|
||||
try:
|
||||
exit_order = self.client.close_position(position.symbol)
|
||||
exit_order_id = exit_order.get('id', '') if exit_order else ''
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close position on exchange: {e}")
|
||||
|
||||
record = self.close_position(trade_id, current_price, "take_profit", exit_order_id)
|
||||
if record:
|
||||
closed_trades.append(record)
|
||||
|
||||
self.save_positions()
|
||||
return closed_trades
|
||||
|
||||
def sync_with_exchange(self) -> None:
|
||||
"""
|
||||
Sync local positions with exchange positions.
|
||||
|
||||
Reconciles any discrepancies between local tracking
|
||||
and actual exchange positions.
|
||||
"""
|
||||
try:
|
||||
exchange_positions = self.client.get_positions()
|
||||
exchange_symbols = {p['symbol'] for p in exchange_positions}
|
||||
|
||||
# Check for positions we have locally but not on exchange
|
||||
for trade_id in list(self.positions.keys()):
|
||||
pos = self.positions[trade_id]
|
||||
if pos.symbol not in exchange_symbols:
|
||||
logger.warning(
|
||||
f"Position {trade_id} not found on exchange, removing"
|
||||
)
|
||||
# Get last price and close
|
||||
try:
|
||||
ticker = self.client.get_ticker(pos.symbol)
|
||||
exit_price = ticker['last']
|
||||
except Exception:
|
||||
exit_price = pos.current_price
|
||||
|
||||
self.close_position(trade_id, exit_price, "sync_removed")
|
||||
|
||||
logger.info(f"Position sync complete: {len(self.positions)} local positions")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Position sync failed: {e}")
|
||||
|
||||
def get_portfolio_summary(self) -> dict:
|
||||
"""
|
||||
Get portfolio summary.
|
||||
|
||||
Returns:
|
||||
Dictionary with portfolio statistics
|
||||
"""
|
||||
total_exposure = sum(p.size_usdt for p in self.positions.values())
|
||||
total_unrealized_pnl = sum(p.unrealized_pnl for p in self.positions.values())
|
||||
|
||||
return {
|
||||
'open_positions': len(self.positions),
|
||||
'total_exposure_usdt': total_exposure,
|
||||
'total_unrealized_pnl': total_unrealized_pnl,
|
||||
'positions': [p.to_dict() for p in self.positions.values()],
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def write_trade_log(trades: list[dict], path: Path) -> None:
|
||||
if not trades:
|
||||
return
|
||||
df = pd.DataFrame(trades)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_csv(path, index=False)
|
||||
10
main.py
Normal file
10
main.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Lowkey Backtest CLI - VectorBT Edition
|
||||
|
||||
A backtesting framework supporting multiple market types (spot, perpetual)
|
||||
with realistic trading simulation including leverage, funding, and shorts.
|
||||
"""
|
||||
from engine.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
TAKER_FEE_BPS_DEFAULT = 10.0 # 0.10%
|
||||
|
||||
|
||||
def okx_fee(fee_bps: float, notional_usd: float) -> float:
|
||||
return notional_usd * (fee_bps / 1e4)
|
||||
|
||||
|
||||
def estimate_slippage_rate(slippage_bps: float, notional_usd: float) -> float:
|
||||
return notional_usd * (slippage_bps / 1e4)
|
||||
54
metrics.py
54
metrics.py
@@ -1,54 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@dataclass
|
||||
class Perf:
|
||||
total_return: float
|
||||
max_drawdown: float
|
||||
sharpe_ratio: float
|
||||
win_rate: float
|
||||
num_trades: int
|
||||
final_equity: float
|
||||
initial_equity: float
|
||||
num_stop_losses: int
|
||||
total_fees: float
|
||||
total_slippage_usd: float
|
||||
avg_slippage_bps: float
|
||||
|
||||
|
||||
def compute_metrics(equity_curve: pd.Series, trades: list[dict]) -> Perf:
|
||||
ret = equity_curve.pct_change().fillna(0.0)
|
||||
total_return = equity_curve.iat[-1] / equity_curve.iat[0] - 1.0
|
||||
cummax = equity_curve.cummax()
|
||||
dd = (equity_curve / cummax - 1.0).min()
|
||||
max_drawdown = dd
|
||||
|
||||
if ret.std(ddof=0) > 0:
|
||||
sharpe = (ret.mean() / ret.std(ddof=0)) * np.sqrt(252 * 24 * 60) # minute bars -> annualized
|
||||
else:
|
||||
sharpe = 0.0
|
||||
|
||||
closes = [t for t in trades if t.get("side") == "SELL"]
|
||||
wins = [t for t in closes if t.get("pnl", 0.0) > 0]
|
||||
win_rate = (len(wins) / len(closes)) if closes else 0.0
|
||||
|
||||
fees = sum(t.get("fee", 0.0) for t in trades)
|
||||
slip = sum(t.get("slippage", 0.0) for t in trades)
|
||||
slippage_bps = [t.get("slippage_bps", 0.0) for t in trades if "slippage_bps" in t]
|
||||
|
||||
return Perf(
|
||||
total_return=total_return,
|
||||
max_drawdown=max_drawdown,
|
||||
sharpe_ratio=sharpe,
|
||||
win_rate=win_rate,
|
||||
num_trades=len(closes),
|
||||
final_equity=float(equity_curve.iat[-1]),
|
||||
initial_equity=float(equity_curve.iat[0]),
|
||||
num_stop_losses=sum(1 for t in closes if t.get("reason") == "stop"),
|
||||
total_fees=fees,
|
||||
total_slippage_usd=slip,
|
||||
avg_slippage_bps=float(np.mean(slippage_bps)) if slippage_bps else 0.0,
|
||||
)
|
||||
@@ -5,5 +5,29 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"ccxt>=4.5.32",
|
||||
"numpy>=2.3.2",
|
||||
"pandas>=2.3.1",
|
||||
"ta>=0.11.0",
|
||||
"vectorbt>=0.28.2",
|
||||
"scikit-learn>=1.6.0",
|
||||
"matplotlib>=3.10.0",
|
||||
"plotly>=5.24.0",
|
||||
"requests>=2.32.5",
|
||||
"python-dotenv>=1.2.1",
|
||||
# API dependencies
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.34.0",
|
||||
"sqlalchemy>=2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["."]
|
||||
markers = [
|
||||
"network: marks tests as requiring network access",
|
||||
]
|
||||
|
||||
342
research/regime_detection.py
Normal file
342
research/regime_detection.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Regime Detection Research Script with Walk-Forward Training.
|
||||
|
||||
Tests multiple holding horizons to find optimal parameters
|
||||
without look-ahead bias.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report, f1_score
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Configuration
|
||||
TRAIN_RATIO = 0.7 # 70% train, 30% test
|
||||
PROFIT_THRESHOLD = 0.005 # 0.5% profit target
|
||||
Z_WINDOW = 24
|
||||
FEE_RATE = 0.001 # 0.1% round-trip fee
|
||||
|
||||
|
||||
def load_data():
|
||||
"""Load and align BTC/ETH data."""
|
||||
dm = DataManager()
|
||||
|
||||
df_btc = dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
|
||||
df_eth = dm.load_data("okx", "ETH-USDT", "1h", MarketType.SPOT)
|
||||
|
||||
# Filter to Oct-Dec 2025
|
||||
start = pd.Timestamp("2025-10-01", tz="UTC")
|
||||
end = pd.Timestamp("2025-12-31", tz="UTC")
|
||||
|
||||
df_btc = df_btc[(df_btc.index >= start) & (df_btc.index <= end)]
|
||||
df_eth = df_eth[(df_eth.index >= start) & (df_eth.index <= end)]
|
||||
|
||||
# Align indices
|
||||
common = df_btc.index.intersection(df_eth.index)
|
||||
df_btc = df_btc.loc[common]
|
||||
df_eth = df_eth.loc[common]
|
||||
|
||||
logger.info(f"Loaded {len(common)} aligned hourly bars")
|
||||
return df_btc, df_eth
|
||||
|
||||
|
||||
def load_cryptoquant_data():
|
||||
"""Load CryptoQuant on-chain data if available."""
|
||||
try:
|
||||
cq_path = "data/cq_training_data.csv"
|
||||
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
|
||||
if cq_df.index.tz is None:
|
||||
cq_df.index = cq_df.index.tz_localize('UTC')
|
||||
logger.info(f"Loaded CryptoQuant data: {len(cq_df)} rows")
|
||||
return cq_df
|
||||
except Exception as e:
|
||||
logger.warning(f"CryptoQuant data not available: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def calculate_features(df_btc, df_eth, cq_df=None):
|
||||
"""Calculate all features for the model."""
|
||||
spread = df_eth['close'] / df_btc['close']
|
||||
|
||||
# Z-Score
|
||||
rolling_mean = spread.rolling(window=Z_WINDOW).mean()
|
||||
rolling_std = spread.rolling(window=Z_WINDOW).std()
|
||||
z_score = (spread - rolling_mean) / rolling_std
|
||||
|
||||
# Technicals
|
||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||
spread_roc = spread.pct_change(periods=5) * 100
|
||||
spread_change_1h = spread.pct_change(periods=1)
|
||||
|
||||
# Volume
|
||||
vol_ratio = df_eth['volume'] / df_btc['volume']
|
||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||
|
||||
# Volatility
|
||||
ret_btc = df_btc['close'].pct_change()
|
||||
ret_eth = df_eth['close'].pct_change()
|
||||
vol_btc = ret_btc.rolling(window=Z_WINDOW).std()
|
||||
vol_eth = ret_eth.rolling(window=Z_WINDOW).std()
|
||||
vol_spread_ratio = vol_eth / vol_btc
|
||||
|
||||
features = pd.DataFrame(index=spread.index)
|
||||
features['spread'] = spread
|
||||
features['z_score'] = z_score
|
||||
features['spread_rsi'] = spread_rsi
|
||||
features['spread_roc'] = spread_roc
|
||||
features['spread_change_1h'] = spread_change_1h
|
||||
features['vol_ratio'] = vol_ratio
|
||||
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
|
||||
features['vol_diff_ratio'] = vol_spread_ratio
|
||||
|
||||
# Add CQ features if available
|
||||
if cq_df is not None:
|
||||
cq_aligned = cq_df.reindex(features.index, method='ffill')
|
||||
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
|
||||
cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
|
||||
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
|
||||
cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
|
||||
features = features.join(cq_aligned)
|
||||
|
||||
return features.dropna()
|
||||
|
||||
|
||||
def calculate_targets(features, horizon):
|
||||
"""Calculate target labels for a given horizon."""
|
||||
spread = features['spread']
|
||||
z_score = features['z_score']
|
||||
|
||||
# For Short (Z > 1): Did spread drop below target?
|
||||
future_min = spread.rolling(window=horizon).min().shift(-horizon)
|
||||
target_short = spread * (1 - PROFIT_THRESHOLD)
|
||||
success_short = (z_score > 1.0) & (future_min < target_short)
|
||||
|
||||
# For Long (Z < -1): Did spread rise above target?
|
||||
future_max = spread.rolling(window=horizon).max().shift(-horizon)
|
||||
target_long = spread * (1 + PROFIT_THRESHOLD)
|
||||
success_long = (z_score < -1.0) & (future_max > target_long)
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
|
||||
# Create valid mask (rows with complete future data)
|
||||
valid_mask = future_min.notna() & future_max.notna()
|
||||
|
||||
return targets, valid_mask, future_min, future_max
|
||||
|
||||
|
||||
def calculate_mae(features, predictions, test_idx, horizon):
|
||||
"""Calculate Maximum Adverse Excursion for predicted trades."""
|
||||
test_features = features.loc[test_idx]
|
||||
spread = test_features['spread']
|
||||
z_score = test_features['z_score']
|
||||
|
||||
mae_values = []
|
||||
|
||||
for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
|
||||
if pred != 1:
|
||||
continue
|
||||
|
||||
entry_spread = spread.loc[idx]
|
||||
z = z_score.loc[idx]
|
||||
|
||||
# Get future spread values
|
||||
future_idx = features.index.get_loc(idx)
|
||||
future_end = min(future_idx + horizon, len(features))
|
||||
future_spreads = features['spread'].iloc[future_idx:future_end]
|
||||
|
||||
if len(future_spreads) < 2:
|
||||
continue
|
||||
|
||||
if z > 1.0: # Short trade
|
||||
max_adverse = (future_spreads.max() - entry_spread) / entry_spread
|
||||
else: # Long trade
|
||||
max_adverse = (entry_spread - future_spreads.min()) / entry_spread
|
||||
|
||||
mae_values.append(max_adverse * 100) # As percentage
|
||||
|
||||
return np.mean(mae_values) if mae_values else 0.0
|
||||
|
||||
|
||||
def calculate_net_profit(features, predictions, test_idx, horizon):
|
||||
"""Calculate estimated net profit including fees."""
|
||||
test_features = features.loc[test_idx]
|
||||
spread = test_features['spread']
|
||||
z_score = test_features['z_score']
|
||||
|
||||
total_pnl = 0.0
|
||||
n_trades = 0
|
||||
|
||||
for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
|
||||
if pred != 1:
|
||||
continue
|
||||
|
||||
entry_spread = spread.loc[idx]
|
||||
z = z_score.loc[idx]
|
||||
|
||||
# Get future spread values
|
||||
future_idx = features.index.get_loc(idx)
|
||||
future_end = min(future_idx + horizon, len(features))
|
||||
future_spreads = features['spread'].iloc[future_idx:future_end]
|
||||
|
||||
if len(future_spreads) < 2:
|
||||
continue
|
||||
|
||||
# Calculate PnL based on direction
|
||||
if z > 1.0: # Short trade - profit if spread drops
|
||||
exit_spread = future_spreads.iloc[-1] # Exit at horizon
|
||||
pnl = (entry_spread - exit_spread) / entry_spread
|
||||
else: # Long trade - profit if spread rises
|
||||
exit_spread = future_spreads.iloc[-1]
|
||||
pnl = (exit_spread - entry_spread) / entry_spread
|
||||
|
||||
# Subtract fees
|
||||
net_pnl = pnl - FEE_RATE
|
||||
total_pnl += net_pnl
|
||||
n_trades += 1
|
||||
|
||||
return total_pnl, n_trades
|
||||
|
||||
|
||||
def test_horizon(features, horizon):
|
||||
"""Test a single horizon with walk-forward training."""
|
||||
# Calculate targets
|
||||
targets, valid_mask, _, _ = calculate_targets(features, horizon)
|
||||
|
||||
# Walk-forward split
|
||||
n_samples = len(features)
|
||||
train_size = int(n_samples * TRAIN_RATIO)
|
||||
|
||||
train_features = features.iloc[:train_size]
|
||||
test_features = features.iloc[train_size:]
|
||||
|
||||
train_targets = targets[:train_size]
|
||||
test_targets = targets[train_size:]
|
||||
|
||||
train_valid = valid_mask.iloc[:train_size]
|
||||
test_valid = valid_mask.iloc[train_size:]
|
||||
|
||||
# Prepare training data (only valid rows)
|
||||
exclude = ['spread']
|
||||
cols = [c for c in features.columns if c not in exclude]
|
||||
|
||||
X_train = train_features[cols].fillna(0).replace([np.inf, -np.inf], 0)
|
||||
X_train_valid = X_train[train_valid]
|
||||
y_train_valid = train_targets[train_valid]
|
||||
|
||||
if len(X_train_valid) < 50:
|
||||
return None # Not enough training data
|
||||
|
||||
# Train model
|
||||
model = RandomForestClassifier(
|
||||
n_estimators=300, max_depth=5, min_samples_leaf=30,
|
||||
class_weight={0: 1, 1: 3}, random_state=42
|
||||
)
|
||||
model.fit(X_train_valid, y_train_valid)
|
||||
|
||||
# Predict on test set
|
||||
X_test = test_features[cols].fillna(0).replace([np.inf, -np.inf], 0)
|
||||
predictions = model.predict(X_test)
|
||||
|
||||
# Only evaluate on valid test rows (those with complete future data)
|
||||
test_valid_mask = test_valid.values
|
||||
y_test_valid = test_targets[test_valid_mask]
|
||||
pred_valid = predictions[test_valid_mask]
|
||||
|
||||
if len(y_test_valid) < 10:
|
||||
return None
|
||||
|
||||
# Calculate metrics
|
||||
f1 = f1_score(y_test_valid, pred_valid, zero_division=0)
|
||||
|
||||
# Calculate MAE and Net Profit on ALL test predictions (not just valid targets)
|
||||
test_idx = test_features.index
|
||||
avg_mae = calculate_mae(features, predictions, test_idx, horizon)
|
||||
net_pnl, n_trades = calculate_net_profit(features, predictions, test_idx, horizon)
|
||||
|
||||
return {
|
||||
'horizon': horizon,
|
||||
'f1_score': f1,
|
||||
'avg_mae': avg_mae,
|
||||
'net_pnl': net_pnl,
|
||||
'n_trades': n_trades,
|
||||
'train_samples': len(X_train_valid),
|
||||
'test_samples': len(X_test)
|
||||
}
|
||||
|
||||
|
||||
def test_horizons(features, horizons):
|
||||
"""Test multiple horizons and return comparison."""
|
||||
results = []
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("WALK-FORWARD HORIZON OPTIMIZATION")
|
||||
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
|
||||
print("=" * 80)
|
||||
|
||||
for h in horizons:
|
||||
result = test_horizon(features, h)
|
||||
if result:
|
||||
results.append(result)
|
||||
print(f"Horizon {h:3d}h: F1={result['f1_score']:.3f}, "
|
||||
f"MAE={result['avg_mae']:.2f}%, "
|
||||
f"Net PnL={result['net_pnl']*100:.2f}%, "
|
||||
f"Trades={result['n_trades']}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
"""Main research function."""
|
||||
# Load data
|
||||
df_btc, df_eth = load_data()
|
||||
cq_df = load_cryptoquant_data()
|
||||
|
||||
# Calculate features
|
||||
features = calculate_features(df_btc, df_eth, cq_df)
|
||||
logger.info(f"Calculated {len(features)} feature rows with {len(features.columns)} columns")
|
||||
|
||||
# Test horizons from 6h to 150h
|
||||
horizons = list(range(6, 151, 6)) # 6, 12, 18, ..., 150
|
||||
|
||||
results = test_horizons(features, horizons)
|
||||
|
||||
if not results:
|
||||
print("No valid results!")
|
||||
return
|
||||
|
||||
# Find best by different metrics
|
||||
results_df = pd.DataFrame(results)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("BEST HORIZONS BY METRIC")
|
||||
print("=" * 80)
|
||||
|
||||
best_f1 = results_df.loc[results_df['f1_score'].idxmax()]
|
||||
print(f"Best F1 Score: {best_f1['horizon']:.0f}h (F1={best_f1['f1_score']:.3f})")
|
||||
|
||||
best_pnl = results_df.loc[results_df['net_pnl'].idxmax()]
|
||||
print(f"Best Net PnL: {best_pnl['horizon']:.0f}h (PnL={best_pnl['net_pnl']*100:.2f}%)")
|
||||
|
||||
lowest_mae = results_df.loc[results_df['avg_mae'].idxmin()]
|
||||
print(f"Lowest MAE: {lowest_mae['horizon']:.0f}h (MAE={lowest_mae['avg_mae']:.2f}%)")
|
||||
|
||||
# Save results
|
||||
output_path = "research/horizon_optimization_results.csv"
|
||||
results_df.to_csv(output_path, index=False)
|
||||
print(f"\nResults saved to {output_path}")
|
||||
|
||||
return results_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
research/regime_results.png
Normal file
BIN
research/regime_results.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 289 KiB |
47
scripts/download_multi_pair_data.py
Normal file
47
scripts/download_multi_pair_data.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Download historical data for Multi-Pair Divergence Strategy.
|
||||
|
||||
Downloads 1h OHLCV data for top 10 cryptocurrencies from OKX.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import setup_logging, get_logger
|
||||
from strategies.multi_pair import MultiPairConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
"""Download data for all configured assets."""
|
||||
setup_logging()
|
||||
|
||||
config = MultiPairConfig()
|
||||
dm = DataManager()
|
||||
|
||||
logger.info("Downloading data for %d assets...", len(config.assets))
|
||||
|
||||
for symbol in config.assets:
|
||||
logger.info("Downloading %s perpetual 1h data...", symbol)
|
||||
try:
|
||||
df = dm.download_data(
|
||||
exchange_id=config.exchange_id,
|
||||
symbol=symbol,
|
||||
timeframe=config.timeframe,
|
||||
market_type=MarketType.PERPETUAL
|
||||
)
|
||||
if df is not None:
|
||||
logger.info("Downloaded %d candles for %s", len(df), symbol)
|
||||
else:
|
||||
logger.warning("No data downloaded for %s", symbol)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download %s: %s", symbol, e)
|
||||
|
||||
logger.info("Download complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
156
scripts/run_multi_pair_backtest.py
Normal file
156
scripts/run_multi_pair_backtest.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run Multi-Pair Divergence Strategy backtest and compare with baseline.
|
||||
|
||||
Compares the multi-pair strategy against the single-pair BTC/ETH regime strategy.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import setup_logging, get_logger
|
||||
from engine.reporting import Reporter
|
||||
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
|
||||
from strategies.regime_strategy import RegimeReversionStrategy
|
||||
from engine.market import MarketType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_baseline():
|
||||
"""Run baseline BTC/ETH regime strategy."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
|
||||
logger.info("=" * 60)
|
||||
|
||||
dm = DataManager()
|
||||
bt = Backtester(dm)
|
||||
|
||||
strategy = RegimeReversionStrategy()
|
||||
|
||||
result = bt.run_strategy(
|
||||
strategy,
|
||||
'okx',
|
||||
'ETH-USDT',
|
||||
timeframe='1h',
|
||||
init_cash=10000
|
||||
)
|
||||
|
||||
logger.info("Baseline Results:")
|
||||
logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
|
||||
logger.info(" Total Trades: %d", result.portfolio.trades.count())
|
||||
logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def run_multi_pair(assets: list[str] | None = None):
|
||||
"""Run multi-pair divergence strategy."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("MULTI-PAIR: Divergence Selection Strategy")
|
||||
logger.info("=" * 60)
|
||||
|
||||
dm = DataManager()
|
||||
bt = Backtester(dm)
|
||||
|
||||
# Use provided assets or default
|
||||
if assets:
|
||||
config = MultiPairConfig(assets=assets)
|
||||
else:
|
||||
config = MultiPairConfig()
|
||||
|
||||
logger.info("Configured %d assets, %d pairs", len(config.assets), config.get_pair_count())
|
||||
|
||||
strategy = MultiPairDivergenceStrategy(config=config)
|
||||
|
||||
result = bt.run_strategy(
|
||||
strategy,
|
||||
'okx',
|
||||
'ETH-USDT', # Reference asset (not used for trading, just index alignment)
|
||||
timeframe='1h',
|
||||
init_cash=10000
|
||||
)
|
||||
|
||||
logger.info("Multi-Pair Results:")
|
||||
logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
|
||||
logger.info(" Total Trades: %d", result.portfolio.trades.count())
|
||||
logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def compare_results(baseline, multi_pair):
|
||||
"""Compare and display results."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("COMPARISON")
|
||||
logger.info("=" * 60)
|
||||
|
||||
baseline_return = baseline.portfolio.total_return() * 100
|
||||
multi_return = multi_pair.portfolio.total_return() * 100
|
||||
|
||||
improvement = multi_return - baseline_return
|
||||
|
||||
logger.info("Baseline Return: %.2f%%", baseline_return)
|
||||
logger.info("Multi-Pair Return: %.2f%%", multi_return)
|
||||
logger.info("Improvement: %.2f%% (%.1fx)",
|
||||
improvement,
|
||||
multi_return / baseline_return if baseline_return != 0 else 0)
|
||||
|
||||
baseline_trades = baseline.portfolio.trades.count()
|
||||
multi_trades = multi_pair.portfolio.trades.count()
|
||||
|
||||
logger.info("Baseline Trades: %d", baseline_trades)
|
||||
logger.info("Multi-Pair Trades: %d", multi_trades)
|
||||
|
||||
return {
|
||||
'baseline_return': baseline_return,
|
||||
'multi_pair_return': multi_return,
|
||||
'improvement': improvement,
|
||||
'baseline_trades': baseline_trades,
|
||||
'multi_pair_trades': multi_trades
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
setup_logging()
|
||||
|
||||
# Check available assets
|
||||
dm = DataManager()
|
||||
available = []
|
||||
|
||||
for symbol in MultiPairConfig().assets:
|
||||
try:
|
||||
dm.load_data('okx', symbol, '1h', market_type=MarketType.PERPETUAL)
|
||||
available.append(symbol)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
if len(available) < 2:
|
||||
logger.error(
|
||||
"Need at least 2 assets to run multi-pair strategy. "
|
||||
"Run: uv run python scripts/download_multi_pair_data.py"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("Found data for %d assets: %s", len(available), available)
|
||||
|
||||
# Run baseline
|
||||
baseline_result = run_baseline()
|
||||
|
||||
# Run multi-pair
|
||||
multi_result = run_multi_pair(available)
|
||||
|
||||
# Compare
|
||||
comparison = compare_results(baseline_result, multi_result)
|
||||
|
||||
# Save reports
|
||||
reporter = Reporter()
|
||||
reporter.save_reports(multi_result, "multi_pair_divergence")
|
||||
|
||||
logger.info("Reports saved to backtest_logs/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
80
strategies/base.py
Normal file
80
strategies/base.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Base strategy class for all trading strategies.
|
||||
|
||||
Strategies should inherit from BaseStrategy and implement the run() method.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""
|
||||
Abstract base class for trading strategies.
|
||||
|
||||
Class Attributes:
|
||||
default_market_type: Default market type for this strategy
|
||||
default_leverage: Default leverage (only applies to perpetuals)
|
||||
default_sl_stop: Default stop-loss percentage
|
||||
default_tp_stop: Default take-profit percentage
|
||||
default_sl_trail: Whether stop-loss is trailing by default
|
||||
"""
|
||||
# Market configuration defaults
|
||||
default_market_type: MarketType = MarketType.SPOT
|
||||
default_leverage: int = 1
|
||||
|
||||
# Risk management defaults (can be overridden per strategy)
|
||||
default_sl_stop: float | None = None
|
||||
default_tp_stop: float | None = None
|
||||
default_sl_trail: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.params = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Run the strategy logic.
|
||||
|
||||
Args:
|
||||
close: Price series (can be multiple columns for grid search)
|
||||
**kwargs: Additional data (high, low, open, volume) and parameters
|
||||
|
||||
Returns:
|
||||
Tuple of 4 DataFrames/Series:
|
||||
- long_entries: Boolean signals to open long positions
|
||||
- long_exits: Boolean signals to close long positions
|
||||
- short_entries: Boolean signals to open short positions
|
||||
- short_exits: Boolean signals to close short positions
|
||||
|
||||
Note:
|
||||
For spot markets, short signals will be ignored.
|
||||
For backward compatibility, strategies can return 2-tuple (entries, exits)
|
||||
which will be interpreted as long-only signals.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_indicator(self, ind_cls, *args, **kwargs):
|
||||
"""Helper to run a vectorbt indicator."""
|
||||
return ind_cls.run(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_empty_signals(reference: pd.Series | pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Create an empty (all False) signal DataFrame matching the reference shape.
|
||||
|
||||
Args:
|
||||
reference: Series or DataFrame to match shape/index
|
||||
|
||||
Returns:
|
||||
DataFrame of False values with same shape as reference
|
||||
"""
|
||||
if isinstance(reference, pd.DataFrame):
|
||||
return pd.DataFrame(False, index=reference.index, columns=reference.columns)
|
||||
return pd.Series(False, index=reference.index)
|
||||
97
strategies/examples.py
Normal file
97
strategies/examples.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Example trading strategies for backtesting.
|
||||
|
||||
These are simple strategies demonstrating the framework usage.
|
||||
"""
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class RsiStrategy(BaseStrategy):
|
||||
"""
|
||||
RSI mean-reversion strategy.
|
||||
|
||||
Long entry when RSI crosses below oversold level.
|
||||
Long exit when RSI crosses above overbought level.
|
||||
"""
|
||||
default_market_type = MarketType.SPOT
|
||||
default_leverage = 1
|
||||
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
period: int = 14,
|
||||
rsi_lower: int = 30,
|
||||
rsi_upper: int = 70,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Generate RSI-based trading signals.
|
||||
|
||||
Args:
|
||||
close: Price series
|
||||
period: RSI calculation period
|
||||
rsi_lower: Oversold threshold (buy signal)
|
||||
rsi_upper: Overbought threshold (sell signal)
|
||||
|
||||
Returns:
|
||||
4-tuple of (long_entries, long_exits, short_entries, short_exits)
|
||||
"""
|
||||
# Calculate RSI
|
||||
rsi = vbt.RSI.run(close, window=period)
|
||||
|
||||
# Long signals: buy oversold, sell overbought
|
||||
long_entries = rsi.rsi_crossed_below(rsi_lower)
|
||||
long_exits = rsi.rsi_crossed_above(rsi_upper)
|
||||
|
||||
# No short signals for this strategy (spot-focused)
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
|
||||
|
||||
class MaCrossStrategy(BaseStrategy):
|
||||
"""
|
||||
Moving Average crossover strategy.
|
||||
|
||||
Long entry when fast MA crosses above slow MA.
|
||||
Long exit when fast MA crosses below slow MA.
|
||||
"""
|
||||
default_market_type = MarketType.SPOT
|
||||
default_leverage = 1
|
||||
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
fast_window: int = 10,
|
||||
slow_window: int = 20,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Generate MA crossover trading signals.
|
||||
|
||||
Args:
|
||||
close: Price series
|
||||
fast_window: Fast MA period
|
||||
slow_window: Slow MA period
|
||||
|
||||
Returns:
|
||||
4-tuple of (long_entries, long_exits, short_entries, short_exits)
|
||||
"""
|
||||
# Calculate Moving Averages
|
||||
fast_ma = vbt.MA.run(close, window=fast_window)
|
||||
slow_ma = vbt.MA.run(close, window=slow_window)
|
||||
|
||||
# Long signals
|
||||
long_entries = fast_ma.ma_crossed_above(slow_ma)
|
||||
long_exits = fast_ma.ma_crossed_below(slow_ma)
|
||||
|
||||
# No short signals for this strategy
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
164
strategies/factory.py
Normal file
164
strategies/factory.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Strategy factory for creating strategy instances with their parameters.
|
||||
|
||||
Centralizes strategy creation and parameter configuration.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyConfig:
|
||||
"""
|
||||
Configuration for a strategy including default and grid parameters.
|
||||
|
||||
Attributes:
|
||||
strategy_class: The strategy class to instantiate
|
||||
default_params: Parameters for single backtest runs
|
||||
grid_params: Parameters for grid search optimization
|
||||
"""
|
||||
strategy_class: type[BaseStrategy]
|
||||
default_params: dict[str, Any] = field(default_factory=dict)
|
||||
grid_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _build_registry() -> dict[str, StrategyConfig]:
|
||||
"""
|
||||
Build the strategy registry lazily to avoid circular imports.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping strategy names to their configurations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from strategies.examples import MaCrossStrategy, RsiStrategy
|
||||
from strategies.supertrend import MetaSupertrendStrategy
|
||||
from strategies.regime_strategy import RegimeReversionStrategy
|
||||
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
|
||||
|
||||
return {
|
||||
"rsi": StrategyConfig(
|
||||
strategy_class=RsiStrategy,
|
||||
default_params={
|
||||
'period': 14,
|
||||
'rsi_lower': 30,
|
||||
'rsi_upper': 70
|
||||
},
|
||||
grid_params={
|
||||
'period': np.arange(10, 25, 2),
|
||||
'rsi_lower': [20, 30, 40],
|
||||
'rsi_upper': [60, 70, 80]
|
||||
}
|
||||
),
|
||||
"macross": StrategyConfig(
|
||||
strategy_class=MaCrossStrategy,
|
||||
default_params={
|
||||
'fast_window': 10,
|
||||
'slow_window': 20
|
||||
},
|
||||
grid_params={
|
||||
'fast_window': np.arange(5, 20, 5),
|
||||
'slow_window': np.arange(20, 60, 10)
|
||||
}
|
||||
),
|
||||
"meta_st": StrategyConfig(
|
||||
strategy_class=MetaSupertrendStrategy,
|
||||
default_params={
|
||||
'period1': 12, 'multiplier1': 3.0,
|
||||
'period2': 10, 'multiplier2': 1.0,
|
||||
'period3': 11, 'multiplier3': 2.0
|
||||
},
|
||||
grid_params={
|
||||
'multiplier1': [2.0, 3.0, 4.0],
|
||||
'period1': [10, 12, 14],
|
||||
'period2': 11, 'multiplier2': 2.0,
|
||||
'period3': 12, 'multiplier3': 1.0
|
||||
}
|
||||
),
|
||||
"regime": StrategyConfig(
|
||||
strategy_class=RegimeReversionStrategy,
|
||||
default_params={
|
||||
# Optimal from walk-forward research (research/horizon_optimization_results.csv)
|
||||
'horizon': 102, # 4.25 days - best Net PnL
|
||||
'z_window': 24, # 24h rolling Z-score window
|
||||
'z_entry_threshold': 1.0, # Enter when |Z| > 1.0
|
||||
'profit_target': 0.005, # 0.5% target for ML labels
|
||||
'stop_loss': 0.06, # 6% stop loss
|
||||
'take_profit': 0.05, # 5% take profit
|
||||
'train_ratio': 0.7, # 70% train / 30% test
|
||||
'trend_window': 0, # Disabled SMA filter
|
||||
'use_funding_filter': True, # Enabled Funding filter
|
||||
'funding_threshold': 0.005 # 0.005% threshold (Proven profitable)
|
||||
},
|
||||
grid_params={
|
||||
'horizon': [84, 96, 102, 108, 120],
|
||||
'z_entry_threshold': [0.8, 1.0, 1.2],
|
||||
'stop_loss': [0.04, 0.06, 0.08],
|
||||
'funding_threshold': [0.005, 0.01, 0.02]
|
||||
}
|
||||
),
|
||||
"multi_pair": StrategyConfig(
|
||||
strategy_class=MultiPairDivergenceStrategy,
|
||||
default_params={
|
||||
# Multi-pair divergence strategy uses config object
|
||||
# Parameters passed here will override MultiPairConfig defaults
|
||||
},
|
||||
grid_params={
|
||||
'z_entry_threshold': [0.8, 1.0, 1.2],
|
||||
'prob_threshold': [0.4, 0.5, 0.6],
|
||||
'correlation_threshold': [0.75, 0.85, 0.95]
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
# Module-level cache for the registry
|
||||
_REGISTRY_CACHE: dict[str, StrategyConfig] | None = None
|
||||
|
||||
|
||||
def get_registry() -> dict[str, StrategyConfig]:
|
||||
"""Get the strategy registry, building it on first access."""
|
||||
global _REGISTRY_CACHE
|
||||
if _REGISTRY_CACHE is None:
|
||||
_REGISTRY_CACHE = _build_registry()
|
||||
return _REGISTRY_CACHE
|
||||
|
||||
|
||||
def get_strategy_names() -> list[str]:
|
||||
"""
|
||||
Get list of available strategy names.
|
||||
|
||||
Returns:
|
||||
List of strategy name strings
|
||||
"""
|
||||
return list(get_registry().keys())
|
||||
|
||||
|
||||
def get_strategy(name: str, is_grid: bool = False) -> tuple[BaseStrategy, dict[str, Any]]:
|
||||
"""
|
||||
Create a strategy instance with appropriate parameters.
|
||||
|
||||
Args:
|
||||
name: Strategy identifier (e.g., 'rsi', 'macross', 'meta_st')
|
||||
is_grid: If True, return grid search parameters
|
||||
|
||||
Returns:
|
||||
Tuple of (strategy instance, parameters dict)
|
||||
|
||||
Raises:
|
||||
KeyError: If strategy name is not found in registry
|
||||
"""
|
||||
registry = get_registry()
|
||||
|
||||
if name not in registry:
|
||||
available = ", ".join(registry.keys())
|
||||
raise KeyError(f"Unknown strategy '{name}'. Available: {available}")
|
||||
|
||||
config = registry[name]
|
||||
strategy = config.strategy_class()
|
||||
params = config.grid_params if is_grid else config.default_params
|
||||
|
||||
return strategy, params.copy()
|
||||
24
strategies/multi_pair/__init__.py
Normal file
24
strategies/multi_pair/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Extends regime detection to multiple cryptocurrency pairs and dynamically
|
||||
selects the most divergent pair for trading.
|
||||
"""
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import PairScanner, TradingPair
|
||||
from .correlation import CorrelationFilter
|
||||
from .feature_engine import MultiPairFeatureEngine
|
||||
from .divergence_scorer import DivergenceScorer
|
||||
from .strategy import MultiPairDivergenceStrategy
|
||||
from .funding import FundingRateFetcher
|
||||
|
||||
__all__ = [
|
||||
"MultiPairConfig",
|
||||
"PairScanner",
|
||||
"TradingPair",
|
||||
"CorrelationFilter",
|
||||
"MultiPairFeatureEngine",
|
||||
"DivergenceScorer",
|
||||
"MultiPairDivergenceStrategy",
|
||||
"FundingRateFetcher",
|
||||
]
|
||||
88
strategies/multi_pair/config.py
Normal file
88
strategies/multi_pair/config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Configuration for Multi-Pair Divergence Strategy.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiPairConfig:
|
||||
"""
|
||||
Configuration parameters for multi-pair divergence strategy.
|
||||
|
||||
Attributes:
|
||||
assets: List of asset symbols to analyze (top 10 by market cap)
|
||||
z_window: Rolling window for Z-Score calculation (hours)
|
||||
z_entry_threshold: Minimum |Z-Score| to consider for entry
|
||||
prob_threshold: Minimum ML probability to consider for entry
|
||||
correlation_threshold: Max correlation to allow between pairs
|
||||
correlation_window: Rolling window for correlation (hours)
|
||||
atr_period: ATR lookback period for dynamic stops
|
||||
sl_atr_multiplier: Stop-loss as multiple of ATR
|
||||
tp_atr_multiplier: Take-profit as multiple of ATR
|
||||
train_ratio: Walk-forward train/test split ratio
|
||||
horizon: Look-ahead horizon for target calculation (hours)
|
||||
profit_target: Minimum profit threshold for target labels
|
||||
funding_threshold: Funding rate threshold for filtering
|
||||
"""
|
||||
# Asset Universe
|
||||
assets: list[str] = field(default_factory=lambda: [
|
||||
"BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
|
||||
"DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT"
|
||||
])
|
||||
|
||||
# Z-Score Thresholds
|
||||
z_window: int = 24
|
||||
z_entry_threshold: float = 1.0
|
||||
|
||||
# ML Thresholds
|
||||
prob_threshold: float = 0.5
|
||||
train_ratio: float = 0.7
|
||||
horizon: int = 102
|
||||
profit_target: float = 0.005
|
||||
|
||||
# Correlation Filtering
|
||||
correlation_threshold: float = 0.85
|
||||
correlation_window: int = 168 # 7 days in hours
|
||||
|
||||
# Risk Management - ATR-Based Stops
|
||||
# SL/TP are calculated as multiples of ATR
|
||||
# Mean ATR for crypto is ~0.6% per hour, so:
|
||||
# - 10x ATR = ~6% SL (matches previous fixed 6%)
|
||||
# - 8x ATR = ~5% TP (matches previous fixed 5%)
|
||||
atr_period: int = 14 # ATR lookback period (hours for 1h timeframe)
|
||||
sl_atr_multiplier: float = 10.0 # Stop-loss = entry +/- (ATR * multiplier)
|
||||
tp_atr_multiplier: float = 8.0 # Take-profit = entry +/- (ATR * multiplier)
|
||||
|
||||
# Fallback fixed percentages (used if ATR is unavailable)
|
||||
base_sl_pct: float = 0.06
|
||||
base_tp_pct: float = 0.05
|
||||
|
||||
# ATR bounds to prevent extreme stops
|
||||
min_sl_pct: float = 0.02 # Minimum 2% stop-loss
|
||||
max_sl_pct: float = 0.10 # Maximum 10% stop-loss
|
||||
min_tp_pct: float = 0.02 # Minimum 2% take-profit
|
||||
max_tp_pct: float = 0.15 # Maximum 15% take-profit
|
||||
|
||||
volatility_window: int = 24
|
||||
|
||||
# Funding Rate Filter
|
||||
# OKX funding rates are typically 0.0001 (0.01%) per 8h
|
||||
# Extreme funding is > 0.0005 (0.05%) which indicates crowded trade
|
||||
funding_threshold: float = 0.0005 # 0.05% - filter extreme funding
|
||||
|
||||
# Trade Management
|
||||
# Note: Setting min_hold_bars=0 and z_exit_threshold=0 gives best results
|
||||
# The mean-reversion exit at Z=0 is the primary profit driver
|
||||
min_hold_bars: int = 0 # Disabled - let mean reversion drive exits
|
||||
switch_threshold: float = 999.0 # Disabled - don't switch mid-trade
|
||||
cooldown_bars: int = 0 # Disabled - enter when signal appears
|
||||
z_exit_threshold: float = 0.0 # Exit at Z=0 (mean reversion complete)
|
||||
|
||||
# Exchange
|
||||
exchange_id: str = "okx"
|
||||
timeframe: str = "1h"
|
||||
|
||||
def get_pair_count(self) -> int:
|
||||
"""Calculate number of unique pairs from asset list."""
|
||||
n = len(self.assets)
|
||||
return n * (n - 1) // 2
|
||||
173
strategies/multi_pair/correlation.py
Normal file
173
strategies/multi_pair/correlation.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Correlation Filter for Multi-Pair Divergence Strategy.
|
||||
|
||||
Calculates rolling correlation matrix and filters pairs
|
||||
to avoid highly correlated positions.
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import TradingPair
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CorrelationFilter:
|
||||
"""
|
||||
Calculates and filters based on asset correlations.
|
||||
|
||||
Uses rolling correlation of returns to identify assets
|
||||
moving together, avoiding redundant positions.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiPairConfig):
|
||||
self.config = config
|
||||
self._correlation_matrix: pd.DataFrame | None = None
|
||||
self._last_update_idx: int = -1
|
||||
|
||||
def calculate_correlation_matrix(
|
||||
self,
|
||||
price_data: dict[str, pd.Series],
|
||||
current_idx: int | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Calculate rolling correlation matrix between all assets.
|
||||
|
||||
Args:
|
||||
price_data: Dictionary mapping asset symbols to price series
|
||||
current_idx: Current bar index (for caching)
|
||||
|
||||
Returns:
|
||||
Correlation matrix DataFrame
|
||||
"""
|
||||
# Use cached if recent
|
||||
if (
|
||||
current_idx is not None
|
||||
and self._correlation_matrix is not None
|
||||
and current_idx - self._last_update_idx < 24 # Update every 24 bars
|
||||
):
|
||||
return self._correlation_matrix
|
||||
|
||||
# Calculate returns
|
||||
returns = {}
|
||||
for symbol, prices in price_data.items():
|
||||
returns[symbol] = prices.pct_change()
|
||||
|
||||
returns_df = pd.DataFrame(returns)
|
||||
|
||||
# Rolling correlation
|
||||
window = self.config.correlation_window
|
||||
|
||||
# Get latest correlation (last row of rolling correlation)
|
||||
if len(returns_df) >= window:
|
||||
rolling_corr = returns_df.rolling(window=window).corr()
|
||||
# Extract last timestamp correlation matrix
|
||||
last_idx = returns_df.index[-1]
|
||||
corr_matrix = rolling_corr.loc[last_idx]
|
||||
else:
|
||||
# Fallback to full-period correlation if not enough data
|
||||
corr_matrix = returns_df.corr()
|
||||
|
||||
self._correlation_matrix = corr_matrix
|
||||
if current_idx is not None:
|
||||
self._last_update_idx = current_idx
|
||||
|
||||
return corr_matrix
|
||||
|
||||
def filter_pairs(
|
||||
self,
|
||||
pairs: list[TradingPair],
|
||||
current_position_asset: str | None,
|
||||
price_data: dict[str, pd.Series],
|
||||
current_idx: int | None = None
|
||||
) -> list[TradingPair]:
|
||||
"""
|
||||
Filter pairs based on correlation with current position.
|
||||
|
||||
If we have an open position in an asset, exclude pairs where
|
||||
either asset is highly correlated with the held asset.
|
||||
|
||||
Args:
|
||||
pairs: List of candidate pairs
|
||||
current_position_asset: Currently held asset (or None)
|
||||
price_data: Dictionary of price series by symbol
|
||||
current_idx: Current bar index for caching
|
||||
|
||||
Returns:
|
||||
Filtered list of pairs
|
||||
"""
|
||||
if current_position_asset is None:
|
||||
return pairs
|
||||
|
||||
corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
|
||||
threshold = self.config.correlation_threshold
|
||||
|
||||
filtered = []
|
||||
for pair in pairs:
|
||||
# Check correlation of base and quote with held asset
|
||||
base_corr = self._get_correlation(
|
||||
corr_matrix, pair.base_asset, current_position_asset
|
||||
)
|
||||
quote_corr = self._get_correlation(
|
||||
corr_matrix, pair.quote_asset, current_position_asset
|
||||
)
|
||||
|
||||
# Filter if either asset highly correlated with position
|
||||
if abs(base_corr) > threshold or abs(quote_corr) > threshold:
|
||||
logger.debug(
|
||||
"Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
|
||||
pair.name, base_corr, quote_corr, current_position_asset
|
||||
)
|
||||
continue
|
||||
|
||||
filtered.append(pair)
|
||||
|
||||
if len(filtered) < len(pairs):
|
||||
logger.info(
|
||||
"Correlation filter: %d/%d pairs remaining (held: %s)",
|
||||
len(filtered), len(pairs), current_position_asset
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
def _get_correlation(
|
||||
self,
|
||||
corr_matrix: pd.DataFrame,
|
||||
asset1: str,
|
||||
asset2: str
|
||||
) -> float:
|
||||
"""
|
||||
Get correlation between two assets from matrix.
|
||||
|
||||
Args:
|
||||
corr_matrix: Correlation matrix
|
||||
asset1: First asset symbol
|
||||
asset2: Second asset symbol
|
||||
|
||||
Returns:
|
||||
Correlation coefficient (-1 to 1), or 0 if not found
|
||||
"""
|
||||
if asset1 == asset2:
|
||||
return 1.0
|
||||
|
||||
try:
|
||||
return corr_matrix.loc[asset1, asset2]
|
||||
except KeyError:
|
||||
return 0.0
|
||||
|
||||
def get_correlation_report(
|
||||
self,
|
||||
price_data: dict[str, pd.Series]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Generate a readable correlation report.
|
||||
|
||||
Args:
|
||||
price_data: Dictionary of price series
|
||||
|
||||
Returns:
|
||||
Correlation matrix as DataFrame
|
||||
"""
|
||||
return self.calculate_correlation_matrix(price_data)
|
||||
311
strategies/multi_pair/divergence_scorer.py
Normal file
311
strategies/multi_pair/divergence_scorer.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Divergence Scorer for Multi-Pair Strategy.
|
||||
|
||||
Ranks pairs by divergence score and selects the best candidate.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import TradingPair
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DivergenceSignal:
|
||||
"""
|
||||
Signal for a divergent pair.
|
||||
|
||||
Attributes:
|
||||
pair: Trading pair
|
||||
z_score: Current Z-Score of the spread
|
||||
probability: ML model probability of profitable reversion
|
||||
divergence_score: Combined score (|z_score| * probability)
|
||||
direction: 'long' or 'short' (relative to base asset)
|
||||
base_price: Current price of base asset
|
||||
quote_price: Current price of quote asset
|
||||
atr: Average True Range in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
"""
|
||||
pair: TradingPair
|
||||
z_score: float
|
||||
probability: float
|
||||
divergence_score: float
|
||||
direction: str
|
||||
base_price: float
|
||||
quote_price: float
|
||||
atr: float
|
||||
atr_pct: float
|
||||
timestamp: pd.Timestamp
|
||||
|
||||
|
||||
class DivergenceScorer:
|
||||
"""
|
||||
Scores and ranks pairs by divergence potential.
|
||||
|
||||
Uses ML model predictions combined with Z-Score magnitude
|
||||
to identify the most promising mean-reversion opportunity.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
|
||||
self.config = config
|
||||
self.model_path = Path(model_path)
|
||||
self.model: RandomForestClassifier | None = None
|
||||
self.feature_cols: list[str] | None = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Load pre-trained model if available."""
|
||||
if self.model_path.exists():
|
||||
try:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
saved = pickle.load(f)
|
||||
self.model = saved['model']
|
||||
self.feature_cols = saved['feature_cols']
|
||||
logger.info("Loaded model from %s", self.model_path)
|
||||
except Exception as e:
|
||||
logger.warning("Could not load model: %s", e)
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save trained model."""
|
||||
if self.model is None:
|
||||
return
|
||||
|
||||
self.model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.model_path, 'wb') as f:
|
||||
pickle.dump({
|
||||
'model': self.model,
|
||||
'feature_cols': self.feature_cols,
|
||||
}, f)
|
||||
logger.info("Saved model to %s", self.model_path)
|
||||
|
||||
def train_model(
|
||||
self,
|
||||
combined_features: pd.DataFrame,
|
||||
pair_features: dict[str, pd.DataFrame]
|
||||
) -> None:
|
||||
"""
|
||||
Train universal model on all pairs.
|
||||
|
||||
Args:
|
||||
combined_features: Combined feature DataFrame from all pairs
|
||||
pair_features: Individual pair feature DataFrames (for target calculation)
|
||||
"""
|
||||
logger.info("Training universal model on %d samples...", len(combined_features))
|
||||
|
||||
z_thresh = self.config.z_entry_threshold
|
||||
horizon = self.config.horizon
|
||||
profit_target = self.config.profit_target
|
||||
|
||||
# Calculate targets for each pair
|
||||
all_targets = []
|
||||
all_features = []
|
||||
|
||||
for pair_id, features in pair_features.items():
|
||||
if len(features) < horizon + 50:
|
||||
continue
|
||||
|
||||
spread = features['spread']
|
||||
z_score = features['z_score']
|
||||
|
||||
# Future price movements
|
||||
future_min = spread.rolling(window=horizon).min().shift(-horizon)
|
||||
future_max = spread.rolling(window=horizon).max().shift(-horizon)
|
||||
|
||||
# Target labels
|
||||
target_short = spread * (1 - profit_target)
|
||||
target_long = spread * (1 + profit_target)
|
||||
|
||||
success_short = (z_score > z_thresh) & (future_min < target_short)
|
||||
success_long = (z_score < -z_thresh) & (future_max > target_long)
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
|
||||
# Valid mask (exclude rows without complete future data)
|
||||
valid_mask = future_min.notna() & future_max.notna()
|
||||
|
||||
# Collect valid samples
|
||||
valid_features = features[valid_mask]
|
||||
valid_targets = targets[valid_mask.values]
|
||||
|
||||
if len(valid_features) > 0:
|
||||
all_features.append(valid_features)
|
||||
all_targets.extend(valid_targets)
|
||||
|
||||
if not all_features:
|
||||
logger.warning("No valid training samples")
|
||||
return
|
||||
|
||||
# Combine all training data
|
||||
X_df = pd.concat(all_features, ignore_index=True)
|
||||
y = np.array(all_targets)
|
||||
|
||||
# Get feature columns
|
||||
exclude_cols = [
|
||||
'pair_id', 'base_asset', 'quote_asset',
|
||||
'spread', 'base_close', 'quote_close', 'base_volume'
|
||||
]
|
||||
self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]
|
||||
|
||||
# Prepare features
|
||||
X = X_df[self.feature_cols].fillna(0)
|
||||
X = X.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Train model
|
||||
self.model = RandomForestClassifier(
|
||||
n_estimators=300,
|
||||
max_depth=5,
|
||||
min_samples_leaf=30,
|
||||
class_weight={0: 1, 1: 3},
|
||||
random_state=42
|
||||
)
|
||||
self.model.fit(X, y)
|
||||
|
||||
logger.info(
|
||||
"Model trained on %d samples, %d features, %.1f%% positive class",
|
||||
len(X), len(self.feature_cols), y.mean() * 100
|
||||
)
|
||||
self.save_model()
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
pairs: list[TradingPair],
|
||||
timestamp: pd.Timestamp | None = None
|
||||
) -> list[DivergenceSignal]:
|
||||
"""
|
||||
Score all pairs and return ranked signals.
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
pairs: List of TradingPair objects
|
||||
timestamp: Current timestamp for feature extraction
|
||||
|
||||
Returns:
|
||||
List of DivergenceSignal sorted by score (descending)
|
||||
"""
|
||||
if self.model is None:
|
||||
logger.warning("Model not trained, returning empty signals")
|
||||
return []
|
||||
|
||||
signals = []
|
||||
pair_map = {p.pair_id: p for p in pairs}
|
||||
|
||||
for pair_id, features in pair_features.items():
|
||||
if pair_id not in pair_map:
|
||||
continue
|
||||
|
||||
pair = pair_map[pair_id]
|
||||
|
||||
# Get latest features
|
||||
if timestamp is not None:
|
||||
valid = features[features.index <= timestamp]
|
||||
if len(valid) == 0:
|
||||
continue
|
||||
latest = valid.iloc[-1]
|
||||
ts = valid.index[-1]
|
||||
else:
|
||||
latest = features.iloc[-1]
|
||||
ts = features.index[-1]
|
||||
|
||||
z_score = latest['z_score']
|
||||
|
||||
# Skip if Z-score below threshold
|
||||
if abs(z_score) < self.config.z_entry_threshold:
|
||||
continue
|
||||
|
||||
# Prepare features for prediction
|
||||
feature_row = latest[self.feature_cols].fillna(0).infer_objects(copy=False)
|
||||
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
||||
X = pd.DataFrame([feature_row.values], columns=self.feature_cols)
|
||||
|
||||
# Predict probability
|
||||
prob = self.model.predict_proba(X)[0, 1]
|
||||
|
||||
# Skip if probability below threshold
|
||||
if prob < self.config.prob_threshold:
|
||||
continue
|
||||
|
||||
# Apply funding rate filter
|
||||
# Block trades where funding opposes our direction
|
||||
base_funding = latest.get('base_funding', 0) or 0
|
||||
funding_thresh = self.config.funding_threshold
|
||||
|
||||
if z_score > 0: # Short signal
|
||||
# High negative funding = shorts are paying -> skip
|
||||
if base_funding < -funding_thresh:
|
||||
logger.debug(
|
||||
"Skipping %s short: funding too negative (%.4f)",
|
||||
pair.name, base_funding
|
||||
)
|
||||
continue
|
||||
else: # Long signal
|
||||
# High positive funding = longs are paying -> skip
|
||||
if base_funding > funding_thresh:
|
||||
logger.debug(
|
||||
"Skipping %s long: funding too positive (%.4f)",
|
||||
pair.name, base_funding
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate divergence score
|
||||
divergence_score = abs(z_score) * prob
|
||||
|
||||
# Determine direction
|
||||
# Z > 0: Spread high (base expensive vs quote) -> Short base
|
||||
# Z < 0: Spread low (base cheap vs quote) -> Long base
|
||||
direction = 'short' if z_score > 0 else 'long'
|
||||
|
||||
signal = DivergenceSignal(
|
||||
pair=pair,
|
||||
z_score=z_score,
|
||||
probability=prob,
|
||||
divergence_score=divergence_score,
|
||||
direction=direction,
|
||||
base_price=latest['base_close'],
|
||||
quote_price=latest['quote_close'],
|
||||
atr=latest.get('atr_base', 0),
|
||||
atr_pct=latest.get('atr_pct_base', 0.02),
|
||||
timestamp=ts
|
||||
)
|
||||
signals.append(signal)
|
||||
|
||||
# Sort by divergence score (highest first)
|
||||
signals.sort(key=lambda s: s.divergence_score, reverse=True)
|
||||
|
||||
if signals:
|
||||
logger.debug(
|
||||
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
|
||||
len(signals),
|
||||
signals[0].pair.name,
|
||||
signals[0].divergence_score,
|
||||
signals[0].z_score,
|
||||
signals[0].probability
|
||||
)
|
||||
|
||||
return signals
|
||||
|
||||
def select_best_pair(
|
||||
self,
|
||||
signals: list[DivergenceSignal]
|
||||
) -> DivergenceSignal | None:
|
||||
"""
|
||||
Select the best pair from scored signals.
|
||||
|
||||
Args:
|
||||
signals: List of DivergenceSignal (pre-sorted by score)
|
||||
|
||||
Returns:
|
||||
Best signal or None if no valid candidates
|
||||
"""
|
||||
if not signals:
|
||||
return None
|
||||
return signals[0]
|
||||
433
strategies/multi_pair/feature_engine.py
Normal file
433
strategies/multi_pair/feature_engine.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Feature Engineering for Multi-Pair Divergence Strategy.
|
||||
|
||||
Calculates features for all pairs in the universe, including
|
||||
spread technicals, volatility, and on-chain data.
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import TradingPair
|
||||
from .funding import FundingRateFetcher
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MultiPairFeatureEngine:
|
||||
"""
|
||||
Calculates features for multiple trading pairs.
|
||||
|
||||
Generates consistent feature sets across all pairs for
|
||||
the universal ML model.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiPairConfig):
|
||||
self.config = config
|
||||
self.dm = DataManager()
|
||||
self.funding_fetcher = FundingRateFetcher()
|
||||
self._funding_data: pd.DataFrame | None = None
|
||||
|
||||
def load_all_assets(
|
||||
self,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Load OHLCV data for all assets in the universe.
|
||||
|
||||
Args:
|
||||
start_date: Start date filter (YYYY-MM-DD)
|
||||
end_date: End date filter (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to OHLCV DataFrame
|
||||
"""
|
||||
data = {}
|
||||
market_type = MarketType.PERPETUAL
|
||||
|
||||
for symbol in self.config.assets:
|
||||
try:
|
||||
df = self.dm.load_data(
|
||||
self.config.exchange_id,
|
||||
symbol,
|
||||
self.config.timeframe,
|
||||
market_type
|
||||
)
|
||||
|
||||
# Apply date filters
|
||||
if start_date:
|
||||
df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
|
||||
if end_date:
|
||||
df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
|
||||
|
||||
if len(df) >= 200: # Minimum data requirement
|
||||
data[symbol] = df
|
||||
logger.debug("Loaded %s: %d bars", symbol, len(df))
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping %s: insufficient data (%d bars)",
|
||||
symbol, len(df)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
logger.warning("Data not found for %s", symbol)
|
||||
except Exception as e:
|
||||
logger.error("Error loading %s: %s", symbol, e)
|
||||
|
||||
logger.info("Loaded %d/%d assets", len(data), len(self.config.assets))
|
||||
return data
|
||||
|
||||
def load_funding_data(
|
||||
self,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
use_cache: bool = True
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Load funding rate data for all assets.
|
||||
|
||||
Args:
|
||||
start_date: Start date filter
|
||||
end_date: End date filter
|
||||
use_cache: Whether to use cached data
|
||||
|
||||
Returns:
|
||||
DataFrame with funding rates for all assets
|
||||
"""
|
||||
self._funding_data = self.funding_fetcher.get_funding_data(
|
||||
self.config.assets,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
use_cache=use_cache
|
||||
)
|
||||
|
||||
if self._funding_data is not None and not self._funding_data.empty:
|
||||
logger.info(
|
||||
"Loaded funding data: %d rows, %d assets",
|
||||
len(self._funding_data),
|
||||
len(self._funding_data.columns)
|
||||
)
|
||||
else:
|
||||
logger.warning("No funding data available")
|
||||
|
||||
return self._funding_data
|
||||
|
||||
def calculate_pair_features(
|
||||
self,
|
||||
pair: TradingPair,
|
||||
asset_data: dict[str, pd.DataFrame],
|
||||
on_chain_data: pd.DataFrame | None = None
|
||||
) -> pd.DataFrame | None:
|
||||
"""
|
||||
Calculate features for a single pair.
|
||||
|
||||
Args:
|
||||
pair: Trading pair
|
||||
asset_data: Dictionary of OHLCV DataFrames by symbol
|
||||
on_chain_data: Optional on-chain data (funding, inflows)
|
||||
|
||||
Returns:
|
||||
DataFrame with features, or None if insufficient data
|
||||
"""
|
||||
base = pair.base_asset
|
||||
quote = pair.quote_asset
|
||||
|
||||
if base not in asset_data or quote not in asset_data:
|
||||
return None
|
||||
|
||||
df_base = asset_data[base]
|
||||
df_quote = asset_data[quote]
|
||||
|
||||
# Align indices
|
||||
common_idx = df_base.index.intersection(df_quote.index)
|
||||
if len(common_idx) < 200:
|
||||
logger.debug("Pair %s: insufficient aligned data", pair.name)
|
||||
return None
|
||||
|
||||
df_a = df_base.loc[common_idx]
|
||||
df_b = df_quote.loc[common_idx]
|
||||
|
||||
# Calculate spread (base / quote)
|
||||
spread = df_a['close'] / df_b['close']
|
||||
|
||||
# Z-Score
|
||||
z_window = self.config.z_window
|
||||
rolling_mean = spread.rolling(window=z_window).mean()
|
||||
rolling_std = spread.rolling(window=z_window).std()
|
||||
z_score = (spread - rolling_mean) / rolling_std
|
||||
|
||||
# Spread Technicals
|
||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||
spread_roc = spread.pct_change(periods=5) * 100
|
||||
spread_change_1h = spread.pct_change(periods=1)
|
||||
|
||||
# Volume Analysis
|
||||
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
|
||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
|
||||
|
||||
# Volatility
|
||||
ret_a = df_a['close'].pct_change()
|
||||
ret_b = df_b['close'].pct_change()
|
||||
vol_a = ret_a.rolling(window=z_window).std()
|
||||
vol_b = ret_b.rolling(window=z_window).std()
|
||||
vol_spread_ratio = vol_a / (vol_b + 1e-10)
|
||||
|
||||
# Realized Volatility (for dynamic SL/TP)
|
||||
realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std()
|
||||
realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std()
|
||||
|
||||
# ATR (Average True Range) for dynamic stops
|
||||
# ATR = average of max(high-low, |high-prev_close|, |low-prev_close|)
|
||||
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
|
||||
high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close']
|
||||
|
||||
# True Range for base asset
|
||||
tr_a = pd.concat([
|
||||
high_a - low_a,
|
||||
(high_a - close_a.shift(1)).abs(),
|
||||
(low_a - close_a.shift(1)).abs()
|
||||
], axis=1).max(axis=1)
|
||||
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
|
||||
|
||||
# True Range for quote asset
|
||||
tr_b = pd.concat([
|
||||
high_b - low_b,
|
||||
(high_b - close_b.shift(1)).abs(),
|
||||
(low_b - close_b.shift(1)).abs()
|
||||
], axis=1).max(axis=1)
|
||||
atr_b = tr_b.rolling(window=self.config.atr_period).mean()
|
||||
|
||||
# ATR as percentage of price (normalized)
|
||||
atr_pct_a = atr_a / close_a
|
||||
atr_pct_b = atr_b / close_b
|
||||
|
||||
# Build feature DataFrame
|
||||
features = pd.DataFrame(index=common_idx)
|
||||
features['pair_id'] = pair.pair_id
|
||||
features['base_asset'] = base
|
||||
features['quote_asset'] = quote
|
||||
|
||||
# Price data (for reference, not features)
|
||||
features['spread'] = spread
|
||||
features['base_close'] = df_a['close']
|
||||
features['quote_close'] = df_b['close']
|
||||
features['base_volume'] = df_a['volume']
|
||||
|
||||
# Core Features
|
||||
features['z_score'] = z_score
|
||||
features['spread_rsi'] = spread_rsi
|
||||
features['spread_roc'] = spread_roc
|
||||
features['spread_change_1h'] = spread_change_1h
|
||||
features['vol_ratio'] = vol_ratio
|
||||
features['vol_ratio_rel'] = vol_ratio_rel
|
||||
features['vol_diff_ratio'] = vol_spread_ratio
|
||||
|
||||
# Volatility for SL/TP
|
||||
features['realized_vol_base'] = realized_vol_a
|
||||
features['realized_vol_quote'] = realized_vol_b
|
||||
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
|
||||
|
||||
# ATR for dynamic stops (in price units and as percentage)
|
||||
features['atr_base'] = atr_a
|
||||
features['atr_quote'] = atr_b
|
||||
features['atr_pct_base'] = atr_pct_a
|
||||
features['atr_pct_quote'] = atr_pct_b
|
||||
features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2
|
||||
|
||||
# Pair encoding (for universal model)
|
||||
# Using base and quote indices for hierarchical encoding
|
||||
assets = self.config.assets
|
||||
features['base_idx'] = assets.index(base) if base in assets else -1
|
||||
features['quote_idx'] = assets.index(quote) if quote in assets else -1
|
||||
|
||||
# Add funding and on-chain features
|
||||
# Funding data is always added from self._funding_data (OKX, all 10 assets)
|
||||
# On-chain data is optional (CryptoQuant, BTC/ETH only)
|
||||
features = self._add_on_chain_features(
|
||||
features, on_chain_data, base, quote
|
||||
)
|
||||
|
||||
# Drop rows with NaN in core features only (not funding/on-chain)
|
||||
core_cols = [
|
||||
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
|
||||
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
|
||||
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
|
||||
'atr_base', 'atr_pct_base' # ATR is core for SL/TP
|
||||
]
|
||||
features = features.dropna(subset=core_cols)
|
||||
|
||||
# Fill missing funding/on-chain features with 0 (neutral)
|
||||
optional_cols = [
|
||||
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg',
|
||||
'base_inflow', 'quote_inflow', 'inflow_ratio'
|
||||
]
|
||||
for col in optional_cols:
|
||||
if col in features.columns:
|
||||
features[col] = features[col].fillna(0)
|
||||
|
||||
return features
|
||||
|
||||
def calculate_all_pair_features(
|
||||
self,
|
||||
pairs: list[TradingPair],
|
||||
asset_data: dict[str, pd.DataFrame],
|
||||
on_chain_data: pd.DataFrame | None = None
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Calculate features for all pairs.
|
||||
|
||||
Args:
|
||||
pairs: List of trading pairs
|
||||
asset_data: Dictionary of OHLCV DataFrames
|
||||
on_chain_data: Optional on-chain data
|
||||
|
||||
Returns:
|
||||
Dictionary mapping pair_id to feature DataFrame
|
||||
"""
|
||||
all_features = {}
|
||||
|
||||
for pair in pairs:
|
||||
features = self.calculate_pair_features(
|
||||
pair, asset_data, on_chain_data
|
||||
)
|
||||
if features is not None and len(features) > 0:
|
||||
all_features[pair.pair_id] = features
|
||||
|
||||
logger.info(
|
||||
"Calculated features for %d/%d pairs",
|
||||
len(all_features), len(pairs)
|
||||
)
|
||||
|
||||
return all_features
|
||||
|
||||
def get_combined_features(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
timestamp: pd.Timestamp | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Combine all pair features into a single DataFrame.
|
||||
|
||||
Useful for batch model prediction across all pairs.
|
||||
|
||||
Args:
|
||||
pair_features: Dictionary of feature DataFrames by pair_id
|
||||
timestamp: Optional specific timestamp to filter to
|
||||
|
||||
Returns:
|
||||
Combined DataFrame with all pairs as rows
|
||||
"""
|
||||
if not pair_features:
|
||||
return pd.DataFrame()
|
||||
|
||||
if timestamp is not None:
|
||||
# Get latest row from each pair at or before timestamp
|
||||
rows = []
|
||||
for pair_id, features in pair_features.items():
|
||||
valid = features[features.index <= timestamp]
|
||||
if len(valid) > 0:
|
||||
row = valid.iloc[-1:].copy()
|
||||
rows.append(row)
|
||||
|
||||
if rows:
|
||||
return pd.concat(rows, ignore_index=False)
|
||||
return pd.DataFrame()
|
||||
|
||||
# Combine all features (for training)
|
||||
return pd.concat(pair_features.values(), ignore_index=False)
|
||||
|
||||
def _add_on_chain_features(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
on_chain_data: pd.DataFrame | None,
|
||||
base_asset: str,
|
||||
quote_asset: str
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add on-chain and funding rate features for the pair.
|
||||
|
||||
Uses funding data from OKX (all 10 assets) and on-chain data
|
||||
from CryptoQuant (BTC/ETH only for inflows).
|
||||
"""
|
||||
base_short = base_asset.replace('-USDT', '').lower()
|
||||
quote_short = quote_asset.replace('-USDT', '').lower()
|
||||
|
||||
# Add funding rates from cached funding data
|
||||
if self._funding_data is not None and not self._funding_data.empty:
|
||||
funding_aligned = self._funding_data.reindex(
|
||||
features.index, method='ffill'
|
||||
)
|
||||
|
||||
base_funding_col = f'{base_short}_funding'
|
||||
quote_funding_col = f'{quote_short}_funding'
|
||||
|
||||
if base_funding_col in funding_aligned.columns:
|
||||
features['base_funding'] = funding_aligned[base_funding_col]
|
||||
if quote_funding_col in funding_aligned.columns:
|
||||
features['quote_funding'] = funding_aligned[quote_funding_col]
|
||||
|
||||
# Funding difference (positive = base has higher funding)
|
||||
if 'base_funding' in features.columns and 'quote_funding' in features.columns:
|
||||
features['funding_diff'] = (
|
||||
features['base_funding'] - features['quote_funding']
|
||||
)
|
||||
|
||||
# Funding sentiment: average of both assets
|
||||
features['funding_avg'] = (
|
||||
features['base_funding'] + features['quote_funding']
|
||||
) / 2
|
||||
|
||||
# Add on-chain features from CryptoQuant (BTC/ETH only)
|
||||
if on_chain_data is not None and not on_chain_data.empty:
|
||||
cq_aligned = on_chain_data.reindex(features.index, method='ffill')
|
||||
|
||||
# Inflows (only available for BTC/ETH)
|
||||
base_inflow_col = f'{base_short}_inflow'
|
||||
quote_inflow_col = f'{quote_short}_inflow'
|
||||
|
||||
if base_inflow_col in cq_aligned.columns:
|
||||
features['base_inflow'] = cq_aligned[base_inflow_col]
|
||||
if quote_inflow_col in cq_aligned.columns:
|
||||
features['quote_inflow'] = cq_aligned[quote_inflow_col]
|
||||
|
||||
if 'base_inflow' in features.columns and 'quote_inflow' in features.columns:
|
||||
features['inflow_ratio'] = (
|
||||
features['base_inflow'] /
|
||||
(features['quote_inflow'] + 1)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def get_feature_columns(self) -> list[str]:
|
||||
"""
|
||||
Get list of feature columns for ML model.
|
||||
|
||||
Excludes metadata and target-related columns.
|
||||
|
||||
Returns:
|
||||
List of feature column names
|
||||
"""
|
||||
# Core features (always present)
|
||||
core_features = [
|
||||
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
|
||||
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
|
||||
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
|
||||
'base_idx', 'quote_idx'
|
||||
]
|
||||
|
||||
# Funding features (now available for all 10 assets via OKX)
|
||||
funding_features = [
|
||||
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg'
|
||||
]
|
||||
|
||||
# On-chain features (BTC/ETH only via CryptoQuant)
|
||||
onchain_features = [
|
||||
'base_inflow', 'quote_inflow', 'inflow_ratio'
|
||||
]
|
||||
|
||||
return core_features + funding_features + onchain_features
|
||||
272
strategies/multi_pair/funding.py
Normal file
272
strategies/multi_pair/funding.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Funding Rate Fetcher for Multi-Pair Strategy.
|
||||
|
||||
Fetches historical funding rates from OKX for all assets.
|
||||
CryptoQuant only supports BTC/ETH, so we use OKX for the full universe.
|
||||
"""
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import ccxt
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FundingRateFetcher:
|
||||
"""
|
||||
Fetches and caches funding rate data from OKX.
|
||||
|
||||
OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC).
|
||||
This fetcher retrieves historical funding rate data and aligns it
|
||||
to hourly candles for use in the multi-pair strategy.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/funding"):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.exchange: ccxt.okx | None = None
|
||||
|
||||
def _init_exchange(self) -> None:
|
||||
"""Initialize OKX exchange connection."""
|
||||
if self.exchange is None:
|
||||
self.exchange = ccxt.okx({
|
||||
'enableRateLimit': True,
|
||||
'options': {'defaultType': 'swap'}
|
||||
})
|
||||
self.exchange.load_markets()
|
||||
|
||||
def fetch_funding_history(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 100
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch historical funding rates for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Asset symbol (e.g., 'BTC-USDT')
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
limit: Max records per request
|
||||
|
||||
Returns:
|
||||
DataFrame with funding rate history
|
||||
"""
|
||||
self._init_exchange()
|
||||
|
||||
# Convert symbol format
|
||||
base = symbol.replace('-USDT', '')
|
||||
okx_symbol = f"{base}/USDT:USDT"
|
||||
|
||||
try:
|
||||
# OKX funding rate history endpoint
|
||||
# Uses fetch_funding_rate_history if available
|
||||
all_funding = []
|
||||
|
||||
# Parse dates
|
||||
if start_date:
|
||||
since = self.exchange.parse8601(f"{start_date}T00:00:00Z")
|
||||
else:
|
||||
# Default to 1 year ago
|
||||
since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
|
||||
|
||||
if end_date:
|
||||
until = self.exchange.parse8601(f"{end_date}T23:59:59Z")
|
||||
else:
|
||||
until = self.exchange.milliseconds()
|
||||
|
||||
# Fetch in batches
|
||||
current_since = since
|
||||
while current_since < until:
|
||||
try:
|
||||
funding = self.exchange.fetch_funding_rate_history(
|
||||
okx_symbol,
|
||||
since=current_since,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if not funding:
|
||||
break
|
||||
|
||||
all_funding.extend(funding)
|
||||
|
||||
# Move to next batch
|
||||
last_ts = funding[-1]['timestamp']
|
||||
if last_ts <= current_since:
|
||||
break
|
||||
current_since = last_ts + 1
|
||||
|
||||
time.sleep(0.1) # Rate limit
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error fetching funding batch for %s: %s",
|
||||
symbol, str(e)[:50]
|
||||
)
|
||||
break
|
||||
|
||||
if not all_funding:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(all_funding)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
df.set_index('timestamp', inplace=True)
|
||||
df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'})
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
# Remove duplicates
|
||||
df = df[~df.index.duplicated(keep='first')]
|
||||
|
||||
logger.info("Fetched %d funding records for %s", len(df), symbol)
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to fetch funding for %s: %s", symbol, e)
|
||||
return pd.DataFrame()
|
||||
|
||||
def fetch_all_assets(
|
||||
self,
|
||||
assets: list[str],
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch funding rates for all assets and combine.
|
||||
|
||||
Args:
|
||||
assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT'])
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
Returns:
|
||||
Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc.
|
||||
"""
|
||||
combined = pd.DataFrame()
|
||||
|
||||
for symbol in assets:
|
||||
df = self.fetch_funding_history(symbol, start_date, end_date)
|
||||
|
||||
if df.empty:
|
||||
continue
|
||||
|
||||
# Rename column
|
||||
asset_name = symbol.replace('-USDT', '').lower()
|
||||
col_name = f"{asset_name}_funding"
|
||||
df = df.rename(columns={'funding_rate': col_name})
|
||||
|
||||
if combined.empty:
|
||||
combined = df
|
||||
else:
|
||||
combined = combined.join(df, how='outer')
|
||||
|
||||
time.sleep(0.2) # Be nice to API
|
||||
|
||||
# Forward fill to hourly (funding is every 8h)
|
||||
if not combined.empty:
|
||||
combined = combined.sort_index()
|
||||
combined = combined.ffill()
|
||||
|
||||
return combined
|
||||
|
||||
def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None:
|
||||
"""Save funding data to cache file."""
|
||||
path = self.cache_dir / filename
|
||||
df.to_csv(path)
|
||||
logger.info("Saved funding rates to %s", path)
|
||||
|
||||
def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None:
|
||||
"""Load funding data from cache if available."""
|
||||
path = self.cache_dir / filename
|
||||
if path.exists():
|
||||
df = pd.read_csv(path, index_col='timestamp', parse_dates=True)
|
||||
logger.info("Loaded funding rates from cache: %d rows", len(df))
|
||||
return df
|
||||
return None
|
||||
|
||||
def get_funding_data(
|
||||
self,
|
||||
assets: list[str],
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
use_cache: bool = True,
|
||||
force_refresh: bool = False
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Get funding data, using cache if available.
|
||||
|
||||
Args:
|
||||
assets: List of asset symbols
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
use_cache: Whether to use cached data
|
||||
force_refresh: Force refresh even if cache exists
|
||||
|
||||
Returns:
|
||||
DataFrame with funding rates for all assets
|
||||
"""
|
||||
cache_file = "funding_rates.csv"
|
||||
|
||||
# Try cache first
|
||||
if use_cache and not force_refresh:
|
||||
cached = self.load_from_cache(cache_file)
|
||||
if cached is not None:
|
||||
# Check if cache covers requested range
|
||||
if start_date and end_date:
|
||||
start_ts = pd.Timestamp(start_date, tz='UTC')
|
||||
end_ts = pd.Timestamp(end_date, tz='UTC')
|
||||
|
||||
if cached.index.min() <= start_ts and cached.index.max() >= end_ts:
|
||||
# Filter to requested range
|
||||
return cached[(cached.index >= start_ts) & (cached.index <= end_ts)]
|
||||
|
||||
# Fetch fresh data
|
||||
logger.info("Fetching fresh funding rate data...")
|
||||
df = self.fetch_all_assets(assets, start_date, end_date)
|
||||
|
||||
if not df.empty and use_cache:
|
||||
self.save_to_cache(df, cache_file)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def download_funding_data():
|
||||
"""Download funding data for all multi-pair assets."""
|
||||
from strategies.multi_pair.config import MultiPairConfig
|
||||
|
||||
config = MultiPairConfig()
|
||||
fetcher = FundingRateFetcher()
|
||||
|
||||
# Fetch last year of data
|
||||
end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
start_date = (datetime.now(timezone.utc) - pd.Timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
|
||||
logger.info("Downloading funding rates for %d assets...", len(config.assets))
|
||||
logger.info("Date range: %s to %s", start_date, end_date)
|
||||
|
||||
df = fetcher.get_funding_data(
|
||||
config.assets,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
force_refresh=True
|
||||
)
|
||||
|
||||
if not df.empty:
|
||||
logger.info("Downloaded %d funding rate records", len(df))
|
||||
logger.info("Columns: %s", list(df.columns))
|
||||
else:
|
||||
logger.warning("No funding data downloaded")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from engine.logging_config import setup_logging
|
||||
setup_logging()
|
||||
download_funding_data()
|
||||
168
strategies/multi_pair/pair_scanner.py
Normal file
168
strategies/multi_pair/pair_scanner.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Pair Scanner for Multi-Pair Divergence Strategy.
|
||||
|
||||
Generates all possible pairs from asset universe and checks tradeability.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from itertools import combinations
|
||||
from typing import Optional
|
||||
|
||||
import ccxt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradingPair:
|
||||
"""
|
||||
Represents a tradeable pair for spread analysis.
|
||||
|
||||
Attributes:
|
||||
base_asset: First asset in the pair (numerator)
|
||||
quote_asset: Second asset in the pair (denominator)
|
||||
pair_id: Unique identifier for the pair
|
||||
is_direct: Whether pair can be traded directly on exchange
|
||||
exchange_symbol: Symbol for direct trading (if available)
|
||||
"""
|
||||
base_asset: str
|
||||
quote_asset: str
|
||||
pair_id: str
|
||||
is_direct: bool = False
|
||||
exchange_symbol: Optional[str] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable pair name."""
|
||||
return f"{self.base_asset}/{self.quote_asset}"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.pair_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TradingPair):
|
||||
return False
|
||||
return self.pair_id == other.pair_id
|
||||
|
||||
|
||||
class PairScanner:
|
||||
"""
|
||||
Scans and generates tradeable pairs from asset universe.
|
||||
|
||||
Checks OKX for directly tradeable cross-pairs and generates
|
||||
synthetic pairs via USDT for others.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiPairConfig):
|
||||
self.config = config
|
||||
self.exchange: Optional[ccxt.Exchange] = None
|
||||
self._available_markets: set[str] = set()
|
||||
|
||||
def _init_exchange(self) -> None:
|
||||
"""Initialize exchange connection for market lookup."""
|
||||
if self.exchange is None:
|
||||
exchange_class = getattr(ccxt, self.config.exchange_id)
|
||||
self.exchange = exchange_class({'enableRateLimit': True})
|
||||
self.exchange.load_markets()
|
||||
self._available_markets = set(self.exchange.symbols)
|
||||
logger.info(
|
||||
"Loaded %d markets from %s",
|
||||
len(self._available_markets),
|
||||
self.config.exchange_id
|
||||
)
|
||||
|
||||
def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]:
|
||||
"""
|
||||
Generate all unique pairs from asset universe.
|
||||
|
||||
Args:
|
||||
check_exchange: Whether to check OKX for direct trading
|
||||
|
||||
Returns:
|
||||
List of TradingPair objects
|
||||
"""
|
||||
if check_exchange:
|
||||
self._init_exchange()
|
||||
|
||||
pairs = []
|
||||
assets = self.config.assets
|
||||
|
||||
for base, quote in combinations(assets, 2):
|
||||
pair_id = f"{base}__{quote}"
|
||||
|
||||
# Check if directly tradeable as cross-pair on OKX
|
||||
is_direct = False
|
||||
exchange_symbol = None
|
||||
|
||||
if check_exchange:
|
||||
# Check perpetual cross-pair (e.g., ETH/BTC:BTC)
|
||||
# OKX perpetuals are typically quoted in USDT
|
||||
# Cross-pairs like ETH/BTC are less common
|
||||
cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT"
|
||||
if cross_symbol in self._available_markets:
|
||||
is_direct = True
|
||||
exchange_symbol = cross_symbol
|
||||
|
||||
pair = TradingPair(
|
||||
base_asset=base,
|
||||
quote_asset=quote,
|
||||
pair_id=pair_id,
|
||||
is_direct=is_direct,
|
||||
exchange_symbol=exchange_symbol
|
||||
)
|
||||
pairs.append(pair)
|
||||
|
||||
# Log summary
|
||||
direct_count = sum(1 for p in pairs if p.is_direct)
|
||||
logger.info(
|
||||
"Generated %d pairs: %d direct, %d synthetic",
|
||||
len(pairs), direct_count, len(pairs) - direct_count
|
||||
)
|
||||
|
||||
return pairs
|
||||
|
||||
def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]:
|
||||
"""
|
||||
Get list of symbols needed to calculate all pair spreads.
|
||||
|
||||
For synthetic pairs, we need both USDT pairs.
|
||||
For direct pairs, we still load USDT pairs for simplicity.
|
||||
|
||||
Args:
|
||||
pairs: List of trading pairs
|
||||
|
||||
Returns:
|
||||
List of unique symbols to load (e.g., ['BTC-USDT', 'ETH-USDT'])
|
||||
"""
|
||||
symbols = set()
|
||||
for pair in pairs:
|
||||
symbols.add(pair.base_asset)
|
||||
symbols.add(pair.quote_asset)
|
||||
return list(symbols)
|
||||
|
||||
def filter_by_assets(
|
||||
self,
|
||||
pairs: list[TradingPair],
|
||||
exclude_assets: list[str]
|
||||
) -> list[TradingPair]:
|
||||
"""
|
||||
Filter pairs that contain any of the excluded assets.
|
||||
|
||||
Args:
|
||||
pairs: List of trading pairs
|
||||
exclude_assets: Assets to exclude
|
||||
|
||||
Returns:
|
||||
Filtered list of pairs
|
||||
"""
|
||||
if not exclude_assets:
|
||||
return pairs
|
||||
|
||||
exclude_set = set(exclude_assets)
|
||||
return [
|
||||
p for p in pairs
|
||||
if p.base_asset not in exclude_set
|
||||
and p.quote_asset not in exclude_set
|
||||
]
|
||||
525
strategies/multi_pair/strategy.py
Normal file
525
strategies/multi_pair/strategy.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Main strategy class that orchestrates pair scanning, feature calculation,
|
||||
model training, and signal generation for backtesting.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import PairScanner, TradingPair
|
||||
from .correlation import CorrelationFilter
|
||||
from .feature_engine import MultiPairFeatureEngine
|
||||
from .divergence_scorer import DivergenceScorer, DivergenceSignal
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PositionState:
|
||||
"""Tracks current position state."""
|
||||
pair: TradingPair | None = None
|
||||
direction: str | None = None # 'long' or 'short'
|
||||
entry_price: float = 0.0
|
||||
entry_idx: int = -1
|
||||
stop_loss: float = 0.0
|
||||
take_profit: float = 0.0
|
||||
atr: float = 0.0 # ATR at entry for reference
|
||||
last_exit_idx: int = -100 # For cooldown tracking
|
||||
|
||||
|
||||
class MultiPairDivergenceStrategy(BaseStrategy):
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Scans multiple cryptocurrency pairs for spread divergence,
|
||||
selects the most divergent pair using ML-enhanced scoring,
|
||||
and trades mean-reversion opportunities.
|
||||
|
||||
Key Features:
|
||||
- Universal ML model across all pairs
|
||||
- Correlation-based pair filtering
|
||||
- Dynamic SL/TP based on volatility
|
||||
- Walk-forward training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MultiPairConfig | None = None,
|
||||
model_path: str = "data/multi_pair_model.pkl"
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config or MultiPairConfig()
|
||||
|
||||
# Initialize components
|
||||
self.pair_scanner = PairScanner(self.config)
|
||||
self.correlation_filter = CorrelationFilter(self.config)
|
||||
self.feature_engine = MultiPairFeatureEngine(self.config)
|
||||
self.divergence_scorer = DivergenceScorer(self.config, model_path)
|
||||
|
||||
# Strategy configuration
|
||||
self.default_market_type = MarketType.PERPETUAL
|
||||
self.default_leverage = 1
|
||||
|
||||
# Runtime state
|
||||
self.pairs: list[TradingPair] = []
|
||||
self.asset_data: dict[str, pd.DataFrame] = {}
|
||||
self.pair_features: dict[str, pd.DataFrame] = {}
|
||||
self.position = PositionState()
|
||||
self.train_end_idx: int = 0
|
||||
|
||||
def run(self, close: pd.Series, **kwargs) -> tuple:
|
||||
"""
|
||||
Execute the multi-pair divergence strategy.
|
||||
|
||||
This method is called by the backtester with the primary asset's
|
||||
close prices. For multi-pair, we load all assets internally.
|
||||
|
||||
Args:
|
||||
close: Primary close prices (used for index alignment)
|
||||
**kwargs: Additional data (high, low, volume)
|
||||
|
||||
Returns:
|
||||
Tuple of (long_entries, long_exits, short_entries, short_exits, size)
|
||||
"""
|
||||
logger.info("Starting Multi-Pair Divergence Strategy")
|
||||
|
||||
# 1. Load all asset data
|
||||
start_date = close.index.min().strftime("%Y-%m-%d")
|
||||
end_date = close.index.max().strftime("%Y-%m-%d")
|
||||
|
||||
self.asset_data = self.feature_engine.load_all_assets(
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# 1b. Load funding rate data for all assets
|
||||
self.feature_engine.load_funding_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
use_cache=True
|
||||
)
|
||||
|
||||
if len(self.asset_data) < 2:
|
||||
logger.error("Insufficient assets loaded, need at least 2")
|
||||
return self._empty_signals(close)
|
||||
|
||||
# 2. Generate pairs
|
||||
self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)
|
||||
|
||||
# Filter to pairs with available data
|
||||
available_assets = set(self.asset_data.keys())
|
||||
self.pairs = [
|
||||
p for p in self.pairs
|
||||
if p.base_asset in available_assets
|
||||
and p.quote_asset in available_assets
|
||||
]
|
||||
|
||||
logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))
|
||||
|
||||
# 3. Calculate features for all pairs
|
||||
self.pair_features = self.feature_engine.calculate_all_pair_features(
|
||||
self.pairs, self.asset_data
|
||||
)
|
||||
|
||||
if not self.pair_features:
|
||||
logger.error("No pair features calculated")
|
||||
return self._empty_signals(close)
|
||||
|
||||
# 4. Align to common index
|
||||
common_index = self._get_common_index()
|
||||
if len(common_index) < 200:
|
||||
logger.error("Insufficient common data across pairs")
|
||||
return self._empty_signals(close)
|
||||
|
||||
# 5. Walk-forward split
|
||||
n_samples = len(common_index)
|
||||
train_size = int(n_samples * self.config.train_ratio)
|
||||
self.train_end_idx = train_size
|
||||
|
||||
train_end_date = common_index[train_size - 1]
|
||||
test_start_date = common_index[train_size]
|
||||
|
||||
logger.info(
|
||||
"Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
|
||||
train_size, train_end_date.strftime('%Y-%m-%d'),
|
||||
n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
# 6. Train model on training period
|
||||
if self.divergence_scorer.model is None:
|
||||
train_features = {
|
||||
pid: feat[feat.index <= train_end_date]
|
||||
for pid, feat in self.pair_features.items()
|
||||
}
|
||||
combined = self.feature_engine.get_combined_features(train_features)
|
||||
self.divergence_scorer.train_model(combined, train_features)
|
||||
|
||||
# 7. Generate signals for test period
|
||||
return self._generate_signals(common_index, train_size, close)
|
||||
|
||||
def _generate_signals(
|
||||
self,
|
||||
index: pd.DatetimeIndex,
|
||||
train_size: int,
|
||||
reference_close: pd.Series
|
||||
) -> tuple:
|
||||
"""
|
||||
Generate entry/exit signals for the test period.
|
||||
|
||||
Iterates through each bar in the test period, scoring pairs
|
||||
and generating signals based on divergence scores.
|
||||
"""
|
||||
# Initialize signal arrays aligned to reference close
|
||||
long_entries = pd.Series(False, index=reference_close.index)
|
||||
long_exits = pd.Series(False, index=reference_close.index)
|
||||
short_entries = pd.Series(False, index=reference_close.index)
|
||||
short_exits = pd.Series(False, index=reference_close.index)
|
||||
size = pd.Series(1.0, index=reference_close.index)
|
||||
|
||||
# Track position state
|
||||
self.position = PositionState()
|
||||
|
||||
# Price data for correlation calculation
|
||||
price_data = {
|
||||
symbol: df['close'] for symbol, df in self.asset_data.items()
|
||||
}
|
||||
|
||||
# Iterate through test period
|
||||
test_indices = index[train_size:]
|
||||
|
||||
trade_count = 0
|
||||
|
||||
for i, timestamp in enumerate(test_indices):
|
||||
current_idx = train_size + i
|
||||
|
||||
# Check exit conditions first
|
||||
if self.position.pair is not None:
|
||||
# Enforce minimum hold period
|
||||
bars_held = current_idx - self.position.entry_idx
|
||||
if bars_held < self.config.min_hold_bars:
|
||||
# Only allow SL/TP exits during min hold period
|
||||
should_exit, exit_reason = self._check_sl_tp_only(timestamp)
|
||||
else:
|
||||
should_exit, exit_reason = self._check_exit(timestamp)
|
||||
|
||||
if should_exit:
|
||||
# Map exit signal to reference index
|
||||
if timestamp in reference_close.index:
|
||||
if self.position.direction == 'long':
|
||||
long_exits.loc[timestamp] = True
|
||||
else:
|
||||
short_exits.loc[timestamp] = True
|
||||
|
||||
logger.debug(
|
||||
"Exit %s %s at %s: %s (held %d bars)",
|
||||
self.position.direction,
|
||||
self.position.pair.name,
|
||||
timestamp.strftime('%Y-%m-%d %H:%M'),
|
||||
exit_reason,
|
||||
bars_held
|
||||
)
|
||||
self.position = PositionState(last_exit_idx=current_idx)
|
||||
|
||||
# Score pairs (with correlation filter if position exists)
|
||||
held_asset = None
|
||||
if self.position.pair is not None:
|
||||
held_asset = self.position.pair.base_asset
|
||||
|
||||
# Filter pairs by correlation
|
||||
candidate_pairs = self.correlation_filter.filter_pairs(
|
||||
self.pairs,
|
||||
held_asset,
|
||||
price_data,
|
||||
current_idx
|
||||
)
|
||||
|
||||
# Get candidate features
|
||||
candidate_features = {
|
||||
pid: feat for pid, feat in self.pair_features.items()
|
||||
if any(p.pair_id == pid for p in candidate_pairs)
|
||||
}
|
||||
|
||||
# Score pairs
|
||||
signals = self.divergence_scorer.score_pairs(
|
||||
candidate_features,
|
||||
candidate_pairs,
|
||||
timestamp
|
||||
)
|
||||
|
||||
# Get best signal
|
||||
best = self.divergence_scorer.select_best_pair(signals)
|
||||
|
||||
if best is None:
|
||||
continue
|
||||
|
||||
# Check if we should switch positions or enter new
|
||||
should_enter = False
|
||||
|
||||
# Check cooldown
|
||||
bars_since_exit = current_idx - self.position.last_exit_idx
|
||||
in_cooldown = bars_since_exit < self.config.cooldown_bars
|
||||
|
||||
if self.position.pair is None and not in_cooldown:
|
||||
# No position and not in cooldown, can enter
|
||||
should_enter = True
|
||||
elif self.position.pair is not None:
|
||||
# Check if we should switch (requires min hold + significant improvement)
|
||||
bars_held = current_idx - self.position.entry_idx
|
||||
current_score = self._get_current_score(timestamp)
|
||||
|
||||
if (bars_held >= self.config.min_hold_bars and
|
||||
best.divergence_score > current_score * self.config.switch_threshold):
|
||||
# New opportunity is significantly better
|
||||
if timestamp in reference_close.index:
|
||||
if self.position.direction == 'long':
|
||||
long_exits.loc[timestamp] = True
|
||||
else:
|
||||
short_exits.loc[timestamp] = True
|
||||
self.position = PositionState(last_exit_idx=current_idx)
|
||||
should_enter = True
|
||||
|
||||
if should_enter:
|
||||
# Calculate ATR-based dynamic SL/TP
|
||||
sl_price, tp_price = self._calculate_sl_tp(
|
||||
best.base_price,
|
||||
best.direction,
|
||||
best.atr,
|
||||
best.atr_pct
|
||||
)
|
||||
|
||||
# Set position
|
||||
self.position = PositionState(
|
||||
pair=best.pair,
|
||||
direction=best.direction,
|
||||
entry_price=best.base_price,
|
||||
entry_idx=current_idx,
|
||||
stop_loss=sl_price,
|
||||
take_profit=tp_price,
|
||||
atr=best.atr
|
||||
)
|
||||
|
||||
# Calculate position size based on divergence
|
||||
pos_size = self._calculate_size(best.divergence_score)
|
||||
|
||||
# Generate entry signal
|
||||
if timestamp in reference_close.index:
|
||||
if best.direction == 'long':
|
||||
long_entries.loc[timestamp] = True
|
||||
else:
|
||||
short_entries.loc[timestamp] = True
|
||||
size.loc[timestamp] = pos_size
|
||||
|
||||
trade_count += 1
|
||||
logger.debug(
|
||||
"Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
|
||||
best.direction,
|
||||
best.pair.name,
|
||||
timestamp.strftime('%Y-%m-%d %H:%M'),
|
||||
best.z_score,
|
||||
best.probability,
|
||||
best.divergence_score
|
||||
)
|
||||
|
||||
logger.info("Generated %d trades in test period", trade_count)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits, size
|
||||
|
||||
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if current position should be exited.
|
||||
|
||||
Exit conditions:
|
||||
1. Z-Score reverted to mean (|Z| < threshold)
|
||||
2. Stop-loss hit
|
||||
3. Take-profit hit
|
||||
|
||||
Returns:
|
||||
Tuple of (should_exit, reason)
|
||||
"""
|
||||
if self.position.pair is None:
|
||||
return False, ""
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return True, "pair_data_missing"
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return True, "no_data"
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
current_price = latest['base_close']
|
||||
|
||||
# Check mean reversion (primary exit)
|
||||
if abs(z_score) < self.config.z_exit_threshold:
|
||||
return True, f"mean_reversion (z={z_score:.2f})"
|
||||
|
||||
# Check SL/TP
|
||||
return self._check_sl_tp(current_price)
|
||||
|
||||
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
||||
"""
|
||||
Check only stop-loss and take-profit conditions.
|
||||
Used during minimum hold period.
|
||||
"""
|
||||
if self.position.pair is None:
|
||||
return False, ""
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return True, "pair_data_missing"
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return True, "no_data"
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
current_price = latest['base_close']
|
||||
|
||||
return self._check_sl_tp(current_price)
|
||||
|
||||
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
|
||||
"""Check stop-loss and take-profit levels."""
|
||||
if self.position.direction == 'long':
|
||||
if current_price <= self.position.stop_loss:
|
||||
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
|
||||
if current_price >= self.position.take_profit:
|
||||
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
|
||||
else: # short
|
||||
if current_price >= self.position.stop_loss:
|
||||
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
|
||||
if current_price <= self.position.take_profit:
|
||||
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
|
||||
|
||||
return False, ""
|
||||
|
||||
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
|
||||
"""Get current position's divergence score for comparison."""
|
||||
if self.position.pair is None:
|
||||
return 0.0
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return 0.0
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return 0.0
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
z_score = abs(latest['z_score'])
|
||||
|
||||
# Re-score with model
|
||||
if self.divergence_scorer.model is not None:
|
||||
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
|
||||
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
||||
X = pd.DataFrame(
|
||||
[feature_row.values],
|
||||
columns=self.divergence_scorer.feature_cols
|
||||
)
|
||||
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
|
||||
return z_score * prob
|
||||
|
||||
return z_score * 0.5
|
||||
|
||||
def _calculate_sl_tp(
|
||||
self,
|
||||
entry_price: float,
|
||||
direction: str,
|
||||
atr: float,
|
||||
atr_pct: float
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate ATR-based dynamic stop-loss and take-profit prices.
|
||||
|
||||
Uses ATR (Average True Range) to set stops that adapt to
|
||||
each asset's volatility. More volatile assets get wider stops.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
direction: 'long' or 'short'
|
||||
atr: ATR in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
|
||||
Returns:
|
||||
Tuple of (stop_loss_price, take_profit_price)
|
||||
"""
|
||||
# Calculate SL/TP as ATR multiples
|
||||
if atr > 0 and atr_pct > 0:
|
||||
# ATR-based calculation
|
||||
sl_distance = atr * self.config.sl_atr_multiplier
|
||||
tp_distance = atr * self.config.tp_atr_multiplier
|
||||
|
||||
# Convert to percentage for bounds checking
|
||||
sl_pct = sl_distance / entry_price
|
||||
tp_pct = tp_distance / entry_price
|
||||
else:
|
||||
# Fallback to fixed percentages if ATR unavailable
|
||||
sl_pct = self.config.base_sl_pct
|
||||
tp_pct = self.config.base_tp_pct
|
||||
|
||||
# Apply bounds to prevent extreme stops
|
||||
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
|
||||
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
|
||||
|
||||
# Calculate actual prices
|
||||
if direction == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else: # short
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
return stop_loss, take_profit
|
||||
|
||||
def _calculate_size(self, divergence_score: float) -> float:
|
||||
"""
|
||||
Calculate position size based on divergence score.
|
||||
|
||||
Higher divergence = larger position (up to 2x).
|
||||
"""
|
||||
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
|
||||
base_threshold = 0.5
|
||||
|
||||
# Scale factor
|
||||
if divergence_score <= base_threshold:
|
||||
return 1.0
|
||||
|
||||
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
|
||||
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
|
||||
return min(scale, 2.0)
|
||||
|
||||
def _get_common_index(self) -> pd.DatetimeIndex:
|
||||
"""Get the intersection of all pair feature indices."""
|
||||
if not self.pair_features:
|
||||
return pd.DatetimeIndex([])
|
||||
|
||||
common = None
|
||||
for features in self.pair_features.values():
|
||||
if common is None:
|
||||
common = features.index
|
||||
else:
|
||||
common = common.intersection(features.index)
|
||||
|
||||
return common.sort_values()
|
||||
|
||||
def _empty_signals(self, close: pd.Series) -> tuple:
|
||||
"""Return empty signal arrays."""
|
||||
empty = self.create_empty_signals(close)
|
||||
size = pd.Series(1.0, index=close.index)
|
||||
return empty, empty, empty, empty, size
|
||||
365
strategies/regime_strategy.py
Normal file
365
strategies/regime_strategy.py
Normal file
@@ -0,0 +1,365 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
import vectorbt as vbt
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
from engine.market import MarketType
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class RegimeReversionStrategy(BaseStrategy):
|
||||
"""
|
||||
ML-Based Regime Detection & Mean Reversion Strategy.
|
||||
|
||||
Logic:
|
||||
1. Tracks the BTC/ETH Spread and its Z-Score (24h window).
|
||||
2. Uses a Random Forest model to predict if an extreme Z-Score will revert profitably.
|
||||
3. Features: Spread Technicals (RSI, ROC) + On-Chain Flows (Inflow, Funding).
|
||||
4. Entry: When Model Probability > 0.5.
|
||||
5. Exit: Z-Score reversion to 0 or SL/TP.
|
||||
|
||||
Walk-Forward Training:
|
||||
- Trains on first `train_ratio` of data (default 70%)
|
||||
- Generates signals only for remaining test period (30%)
|
||||
- Eliminates look-ahead bias for realistic backtest results
|
||||
"""
|
||||
|
||||
# Optimal parameters from walk-forward research (2025-10 to 2025-12)
|
||||
# Research: research/horizon_optimization_results.csv
|
||||
OPTIMAL_HORIZON = 102 # 4.25 days - best Net PnL (+232%)
|
||||
OPTIMAL_Z_WINDOW = 24 # 24h rolling window for spread Z-score
|
||||
OPTIMAL_TRAIN_RATIO = 0.7 # 70% train / 30% test split
|
||||
OPTIMAL_PROFIT_TARGET = 0.005 # 0.5% profit threshold for target definition
|
||||
OPTIMAL_Z_ENTRY = 1.0 # Enter when |Z| > 1.0
|
||||
|
||||
def __init__(self,
|
||||
model_path: str = "data/regime_model.pkl",
|
||||
horizon: int = OPTIMAL_HORIZON,
|
||||
z_window: int = OPTIMAL_Z_WINDOW,
|
||||
z_entry_threshold: float = OPTIMAL_Z_ENTRY,
|
||||
profit_target: float = OPTIMAL_PROFIT_TARGET,
|
||||
stop_loss: float = 0.06, # 6% - accommodates 1.95% avg MAE
|
||||
take_profit: float = 0.05, # 5% swing target
|
||||
train_ratio: float = OPTIMAL_TRAIN_RATIO,
|
||||
trend_window: int = 0, # Disable SMA filter
|
||||
use_funding_filter: bool = True, # Enable Funding Rate filter
|
||||
funding_threshold: float = 0.005 # Tightened to 0.005%
|
||||
):
|
||||
super().__init__()
|
||||
self.model_path = model_path
|
||||
self.horizon = horizon
|
||||
self.z_window = z_window
|
||||
self.z_entry_threshold = z_entry_threshold
|
||||
self.profit_target = profit_target
|
||||
self.stop_loss = stop_loss
|
||||
self.take_profit = take_profit
|
||||
self.train_ratio = train_ratio
|
||||
self.trend_window = trend_window
|
||||
self.use_funding_filter = use_funding_filter
|
||||
self.funding_threshold = funding_threshold
|
||||
|
||||
# Default Strategy Config
|
||||
self.default_market_type = MarketType.PERPETUAL
|
||||
self.default_leverage = 1
|
||||
|
||||
self.dm = DataManager()
|
||||
self.model = None
|
||||
self.feature_cols = None
|
||||
self.train_end_idx = None # Will store the training cutoff point
|
||||
|
||||
def run(self, close, **kwargs):
|
||||
"""
|
||||
Execute the strategy logic.
|
||||
We assume this strategy is run on ETH-USDT (the active asset).
|
||||
We will fetch BTC-USDT internally to calculate the spread.
|
||||
"""
|
||||
# 1. Identify Context
|
||||
# We need BTC data aligned with the incoming ETH 'close' series
|
||||
start_date = close.index.min()
|
||||
end_date = close.index.max()
|
||||
|
||||
logger.info("Fetching BTC context data...")
|
||||
try:
|
||||
# Load BTC data (Context) - Must match the timeframe of the backtest
|
||||
# Research was done on 1h candles, so strategy should be run on 1h
|
||||
# Use PERPETUAL data to match the trading instrument (ETH Perp)
|
||||
df_btc = self.dm.load_data("okx", "BTC-USDT", "1h", MarketType.PERPETUAL)
|
||||
|
||||
# Align BTC to ETH (close)
|
||||
df_btc = df_btc.reindex(close.index, method='ffill')
|
||||
btc_close = df_btc['close']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load BTC context: {e}")
|
||||
empty = self.create_empty_signals(close)
|
||||
return empty, empty, empty, empty
|
||||
|
||||
# 2. Construct DataFrames for Feature Engineering
|
||||
# We need volume/high/low for features, but 'run' signature primarily gives 'close'.
|
||||
# kwargs might have high/low/volume if passed by Backtester.run_strategy
|
||||
eth_vol = kwargs.get('volume')
|
||||
|
||||
if eth_vol is None:
|
||||
logger.warning("Volume data missing. Feature calculation might fail.")
|
||||
# Fallback or error handling
|
||||
eth_vol = pd.Series(0, index=close.index)
|
||||
|
||||
# Construct dummy dfs for prepare_features
|
||||
# We only really need Close and Volume for the current feature set
|
||||
df_a = pd.DataFrame({'close': btc_close, 'volume': df_btc['volume']})
|
||||
df_b = pd.DataFrame({'close': close, 'volume': eth_vol})
|
||||
|
||||
# 3. Load On-Chain Data (CryptoQuant)
|
||||
# We use the saved CSV for training/inference
|
||||
# In a live setting, this would query the API for recent data
|
||||
cq_df = None
|
||||
try:
|
||||
cq_path = "data/cq_training_data.csv"
|
||||
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
|
||||
if cq_df.index.tz is None:
|
||||
cq_df.index = cq_df.index.tz_localize('UTC')
|
||||
except Exception:
|
||||
logger.warning("CryptoQuant data not found. Running without on-chain features.")
|
||||
|
||||
# 4. Calculate Features
|
||||
features = self.prepare_features(df_a, df_b, cq_df)
|
||||
|
||||
# 5. Walk-Forward Split
|
||||
# Train on first `train_ratio` of data, test on remainder
|
||||
n_samples = len(features)
|
||||
train_size = int(n_samples * self.train_ratio)
|
||||
|
||||
train_features = features.iloc[:train_size]
|
||||
test_features = features.iloc[train_size:]
|
||||
|
||||
train_end_date = train_features.index[-1]
|
||||
test_start_date = test_features.index[0]
|
||||
|
||||
logger.info(
|
||||
f"Walk-Forward Split: Train={len(train_features)} bars "
|
||||
f"(until {train_end_date.strftime('%Y-%m-%d')}), "
|
||||
f"Test={len(test_features)} bars "
|
||||
f"(from {test_start_date.strftime('%Y-%m-%d')})"
|
||||
)
|
||||
|
||||
# 6. Train Model on Training Period ONLY
|
||||
if self.model is None:
|
||||
logger.info("Training Regime Model on training period only...")
|
||||
self.model, self.feature_cols = self.train_model(train_features)
|
||||
|
||||
# 7. Predict on TEST Period ONLY
|
||||
# Use valid columns only
|
||||
X_test = test_features[self.feature_cols].fillna(0)
|
||||
X_test = X_test.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Predict Probabilities for test period
|
||||
probs = self.model.predict_proba(X_test)[:, 1]
|
||||
|
||||
# 8. Generate Entry Signals (TEST period only)
|
||||
# If Z > threshold (Spread High, ETH Expensive) -> Short ETH
|
||||
# If Z < -threshold (Spread Low, ETH Cheap) -> Long ETH
|
||||
z_thresh = self.z_entry_threshold
|
||||
|
||||
short_signal_test = (probs > 0.5) & (test_features['z_score'].values > z_thresh)
|
||||
long_signal_test = (probs > 0.5) & (test_features['z_score'].values < -z_thresh)
|
||||
|
||||
# 8b. Apply Trend Filter (Macro Regime)
|
||||
# Rule: Long only if BTC > SMA (Bull), Short only if BTC < SMA (Bear)
|
||||
if self.trend_window > 0:
|
||||
# Calculate SMA on full BTC history first
|
||||
btc_sma = btc_close.rolling(window=self.trend_window).mean()
|
||||
|
||||
# Align with test period
|
||||
test_btc_close = btc_close.reindex(test_features.index)
|
||||
test_btc_sma = btc_sma.reindex(test_features.index)
|
||||
|
||||
# Define Regimes
|
||||
is_bull = (test_btc_close > test_btc_sma).values
|
||||
is_bear = (test_btc_close < test_btc_sma).values
|
||||
|
||||
# Apply Filter
|
||||
long_signal_test = long_signal_test & is_bull
|
||||
short_signal_test = short_signal_test & is_bear
|
||||
|
||||
# 8c. Apply Funding Rate Filter
|
||||
# Rule: If Funding > Threshold (Greedy) -> No Longs.
|
||||
# If Funding < -Threshold (Fearful) -> No Shorts.
|
||||
if self.use_funding_filter and 'btc_funding' in test_features.columns:
|
||||
funding = test_features['btc_funding'].values
|
||||
thresh = self.funding_threshold
|
||||
|
||||
# Greedy Market (High Positive Funding) -> Risk of Long Squeeze -> Block Longs
|
||||
# (Or implies trend is up? Actually for Mean Reversion, high funding often marks tops)
|
||||
# We block Longs because we don't want to buy into an overheated market?
|
||||
# Actually, "Greedy" means Longs are paying Shorts.
|
||||
# If we Long, we pay funding.
|
||||
# If we Short, we receive funding.
|
||||
# So High Funding = Good for Shorts (receive yield + reversion).
|
||||
# Bad for Longs (pay yield + likely top).
|
||||
|
||||
is_overheated = funding > thresh
|
||||
is_oversold = funding < -thresh
|
||||
|
||||
# Block Longs if Overheated
|
||||
long_signal_test = long_signal_test & (~is_overheated)
|
||||
|
||||
# Block Shorts if Oversold (Negative Funding) -> Risk of Short Squeeze
|
||||
short_signal_test = short_signal_test & (~is_oversold)
|
||||
|
||||
n_blocked_long = (is_overheated & (probs > 0.5) & (test_features['z_score'].values < -z_thresh)).sum()
|
||||
n_blocked_short = (is_oversold & (probs > 0.5) & (test_features['z_score'].values > z_thresh)).sum()
|
||||
|
||||
if n_blocked_long > 0 or n_blocked_short > 0:
|
||||
logger.info(f"Funding Filter: Blocked {n_blocked_long} Longs, {n_blocked_short} Shorts")
|
||||
|
||||
# 9. Calculate Position Sizing (Probability-Based)
|
||||
# Base size = 1.0 (100% of equity)
|
||||
# Scale: 1.0 + (Prob - 0.5) * 2
|
||||
# Example: Prob=0.6 -> Size=1.2, Prob=0.8 -> Size=1.6
|
||||
|
||||
# Align probabilities to close index
|
||||
probs_series = pd.Series(0.0, index=test_features.index)
|
||||
probs_series[:] = probs
|
||||
probs_aligned = probs_series.reindex(close.index, fill_value=0.0)
|
||||
|
||||
# Calculate dynamic size
|
||||
dynamic_size = 1.0 + (probs_aligned - 0.5) * 2.0
|
||||
# Cap leverage between 1x and 2x
|
||||
size = dynamic_size.clip(lower=1.0, upper=2.0)
|
||||
|
||||
# Create full-length signal series (False for training period)
|
||||
long_entries = pd.Series(False, index=close.index)
|
||||
short_entries = pd.Series(False, index=close.index)
|
||||
|
||||
# Map test signals to their correct indices
|
||||
test_idx = test_features.index
|
||||
for i, idx in enumerate(test_idx):
|
||||
if idx in close.index:
|
||||
long_entries.loc[idx] = bool(long_signal_test[i])
|
||||
short_entries.loc[idx] = bool(short_signal_test[i])
|
||||
|
||||
# 9. Generate Exits
|
||||
# Exit when Z-Score crosses back through 0 (mean reversion complete)
|
||||
z_reindexed = features['z_score'].reindex(close.index, fill_value=0)
|
||||
|
||||
# Exit Long when Z > 0, Exit Short when Z < 0
|
||||
long_exits = z_reindexed > 0
|
||||
short_exits = z_reindexed < 0
|
||||
|
||||
# Log signal counts for verification
|
||||
n_long = long_entries.sum()
|
||||
n_short = short_entries.sum()
|
||||
logger.info(f"Generated {n_long} long signals, {n_short} short signals (test period only)")
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits, size
|
||||
|
||||
def prepare_features(self, df_btc, df_eth, cq_df=None):
|
||||
"""Replicate research feature engineering"""
|
||||
# Align
|
||||
common = df_btc.index.intersection(df_eth.index)
|
||||
df_a = df_btc.loc[common].copy()
|
||||
df_b = df_eth.loc[common].copy()
|
||||
|
||||
# Spread
|
||||
spread = df_b['close'] / df_a['close']
|
||||
|
||||
# Z-Score
|
||||
rolling_mean = spread.rolling(window=self.z_window).mean()
|
||||
rolling_std = spread.rolling(window=self.z_window).std()
|
||||
z_score = (spread - rolling_mean) / rolling_std
|
||||
|
||||
# Technicals
|
||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||
spread_roc = spread.pct_change(periods=5) * 100
|
||||
spread_change_1h = spread.pct_change(periods=1)
|
||||
|
||||
# Volume
|
||||
vol_ratio = df_b['volume'] / df_a['volume']
|
||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||
|
||||
# Volatility
|
||||
ret_a = df_a['close'].pct_change()
|
||||
ret_b = df_b['close'].pct_change()
|
||||
vol_a = ret_a.rolling(window=self.z_window).std()
|
||||
vol_b = ret_b.rolling(window=self.z_window).std()
|
||||
vol_spread_ratio = vol_b / vol_a
|
||||
|
||||
features = pd.DataFrame(index=spread.index)
|
||||
features['spread'] = spread
|
||||
features['z_score'] = z_score
|
||||
features['spread_rsi'] = spread_rsi
|
||||
features['spread_roc'] = spread_roc
|
||||
features['spread_change_1h'] = spread_change_1h
|
||||
features['vol_ratio'] = vol_ratio
|
||||
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
|
||||
features['vol_diff_ratio'] = vol_spread_ratio
|
||||
|
||||
# CQ Merge
|
||||
if cq_df is not None:
|
||||
cq_aligned = cq_df.reindex(features.index, method='ffill')
|
||||
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
|
||||
cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
|
||||
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
|
||||
cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
|
||||
features = features.join(cq_aligned)
|
||||
|
||||
return features.dropna()
|
||||
|
||||
def train_model(self, train_features):
|
||||
"""
|
||||
Train Random Forest on training data only.
|
||||
|
||||
This method receives ONLY the training subset of features,
|
||||
ensuring no look-ahead bias. The model learns from historical
|
||||
patterns and is then applied to unseen test data.
|
||||
|
||||
Args:
|
||||
train_features: DataFrame containing features for training period only
|
||||
"""
|
||||
threshold = self.profit_target
|
||||
horizon = self.horizon
|
||||
z_thresh = self.z_entry_threshold
|
||||
|
||||
# Define targets using ONLY training data
|
||||
# For Short Spread (Z > threshold): Did spread drop below target within horizon?
|
||||
future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
|
||||
target_short = train_features['spread'] * (1 - threshold)
|
||||
success_short = (train_features['z_score'] > z_thresh) & (future_min < target_short)
|
||||
|
||||
# For Long Spread (Z < -threshold): Did spread rise above target within horizon?
|
||||
future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
|
||||
target_long = train_features['spread'] * (1 + threshold)
|
||||
success_long = (train_features['z_score'] < -z_thresh) & (future_max > target_long)
|
||||
|
||||
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||
|
||||
# Build model
|
||||
model = RandomForestClassifier(
|
||||
n_estimators=300, max_depth=5, min_samples_leaf=30,
|
||||
class_weight={0: 1, 1: 3}, random_state=42
|
||||
)
|
||||
|
||||
# Exclude non-feature columns
|
||||
exclude = ['spread']
|
||||
cols = [c for c in train_features.columns if c not in exclude]
|
||||
|
||||
# Clean features
|
||||
X_train = train_features[cols].fillna(0)
|
||||
X_train = X_train.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Remove rows with NaN targets (from rolling window at end of training period)
|
||||
valid_mask = ~np.isnan(targets) & ~np.isinf(targets)
|
||||
# Also check for rows where future data doesn't exist (shift created NaNs)
|
||||
valid_mask = valid_mask & (future_min.notna().values) & (future_max.notna().values)
|
||||
|
||||
X_train_clean = X_train[valid_mask]
|
||||
targets_clean = targets[valid_mask]
|
||||
|
||||
logger.info(f"Training on {len(X_train_clean)} valid samples (removed {len(X_train) - len(X_train_clean)} with incomplete future data)")
|
||||
|
||||
model.fit(X_train_clean, targets_clean)
|
||||
return model, cols
|
||||
6
strategies/supertrend/__init__.py
Normal file
6
strategies/supertrend/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Meta Supertrend strategy package.
|
||||
"""
|
||||
from .strategy import MetaSupertrendStrategy
|
||||
|
||||
__all__ = ['MetaSupertrendStrategy']
|
||||
128
strategies/supertrend/indicators.py
Normal file
128
strategies/supertrend/indicators.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Supertrend indicators and helper functions.
|
||||
"""
|
||||
import numpy as np
|
||||
import vectorbt as vbt
|
||||
from numba import njit
|
||||
|
||||
# --- Numba Compiled Helper Functions ---
|
||||
|
||||
@njit(cache=False) # Disable cache to avoid stale compilation issues
|
||||
def get_tr_nb(high, low, close):
|
||||
"""Calculate True Range (Numba compiled)."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
tr = np.empty_like(close)
|
||||
tr[0] = high[0] - low[0]
|
||||
for i in range(1, len(close)):
|
||||
tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
|
||||
return tr
|
||||
|
||||
@njit(cache=False)
|
||||
def get_atr_nb(high, low, close, period):
|
||||
"""Calculate ATR using Wilder's Smoothing (Numba compiled)."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
# Ensure period is native Python int (critical for Numba array indexing)
|
||||
n = len(close)
|
||||
p = int(period)
|
||||
|
||||
tr = get_tr_nb(high, low, close)
|
||||
atr = np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
if n < p:
|
||||
return atr
|
||||
|
||||
# Initial ATR is simple average of TR
|
||||
sum_tr = 0.0
|
||||
for i in range(p):
|
||||
sum_tr += tr[i]
|
||||
atr[p - 1] = sum_tr / p
|
||||
|
||||
# Subsequent ATR is Wilder's smoothed
|
||||
for i in range(p, n):
|
||||
atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
|
||||
|
||||
return atr
|
||||
|
||||
@njit(cache=False)
|
||||
def get_supertrend_nb(high, low, close, period, multiplier):
|
||||
"""Calculate SuperTrend completely in Numba."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
# Ensure params are native Python types (critical for Numba)
|
||||
n = len(close)
|
||||
p = int(period)
|
||||
m = float(multiplier)
|
||||
|
||||
atr = get_atr_nb(high, low, close, p)
|
||||
|
||||
final_upper = np.full(n, np.nan, dtype=np.float64)
|
||||
final_lower = np.full(n, np.nan, dtype=np.float64)
|
||||
trend = np.ones(n, dtype=np.int8) # 1 Bull, -1 Bear
|
||||
|
||||
# Skip until we have valid ATR
|
||||
start_idx = p
|
||||
if start_idx >= n:
|
||||
return trend
|
||||
|
||||
# Init first valid point
|
||||
hl2 = (high[start_idx] + low[start_idx]) / 2
|
||||
final_upper[start_idx] = hl2 + m * atr[start_idx]
|
||||
final_lower[start_idx] = hl2 - m * atr[start_idx]
|
||||
|
||||
# Loop
|
||||
for i in range(start_idx + 1, n):
|
||||
cur_hl2 = (high[i] + low[i]) / 2
|
||||
cur_atr = atr[i]
|
||||
basic_upper = cur_hl2 + m * cur_atr
|
||||
basic_lower = cur_hl2 - m * cur_atr
|
||||
|
||||
# Upper Band Logic
|
||||
if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
|
||||
final_upper[i] = basic_upper
|
||||
else:
|
||||
final_upper[i] = final_upper[i-1]
|
||||
|
||||
# Lower Band Logic
|
||||
if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
|
||||
final_lower[i] = basic_lower
|
||||
else:
|
||||
final_lower[i] = final_lower[i-1]
|
||||
|
||||
# Trend Logic
|
||||
if trend[i-1] == 1:
|
||||
if close[i] < final_lower[i-1]:
|
||||
trend[i] = -1
|
||||
else:
|
||||
trend[i] = 1
|
||||
else:
|
||||
if close[i] > final_upper[i-1]:
|
||||
trend[i] = 1
|
||||
else:
|
||||
trend[i] = -1
|
||||
|
||||
return trend
|
||||
|
||||
# --- VectorBT Indicator Factory ---
|
||||
|
||||
SuperTrendIndicator = vbt.IndicatorFactory(
|
||||
class_name='SuperTrend',
|
||||
short_name='st',
|
||||
input_names=['high', 'low', 'close'],
|
||||
param_names=['period', 'multiplier'],
|
||||
output_names=['trend']
|
||||
).from_apply_func(
|
||||
get_supertrend_nb,
|
||||
keep_pd=False, # Disable automatic Pandas wrapping of inputs
|
||||
param_product=True # Enable Cartesian product for list params
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user