9 Commits

Author SHA1 Message Date
1af0aab5fa feat: Add Multi-Pair Divergence Live Trading Module
- Introduced a new module for live trading based on the Multi-Pair Divergence Strategy.
- Implemented configuration classes for OKX API and multi-pair settings.
- Developed data feed functionality to fetch real-time OHLCV and funding data for multiple assets.
- Created a trading bot orchestrator to manage trading cycles, including entry and exit signals based on ML model predictions.
- Added comprehensive logging and error handling for robust operation.
- Included a README with setup instructions and usage guidelines for the new module.
2026-01-15 22:17:13 +08:00
df37366603 feat: Multi-Pair Divergence Selection Strategy
- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
2026-01-15 20:47:23 +08:00
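The selection rule described above (score = |z_score| * probability, with correlation-based filtering) can be sketched roughly as follows. This is an illustrative reconstruction under assumed semantics, not the strategy's actual code; pair names, the correlation limit, and the position cap are hypothetical.

```python
# Hypothetical sketch of divergence-based pair selection.
# All names and thresholds are illustrative assumptions.

def divergence_score(z_score: float, probability: float) -> float:
    # Score = |z-score| times the ML-predicted reversion probability.
    return abs(z_score) * probability

def select_pairs(candidates, correlations, max_pairs=3, corr_limit=0.8):
    """Pick the highest-scoring pairs, skipping any pair that is too
    correlated with an already-selected one (illustrative filter)."""
    ranked = sorted(
        candidates,
        key=lambda c: divergence_score(c["z"], c["p"]),
        reverse=True,
    )
    selected = []
    for cand in ranked:
        too_correlated = any(
            correlations.get(frozenset((cand["pair"], s["pair"])), 0.0) > corr_limit
            for s in selected
        )
        if not too_correlated:
            selected.append(cand)
        if len(selected) >= max_pairs:
            break
    return [c["pair"] for c in selected]
```

With this sketch, a pair with a moderate z-score but high model confidence can outrank a pair with an extreme z-score and low confidence, which is the point of weighting by probability.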
7e4a6874a2 Clean up PRD files 2026-01-15 10:51:42 +08:00
c4ecb29d4c Update trading configuration to allow full fund utilization and adjust base size calculation in strategy
- Changed `max_position_usdt` to -1.0 to indicate that all available funds should be used if the value is less than or equal to zero.
- Modified the base size calculation in `LiveRegimeStrategy` to accommodate the new logic for `max_position_usdt`, ensuring it uses all available funds when applicable.
2026-01-15 10:40:43 +08:00
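The sentinel logic this commit describes reduces to a small function. A minimal sketch, assuming the semantics stated above (non-positive `max_position_usdt` means "use all available funds"); the function name is hypothetical:

```python
def base_position_usdt(available_usdt: float, max_position_usdt: float) -> float:
    """Return the base position size in USDT.

    A max_position_usdt <= 0 is treated as a sentinel meaning
    'use all available funds' (assumed semantics from the commit).
    """
    if max_position_usdt <= 0:
        return available_usdt
    return min(available_usdt, max_position_usdt)
```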
0c82c4f366 Implement FastAPI backend and Vue 3 frontend for Lowkey Backtest UI
- Added FastAPI backend with core API endpoints for strategies, backtests, and data management.
- Introduced Vue 3 frontend with a dark theme, enabling users to run backtests, adjust parameters, and compare results.
- Implemented Pydantic schemas for request/response validation and SQLAlchemy models for database interactions.
- Enhanced project structure with dedicated modules for services, routers, and components.
- Updated dependencies in `pyproject.toml` and `frontend/package.json` to include FastAPI, SQLAlchemy, and Vue-related packages.
- Improved `.gitignore` to exclude unnecessary files and directories.
2026-01-14 21:44:04 +08:00
1e4cb87da3 Add check_symbols.py for ETH perpetuals filtering and enhance backtester with size handling
- Introduced `check_symbols.py` to load and filter ETH perpetual markets from the OKX exchange using CCXT.
- Updated the backtester to normalize signals to a 5-tuple format, incorporating size management for trades.
- Enhanced portfolio functions to support variable size and leverage adjustments based on initial capital.
- Added a new method in `CryptoQuantClient` for chunked historical data fetching to avoid API limits.
- Improved market symbol normalization in `market.py` to handle different formats.
- Updated regime strategy parameters based on recent research findings for optimal performance.
2026-01-14 09:46:51 +08:00
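The chunked-fetching idea mentioned above can be sketched generically: split a [start, end) time range into API-sized windows and issue one request per window. The helper below is an illustrative stand-in, not the actual `CryptoQuantClient` method:

```python
from datetime import datetime, timedelta

def time_chunks(start: datetime, end: datetime, chunk: timedelta):
    """Yield (chunk_start, chunk_end) windows covering [start, end),
    so each request stays under the API's per-call row limit."""
    cursor = start
    while cursor < end:
        nxt = min(cursor + chunk, end)
        yield cursor, nxt
        cursor = nxt
```

The last window is clamped to `end`, so the caller never over-fetches past the requested range.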
10bb371054 Implement Regime Reversion Strategy and remove regime_detection.py
- Introduced `RegimeReversionStrategy` for ML-based regime detection and mean reversion trading.
- Added feature engineering and model training logic within the new strategy.
- Removed the deprecated `regime_detection.py` file to streamline the codebase.
- Updated the strategy factory to include the new regime strategy configuration.
2026-01-13 21:55:34 +08:00
e6d69ed04d Add CryptoQuant client and regime detection analysis
- Introduced `CryptoQuantClient` for fetching data from the CryptoQuant API.
- Added `regime_detection.py` for advanced regime detection analysis using machine learning.
- Updated dependencies in `pyproject.toml` and `uv.lock` to include `scikit-learn`, `matplotlib`, `plotly`, `requests`, and `python-dotenv`.
- Enhanced `.gitignore` to exclude `regime_results.html` and CSV files.
- Created an interactive HTML plot for regime detection results and saved it as `regime_results.html`.
2026-01-13 16:13:57 +08:00
44fac1ed25 Remove deprecated modules and restructure the backtesting engine
- Removed deprecated backtesting modules: backtest.py, cli.py, config.py, data.py, intrabar.py, logging_utils.py, market_costs.py, metrics.py, trade.py, and the supertrend indicators.
- Introduced a new structure for the backtesting engine with improved organization and functionality, including a CLI handler, data manager, and reporting capabilities.
- Updated dependencies in pyproject.toml to support the new architecture.
2026-01-12 21:11:39 +08:00
109 changed files with 18222 additions and 394 deletions

7
.gitignore vendored

@@ -169,4 +169,9 @@ cython_debug/
 #.idea/
 ./logs/
 *.csv
+research/regime_results.html
+data/backtest_runs.db
+.gitignore
+live_trading/regime_model.pkl
+live_trading/positions.json


@@ -1 +0,0 @@
-../OHLCVPredictor

3
api/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
FastAPI backend for Lowkey Backtest UI.
"""

47
api/main.py Normal file

@@ -0,0 +1,47 @@
"""
FastAPI application entry point for Lowkey Backtest UI.
Run with: uvicorn api.main:app --reload
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.models.database import init_db
from api.routers import backtest, data, strategies
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize database on startup."""
init_db()
yield
app = FastAPI(
title="Lowkey Backtest API",
description="API for running and analyzing trading strategy backtests",
version="0.1.0",
lifespan=lifespan,
)
# CORS configuration for local development
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Register routers
app.include_router(strategies.router, prefix="/api", tags=["strategies"])
app.include_router(data.router, prefix="/api", tags=["data"])
app.include_router(backtest.router, prefix="/api", tags=["backtest"])
@app.get("/api/health")
async def health_check():
"""Health check endpoint."""
return {"status": "ok", "service": "lowkey-backtest-api"}

3
api/models/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
Pydantic schemas and database models.
"""

99
api/models/database.py Normal file

@@ -0,0 +1,99 @@
"""
SQLAlchemy database models and session management for backtest run persistence.
"""
import json
from datetime import datetime, timezone
from pathlib import Path
from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
# Database file location
DB_PATH = Path(__file__).parent.parent.parent / "data" / "backtest_runs.db"
DATABASE_URL = f"sqlite:///{DB_PATH}"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
class Base(DeclarativeBase):
"""Base class for SQLAlchemy models."""
pass
class BacktestRun(Base):
"""
Persisted backtest run record.
Stores all information needed to display and compare runs.
"""
__tablename__ = "backtest_runs"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(36), unique=True, nullable=False, index=True)
# Configuration
strategy = Column(String(50), nullable=False, index=True)
symbol = Column(String(20), nullable=False, index=True)
exchange = Column(String(20), nullable=False, default="okx")
market_type = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False)
leverage = Column(Integer, nullable=False, default=1)
params = Column(JSON, nullable=False, default=dict)
# Date range
start_date = Column(String(20), nullable=True)
end_date = Column(String(20), nullable=True)
# Metrics (denormalized for quick listing)
total_return = Column(Float, nullable=False)
benchmark_return = Column(Float, nullable=False, default=0.0)
alpha = Column(Float, nullable=False, default=0.0)
sharpe_ratio = Column(Float, nullable=False)
max_drawdown = Column(Float, nullable=False)
win_rate = Column(Float, nullable=False)
total_trades = Column(Integer, nullable=False)
profit_factor = Column(Float, nullable=True)
total_fees = Column(Float, nullable=False, default=0.0)
total_funding = Column(Float, nullable=False, default=0.0)
liquidation_count = Column(Integer, nullable=False, default=0)
liquidation_loss = Column(Float, nullable=False, default=0.0)
adjusted_return = Column(Float, nullable=True)
# Full data (JSON serialized)
equity_curve = Column(Text, nullable=False) # JSON array
trades = Column(Text, nullable=False) # JSON array
# Metadata
created_at = Column(DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
def set_equity_curve(self, data: list[dict]):
"""Serialize equity curve to JSON string."""
self.equity_curve = json.dumps(data)
def get_equity_curve(self) -> list[dict]:
"""Deserialize equity curve from JSON string."""
return json.loads(self.equity_curve) if self.equity_curve else []
def set_trades(self, data: list[dict]):
"""Serialize trades to JSON string."""
self.trades = json.dumps(data)
def get_trades(self) -> list[dict]:
"""Deserialize trades from JSON string."""
return json.loads(self.trades) if self.trades else []
def init_db():
"""Create database tables if they don't exist."""
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
Base.metadata.create_all(bind=engine)
def get_db():
"""Get database session (dependency injection)."""
db = SessionLocal()
try:
yield db
finally:
db.close()

162
api/models/schemas.py Normal file

@@ -0,0 +1,162 @@
"""
Pydantic schemas for API request/response models.
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
# --- Strategy Schemas ---
class StrategyParam(BaseModel):
"""Single strategy parameter definition."""
name: str
value: Any
param_type: str = Field(description="Type: int, float, bool, list")
min_value: float | None = None
max_value: float | None = None
description: str | None = None
class StrategyInfo(BaseModel):
"""Strategy information with parameters."""
name: str
display_name: str
market_type: str
default_leverage: int
default_params: dict[str, Any]
grid_params: dict[str, Any]
class StrategiesResponse(BaseModel):
"""Response for GET /api/strategies."""
strategies: list[StrategyInfo]
# --- Symbol/Data Schemas ---
class SymbolInfo(BaseModel):
"""Available symbol information."""
symbol: str
exchange: str
market_type: str
timeframes: list[str]
start_date: str | None = None
end_date: str | None = None
row_count: int = 0
class DataStatusResponse(BaseModel):
"""Response for GET /api/data/status."""
symbols: list[SymbolInfo]
# --- Backtest Schemas ---
class BacktestRequest(BaseModel):
"""Request body for POST /api/backtest."""
strategy: str
symbol: str
exchange: str = "okx"
timeframe: str = "1h"
market_type: str = "perpetual"
start_date: str | None = None
end_date: str | None = None
init_cash: float = 10000.0
leverage: int | None = None
fees: float | None = None
slippage: float = 0.001
sl_stop: float | None = None
tp_stop: float | None = None
sl_trail: bool = False
params: dict[str, Any] = Field(default_factory=dict)
class TradeRecord(BaseModel):
"""Single trade record."""
entry_time: str
exit_time: str | None = None
entry_price: float
exit_price: float | None = None
size: float
direction: str
pnl: float | None = None
return_pct: float | None = None
status: str = "closed"
class EquityPoint(BaseModel):
"""Single point on equity curve."""
timestamp: str
value: float
drawdown: float = 0.0
class BacktestMetrics(BaseModel):
"""Backtest performance metrics."""
total_return: float
benchmark_return: float = 0.0
alpha: float = 0.0
sharpe_ratio: float
max_drawdown: float
win_rate: float
total_trades: int
profit_factor: float | None = None
avg_trade_return: float | None = None
total_fees: float = 0.0
total_funding: float = 0.0
liquidation_count: int = 0
liquidation_loss: float = 0.0
adjusted_return: float | None = None
class BacktestResult(BaseModel):
"""Complete backtest result."""
run_id: str
strategy: str
symbol: str
market_type: str
timeframe: str
start_date: str
end_date: str
leverage: int
params: dict[str, Any]
metrics: BacktestMetrics
equity_curve: list[EquityPoint]
trades: list[TradeRecord]
created_at: str
class BacktestSummary(BaseModel):
"""Summary for backtest list view."""
run_id: str
strategy: str
symbol: str
market_type: str
timeframe: str
total_return: float
sharpe_ratio: float
max_drawdown: float
total_trades: int
created_at: str
params: dict[str, Any]
class BacktestListResponse(BaseModel):
"""Response for GET /api/backtests."""
runs: list[BacktestSummary]
total: int
# --- Comparison Schemas ---
class CompareRequest(BaseModel):
"""Request body for POST /api/compare."""
run_ids: list[str] = Field(min_length=2, max_length=5)
class CompareResult(BaseModel):
"""Comparison of multiple backtest runs."""
runs: list[BacktestResult]
param_diff: dict[str, list[Any]]

3
api/routers/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
API routers for backtest, strategies, and data endpoints.
"""

193
api/routers/backtest.py Normal file

@@ -0,0 +1,193 @@
"""
Backtest execution and history endpoints.
"""
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from api.models.database import get_db
from api.models.schemas import (
BacktestListResponse,
BacktestRequest,
BacktestResult,
CompareRequest,
CompareResult,
)
from api.services.runner import get_runner
from api.services.storage import get_storage
from engine.logging_config import get_logger
router = APIRouter()
logger = get_logger(__name__)
@router.post("/backtest", response_model=BacktestResult)
async def run_backtest(
request: BacktestRequest,
db: Session = Depends(get_db),
):
"""
Execute a backtest with the specified configuration.
Runs the strategy on historical data and returns metrics,
equity curve, and trade records. Results are automatically saved.
"""
runner = get_runner()
storage = get_storage()
try:
# Execute backtest
result = runner.run(request)
# Save to database
storage.save_run(db, result)
logger.info(
"Backtest completed and saved: %s (return=%.2f%%, sharpe=%.2f)",
result.run_id,
result.metrics.total_return,
result.metrics.sharpe_ratio,
)
return result
except KeyError as e:
raise HTTPException(status_code=400, detail=f"Invalid strategy: {e}")
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=f"Data not found: {e}")
except Exception as e:
logger.error("Backtest failed: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/backtests", response_model=BacktestListResponse)
async def list_backtests(
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
strategy: str | None = None,
symbol: str | None = None,
db: Session = Depends(get_db),
):
"""
List saved backtest runs with optional filtering.
Returns summaries for quick display in the history sidebar.
"""
storage = get_storage()
runs, total = storage.list_runs(
db,
limit=limit,
offset=offset,
strategy=strategy,
symbol=symbol,
)
return BacktestListResponse(runs=runs, total=total)
@router.get("/backtest/{run_id}", response_model=BacktestResult)
async def get_backtest(
run_id: str,
db: Session = Depends(get_db),
):
"""
Retrieve a specific backtest run by ID.
Returns full result including equity curve and trades.
"""
storage = get_storage()
result = storage.get_run(db, run_id)
if not result:
raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
return result
@router.delete("/backtest/{run_id}")
async def delete_backtest(
run_id: str,
db: Session = Depends(get_db),
):
"""
Delete a backtest run.
"""
storage = get_storage()
deleted = storage.delete_run(db, run_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
return {"status": "deleted", "run_id": run_id}
@router.post("/compare", response_model=CompareResult)
async def compare_runs(
request: CompareRequest,
db: Session = Depends(get_db),
):
"""
Compare multiple backtest runs (2-5 runs).
Returns full results for each run plus parameter differences.
"""
storage = get_storage()
runs = storage.get_runs_by_ids(db, request.run_ids)
if len(runs) != len(request.run_ids):
found_ids = {r.run_id for r in runs}
missing = [rid for rid in request.run_ids if rid not in found_ids]
raise HTTPException(
status_code=404,
detail=f"Runs not found: {missing}"
)
# Calculate parameter differences
param_diff = _calculate_param_diff(runs)
return CompareResult(runs=runs, param_diff=param_diff)
def _calculate_param_diff(runs: list[BacktestResult]) -> dict[str, list[Any]]:
"""
Find parameters that differ between runs.
Returns dict mapping param name to list of values (one per run).
"""
if not runs:
return {}
# Collect all param keys
all_keys: set[str] = set()
for run in runs:
all_keys.update(run.params.keys())
# Also include strategy and key config
all_keys.update(['strategy', 'symbol', 'leverage', 'timeframe'])
diff: dict[str, list[Any]] = {}
for key in sorted(all_keys):
values = []
for run in runs:
if key == 'strategy':
values.append(run.strategy)
elif key == 'symbol':
values.append(run.symbol)
elif key == 'leverage':
values.append(run.leverage)
elif key == 'timeframe':
values.append(run.timeframe)
else:
values.append(run.params.get(key))
# Only include if values differ
if len(set(str(v) for v in values)) > 1:
diff[key] = values
return diff
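Stripped of the schema-specific fields, the diff logic above reduces to: collect every key, read one value per run, and keep keys whose values are not all equal. A self-contained version of that core:

```python
from typing import Any

def param_diff(param_sets: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Map each parameter name to its per-run values, keeping only
    parameters whose values differ between at least two runs."""
    if not param_sets:
        return {}
    all_keys: set[str] = set()
    for params in param_sets:
        all_keys.update(params)
    diff: dict[str, list[Any]] = {}
    for key in sorted(all_keys):
        values = [params.get(key) for params in param_sets]
        # Compare via str() so unhashable values (lists, dicts) work too.
        if len({str(v) for v in values}) > 1:
            diff[key] = values
    return diff
```

A key missing from one run shows up as `None` in that run's slot, which still registers as a difference.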

97
api/routers/data.py Normal file

@@ -0,0 +1,97 @@
"""
Data status and symbol information endpoints.
"""
from pathlib import Path
import pandas as pd
from fastapi import APIRouter
from api.models.schemas import DataStatusResponse, SymbolInfo
router = APIRouter()
# Base path for CCXT data
DATA_BASE = Path(__file__).parent.parent.parent / "data" / "ccxt"
def _scan_available_data() -> list[SymbolInfo]:
"""
Scan the data directory for available symbols and timeframes.
Returns list of SymbolInfo with date ranges and row counts.
"""
symbols = []
if not DATA_BASE.exists():
return symbols
# Structure: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
for exchange_dir in DATA_BASE.iterdir():
if not exchange_dir.is_dir():
continue
exchange = exchange_dir.name
for market_dir in exchange_dir.iterdir():
if not market_dir.is_dir():
continue
market_type = market_dir.name
for symbol_dir in market_dir.iterdir():
if not symbol_dir.is_dir():
continue
symbol = symbol_dir.name
# Find all timeframes
timeframes = []
start_date = None
end_date = None
row_count = 0
for csv_file in symbol_dir.glob("*.csv"):
tf = csv_file.stem
timeframes.append(tf)
# Read the full CSV to get the row count and date range
try:
df = pd.read_csv(csv_file, parse_dates=['timestamp'])
if not df.empty:
row_count = len(df)
start_date = df['timestamp'].min().strftime("%Y-%m-%d")
end_date = df['timestamp'].max().strftime("%Y-%m-%d")
except Exception:
pass
if timeframes:
symbols.append(SymbolInfo(
symbol=symbol,
exchange=exchange,
market_type=market_type,
timeframes=sorted(timeframes),
start_date=start_date,
end_date=end_date,
row_count=row_count,
))
return symbols
@router.get("/symbols", response_model=DataStatusResponse)
async def get_symbols():
"""
Get list of available symbols with their data ranges.
Scans the local data directory for downloaded OHLCV data.
"""
symbols = _scan_available_data()
return DataStatusResponse(symbols=symbols)
@router.get("/data/status", response_model=DataStatusResponse)
async def get_data_status():
"""
Get detailed data inventory status.
Alias for /symbols with additional metadata.
"""
symbols = _scan_available_data()
return DataStatusResponse(symbols=symbols)

67
api/routers/strategies.py Normal file

@@ -0,0 +1,67 @@
"""
Strategy information endpoints.
"""
from typing import Any
import numpy as np
from fastapi import APIRouter
from api.models.schemas import StrategiesResponse, StrategyInfo
from strategies.factory import get_registry
router = APIRouter()
def _serialize_param_value(value: Any) -> Any:
"""Convert numpy arrays and other types to JSON-serializable format."""
if isinstance(value, np.ndarray):
return value.tolist()
if isinstance(value, (np.integer, np.floating)):
return value.item()
return value
def _get_display_name(name: str) -> str:
"""Convert strategy key to display name."""
display_names = {
"rsi": "RSI Strategy",
"macross": "MA Crossover",
"meta_st": "Meta Supertrend",
"regime": "Regime Reversion (ML)",
}
return display_names.get(name, name.replace("_", " ").title())
@router.get("/strategies", response_model=StrategiesResponse)
async def get_strategies():
"""
Get list of available strategies with their parameters.
Returns strategy names, default parameters, and grid search ranges.
"""
registry = get_registry()
strategies = []
for name, config in registry.items():
strategy_instance = config.strategy_class()
# Serialize parameters (convert numpy arrays to lists)
default_params = {
k: _serialize_param_value(v)
for k, v in config.default_params.items()
}
grid_params = {
k: _serialize_param_value(v)
for k, v in config.grid_params.items()
}
strategies.append(StrategyInfo(
name=name,
display_name=_get_display_name(name),
market_type=strategy_instance.default_market_type.value,
default_leverage=strategy_instance.default_leverage,
default_params=default_params,
grid_params=grid_params,
))
return StrategiesResponse(strategies=strategies)

3
api/services/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
Business logic services for backtest execution and storage.
"""

300
api/services/runner.py Normal file

@@ -0,0 +1,300 @@
"""
Backtest runner service that wraps the existing Backtester engine.
"""
import uuid
from datetime import datetime, timezone
from typing import Any
import pandas as pd
from api.models.schemas import (
BacktestMetrics,
BacktestRequest,
BacktestResult,
EquityPoint,
TradeRecord,
)
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import get_logger
from engine.market import MarketType
from strategies.factory import get_strategy
logger = get_logger(__name__)
class BacktestRunner:
"""
Service for executing backtests via the API.
Wraps the existing Backtester engine and converts results
to API response format.
"""
def __init__(self):
self.dm = DataManager()
self.bt = Backtester(self.dm)
def run(self, request: BacktestRequest) -> BacktestResult:
"""
Execute a backtest and return structured results.
Args:
request: BacktestRequest with strategy, symbol, and parameters
Returns:
BacktestResult with metrics, equity curve, and trades
"""
# Get strategy instance
strategy, default_params = get_strategy(request.strategy, is_grid=False)
# Merge default params with request params
params = {**default_params, **request.params}
# Convert market type string to enum
market_type = MarketType(request.market_type)
# Override strategy market type if specified
strategy.default_market_type = market_type
logger.info(
"Running backtest: %s on %s (%s), params=%s",
request.strategy, request.symbol, request.timeframe, params
)
# Execute backtest
result = self.bt.run_strategy(
strategy=strategy,
exchange_id=request.exchange,
symbol=request.symbol,
timeframe=request.timeframe,
start_date=request.start_date,
end_date=request.end_date,
init_cash=request.init_cash,
fees=request.fees,
slippage=request.slippage,
sl_stop=request.sl_stop,
tp_stop=request.tp_stop,
sl_trail=request.sl_trail,
leverage=request.leverage,
**params
)
# Extract data from portfolio
portfolio = result.portfolio
# Build trade records
trades = self._build_trade_records(portfolio)
# Build equity curve (trimmed to trading period)
equity_curve = self._build_equity_curve(portfolio)
# Build metrics
metrics = self._build_metrics(result, portfolio)
# Get date range from actual trading period (first trade to end)
idx = portfolio.wrapper.index
end_date = idx[-1].strftime("%Y-%m-%d %H:%M")
# Use first trade time as start if trades exist
trades_df = portfolio.trades.records_readable
if not trades_df.empty:
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
if first_entry_col in trades_df.columns:
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
start_date = first_trade_time.strftime("%Y-%m-%d %H:%M")
else:
start_date = idx[0].strftime("%Y-%m-%d %H:%M")
else:
start_date = idx[0].strftime("%Y-%m-%d %H:%M")
return BacktestResult(
run_id=str(uuid.uuid4()),
strategy=request.strategy,
symbol=request.symbol,
market_type=result.market_type.value,
timeframe=request.timeframe,
start_date=start_date,
end_date=end_date,
leverage=result.leverage,
params=params,
metrics=metrics,
equity_curve=equity_curve,
trades=trades,
created_at=datetime.now(timezone.utc).isoformat(),
)
def _build_equity_curve(self, portfolio) -> list[EquityPoint]:
"""Extract equity curve with drawdown from portfolio, starting from first trade."""
value_series = portfolio.value()
drawdown_series = portfolio.drawdown()
# Handle multi-column case (from grid search): keep the first column
if hasattr(value_series, 'columns'):
value_series = value_series.iloc[:, 0]
drawdown_series = drawdown_series.iloc[:, 0]
# Find first trade time to trim equity curve
first_trade_idx = 0
trades_df = portfolio.trades.records_readable
if not trades_df.empty:
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
if first_entry_col in trades_df.columns:
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
# Find index in value_series closest to first trade
if hasattr(value_series.index, 'get_indexer'):
first_trade_idx = value_series.index.get_indexer([first_trade_time], method='nearest')[0]
# Start a few bars before first trade for context
first_trade_idx = max(0, first_trade_idx - 5)
# Slice from first trade onwards
value_series = value_series.iloc[first_trade_idx:]
drawdown_series = drawdown_series.iloc[first_trade_idx:]
points = []
for i, (ts, val) in enumerate(value_series.items()):
dd = drawdown_series.iloc[i] if i < len(drawdown_series) else 0.0
points.append(EquityPoint(
timestamp=ts.isoformat(),
value=float(val),
drawdown=float(dd) * 100, # Convert to percentage
))
return points
def _build_trade_records(self, portfolio) -> list[TradeRecord]:
"""Extract trade records from portfolio."""
trades_df = portfolio.trades.records_readable
if trades_df.empty:
return []
records = []
for _, row in trades_df.iterrows():
# Handle different column names in vectorbt
entry_time = row.get('Entry Timestamp', row.get('Entry Time', ''))
exit_time = row.get('Exit Timestamp', row.get('Exit Time', ''))
records.append(TradeRecord(
entry_time=str(entry_time) if pd.notna(entry_time) else "",
exit_time=str(exit_time) if pd.notna(exit_time) else None,
entry_price=float(row.get('Avg Entry Price', row.get('Entry Price', 0))),
exit_price=float(row.get('Avg Exit Price', row.get('Exit Price', 0)))
if pd.notna(row.get('Avg Exit Price', row.get('Exit Price'))) else None,
size=float(row.get('Size', 0)),
direction=str(row.get('Direction', 'Long')),
pnl=float(row.get('PnL', 0)) if pd.notna(row.get('PnL')) else None,
return_pct=float(row.get('Return', 0)) * 100
if pd.notna(row.get('Return')) else None,
status="closed" if pd.notna(exit_time) else "open",
))
return records
def _build_metrics(self, result, portfolio) -> BacktestMetrics:
"""Build metrics from backtest result."""
stats = portfolio.stats()
# Extract values, handling potential multi-column results
def get_stat(key: str, default: float = 0.0) -> float:
val = stats.get(key, default)
if hasattr(val, 'mean'):
return float(val.mean())
return float(val) if pd.notna(val) else default
total_return = portfolio.total_return()
if hasattr(total_return, 'mean'):
total_return = total_return.mean()
# Calculate benchmark return from first trade to end (not full period)
# This gives accurate comparison when strategy has training period
close = portfolio.close
benchmark_return = 0.0
if hasattr(close, 'iloc'):
# Find first trade entry time
trades_df = portfolio.trades.records_readable
if not trades_df.empty:
# Get the first trade entry timestamp
first_entry_col = 'Entry Timestamp' if 'Entry Timestamp' in trades_df.columns else 'Entry Time'
if first_entry_col in trades_df.columns:
first_trade_time = pd.to_datetime(trades_df[first_entry_col].iloc[0])
# Find the price at first trade
if hasattr(close.index, 'get_indexer'):
# Find closest index to first trade time
idx = close.index.get_indexer([first_trade_time], method='nearest')[0]
start_price = close.iloc[idx]
else:
start_price = close.iloc[0]
end_price = close.iloc[-1]
if hasattr(start_price, 'mean'):
start_price = start_price.mean()
if hasattr(end_price, 'mean'):
end_price = end_price.mean()
benchmark_return = ((end_price - start_price) / start_price)
else:
# No trades - use full period
start_price = close.iloc[0]
end_price = close.iloc[-1]
if hasattr(start_price, 'mean'):
start_price = start_price.mean()
if hasattr(end_price, 'mean'):
end_price = end_price.mean()
benchmark_return = ((end_price - start_price) / start_price)
# Alpha = strategy return - benchmark return
alpha = float(total_return) - float(benchmark_return)
sharpe = portfolio.sharpe_ratio()
if hasattr(sharpe, 'mean'):
sharpe = sharpe.mean()
max_dd = portfolio.max_drawdown()
if hasattr(max_dd, 'mean'):
max_dd = max_dd.mean()
win_rate = portfolio.trades.win_rate()
if hasattr(win_rate, 'mean'):
win_rate = win_rate.mean()
trade_count = portfolio.trades.count()
if hasattr(trade_count, 'mean'):
trade_count = int(trade_count.mean())
else:
trade_count = int(trade_count)
return BacktestMetrics(
total_return=float(total_return) * 100,
benchmark_return=float(benchmark_return) * 100,
alpha=float(alpha) * 100,
sharpe_ratio=float(sharpe) if pd.notna(sharpe) else 0.0,
max_drawdown=float(max_dd) * 100,
win_rate=float(win_rate) * 100 if pd.notna(win_rate) else 0.0,
total_trades=trade_count,
profit_factor=get_stat('Profit Factor'),
avg_trade_return=get_stat('Avg Winning Trade [%]'),
total_fees=get_stat('Total Fees Paid'),
total_funding=result.total_funding_paid,
liquidation_count=result.liquidation_count,
liquidation_loss=result.total_liquidation_loss,
adjusted_return=result.adjusted_return,
)
# Singleton instance
_runner: BacktestRunner | None = None
def get_runner() -> BacktestRunner:
"""Get or create the backtest runner instance."""
global _runner
if _runner is None:
_runner = BacktestRunner()
return _runner
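Numerically, the benchmark and alpha computed in `_build_metrics` come down to a simple-return comparison; stripped of the pandas multi-column handling:

```python
def simple_return(start_price: float, end_price: float) -> float:
    # Buy-and-hold return over the measured window, as a fraction.
    return (end_price - start_price) / start_price

def alpha(strategy_return: float, start_price: float, end_price: float) -> float:
    """Alpha = strategy return minus the buy-and-hold benchmark return,
    both as fractions (multiplied by 100 for display upstream)."""
    return strategy_return - simple_return(start_price, end_price)
```

Anchoring `start_price` at the first trade rather than the first bar, as the runner does, keeps the benchmark comparable when the strategy spends its early bars training.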

225
api/services/storage.py Normal file

@@ -0,0 +1,225 @@
"""
Storage service for persisting and retrieving backtest runs.
"""
from sqlalchemy.orm import Session
from api.models.database import BacktestRun
from api.models.schemas import (
BacktestResult,
BacktestSummary,
EquityPoint,
BacktestMetrics,
TradeRecord,
)
class StorageService:
"""
Service for saving and loading backtest runs from SQLite.
"""
def save_run(self, db: Session, result: BacktestResult) -> BacktestRun:
"""
Persist a backtest result to the database.
Args:
db: Database session
result: BacktestResult to save
Returns:
Created BacktestRun record
"""
run = BacktestRun(
run_id=result.run_id,
strategy=result.strategy,
symbol=result.symbol,
market_type=result.market_type,
timeframe=result.timeframe,
leverage=result.leverage,
params=result.params,
start_date=result.start_date,
end_date=result.end_date,
total_return=result.metrics.total_return,
benchmark_return=result.metrics.benchmark_return,
alpha=result.metrics.alpha,
sharpe_ratio=result.metrics.sharpe_ratio,
max_drawdown=result.metrics.max_drawdown,
win_rate=result.metrics.win_rate,
total_trades=result.metrics.total_trades,
profit_factor=result.metrics.profit_factor,
total_fees=result.metrics.total_fees,
total_funding=result.metrics.total_funding,
liquidation_count=result.metrics.liquidation_count,
liquidation_loss=result.metrics.liquidation_loss,
adjusted_return=result.metrics.adjusted_return,
)
# Serialize complex data
run.set_equity_curve([p.model_dump() for p in result.equity_curve])
run.set_trades([t.model_dump() for t in result.trades])
db.add(run)
db.commit()
db.refresh(run)
return run
def get_run(self, db: Session, run_id: str) -> BacktestResult | None:
"""
Retrieve a backtest run by ID.
Args:
db: Database session
run_id: UUID of the run
Returns:
BacktestResult or None if not found
"""
run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
if not run:
return None
return self._to_result(run)
def list_runs(
self,
db: Session,
limit: int = 50,
offset: int = 0,
strategy: str | None = None,
symbol: str | None = None,
) -> tuple[list[BacktestSummary], int]:
"""
List backtest runs with optional filtering.
Args:
db: Database session
limit: Maximum number of runs to return
offset: Offset for pagination
strategy: Filter by strategy name
symbol: Filter by symbol
Returns:
Tuple of (list of summaries, total count)
"""
query = db.query(BacktestRun)
if strategy:
query = query.filter(BacktestRun.strategy == strategy)
if symbol:
query = query.filter(BacktestRun.symbol == symbol)
total = query.count()
runs = query.order_by(BacktestRun.created_at.desc()).offset(offset).limit(limit).all()
summaries = [self._to_summary(run) for run in runs]
return summaries, total
def get_runs_by_ids(self, db: Session, run_ids: list[str]) -> list[BacktestResult]:
"""
Retrieve multiple runs by their IDs.
Args:
db: Database session
run_ids: List of run UUIDs
Returns:
List of BacktestResults (preserves order)
"""
runs = db.query(BacktestRun).filter(BacktestRun.run_id.in_(run_ids)).all()
# Create lookup and preserve order
run_map = {run.run_id: run for run in runs}
results = []
for run_id in run_ids:
if run_id in run_map:
results.append(self._to_result(run_map[run_id]))
return results
def delete_run(self, db: Session, run_id: str) -> bool:
"""
Delete a backtest run.
Args:
db: Database session
run_id: UUID of the run to delete
Returns:
True if deleted, False if not found
"""
run = db.query(BacktestRun).filter(BacktestRun.run_id == run_id).first()
if not run:
return False
db.delete(run)
db.commit()
return True
def _to_result(self, run: BacktestRun) -> BacktestResult:
"""Convert database record to BacktestResult schema."""
equity_data = run.get_equity_curve()
trades_data = run.get_trades()
return BacktestResult(
run_id=run.run_id,
strategy=run.strategy,
symbol=run.symbol,
market_type=run.market_type,
timeframe=run.timeframe,
start_date=run.start_date or "",
end_date=run.end_date or "",
leverage=run.leverage,
params=run.params or {},
metrics=BacktestMetrics(
total_return=run.total_return,
benchmark_return=run.benchmark_return or 0.0,
alpha=run.alpha or 0.0,
sharpe_ratio=run.sharpe_ratio,
max_drawdown=run.max_drawdown,
win_rate=run.win_rate,
total_trades=run.total_trades,
profit_factor=run.profit_factor,
total_fees=run.total_fees,
total_funding=run.total_funding,
liquidation_count=run.liquidation_count,
liquidation_loss=run.liquidation_loss,
adjusted_return=run.adjusted_return,
),
equity_curve=[EquityPoint(**p) for p in equity_data],
trades=[TradeRecord(**t) for t in trades_data],
created_at=run.created_at.isoformat() if run.created_at else "",
)
def _to_summary(self, run: BacktestRun) -> BacktestSummary:
"""Convert database record to BacktestSummary schema."""
return BacktestSummary(
run_id=run.run_id,
strategy=run.strategy,
symbol=run.symbol,
market_type=run.market_type,
timeframe=run.timeframe,
total_return=run.total_return,
sharpe_ratio=run.sharpe_ratio,
max_drawdown=run.max_drawdown,
total_trades=run.total_trades,
created_at=run.created_at.isoformat() if run.created_at else "",
params=run.params or {},
)
# Singleton instance
_storage: StorageService | None = None


def get_storage() -> StorageService:
    """Get or create the storage service instance."""
    global _storage
    if _storage is None:
        _storage = StorageService()
    return _storage
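`get_runs_by_ids` preserves the caller's ordering by building a dict keyed on `run_id` first, then walking the requested IDs. The core idea in isolation (`ordered_lookup` and the sample records are illustrative, not names from the codebase):

```python
def ordered_lookup(records: list[dict], ids: list[str]) -> list[dict]:
    """Return records matching ids, in the order the ids were requested.

    IDs with no matching record are silently skipped, mirroring the
    behavior of StorageService.get_runs_by_ids.
    """
    by_id = {r["run_id"]: r for r in records}
    return [by_id[i] for i in ids if i in by_id]


runs = [{"run_id": "a"}, {"run_id": "b"}, {"run_id": "c"}]
picked = ordered_lookup(runs, ["c", "a", "missing"])
```

This keeps the database query a single `IN (...)` filter while still honoring request order.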


@@ -1,70 +0,0 @@
from __future__ import annotations
import pandas as pd
from pathlib import Path

from trade import TradeState, enter_long, exit_long, maybe_trailing_stop
from indicators import add_supertrends, compute_meta_trend
from metrics import compute_metrics
from logging_utils import write_trade_log

DEFAULT_ST_SETTINGS = [(12, 3.0), (10, 1.0), (11, 2.0)]


def backtest(
    df: pd.DataFrame,
    df_1min: pd.DataFrame,
    timeframe_minutes: int,
    stop_loss: float,
    exit_on_bearish_flip: bool,
    fee_bps: float,
    slippage_bps: float,
    log_path: Path | None = None,
):
    df = add_supertrends(df, DEFAULT_ST_SETTINGS)
    df["meta_bull"] = compute_meta_trend(df, DEFAULT_ST_SETTINGS)
    state = TradeState(stop_loss_frac=stop_loss, fee_bps=fee_bps, slippage_bps=slippage_bps)
    equity, trades = [], []
    for i, row in df.iterrows():
        price = float(row["Close"])
        ts = pd.Timestamp(row["Timestamp"])
        if state.qty <= 0 and row["meta_bull"] == 1:
            evt = enter_long(state, price)
            if evt:
                evt.update({"t": ts.isoformat(), "reason": "bull_flip"})
                trades.append(evt)
        start = ts
        end = df["Timestamp"].iat[i + 1] if i + 1 < len(df) else ts + pd.Timedelta(minutes=timeframe_minutes)
        if state.qty > 0:
            win = df_1min[(df_1min["Timestamp"] >= start) & (df_1min["Timestamp"] < end)]
            for _, m in win.iterrows():
                hi = float(m["High"])
                lo = float(m["Low"])
                state.max_px = max(state.max_px or hi, hi)
                trail = state.max_px * (1.0 - state.stop_loss_frac)
                if lo <= trail:
                    evt = exit_long(state, trail)
                    if evt:
                        prev = trades[-1]
                        pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
                        evt.update({"t": pd.Timestamp(m["Timestamp"]).isoformat(), "reason": "stop", "pnl": pnl})
                        trades.append(evt)
                    break
        if state.qty > 0 and exit_on_bearish_flip and row["meta_bull"] == 0:
            evt = exit_long(state, price)
            if evt:
                prev = trades[-1]
                pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
                evt.update({"t": ts.isoformat(), "reason": "bearish_flip", "pnl": pnl})
                trades.append(evt)
        equity.append(state.cash + state.qty * price)
    equity_curve = pd.Series(equity, index=df["Timestamp"])
    if log_path:
        write_trade_log(trades, log_path)
    perf = compute_metrics(equity_curve, trades)
    return perf, equity_curve, trades

check_demo_account.py

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""
Check OKX demo account positions and recent orders.

Usage:
    uv run python check_demo_account.py
"""
import sys
from pathlib import Path
from datetime import datetime, timezone

sys.path.insert(0, str(Path(__file__).parent))

from live_trading.config import OKXConfig
import ccxt


def main():
    """Check demo account status."""
    config = OKXConfig()
    print(f"\n{'='*60}")
    print(" OKX Demo Account Check")
    print(f"{'='*60}")
    print(f" Demo Mode: {config.demo_mode}")
    print(f" API Key: {config.api_key[:8]}..." if config.api_key else " API Key: NOT SET")
    print(f"{'='*60}\n")
    exchange = ccxt.okx({
        'apiKey': config.api_key,
        'secret': config.secret,
        'password': config.password,
        'sandbox': config.demo_mode,
        'options': {'defaultType': 'swap'},
        'enableRateLimit': True,
    })
    # Check balance
    print("--- BALANCE ---")
    balance = exchange.fetch_balance()
    usdt = balance.get('USDT', {})
    print(f"USDT Total: {usdt.get('total', 0):.2f}")
    print(f"USDT Free: {usdt.get('free', 0):.2f}")
    print(f"USDT Used: {usdt.get('used', 0):.2f}")
    # Check all balances
    print("\n--- ALL NON-ZERO BALANCES ---")
    for currency, data in balance.items():
        if isinstance(data, dict) and data.get('total', 0) > 0:
            print(f"{currency}: total={data.get('total', 0):.6f}, free={data.get('free', 0):.6f}")
    # Check open positions
    print("\n--- OPEN POSITIONS ---")
    positions = exchange.fetch_positions()
    open_positions = [p for p in positions if abs(float(p.get('contracts', 0))) > 0]
    if open_positions:
        for pos in open_positions:
            print(f" {pos['symbol']}: {pos['side']} {pos['contracts']} contracts @ {pos.get('entryPrice', 'N/A')}")
            print(f"   Unrealized PnL: {pos.get('unrealizedPnl', 'N/A')}")
    else:
        print(" No open positions")
    # Check recent orders
    print("\n--- RECENT ORDERS (last 24h) ---")
    try:
        # Fetch closed orders for AVAX
        orders = exchange.fetch_orders('AVAX/USDT:USDT', limit=20)
        if orders:
            for order in orders[-10:]:  # last 10
                ts = datetime.fromtimestamp(order['timestamp'] / 1000, tz=timezone.utc)
                print(f" [{ts.strftime('%H:%M:%S')}] {order['side'].upper()} {order['amount']} AVAX @ {order.get('average', order.get('price', 'market'))}")
                print(f"   Status: {order['status']}, Filled: {order.get('filled', 0)}, ID: {order['id']}")
        else:
            print(" No recent AVAX orders")
    except Exception as e:
        print(f" Could not fetch orders: {e}")
    # Check order history more broadly
    print("\n--- ORDER HISTORY (AVAX) ---")
    try:
        # Try fetching my trades
        trades = exchange.fetch_my_trades('AVAX/USDT:USDT', limit=10)
        if trades:
            for trade in trades[-5:]:
                ts = datetime.fromtimestamp(trade['timestamp'] / 1000, tz=timezone.utc)
                print(f" [{ts.strftime('%Y-%m-%d %H:%M:%S')}] {trade['side'].upper()} {trade['amount']} @ {trade['price']}")
                print(f"   Fee: {trade.get('fee', {}).get('cost', 'N/A')} {trade.get('fee', {}).get('currency', '')}")
        else:
            print(" No recent AVAX trades")
    except Exception as e:
        print(f" Could not fetch trades: {e}")
    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()

check_symbols.py

@@ -0,0 +1,28 @@
import ccxt


def main():
    try:
        exchange = ccxt.okx()
        print("Loading markets...")
        markets = exchange.load_markets()
        # Filter for ETH linear perpetuals
        eth_perps = [
            symbol for symbol, market in markets.items()
            if 'ETH' in symbol and 'USDT' in symbol and market.get('swap') and market.get('linear')
        ]
        print(f"\nFound {len(eth_perps)} ETH Linear Perps:")
        for symbol in eth_perps:
            market = markets[symbol]
            print(f" CCXT Symbol: {symbol}")
            print(f" Exchange ID: {market['id']}")
            print(f" Type: {market['type']}")
            print("-" * 30)
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    main()
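The filter in `check_symbols.py` is a comprehension over the markets dict returned by ccxt's `load_markets()`. The same predicate can be exercised against a stubbed markets dict (`filter_linear_perps` and the stub entries are illustrative):

```python
def filter_linear_perps(markets: dict, base: str = "ETH") -> list[str]:
    """Keep USDT-margined linear perpetual symbols for the given base asset."""
    return [
        symbol for symbol, market in markets.items()
        if base in symbol and "USDT" in symbol
        and market.get("swap") and market.get("linear")
    ]


# Stubbed subset of a ccxt markets dict: one spot pair, one linear perp,
# one inverse perp. Only the linear perp should survive the filter.
markets = {
    "ETH/USDT": {"swap": False, "linear": False},
    "ETH/USDT:USDT": {"swap": True, "linear": True},
    "ETH/USD:ETH": {"swap": True, "linear": False},
}
perps = filter_linear_perps(markets)
```

The `BASE/QUOTE:SETTLE` form (`ETH/USDT:USDT`) is ccxt's unified symbol for a swap settled in USDT, which is why the substring checks alone are not enough and the `swap`/`linear` flags are consulted.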

cli.py

@@ -1,80 +0,0 @@
from __future__ import annotations
import argparse
from pathlib import Path

import pandas as pd

from config import CLIConfig
from data import load_data
from backtest import backtest


def parse_args() -> CLIConfig:
    p = argparse.ArgumentParser(prog="bt", description="Simple supertrend backtester")
    p.add_argument("start")
    p.add_argument("end")
    p.add_argument("--timeframe-minutes", type=int, default=15)   # single TF
    p.add_argument("--timeframes-minutes", nargs="+", type=int)   # multi TF: e.g. 5 15 60 240
    p.add_argument("--stop-loss", dest="stop_losses", type=float, nargs="+", default=[0.02, 0.05])
    p.add_argument("--exit-on-bearish-flip", action="store_true")
    p.add_argument("--csv", dest="csv_path", type=Path, required=True)
    p.add_argument("--out-csv", type=Path, default=Path("summary.csv"))
    p.add_argument("--log-dir", type=Path, default=Path("./logs"))
    p.add_argument("--fee-bps", type=float, default=10.0)
    p.add_argument("--slippage-bps", type=float, default=2.0)
    a = p.parse_args()
    return CLIConfig(
        start=a.start,
        end=a.end,
        timeframe_minutes=a.timeframe_minutes,
        timeframes_minutes=a.timeframes_minutes,
        stop_losses=a.stop_losses,
        exit_on_bearish_flip=a.exit_on_bearish_flip,
        csv_path=a.csv_path,
        out_csv=a.out_csv,
        log_dir=a.log_dir,
        fee_bps=a.fee_bps,
        slippage_bps=a.slippage_bps,
    )


def main():
    cfg = parse_args()
    frames = cfg.timeframes_minutes or [cfg.timeframe_minutes]
    rows: list[dict] = []
    for tfm in frames:
        df_1min, df = load_data(cfg.start, cfg.end, tfm, cfg.csv_path)
        for sl in cfg.stop_losses:
            log_path = cfg.log_dir / f"{tfm}m_sl{sl:.2%}.csv"
            perf, equity, _ = backtest(
                df=df,
                df_1min=df_1min,
                timeframe_minutes=tfm,
                stop_loss=sl,
                exit_on_bearish_flip=cfg.exit_on_bearish_flip,
                fee_bps=cfg.fee_bps,
                slippage_bps=cfg.slippage_bps,
                log_path=log_path,
            )
            rows.append({
                "timeframe": f"{tfm}min",
                "stop_loss": sl,
                "exit_on_bearish_flip": cfg.exit_on_bearish_flip,
                "total_return": f"{perf.total_return:.2%}",
                "max_drawdown": f"{perf.max_drawdown:.2%}",
                "sharpe_ratio": f"{perf.sharpe_ratio:.2f}",
                "win_rate": f"{perf.win_rate:.2%}",
                "num_trades": perf.num_trades,
                "final_equity": f"${perf.final_equity:.2f}",
                "initial_equity": f"${perf.initial_equity:.2f}",
                "num_stop_losses": perf.num_stop_losses,
                "total_fees": perf.total_fees,
                "total_slippage_usd": perf.total_slippage_usd,
                "avg_slippage_bps": perf.avg_slippage_bps,
            })
    out = pd.DataFrame(rows)
    out.to_csv(cfg.out_csv, index=False)
    print(out.to_string(index=False))


if __name__ == "__main__":
    main()


@@ -1,18 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence


@dataclass
class CLIConfig:
    start: str
    end: str
    timeframe_minutes: int
    timeframes_minutes: list[int] | None
    stop_losses: Sequence[float]
    exit_on_bearish_flip: bool
    csv_path: Path | None
    out_csv: Path
    log_dir: Path
    fee_bps: float
    slippage_bps: float

data.py

@@ -1,24 +0,0 @@
from __future__ import annotations
import pandas as pd
from pathlib import Path


def load_data(start: str, end: str, timeframe_minutes: int, csv_path: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    df_1min = pd.read_csv(csv_path)
    df_1min["Timestamp"] = pd.to_datetime(df_1min["Timestamp"], unit="s", utc=True)
    df_1min = df_1min[(df_1min["Timestamp"] >= pd.Timestamp(start, tz="UTC")) &
                      (df_1min["Timestamp"] <= pd.Timestamp(end, tz="UTC"))] \
        .sort_values("Timestamp").reset_index(drop=True)
    if timeframe_minutes != 1:
        g = df_1min.set_index("Timestamp").resample(f"{timeframe_minutes}min")
        df = pd.DataFrame({
            "Open": g["Open"].first(),
            "High": g["High"].max(),
            "Low": g["Low"].min(),
            "Close": g["Close"].last(),
            "Volume": g["Volume"].sum(),
        }).dropna().reset_index()
    else:
        df = df_1min.copy()
    return df_1min, df
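The resampling rule in `load_data` (first/max/min/last/sum) can be exercised on synthetic bars; a usage sketch with made-up 1-minute data aggregated into a single 4-minute OHLCV bar:

```python
import pandas as pd

# Four synthetic 1-minute bars inside one 4-minute bucket.
idx = pd.date_range("2024-01-01", periods=4, freq="1min", tz="UTC")
df_1min = pd.DataFrame({
    "Open":   [10.0, 11.0, 12.0, 13.0],
    "High":   [11.0, 12.0, 13.0, 14.0],
    "Low":    [9.0, 10.0, 11.0, 12.0],
    "Close":  [10.5, 11.5, 12.5, 13.5],
    "Volume": [1.0, 2.0, 3.0, 4.0],
}, index=idx)

# Same aggregation rule as load_data: open of the first minute, extreme
# high/low across the bucket, close of the last minute, summed volume.
g = df_1min.resample("4min")
bar = pd.DataFrame({
    "Open": g["Open"].first(),
    "High": g["High"].max(),
    "Low": g["Low"].min(),
    "Close": g["Close"].last(),
    "Volume": g["Volume"].sum(),
}).dropna()
```

The `.dropna()` in `load_data` removes buckets with no 1-minute data (for example, exchange downtime) rather than emitting empty bars.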

data/multi_pair_model.pkl (binary file, not shown)

engine/backtester.py

@@ -0,0 +1,377 @@
"""
Core backtesting engine for running strategy simulations.
Supports multiple market types with realistic trading conditions.
"""
from dataclasses import dataclass
import pandas as pd
import vectorbt as vbt
from engine.data_manager import DataManager
from engine.logging_config import get_logger
from engine.market import MarketType, get_market_config
from engine.optimizer import WalkForwardOptimizer
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
from engine.risk import (
LiquidationEvent,
calculate_funding,
calculate_liquidation_adjustment,
inject_liquidation_exits,
)
from strategies.base import BaseStrategy
logger = get_logger(__name__)
@dataclass
class BacktestResult:
"""
Container for backtest results with market-specific metrics.
Attributes:
portfolio: VectorBT Portfolio object
market_type: Market type used for the backtest
leverage: Effective leverage used
total_funding_paid: Total funding fees paid (perpetuals only)
liquidation_count: Number of positions that were liquidated
liquidation_events: Detailed list of liquidation events
total_liquidation_loss: Total margin lost from liquidations
adjusted_return: Return adjusted for liquidation losses (percentage)
"""
portfolio: vbt.Portfolio
market_type: MarketType
leverage: int
total_funding_paid: float = 0.0
liquidation_count: int = 0
liquidation_events: list[LiquidationEvent] | None = None
total_liquidation_loss: float = 0.0
adjusted_return: float | None = None
class Backtester:
"""
Backtester supporting multiple market types with realistic simulation.
Features:
- Spot and Perpetual market support
- Long and short position handling
- Leverage simulation
- Funding rate calculation (perpetuals)
- Liquidation warnings
"""
def __init__(self, data_manager: DataManager):
self.dm = data_manager
def run_strategy(
self,
strategy: BaseStrategy,
exchange_id: str,
symbol: str,
timeframe: str = '1m',
start_date: str | None = None,
end_date: str | None = None,
init_cash: float = 10000,
fees: float | None = None,
slippage: float = 0.001,
sl_stop: float | None = None,
tp_stop: float | None = None,
sl_trail: bool = False,
leverage: int | None = None,
**strategy_params
) -> BacktestResult:
"""
Run a backtest with market-type-aware simulation.
Args:
strategy: Strategy instance to backtest
exchange_id: Exchange identifier (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Data timeframe (e.g., '1m', '1h', '1d')
start_date: Start date filter (YYYY-MM-DD)
end_date: End date filter (YYYY-MM-DD)
init_cash: Initial capital (margin for leveraged)
fees: Transaction fee override (uses market default if None)
slippage: Slippage percentage
sl_stop: Stop loss percentage
tp_stop: Take profit percentage
sl_trail: Enable trailing stop loss
leverage: Leverage override (uses strategy default if None)
**strategy_params: Additional strategy parameters
Returns:
BacktestResult with portfolio and market-specific metrics
"""
# Get market configuration from strategy
market_type = strategy.default_market_type
market_config = get_market_config(market_type)
# Resolve leverage and fees
effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
effective_fees = fees if fees is not None else market_config.taker_fee
# Load and filter data
df = self._load_data(
exchange_id, symbol, timeframe, market_type, start_date, end_date
)
close_price = df['close']
high_price = df['high']
low_price = df['low']
open_price = df['open']
volume = df['volume']
# Run strategy logic
signals = strategy.run(
close_price,
high=high_price,
low=low_price,
open=open_price,
volume=volume,
**strategy_params
)
# Normalize signals to 5-tuple format
signals = self._normalize_signals(signals, close_price, market_config)
long_entries, long_exits, short_entries, short_exits, size = signals
# Default size if None
if size is None:
size = 1.0
# Convert the size multiplier to a notional USD value for vbt.
# This works around the "SizeType.Percent reversal" error and effectively
# gives fixed-fractional sizing based on initial capital
# (it does not compound, but is safe for backtesting).
# Multiplication behaves identically for a pd.Series or a scalar,
# so no isinstance branch is needed.
size = size * init_cash
# Process liquidations - inject forced exits at liquidation points
liquidation_events: list[LiquidationEvent] = []
if effective_leverage > 1:
long_exits, short_exits, liquidation_events = inject_liquidation_exits(
close_price, high_price, low_price,
long_entries, long_exits,
short_entries, short_exits,
effective_leverage,
market_config.maintenance_margin_rate
)
# Calculate perpetual-specific metrics (after liquidation processing)
total_funding = 0.0
if market_type == MarketType.PERPETUAL:
total_funding = calculate_funding(
close_price,
long_entries, long_exits,
short_entries, short_exits,
market_config,
effective_leverage
)
# Run portfolio simulation with liquidation-aware exits
portfolio = self._run_portfolio(
close_price, market_config,
long_entries, long_exits,
short_entries, short_exits,
init_cash, effective_fees, slippage, timeframe,
sl_stop, tp_stop, sl_trail, effective_leverage,
size=size
)
# Calculate adjusted returns accounting for liquidation losses
total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
liquidation_events, init_cash, effective_leverage
)
raw_return = portfolio.total_return().mean() * 100
adjusted_return = raw_return - liq_adjustment
if liquidation_events:
logger.info(
"Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
len(liquidation_events), total_liq_loss, liq_adjustment
)
logger.info(
"Backtest completed: %s market, %dx leverage, fees=%.4f%%",
market_type.value, effective_leverage, effective_fees * 100
)
return BacktestResult(
portfolio=portfolio,
market_type=market_type,
leverage=effective_leverage,
total_funding_paid=total_funding,
liquidation_count=len(liquidation_events),
liquidation_events=liquidation_events,
total_liquidation_loss=total_liq_loss,
adjusted_return=adjusted_return
)
def _resolve_leverage(
self,
leverage: int | None,
strategy: BaseStrategy,
market_type: MarketType
) -> int:
"""Resolve effective leverage from CLI, strategy default, or market type."""
effective = leverage or strategy.default_leverage
if market_type == MarketType.SPOT:
return 1 # Spot cannot have leverage
return effective
def _load_data(
self,
exchange_id: str,
symbol: str,
timeframe: str,
market_type: MarketType,
start_date: str | None,
end_date: str | None
) -> pd.DataFrame:
"""Load and filter OHLCV data."""
try:
df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
except FileNotFoundError:
logger.warning("Data not found locally. Attempting download...")
df = self.dm.download_data(
exchange_id, symbol, timeframe,
start_date, end_date, market_type
)
if start_date:
df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
if end_date:
df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
return df
def _normalize_signals(
self,
signals: tuple,
close: pd.Series,
market_config
) -> tuple:
"""
Normalize strategy signals to 5-tuple format.
Returns:
(long_entries, long_exits, short_entries, short_exits, size)
"""
# Default size is None (will be treated as 1.0 or default later)
size = None
if len(signals) == 2:
long_entries, long_exits = signals
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits, size
if len(signals) == 4:
long_entries, long_exits, short_entries, short_exits = signals
elif len(signals) == 5:
long_entries, long_exits, short_entries, short_exits, size = signals
else:
raise ValueError(
f"Strategy must return 2, 4, or 5 signal arrays, got {len(signals)}"
)
# Warn and clear short signals on spot markets
if not market_config.supports_short:
has_shorts = (
short_entries.any().any()
if hasattr(short_entries, 'any')
else short_entries.any()
)
if has_shorts:
logger.warning(
"Short signals detected but market type is SPOT. "
"Short signals will be ignored."
)
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits, size
def _run_portfolio(
self,
close: pd.Series,
market_config,
long_entries, long_exits,
short_entries, short_exits,
init_cash: float,
fees: float,
slippage: float,
freq: str,
sl_stop: float | None,
tp_stop: float | None,
sl_trail: bool,
leverage: int,
size: pd.Series | float = 1.0
) -> vbt.Portfolio:
"""Select and run appropriate portfolio simulation."""
has_shorts = (
short_entries.any().any()
if hasattr(short_entries, 'any')
else short_entries.any()
)
if market_config.supports_short and has_shorts:
return run_long_short_portfolio(
close,
long_entries, long_exits,
short_entries, short_exits,
init_cash, fees, slippage, freq,
sl_stop, tp_stop, sl_trail, leverage,
size=size
)
return run_long_only_portfolio(
close,
long_entries, long_exits,
init_cash, fees, slippage, freq,
sl_stop, tp_stop, sl_trail, leverage,
# Long-only doesn't support variable size in current implementation
# without modification, but we can add it if needed.
# For now, only regime strategy uses it, which is Long/Short.
)
def run_wfa(
self,
strategy: BaseStrategy,
exchange_id: str,
symbol: str,
param_grid: dict,
n_windows: int = 10,
timeframe: str = '1m'
):
"""
Execute Walk-Forward Analysis.
Args:
strategy: Strategy instance to optimize
exchange_id: Exchange identifier
symbol: Trading pair symbol
param_grid: Parameter grid for optimization
n_windows: Number of walk-forward windows
timeframe: Data timeframe to load
Returns:
Tuple of (results DataFrame, stitched equity curve)
"""
market_type = strategy.default_market_type
df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
wfa = WalkForwardOptimizer(self, strategy, param_grid)
results, stitched_curve = wfa.run(
df['close'],
high=df['high'],
low=df['low'],
n_windows=n_windows
)
return results, stitched_curve
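The liquidation handling above delegates to `engine/risk.py`, which is not shown here. For intuition, a common first-order approximation of an isolated-margin long liquidation price is sketched below; this is an illustrative formula under stated assumptions, not necessarily what `inject_liquidation_exits` implements:

```python
def approx_long_liquidation_price(entry: float, leverage: int, mmr: float) -> float:
    """First-order isolated-margin long liquidation estimate.

    A position posting 1/leverage of notional as margin can absorb roughly
    a 1/leverage adverse move before margin is exhausted; the maintenance
    margin rate (mmr) pulls the liquidation level slightly closer to entry.
    Ignores fees, funding, and tiered margin schedules.
    """
    return entry * (1.0 - 1.0 / leverage + mmr)
```

At 10x leverage with a 0.5% maintenance margin rate, a long entered at 100 would be liquidated near 90.5, which is why the engine injects forced exits only when `effective_leverage > 1`.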

engine/cli.py

@@ -0,0 +1,243 @@
"""
CLI handler for Lowkey Backtest.
"""
import argparse
import sys
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import get_logger, setup_logging
from engine.market import MarketType
from engine.reporting import Reporter
from strategies.factory import get_strategy, get_strategy_names
logger = get_logger(__name__)
def create_parser() -> argparse.ArgumentParser:
"""Create and configure the argument parser."""
parser = argparse.ArgumentParser(
description="Lowkey Backtest CLI (VectorBT Edition)"
)
subparsers = parser.add_subparsers(dest="command", help="Command to run")
_add_download_parser(subparsers)
_add_backtest_parser(subparsers)
_add_wfa_parser(subparsers)
return parser
def _add_download_parser(subparsers) -> None:
"""Add download command parser."""
dl_parser = subparsers.add_parser("download", help="Download historical data")
dl_parser.add_argument("--exchange", "-e", type=str, default="okx")
dl_parser.add_argument("--pair", "-p", type=str, required=True)
dl_parser.add_argument("--timeframe", "-t", type=str, default="1m")
dl_parser.add_argument("--start", type=str, help="Start Date (YYYY-MM-DD)")
dl_parser.add_argument(
"--market", "-m",
type=str,
choices=["spot", "perpetual"],
default="spot"
)
def _add_backtest_parser(subparsers) -> None:
"""Add backtest command parser."""
strategy_choices = get_strategy_names()
bt_parser = subparsers.add_parser("backtest", help="Run a backtest")
bt_parser.add_argument(
"--strategy", "-s",
type=str,
choices=strategy_choices,
required=True
)
bt_parser.add_argument("--exchange", "-e", type=str, default="okx")
bt_parser.add_argument("--pair", "-p", type=str, required=True)
bt_parser.add_argument("--timeframe", "-t", type=str, default="1m")
bt_parser.add_argument("--start", type=str)
bt_parser.add_argument("--end", type=str)
bt_parser.add_argument("--grid", "-g", action="store_true")
bt_parser.add_argument("--plot", action="store_true")
# Risk parameters
bt_parser.add_argument("--sl", type=float, help="Stop Loss %%")
bt_parser.add_argument("--tp", type=float, help="Take Profit %%")
bt_parser.add_argument("--trail", action="store_true")
bt_parser.add_argument("--no-bear-exit", action="store_true")
# Cost parameters
bt_parser.add_argument("--fees", type=float, default=None)
bt_parser.add_argument("--slippage", type=float, default=0.001)
bt_parser.add_argument("--leverage", "-l", type=int, default=None)
def _add_wfa_parser(subparsers) -> None:
"""Add walk-forward analysis command parser."""
strategy_choices = get_strategy_names()
wfa_parser = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")
wfa_parser.add_argument(
"--strategy", "-s",
type=str,
choices=strategy_choices,
required=True
)
wfa_parser.add_argument("--pair", "-p", type=str, required=True)
wfa_parser.add_argument("--timeframe", "-t", type=str, default="1d")
wfa_parser.add_argument("--windows", "-w", type=int, default=10)
wfa_parser.add_argument("--plot", action="store_true")
def run_download(args) -> None:
"""Execute download command."""
dm = DataManager()
market_type = MarketType(args.market)
dm.download_data(
args.exchange,
args.pair,
args.timeframe,
start_date=args.start,
market_type=market_type
)
def run_backtest(args) -> None:
"""Execute backtest command."""
dm = DataManager()
bt = Backtester(dm)
reporter = Reporter()
strategy, params = get_strategy(args.strategy, args.grid)
# Apply CLI overrides for meta_st strategy
params = _apply_strategy_overrides(args, strategy, params)
if args.grid and args.strategy == "meta_st":
logger.info("Running Grid Search for Meta Supertrend...")
try:
result = bt.run_strategy(
strategy,
args.exchange,
args.pair,
timeframe=args.timeframe,
start_date=args.start,
end_date=args.end,
fees=args.fees,
slippage=args.slippage,
sl_stop=args.sl,
tp_stop=args.tp,
sl_trail=args.trail,
leverage=args.leverage,
**params
)
reporter.print_summary(result)
reporter.save_reports(result, f"{args.strategy}_{args.pair.replace('/','-')}")
if args.plot and not args.grid:
reporter.plot(result.portfolio)
elif args.plot and args.grid:
logger.info("Plotting skipped for Grid Search. Check CSV results.")
except Exception as e:
logger.error("Backtest failed: %s", e, exc_info=True)
def run_wfa(args) -> None:
"""Execute walk-forward analysis command."""
dm = DataManager()
bt = Backtester(dm)
reporter = Reporter()
strategy, params = get_strategy(args.strategy, is_grid=True)
logger.info(
"Running WFA on %s for %s (%s) with %d windows...",
args.strategy, args.pair, args.timeframe, args.windows
)
try:
results, stitched_curve = bt.run_wfa(
strategy,
"okx",
args.pair,
params,
n_windows=args.windows,
timeframe=args.timeframe
)
_log_wfa_results(results)
_save_wfa_results(args, results, stitched_curve, reporter)
except Exception as e:
logger.error("WFA failed: %s", e, exc_info=True)
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
"""Apply CLI argument overrides to strategy parameters."""
if args.strategy != "meta_st":
return params
if args.no_bear_exit:
params['exit_on_bearish_flip'] = False
if args.sl is None:
args.sl = strategy.default_sl_stop
if not args.trail:
args.trail = strategy.default_sl_trail
return params
def _log_wfa_results(results) -> None:
"""Log WFA results summary."""
logger.info("Walk-Forward Analysis Results:")
if results.empty or 'window' not in results.columns:
logger.warning("No valid WFA results. All windows may have failed.")
return
columns = ['window', 'train_score', 'test_score', 'test_return']
logger.info("\n%s", results[columns].to_string(index=False))
avg_test_sharpe = results['test_score'].mean()
avg_test_return = results['test_return'].mean()
logger.info("Average Test Sharpe: %.2f", avg_test_sharpe)
logger.info("Average Test Return: %.2f%%", avg_test_return * 100)
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
"""Save WFA results to file and optionally plot."""
if results.empty:
return
output_path = f"backtest_logs/wfa_{args.strategy}_{args.pair.replace('/','-')}.csv"
results.to_csv(output_path)
logger.info("Saved full results to %s", output_path)
if args.plot:
reporter.plot_wfa(results, stitched_curve)
def main():
"""Main entry point."""
setup_logging()
parser = create_parser()
args = parser.parse_args()
commands = {
"download": run_download,
"backtest": run_backtest,
"wfa": run_wfa,
}
if args.command in commands:
commands[args.command](args)
else:
parser.print_help()
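`main()` routes subcommands through a dict of handlers keyed by `args.command`. The parsing half of that pattern in isolation (the prog name and arguments are illustrative, trimmed from the real parser):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser with one subcommand, mirroring the subparsers setup above."""
    parser = argparse.ArgumentParser(prog="bt")
    sub = parser.add_subparsers(dest="command", help="Command to run")
    dl = sub.add_parser("download", help="Download historical data")
    dl.add_argument("--pair", "-p", required=True)
    return parser


# parse_args accepts an explicit argv list, which makes the wiring testable
# without touching sys.argv.
args = build_parser().parse_args(["download", "--pair", "BTC/USDT"])
```

Because `dest="command"` records which subparser matched, dispatch reduces to a single dict lookup, with `parser.print_help()` as the fallback when no subcommand was given.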

engine/cryptoquant.py

@@ -0,0 +1,201 @@
import os
import sys
import time
import requests
import pandas as pd
from datetime import datetime, timedelta
from dotenv import load_dotenv
# Load env vars from .env file
load_dotenv()
# Fix path for direct execution
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from engine.logging_config import get_logger
logger = get_logger(__name__)
class CryptoQuantClient:
"""
Client for fetching data from CryptoQuant API.
"""
BASE_URL = "https://api.cryptoquant.com/v1"
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.getenv("CRYPTOQUANT_API_KEY")
if not self.api_key:
raise ValueError("CryptoQuant API Key not found. Set CRYPTOQUANT_API_KEY env var.")
self.headers = {
"Authorization": f"Bearer {self.api_key}"
}
def fetch_metric(
self,
metric_path: str,
symbol: str,
start_date: str,
end_date: str,
exchange: str | None = "all_exchange",
window: str = "day"
) -> pd.DataFrame:
"""
Fetch a specific metric from CryptoQuant.
"""
url = f"{self.BASE_URL}/{metric_path}"
params = {
"window": window,
"from": start_date,
"to": end_date,
"limit": 100000
}
if exchange:
params["exchange"] = exchange
logger.info(f"Fetching {metric_path} for {symbol} ({start_date}-{end_date})...")
try:
response = requests.get(url, headers=self.headers, params=params)
response.raise_for_status()
data = response.json()
if 'result' in data and 'data' in data['result']:
df = pd.DataFrame(data['result']['data'])
if not df.empty:
if 'date' in df.columns:
df['timestamp'] = pd.to_datetime(df['date'])
df.set_index('timestamp', inplace=True)
df.sort_index(inplace=True)
return df
return pd.DataFrame()
except Exception as e:
logger.error(f"Error fetching CQ data {metric_path}: {e}")
if 'response' in locals() and hasattr(response, 'text'):
logger.error(f"Response: {response.text}")
return pd.DataFrame()
def fetch_multi_metrics(self, symbols: list[str], metrics: dict, start_date: str, end_date: str):
"""
Fetch multiple metrics for multiple symbols and combine them.
"""
combined_df = pd.DataFrame()
for symbol in symbols:
asset = symbol.lower()
for metric_name, api_path in metrics.items():
full_path = f"{asset}/{api_path}"
# 'all_exchange' is the standard aggregate for exchange-flow metrics.
# Funding rates may not support it on every plan; substitute a specific
# exchange (e.g. 'binance') here if the API rejects the aggregate.
exchange_param = "all_exchange"
df = self.fetch_metric(full_path, asset, start_date, end_date, exchange=exchange_param)
if not df.empty:
target_col = None
# Heuristic to find the value column
candidates = ['funding_rate', 'reserve', 'inflow_total', 'outflow_total', 'open_interest', 'ratio', 'value']
for col in df.columns:
if col in candidates:
target_col = col
break
if not target_col:
# Fallback: take first numeric col that isn't date
for col in df.columns:
if col not in ['date', 'datetime', 'timestamp_str', 'block_height']:
target_col = col
break
if target_col:
col_name = f"{asset}_{metric_name}"
subset = df[[target_col]].rename(columns={target_col: col_name})
if combined_df.empty:
combined_df = subset
else:
combined_df = combined_df.join(subset, how='outer')
time.sleep(0.2)
return combined_df
def fetch_history_chunked(
self,
symbols: list[str],
metrics: dict,
start_date: str,
end_date: str,
chunk_months: int = 3
) -> pd.DataFrame:
"""
Fetch historical data in chunks to avoid API limits.
"""
start_dt = datetime.strptime(start_date, "%Y%m%d")
end_dt = datetime.strptime(end_date, "%Y%m%d")
all_data = []
current = start_dt
while current < end_dt:
next_chunk = current + timedelta(days=chunk_months * 30)
if next_chunk > end_dt:
next_chunk = end_dt
s_str = current.strftime("%Y%m%d")
e_str = next_chunk.strftime("%Y%m%d")
logger.info(f"Processing chunk: {s_str} to {e_str}")
chunk_df = self.fetch_multi_metrics(symbols, metrics, s_str, e_str)
if not chunk_df.empty:
all_data.append(chunk_df)
current = next_chunk + timedelta(days=1)
time.sleep(1) # Be nice to API
if not all_data:
return pd.DataFrame()
# Combine all chunks
full_df = pd.concat(all_data)
# Remove duplicates if any overlap
full_df = full_df[~full_df.index.duplicated(keep='first')]
full_df.sort_index(inplace=True)
return full_df
if __name__ == "__main__":
cq = CryptoQuantClient()
# ~12.5 months of data (2025-01-01 to 2026-01-14)
start = "20250101"
end = "20260114"
metrics = {
"reserves": "exchange-flows/exchange-reserve",
"inflow": "exchange-flows/inflow",
"funding": "market-data/funding-rates"
}
print(f"Fetching training data from {start} to {end}...")
df = cq.fetch_history_chunked(["btc", "eth"], metrics, start, end)
output_file = "data/cq_training_data.csv"
os.makedirs("data", exist_ok=True)
df.to_csv(output_file)
print(f"\nSaved {len(df)} rows to {output_file}")
print(df.head())
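`fetch_history_chunked` advances in ~30-day "months", steps one day past each chunk end, and deduplicates overlapping timestamps after concatenation. A standalone sketch of just that date math and dedup logic (no API calls; the helper name `chunk_ranges` is illustrative, not from the module):

```python
import pandas as pd
from datetime import datetime, timedelta

def chunk_ranges(start: str, end: str, chunk_months: int = 3) -> list[tuple[str, str]]:
    # Reproduces the chunking loop: ~30-day months, advance one day past each chunk end
    start_dt = datetime.strptime(start, "%Y%m%d")
    end_dt = datetime.strptime(end, "%Y%m%d")
    ranges = []
    current = start_dt
    while current < end_dt:
        nxt = min(current + timedelta(days=chunk_months * 30), end_dt)
        ranges.append((current.strftime("%Y%m%d"), nxt.strftime("%Y%m%d")))
        current = nxt + timedelta(days=1)
    return ranges

ranges = chunk_ranges("20250101", "20250401", chunk_months=1)

# Overlapping chunks are deduplicated on the index, keeping the first occurrence
a = pd.DataFrame({"v": [1, 2]}, index=pd.to_datetime(["2025-01-01", "2025-01-02"]))
b = pd.DataFrame({"v": [9, 3]}, index=pd.to_datetime(["2025-01-02", "2025-01-03"]))
merged = pd.concat([a, b])
merged = merged[~merged.index.duplicated(keep="first")].sort_index()
```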

209
engine/data_manager.py Normal file
View File

@@ -0,0 +1,209 @@
"""
Data management for OHLCV data download and storage.
Handles data retrieval from exchanges and local file management.
"""
import time
from datetime import datetime, timezone
from pathlib import Path
import ccxt
import pandas as pd
from engine.logging_config import get_logger
from engine.market import MarketType, get_ccxt_symbol
logger = get_logger(__name__)
class DataManager:
"""
Manages OHLCV data download and storage for different market types.
Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
"""
def __init__(self, data_dir: str = "data/ccxt"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.exchanges: dict[str, ccxt.Exchange] = {}
def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
"""Get or create a CCXT exchange instance."""
if exchange_id not in self.exchanges:
exchange_class = getattr(ccxt, exchange_id)
self.exchanges[exchange_id] = exchange_class({
'enableRateLimit': True,
})
return self.exchanges[exchange_id]
def _get_data_path(
self,
exchange_id: str,
symbol: str,
timeframe: str,
market_type: MarketType
) -> Path:
"""
Get the file path for storing/loading data.
Args:
exchange_id: Exchange name (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Candle timeframe (e.g., '1m')
market_type: Market type (spot or perpetual)
Returns:
Path to the CSV file
"""
safe_symbol = symbol.replace('/', '-')
return (
self.data_dir
/ exchange_id
/ market_type.value
/ safe_symbol
/ f"{timeframe}.csv"
)
def download_data(
self,
exchange_id: str,
symbol: str,
timeframe: str = '1m',
start_date: str | None = None,
end_date: str | None = None,
market_type: MarketType = MarketType.SPOT
) -> pd.DataFrame | None:
"""
Download OHLCV data from exchange and save to CSV.
Args:
exchange_id: Exchange name (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Candle timeframe (e.g., '1m')
start_date: Start date string (YYYY-MM-DD)
end_date: End date string (YYYY-MM-DD)
market_type: Market type (spot or perpetual)
Returns:
DataFrame with OHLCV data, or None if download failed
"""
exchange = self.get_exchange(exchange_id)
file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
file_path.parent.mkdir(parents=True, exist_ok=True)
ccxt_symbol = get_ccxt_symbol(symbol, market_type)
since, until = self._parse_date_range(exchange, start_date, end_date)
logger.info(
"Downloading %s (%s) from %s...",
symbol, market_type.value, exchange_id
)
all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)
if not all_ohlcv:
logger.warning("No data downloaded.")
return None
df = self._convert_to_dataframe(all_ohlcv)
df.to_csv(file_path)
logger.info("Saved %d candles to %s", len(df), file_path)
return df
def load_data(
self,
exchange_id: str,
symbol: str,
timeframe: str = '1m',
market_type: MarketType = MarketType.SPOT
) -> pd.DataFrame:
"""
Load saved OHLCV data for vectorbt.
Args:
exchange_id: Exchange name (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Candle timeframe (e.g., '1m')
market_type: Market type (spot or perpetual)
Returns:
DataFrame with OHLCV data indexed by timestamp
Raises:
FileNotFoundError: If data file does not exist
"""
file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
if not file_path.exists():
raise FileNotFoundError(
f"Data not found at {file_path}. "
f"Run: uv run python main.py download --pair {symbol} "
f"--market {market_type.value}"
)
return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)
def _parse_date_range(
self,
exchange: ccxt.Exchange,
start_date: str | None,
end_date: str | None
) -> tuple[int, int]:
"""Parse date strings into millisecond timestamps."""
if start_date:
since = exchange.parse8601(f"{start_date}T00:00:00Z")
else:
since = exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
if end_date:
until = exchange.parse8601(f"{end_date}T23:59:59Z")
else:
until = exchange.milliseconds()
return since, until
def _fetch_all_candles(
self,
exchange: ccxt.Exchange,
symbol: str,
timeframe: str,
since: int,
until: int
) -> list:
"""Fetch all candles in the date range."""
all_ohlcv = []
while since < until:
try:
ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since, limit=100)
if not ohlcv:
break
all_ohlcv.extend(ohlcv)
since = ohlcv[-1][0] + 1
current_date = datetime.fromtimestamp(
since/1000, tz=timezone.utc
).strftime('%Y-%m-%d')
logger.debug("Fetched up to %s", current_date)
time.sleep(exchange.rateLimit / 1000)
except Exception as e:
logger.error("Error fetching data: %s", e)
break
return all_ohlcv
def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
"""Convert OHLCV list to DataFrame."""
df = pd.DataFrame(
ohlcv,
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
return df
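The pagination loop in `_fetch_all_candles` advances `since` to one millisecond past the last candle of each page until the range is covered. A minimal sketch of that loop against a fake in-memory exchange (the `fake_page` fetcher is a stand-in for `exchange.fetch_ohlcv`, not real ccxt):

```python
import pandas as pd

def fetch_all(fetch_page, since: int, until: int, limit: int = 100) -> list:
    # Same pagination as _fetch_all_candles: step past the last candle's timestamp
    out = []
    while since < until:
        page = fetch_page(since, limit)
        if not page:
            break
        out.extend(page)
        since = page[-1][0] + 1
    return out

# Fake exchange data: 250 one-minute candles, rows of [ts_ms, o, h, l, c, volume]
CANDLES = [[i * 60_000, 1.0, 1.0, 1.0, 1.0, 0.0] for i in range(250)]

def fake_page(since, limit):
    rows = [c for c in CANDLES if c[0] >= since]
    return rows[:limit]

ohlcv = fetch_all(fake_page, since=0, until=250 * 60_000)
df = pd.DataFrame(ohlcv, columns=["timestamp", "open", "high", "low", "close", "volume"])
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True)
df = df.set_index("timestamp")
```

With a 100-row page limit the loop needs three requests to cover 250 candles, with no duplicates or gaps.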

124
engine/logging_config.py Normal file
View File

@@ -0,0 +1,124 @@
"""
Centralized logging configuration for the backtest engine.
Provides colored console output and rotating file logs.
"""
import logging
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path
# ANSI color codes for terminal output
class Colors:
"""ANSI escape codes for colored terminal output."""
RESET = "\033[0m"
BOLD = "\033[1m"
# Log level colors
DEBUG = "\033[36m" # Cyan
INFO = "\033[32m" # Green
WARNING = "\033[33m" # Yellow
ERROR = "\033[31m" # Red
CRITICAL = "\033[35m" # Magenta
class ColoredFormatter(logging.Formatter):
"""
Custom formatter that adds colors to log level names in terminal output.
"""
LEVEL_COLORS = {
logging.DEBUG: Colors.DEBUG,
logging.INFO: Colors.INFO,
logging.WARNING: Colors.WARNING,
logging.ERROR: Colors.ERROR,
logging.CRITICAL: Colors.CRITICAL,
}
def __init__(self, fmt: str | None = None, datefmt: str | None = None):
super().__init__(fmt, datefmt)
def format(self, record: logging.LogRecord) -> str:
# Save original levelname
original_levelname = record.levelname
# Add color to levelname
color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
record.levelname = f"{color}{record.levelname}{Colors.RESET}"
# Format the message
result = super().format(record)
# Restore original levelname
record.levelname = original_levelname
return result
def setup_logging(
log_dir: str = "logs",
log_level: int = logging.INFO,
console_level: int = logging.INFO,
max_bytes: int = 5 * 1024 * 1024, # 5MB
backup_count: int = 3
) -> None:
"""
Configure logging for the application.
Args:
log_dir: Directory for log files
log_level: File logging level
console_level: Console logging level
max_bytes: Max size per log file before rotation
backup_count: Number of backup files to keep
"""
log_path = Path(log_dir)
log_path.mkdir(parents=True, exist_ok=True)
# Get root logger
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # Capture all, handlers filter
# Clear existing handlers
root_logger.handlers.clear()
# Console handler with colors
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(console_level)
console_fmt = ColoredFormatter(
fmt="[%(asctime)s] %(levelname)s - %(message)s",
datefmt="%H:%M:%S"
)
console_handler.setFormatter(console_fmt)
root_logger.addHandler(console_handler)
# File handler with rotation
file_handler = RotatingFileHandler(
log_path / "backtest.log",
maxBytes=max_bytes,
backupCount=backup_count,
encoding="utf-8"
)
file_handler.setLevel(log_level)
file_fmt = logging.Formatter(
fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(file_fmt)
root_logger.addHandler(file_handler)
def get_logger(name: str) -> logging.Logger:
"""
Get a logger instance for the given module name.
Args:
name: Module name (typically __name__)
Returns:
Configured logger instance
"""
return logging.getLogger(name)
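The key design choice in `setup_logging` is that the root logger captures everything at DEBUG while each handler filters to its own level, so console and file can log at different verbosities. A minimal sketch of that two-level filtering (using an in-memory buffer instead of stdout/file handlers):

```python
import io
import logging

# Logger captures all records; the handler decides what is emitted
root = logging.getLogger("demo_root")
root.setLevel(logging.DEBUG)
root.handlers.clear()
root.propagate = False

buf = io.StringIO()
handler = logging.StreamHandler(buf)
handler.setLevel(logging.INFO)  # handler-level filter, like the console handler above
handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
root.addHandler(handler)

root.debug("hidden: below the handler's level")
root.info("visible")
output = buf.getvalue()
```

A second handler at DEBUG attached to the same logger would still receive the debug record, which is exactly how the rotating file handler sees more than the console.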

168
engine/market.py Normal file
View File

@@ -0,0 +1,168 @@
"""
Market type definitions and configuration for backtesting.
Supports different market types with their specific trading conditions:
- SPOT: No leverage, no funding, long-only
- PERPETUAL: Leverage, funding rates, long/short
"""
from dataclasses import dataclass
from enum import Enum
class MarketType(Enum):
"""Supported market types for backtesting."""
SPOT = "spot"
PERPETUAL = "perpetual"
@dataclass(frozen=True)
class MarketConfig:
"""
Configuration for a specific market type.
Attributes:
market_type: The market type enum value
maker_fee: Maker fee as decimal (e.g., 0.0008 = 0.08%)
taker_fee: Taker fee as decimal (e.g., 0.001 = 0.1%)
max_leverage: Maximum allowed leverage
funding_rate: Funding rate per 8 hours as decimal (perpetuals only)
funding_interval_hours: Hours between funding payments
maintenance_margin_rate: Rate for liquidation calculation
supports_short: Whether short-selling is supported
"""
market_type: MarketType
maker_fee: float
taker_fee: float
max_leverage: int
funding_rate: float
funding_interval_hours: int
maintenance_margin_rate: float
supports_short: bool
# OKX-based default configurations
SPOT_CONFIG = MarketConfig(
market_type=MarketType.SPOT,
maker_fee=0.0008, # 0.08%
taker_fee=0.0010, # 0.10%
max_leverage=1,
funding_rate=0.0,
funding_interval_hours=0,
maintenance_margin_rate=0.0,
supports_short=False,
)
PERPETUAL_CONFIG = MarketConfig(
market_type=MarketType.PERPETUAL,
maker_fee=0.0002, # 0.02%
taker_fee=0.0005, # 0.05%
max_leverage=125,
funding_rate=0.0001, # 0.01% per 8 hours (simplified average)
funding_interval_hours=8,
maintenance_margin_rate=0.004, # 0.4% for BTC on OKX
supports_short=True,
)
def get_market_config(market_type: MarketType) -> MarketConfig:
"""
Get the configuration for a specific market type.
Args:
market_type: The market type to get configuration for
Returns:
MarketConfig with default values for that market type
"""
configs = {
MarketType.SPOT: SPOT_CONFIG,
MarketType.PERPETUAL: PERPETUAL_CONFIG,
}
return configs[market_type]
def get_ccxt_symbol(symbol: str, market_type: MarketType) -> str:
"""
Convert a standard symbol to CCXT format for the given market type.
Args:
symbol: Standard symbol (e.g., 'BTC/USDT')
market_type: The market type
Returns:
CCXT-formatted symbol (e.g., 'BTC/USDT:USDT' for perpetuals)
"""
if market_type == MarketType.PERPETUAL:
# OKX perpetual format: BTC/USDT:USDT
if '/' in symbol:
_, quote = symbol.split('/')
return f"{symbol}:{quote}"
elif '-' in symbol:
base, quote = symbol.split('-')
return f"{base}/{quote}:{quote}"
else:
# Assume base is symbol, quote is USDT default
return f"{symbol}/USDT:USDT"
# For spot, normalize dash to slash for CCXT
if '-' in symbol:
return symbol.replace('-', '/')
return symbol
def calculate_leverage_stop_loss(
leverage: int,
maintenance_margin_rate: float = 0.004
) -> float:
"""
Calculate the implicit stop-loss percentage from leverage.
At a given leverage, liquidation occurs when the position loses
approximately (1/leverage - maintenance_margin_rate) of its value.
Args:
leverage: Position leverage multiplier
maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)
Returns:
Stop-loss percentage as decimal (e.g., 0.196 for 19.6%)
"""
if leverage <= 1:
return 1.0 # No forced stop for spot
return (1 / leverage) - maintenance_margin_rate
def calculate_liquidation_price(
entry_price: float,
leverage: float,
is_long: bool,
maintenance_margin_rate: float = 0.004
) -> float:
"""
Calculate the liquidation price for a leveraged position.
Args:
entry_price: Position entry price
leverage: Position leverage
is_long: True for long positions, False for short
maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)
Returns:
Liquidation price
"""
if leverage <= 1:
return 0.0 if is_long else float('inf')
# Simplified liquidation formula
# Long: liq_price = entry * (1 - 1/leverage + maintenance_margin_rate)
# Short: liq_price = entry * (1 + 1/leverage - maintenance_margin_rate)
margin_ratio = 1 / leverage
if is_long:
liq_price = entry_price * (1 - margin_ratio + maintenance_margin_rate)
else:
liq_price = entry_price * (1 + margin_ratio - maintenance_margin_rate)
return liq_price
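A worked numeric check of the two formulas above: at 10x with OKX's 0.4% maintenance margin, a long from $100 is liquidated near $90.40, and the adverse move to liquidation equals the implicit stop-loss `1/leverage - mmr = 9.6%`:

```python
def liquidation_price(entry_price: float, leverage: float, is_long: bool,
                      maintenance_margin_rate: float = 0.004) -> float:
    # Same simplified formula as calculate_liquidation_price above
    if leverage <= 1:
        return 0.0 if is_long else float("inf")
    margin_ratio = 1 / leverage
    if is_long:
        return entry_price * (1 - margin_ratio + maintenance_margin_rate)
    return entry_price * (1 + margin_ratio - maintenance_margin_rate)

long_liq = liquidation_price(100.0, 10, is_long=True)    # 100 * (1 - 0.1 + 0.004)
short_liq = liquidation_price(100.0, 10, is_long=False)  # 100 * (1 + 0.1 - 0.004)

# Implicit stop-loss from leverage, as in calculate_leverage_stop_loss
stop_pct = (1 / 10) - 0.004
```

The long's distance to liquidation, `(100 - 90.4) / 100`, matches `stop_pct` exactly, which is why the engine can treat liquidation as a leverage-derived stop.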

245
engine/optimizer.py Normal file
View File

@@ -0,0 +1,245 @@
"""
Walk-Forward Analysis optimizer for strategy parameter optimization.
Implements expanding window walk-forward analysis with train/test splits.
"""
import numpy as np
import pandas as pd
import vectorbt as vbt
from engine.logging_config import get_logger
logger = get_logger(__name__)
def create_rolling_windows(
index: pd.Index,
n_windows: int,
train_split: float = 0.7
):
"""
Create rolling train/test split indices using expanding window approach.
Args:
index: DataFrame index to split
n_windows: Number of walk-forward windows
train_split: Unused, kept for API compatibility
Yields:
Tuples of (train_idx, test_idx) numpy arrays
"""
chunks = np.array_split(index, n_windows + 1)
for i in range(n_windows):
train_idx = np.concatenate(chunks[:i + 1])
test_idx = chunks[i+1]
yield train_idx, test_idx
class WalkForwardOptimizer:
"""
Walk-Forward Analysis optimizer for strategy backtesting.
Optimizes strategy parameters on training windows and validates
on out-of-sample test windows.
"""
def __init__(
self,
backtester,
strategy,
param_grid: dict,
metric: str = 'Sharpe Ratio',
fees: float = 0.001,
freq: str = '1m'
):
"""
Initialize the optimizer.
Args:
backtester: Backtester instance
strategy: Strategy instance to optimize
param_grid: Parameter grid for optimization
metric: Performance metric to optimize
fees: Transaction fees for simulation
freq: Data frequency for portfolio simulation
"""
self.bt = backtester
self.strategy = strategy
self.param_grid = param_grid
self.metric = metric
self.fees = fees
self.freq = freq
# Separate grid params (lists) from fixed params (scalars)
self.grid_keys = []
self.fixed_params = {}
for k, v in param_grid.items():
if isinstance(v, (list, np.ndarray)):
self.grid_keys.append(k)
else:
self.fixed_params[k] = v
def run(
self,
close_price: pd.Series,
high: pd.Series | None = None,
low: pd.Series | None = None,
n_windows: int = 10
) -> tuple[pd.DataFrame, pd.Series | None]:
"""
Execute walk-forward analysis.
Args:
close_price: Close price series
high: High price series (optional)
low: Low price series (optional)
n_windows: Number of walk-forward windows
Returns:
Tuple of (results DataFrame, stitched equity curve)
"""
results = []
equity_curves = []
logger.info(
"Starting Walk-Forward Analysis with %d windows (Expanding Train)...",
n_windows
)
splitter = create_rolling_windows(close_price.index, n_windows)
for i, (train_idx, test_idx) in enumerate(splitter):
logger.info("Processing Window %d/%d...", i + 1, n_windows)
window_result = self._process_window(
i, train_idx, test_idx, close_price, high, low
)
if window_result is not None:
result_dict, eq_curve = window_result
results.append(result_dict)
equity_curves.append(eq_curve)
stitched_series = self._stitch_equity_curves(equity_curves)
return pd.DataFrame(results), stitched_series
def _process_window(
self,
window_idx: int,
train_idx: np.ndarray,
test_idx: np.ndarray,
close_price: pd.Series,
high: pd.Series | None,
low: pd.Series | None
) -> tuple[dict, pd.Series] | None:
"""Process a single WFA window."""
try:
# Slice data for train/test
train_close = close_price.loc[train_idx]
train_high = high.loc[train_idx] if high is not None else None
train_low = low.loc[train_idx] if low is not None else None
# Train phase: find best parameters
best_params, best_score = self._optimize_train(
train_close, train_high, train_low
)
# Test phase: validate with best params
test_close = close_price.loc[test_idx]
test_high = high.loc[test_idx] if high is not None else None
test_low = low.loc[test_idx] if low is not None else None
test_params = {**self.fixed_params, **best_params}
test_score, test_return, eq_curve = self._run_test(
test_close, test_high, test_low, test_params
)
return {
'window': window_idx + 1,
'train_start': train_idx[0],
'train_end': train_idx[-1],
'test_start': test_idx[0],
'test_end': test_idx[-1],
'best_params': best_params,
'train_score': best_score,
'test_score': test_score,
'test_return': test_return
}, eq_curve
except Exception as e:
logger.error("Error in window %d: %s", window_idx + 1, e, exc_info=True)
return None
def _optimize_train(
self,
close: pd.Series,
high: pd.Series | None,
low: pd.Series | None
) -> tuple[dict, float]:
"""Run grid search on training data to find best parameters."""
entries, exits = self.strategy.run(
close, high=high, low=low, **self.param_grid
)
pf_train = vbt.Portfolio.from_signals(
close, entries, exits,
fees=self.fees,
freq=self.freq
)
perf_stats = pf_train.sharpe_ratio()
perf_stats = perf_stats.fillna(-999)
best_idx = perf_stats.idxmax()
best_score = perf_stats.max()
# Extract best params from grid search
if len(self.grid_keys) == 1:
best_params = {self.grid_keys[0]: best_idx}
elif len(self.grid_keys) > 1:
best_params = dict(zip(self.grid_keys, best_idx))
else:
best_params = {}
return best_params, best_score
def _run_test(
self,
close: pd.Series,
high: pd.Series | None,
low: pd.Series | None,
params: dict
) -> tuple[float, float, pd.Series]:
"""Run test phase with given parameters."""
entries, exits = self.strategy.run(
close, high=high, low=low, **params
)
pf_test = vbt.Portfolio.from_signals(
close, entries, exits,
fees=self.fees,
freq=self.freq
)
return pf_test.sharpe_ratio(), pf_test.total_return(), pf_test.value()
def _stitch_equity_curves(
self,
equity_curves: list[pd.Series]
) -> pd.Series | None:
"""Stitch multiple equity curves into a continuous series."""
if not equity_curves:
return None
stitched = [equity_curves[0]]
for j in range(1, len(equity_curves)):
prev_end_val = stitched[-1].iloc[-1]
curr_curve = equity_curves[j]
init_cash = curr_curve.iloc[0]
# Scale curve to continue from previous end value
scaled_curve = (curr_curve / init_cash) * prev_end_val
stitched.append(scaled_curve)
return pd.concat(stitched)
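The expanding-window splitter above divides the index into `n_windows + 1` equal chunks; window *i* trains on chunks `0..i` and tests on chunk `i+1`, so training data grows while each test slice stays the same size. A standalone check of that behavior on a plain integer index:

```python
import numpy as np

def create_rolling_windows(index, n_windows):
    # Expanding-window splits: train grows by one chunk per window, test is the next chunk
    chunks = np.array_split(index, n_windows + 1)
    for i in range(n_windows):
        yield np.concatenate(chunks[: i + 1]), chunks[i + 1]

# 100 points, 4 windows -> 5 chunks of 20
splits = list(create_rolling_windows(np.arange(100), n_windows=4))
```

Each test window starts exactly where its training window ends, so stitching the per-window test equity curves yields one continuous out-of-sample series.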

108
engine/portfolio.py Normal file
View File

@@ -0,0 +1,108 @@
"""
Portfolio simulation utilities for backtesting.
Handles long-only and long/short portfolio creation using VectorBT.
"""
import pandas as pd
import vectorbt as vbt
from engine.logging_config import get_logger
logger = get_logger(__name__)
def run_long_only_portfolio(
close: pd.Series,
entries: pd.DataFrame,
exits: pd.DataFrame,
init_cash: float,
fees: float,
slippage: float,
freq: str,
sl_stop: float | None,
tp_stop: float | None,
sl_trail: bool,
leverage: int
) -> vbt.Portfolio:
"""
Run a long-only portfolio simulation.
Args:
close: Close price series
entries: Entry signals
exits: Exit signals
init_cash: Initial capital
fees: Transaction fee percentage
slippage: Slippage percentage
freq: Data frequency string
sl_stop: Stop loss percentage
tp_stop: Take profit percentage
sl_trail: Enable trailing stop loss
leverage: Leverage multiplier
Returns:
VectorBT Portfolio object
"""
effective_cash = init_cash * leverage
return vbt.Portfolio.from_signals(
close=close,
entries=entries,
exits=exits,
init_cash=effective_cash,
fees=fees,
slippage=slippage,
freq=freq,
sl_stop=sl_stop,
tp_stop=tp_stop,
sl_trail=sl_trail,
size=1.0,
size_type='percent',
)
def run_long_short_portfolio(
close: pd.Series,
long_entries: pd.DataFrame,
long_exits: pd.DataFrame,
short_entries: pd.DataFrame,
short_exits: pd.DataFrame,
init_cash: float,
fees: float,
slippage: float,
freq: str,
sl_stop: float | None,
tp_stop: float | None,
sl_trail: bool,
leverage: int,
size: pd.Series | float = 1.0,
size_type: str = 'value' # Changed to 'value' to support reversals/sizing
) -> vbt.Portfolio:
"""
Run a portfolio supporting both long and short positions.
Uses VectorBT's native support for short_entries/short_exits
to simulate a single unified portfolio.
"""
effective_cash = init_cash * leverage
# If size is passed as value (USD), we don't scale it by leverage here
# The backtester has already scaled it by init_cash.
# If using 'value', vbt treats it as "Amount of CASH to use for the trade"
return vbt.Portfolio.from_signals(
close=close,
entries=long_entries,
exits=long_exits,
short_entries=short_entries,
short_exits=short_exits,
init_cash=effective_cash,
fees=fees,
slippage=slippage,
freq=freq,
sl_stop=sl_stop,
tp_stop=tp_stop,
sl_trail=sl_trail,
size=size,
size_type=size_type,
)
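The two sizing conventions above differ in what `size` means: with `size_type='percent'` and `size=1.0` every entry commits the whole (leverage-scaled) cash balance, while `size_type='value'` spends a fixed amount of cash per trade. A toy arithmetic sketch of the difference (plain pandas, not vectorbt itself; numbers are illustrative):

```python
import pandas as pd

init_cash, leverage = 10_000.0, 3
effective_cash = init_cash * leverage  # as passed to from_signals above

price = pd.Series([100.0, 110.0])  # entry bar, exit bar

# 'percent' size=1.0: buy with all effective cash at the entry price
percent_units = effective_cash / price.iloc[0]
# 'value' sizing: spend a fixed notional, here an assumed $5k per trade
value_units = (0.5 * init_cash) / price.iloc[0]

move = price.iloc[1] - price.iloc[0]
percent_pnl = percent_units * move
value_pnl = value_units * move
```

Fixed-value sizing keeps each trade's exposure constant regardless of equity, which is what makes per-signal position sizing and reversals tractable.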

228
engine/reporting.py Normal file
View File

@@ -0,0 +1,228 @@
"""
Reporting module for backtest results.
Handles summary printing, CSV exports, and plotting.
"""
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
import vectorbt as vbt
from plotly.subplots import make_subplots
from engine.logging_config import get_logger
logger = get_logger(__name__)
class Reporter:
"""Reporter for backtest results with market-specific metrics."""
def __init__(self, output_dir: str = "backtest_logs"):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(exist_ok=True)
def print_summary(self, result) -> None:
"""
Print backtest summary to console via logger.
Args:
result: BacktestResult or vbt.Portfolio object
"""
(portfolio, market_type, leverage, funding_paid,
liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
# Extract period info
idx = portfolio.wrapper.index
start_date = idx[0].strftime("%Y-%m-%d")
end_date = idx[-1].strftime("%Y-%m-%d")
# Extract price info
close = portfolio.close
start_price = close.iloc[0].mean() if hasattr(close.iloc[0], 'mean') else close.iloc[0]
end_price = close.iloc[-1].mean() if hasattr(close.iloc[-1], 'mean') else close.iloc[-1]
price_change = ((end_price - start_price) / start_price) * 100
# Extract fees
stats = portfolio.stats()
total_fees = stats.get('Total Fees Paid', 0)
raw_return = portfolio.total_return().mean() * 100
# Build summary
summary_lines = [
"",
"=" * 50,
"BACKTEST RESULTS",
"=" * 50,
f"Market Type: [{market_type.upper()}]",
f"Leverage: [{leverage}x]",
f"Period: [{start_date} to {end_date}]",
f"Price: [{start_price:,.2f} -> {end_price:,.2f} ({price_change:+.2f}%)]",
]
# Show adjusted return if liquidations occurred
if liq_count > 0 and adjusted_return is not None:
summary_lines.append(f"Raw Return: [{raw_return:.2f}%] (before liq adjustment)")
summary_lines.append(f"Adj Return: [{adjusted_return:.2f}%] (after liq losses)")
else:
summary_lines.append(f"Total Return: [{raw_return:.2f}%]")
summary_lines.extend([
f"Sharpe Ratio: [{portfolio.sharpe_ratio().mean():.2f}]",
f"Max Drawdown: [{portfolio.max_drawdown().mean() * 100:.2f}%]",
f"Total Trades: [{portfolio.trades.count().mean():.0f}]",
f"Win Rate: [{portfolio.trades.win_rate().mean() * 100:.2f}%]",
f"Total Fees: [{total_fees:,.2f}]",
])
if funding_paid != 0:
summary_lines.append(f"Funding Paid: [{funding_paid:,.2f}]")
if liq_count > 0:
summary_lines.append(f"Liquidations: [{liq_count}] (${liq_loss:,.2f} margin lost)")
summary_lines.append("=" * 50)
logger.info("\n".join(summary_lines))
def save_reports(self, result, filename_prefix: str) -> None:
"""
Save trade log, stats, and liquidation events to CSV files.
Args:
result: BacktestResult or vbt.Portfolio object
filename_prefix: Prefix for output filenames
"""
(portfolio, market_type, leverage, funding_paid,
liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
# Save trades
self._save_csv(
data=portfolio.trades.records_readable,
path=self.output_dir / f"{filename_prefix}_trades.csv",
description="trade log"
)
# Save stats with market-specific additions
stats = portfolio.stats()
stats['Market Type'] = market_type
stats['Leverage'] = leverage
stats['Total Funding Paid'] = funding_paid
stats['Liquidations'] = liq_count
stats['Liquidation Loss'] = liq_loss
if adjusted_return is not None:
stats['Adjusted Return'] = adjusted_return
self._save_csv(
data=stats,
path=self.output_dir / f"{filename_prefix}_stats.csv",
description="stats"
)
# Save liquidation events if any
if hasattr(result, 'liquidation_events') and result.liquidation_events:
liq_df = pd.DataFrame([
{
'entry_time': e.entry_time,
'entry_price': e.entry_price,
'liquidation_time': e.liquidation_time,
'liquidation_price': e.liquidation_price,
'actual_price': e.actual_price,
'direction': e.direction,
'margin_lost_pct': e.margin_lost_pct
}
for e in result.liquidation_events
])
self._save_csv(
data=liq_df,
path=self.output_dir / f"{filename_prefix}_liquidations.csv",
description="liquidation events"
)
def plot(self, portfolio: vbt.Portfolio, show: bool = True) -> None:
"""Display portfolio plot."""
if show:
portfolio.plot().show()
def plot_wfa(
self,
wfa_results: pd.DataFrame,
stitched_curve: pd.Series | None = None,
show: bool = True
) -> None:
"""
Plot Walk-Forward Analysis results.
Args:
wfa_results: DataFrame with WFA window results
stitched_curve: Stitched out-of-sample equity curve
show: Whether to display the plot
"""
fig = make_subplots(
rows=2, cols=1,
shared_xaxes=False,
vertical_spacing=0.1,
subplot_titles=(
"Walk-Forward Test Scores (Sharpe)",
"Stitched Out-of-Sample Equity"
)
)
fig.add_trace(
go.Bar(
x=wfa_results['window'],
y=wfa_results['test_score'],
name="Test Sharpe"
),
row=1, col=1
)
if stitched_curve is not None:
fig.add_trace(
go.Scatter(
x=stitched_curve.index,
y=stitched_curve.values,
name="Equity",
mode='lines'
),
row=2, col=1
)
fig.update_layout(height=800, title_text="Walk-Forward Analysis Report")
if show:
fig.show()
def _extract_result_data(self, result) -> tuple:
"""
Extract data from BacktestResult or raw Portfolio.
Returns:
Tuple of (portfolio, market_type, leverage, funding_paid, liq_count,
liq_loss, adjusted_return)
"""
if hasattr(result, 'portfolio'):
return (
result.portfolio,
result.market_type.value,
result.leverage,
result.total_funding_paid,
result.liquidation_count,
getattr(result, 'total_liquidation_loss', 0.0),
getattr(result, 'adjusted_return', None)
)
return (result, "unknown", 1, 0.0, 0, 0.0, None)
def _save_csv(self, data, path: Path, description: str) -> None:
"""
Save data to CSV with consistent error handling.
Args:
data: DataFrame or Series to save
path: Output file path
description: Human-readable description for logging
"""
try:
data.to_csv(path)
logger.info("Saved %s to %s", description, path)
except Exception as e:
logger.error("Could not save %s: %s", description, e)
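`_extract_result_data` duck-types on the presence of a `.portfolio` attribute and uses `getattr` defaults so older result objects without liquidation fields still report cleanly. A minimal sketch of that fallback pattern (`FakeResult` and the reduced field set are illustrative, not the real `BacktestResult`):

```python
from dataclasses import dataclass

@dataclass
class FakeResult:
    # Stand-in for BacktestResult: has a portfolio, lacks some optional fields
    portfolio: str
    liquidation_count: int = 0

def extract(result):
    # Mirrors _extract_result_data's getattr fallbacks for optional fields
    if hasattr(result, "portfolio"):
        return (
            result.portfolio,
            getattr(result, "liquidation_count", 0),
            getattr(result, "total_liquidation_loss", 0.0),
            getattr(result, "adjusted_return", None),
        )
    # Raw portfolio objects get neutral defaults
    return (result, 0, 0.0, None)

full = extract(FakeResult(portfolio="pf", liquidation_count=2))
bare = extract("raw_portfolio")
```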

395
engine/risk.py Normal file
View File

@@ -0,0 +1,395 @@
"""
Risk management utilities for backtesting.
Handles funding rate calculations and liquidation detection.
"""
from dataclasses import dataclass
import pandas as pd
from engine.logging_config import get_logger
from engine.market import MarketConfig, calculate_liquidation_price
logger = get_logger(__name__)
@dataclass
class LiquidationEvent:
"""
Record of a liquidation event during backtesting.
Attributes:
entry_time: Timestamp when position was opened
entry_price: Price at position entry
liquidation_time: Timestamp when liquidation occurred
liquidation_price: Calculated liquidation price
actual_price: Actual price that triggered liquidation (high/low)
direction: 'long' or 'short'
margin_lost_pct: Percentage of margin lost (typically 100%)
"""
entry_time: pd.Timestamp
entry_price: float
liquidation_time: pd.Timestamp
liquidation_price: float
actual_price: float
direction: str
margin_lost_pct: float = 1.0
def calculate_funding(
close: pd.Series,
long_entries: pd.DataFrame,
long_exits: pd.DataFrame,
short_entries: pd.DataFrame,
short_exits: pd.DataFrame,
market_config: MarketConfig,
leverage: int
) -> float:
"""
Calculate total funding paid/received for perpetual positions.
Simplified model: applies funding rate every 8 hours to open positions.
Positive rate means longs pay shorts.
Args:
close: Price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
market_config: Market configuration with funding parameters
leverage: Position leverage
Returns:
Total funding paid (positive) or received (negative)
"""
if market_config.funding_interval_hours == 0:
return 0.0
funding_rate = market_config.funding_rate
interval_hours = market_config.funding_interval_hours
# Determine position state at each bar
long_position = long_entries.cumsum() - long_exits.cumsum()
short_position = short_entries.cumsum() - short_exits.cumsum()
# Clamp to 0/1 (either in position or not)
long_position = (long_position > 0).astype(int)
short_position = (short_position > 0).astype(int)
# Funding marks every `interval_hours` hours (00:00, 08:00, 16:00 UTC);
# require minute == 0 so sub-hourly data is charged once per mark, not once per bar
funding_times = close.index[
(close.index.hour % interval_hours == 0) & (close.index.minute == 0)
]
total_funding = 0.0
for ts in funding_times:
if ts not in close.index:
continue
price = close.loc[ts]
# Long pays funding, short receives (when rate > 0)
if isinstance(long_position, pd.DataFrame):
long_open = long_position.loc[ts].any()
short_open = short_position.loc[ts].any()
else:
long_open = long_position.loc[ts] > 0
short_open = short_position.loc[ts] > 0
position_value = price * leverage
if long_open:
total_funding += position_value * funding_rate
if short_open:
total_funding -= position_value * funding_rate
return total_funding
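As a sanity check on the accrual above, here is a minimal, self-contained sketch (not part of the module) that replays the same rule under assumed inputs: a constant 0.01% funding rate per 8-hour interval and a long position held open across four funding timestamps, with position value taken as `price * leverage` exactly as in the simplified model.

```python
import pandas as pd

funding_rate = 0.0001  # 0.01% per interval (assumed for illustration)
leverage = 5

# Four bars spaced one funding interval apart
close = pd.Series(
    [100.0, 102.0, 101.0, 103.0],
    index=pd.date_range("2024-01-01 00:00", periods=4, freq="8h"),
)

# Each funding timestamp charges position_value * rate,
# where position_value = price * leverage
total_funding = sum(price * leverage * funding_rate for price in close)
# (100 + 102 + 101 + 103) * 5 * 0.0001 = approximately 0.203
```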
def inject_liquidation_exits(
close: pd.Series,
high: pd.Series,
low: pd.Series,
long_entries: pd.DataFrame | pd.Series,
long_exits: pd.DataFrame | pd.Series,
short_entries: pd.DataFrame | pd.Series,
short_exits: pd.DataFrame | pd.Series,
leverage: int,
maintenance_margin_rate: float
) -> tuple[pd.DataFrame | pd.Series, pd.DataFrame | pd.Series, list[LiquidationEvent]]:
"""
Modify exit signals to force position closure at liquidation points.
This function simulates realistic liquidation behavior by:
1. Finding positions that would be liquidated before their normal exit
2. Injecting forced exit signals at the liquidation bar
3. Recording all liquidation events
Args:
close: Close price series
high: High price series
low: Low price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
leverage: Position leverage
maintenance_margin_rate: Maintenance margin rate for liquidation
Returns:
Tuple of (modified_long_exits, modified_short_exits, liquidation_events)
"""
if leverage <= 1:
return long_exits, short_exits, []
liquidation_events: list[LiquidationEvent] = []
# Convert to DataFrame if Series for consistent handling
is_series = isinstance(long_entries, pd.Series)
if is_series:
long_entries_df = long_entries.to_frame()
long_exits_df = long_exits.to_frame()
short_entries_df = short_entries.to_frame()
short_exits_df = short_exits.to_frame()
else:
long_entries_df = long_entries
long_exits_df = long_exits.copy()
short_entries_df = short_entries
short_exits_df = short_exits.copy()
modified_long_exits = long_exits_df.copy()
modified_short_exits = short_exits_df.copy()
# Process long positions
long_mask = long_entries_df.any(axis=1)
for entry_idx in close.index[long_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=True,
maintenance_margin_rate=maintenance_margin_rate
)
# Find the normal exit for this entry
subsequent_exits = long_exits_df.loc[entry_idx:].any(axis=1)
exit_indices = subsequent_exits[subsequent_exits].index
normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
# Check if liquidation occurs before normal exit
price_range = low.loc[entry_idx:normal_exit_idx]
if (price_range < liq_price).any():
liq_bar = price_range[price_range < liq_price].index[0]
# Inject forced exit at liquidation bar
for col in modified_long_exits.columns:
modified_long_exits.loc[liq_bar, col] = True
# Record the liquidation event
liquidation_events.append(LiquidationEvent(
entry_time=entry_idx,
entry_price=entry_price,
liquidation_time=liq_bar,
liquidation_price=liq_price,
actual_price=low.loc[liq_bar],
direction='long',
margin_lost_pct=1.0
))
logger.warning(
"LIQUIDATION (Long): Entry %s ($%.2f) -> Liquidated %s "
"(liq=$%.2f, low=$%.2f)",
entry_idx.strftime('%Y-%m-%d'), entry_price,
liq_bar.strftime('%Y-%m-%d'), liq_price, low.loc[liq_bar]
)
# Process short positions
short_mask = short_entries_df.any(axis=1)
for entry_idx in close.index[short_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=False,
maintenance_margin_rate=maintenance_margin_rate
)
# Find the normal exit for this entry
subsequent_exits = short_exits_df.loc[entry_idx:].any(axis=1)
exit_indices = subsequent_exits[subsequent_exits].index
normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
# Check if liquidation occurs before normal exit
price_range = high.loc[entry_idx:normal_exit_idx]
if (price_range > liq_price).any():
liq_bar = price_range[price_range > liq_price].index[0]
# Inject forced exit at liquidation bar
for col in modified_short_exits.columns:
modified_short_exits.loc[liq_bar, col] = True
# Record the liquidation event
liquidation_events.append(LiquidationEvent(
entry_time=entry_idx,
entry_price=entry_price,
liquidation_time=liq_bar,
liquidation_price=liq_price,
actual_price=high.loc[liq_bar],
direction='short',
margin_lost_pct=1.0
))
logger.warning(
"LIQUIDATION (Short): Entry %s ($%.2f) -> Liquidated %s "
"(liq=$%.2f, high=$%.2f)",
entry_idx.strftime('%Y-%m-%d'), entry_price,
liq_bar.strftime('%Y-%m-%d'), liq_price, high.loc[liq_bar]
)
# Convert back to Series if input was Series
if is_series:
modified_long_exits = modified_long_exits.iloc[:, 0]
modified_short_exits = modified_short_exits.iloc[:, 0]
return modified_long_exits, modified_short_exits, liquidation_events
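`calculate_liquidation_price` is defined elsewhere in the module; a hypothetical stand-in, assuming it follows the standard isolated-margin approximation, makes the threshold logic above concrete. The function name, the 0.4% maintenance rate, and the formula itself are assumptions for illustration, not the module's confirmed implementation.

```python
def liq_price_sketch(entry_price: float, leverage: int, is_long: bool,
                     maintenance_margin_rate: float = 0.004) -> float:
    """Hypothetical stand-in for calculate_liquidation_price.

    A long is liquidated once price moves roughly 1/leverage against it,
    less the maintenance-margin buffer; a short is the mirror image.
    """
    adverse_move = 1.0 / leverage - maintenance_margin_rate
    if is_long:
        return entry_price * (1 - adverse_move)
    return entry_price * (1 + adverse_move)

# At 5x with a 0.4% maintenance rate, a long entered at $100 would be
# liquidated near $80.40, i.e. after a ~19.6% adverse move.
long_liq = liq_price_sketch(100.0, 5, is_long=True)
short_liq = liq_price_sketch(100.0, 5, is_long=False)
```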
def calculate_liquidation_adjustment(
liquidation_events: list[LiquidationEvent],
init_cash: float,
leverage: int
) -> tuple[float, float]:
"""
Calculate the return adjustment for liquidated positions.
VectorBT calculates trade P&L using close price at exit bar.
For liquidations, the actual loss is 100% of the position margin.
This function calculates the difference between what VectorBT
recorded and what actually would have happened.
In our portfolio setup:
- Long/short each get half the capital (init_cash * leverage / 2)
- Each trade uses 100% of that allocation (size=1.0, percent)
- On liquidation, the margin for that trade is lost entirely
The adjustment is the DIFFERENCE between:
- VectorBT's calculated P&L (exit at close price)
- Actual liquidation P&L (100% margin loss)
Args:
liquidation_events: List of liquidation events
init_cash: Initial portfolio cash (before leverage)
leverage: Position leverage used
Returns:
Tuple of (total_margin_lost, adjustment_pct)
- total_margin_lost: Estimated total margin lost from liquidations
- adjustment_pct: Percentage adjustment to apply to returns
"""
if not liquidation_events:
return 0.0, 0.0
# In our setup, each side (long/short) gets half the capital
# Margin per side = init_cash / 2
margin_per_side = init_cash / 2
    # For each liquidation, VectorBT recorded some P&L based on the exit
    # bar's close price, while the actual P&L should be -100% of the
    # margin used for that trade.
    #
    # At liquidation the price has moved roughly (1/leverage) against the
    # position (e.g. a ~19.6% adverse move at 5x, after the maintenance
    # buffer), so we approximate the residual loss not captured by
    # VectorBT as a (1/leverage) fraction of the per-side margin.
    liq_loss_rate = 1.0 / leverage  # Approximate residual loss per trade as a fraction of margin
# Count liquidations
n_liquidations = len(liquidation_events)
    # Estimate total margin lost: each liquidation contributes
    # margin_per_side * liq_loss_rate, capped at the initial capital
    # (the account cannot lose more than it held)
total_margin_lost = min(n_liquidations * margin_per_side * liq_loss_rate, init_cash)
# Calculate as percentage of initial capital
adjustment_pct = (total_margin_lost / init_cash) * 100
return total_margin_lost, adjustment_pct
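A worked instance of the heuristic above, with illustrative numbers ($10,000 initial cash, 5x leverage, two liquidation events), mirroring the arithmetic in `calculate_liquidation_adjustment`:

```python
init_cash = 10_000.0
leverage = 5
n_liquidations = 2

margin_per_side = init_cash / 2     # 5,000 per side (long/short split)
liq_loss_rate = 1.0 / leverage      # 0.2 residual loss fraction per event

# 2 * 5,000 * 0.2 = 2,000, well under the init_cash cap
total_margin_lost = min(
    n_liquidations * margin_per_side * liq_loss_rate,
    init_cash,
)
adjustment_pct = (total_margin_lost / init_cash) * 100  # about 20%
```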
def check_liquidations(
close: pd.Series,
high: pd.Series,
low: pd.Series,
long_entries: pd.DataFrame,
long_exits: pd.DataFrame,
short_entries: pd.DataFrame,
short_exits: pd.DataFrame,
leverage: int,
maintenance_margin_rate: float
) -> int:
"""
Check for liquidation events and log warnings.
Args:
close: Close price series
high: High price series
low: Low price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
leverage: Position leverage
maintenance_margin_rate: Maintenance margin rate for liquidation
Returns:
Count of liquidation warnings
"""
warnings = 0
# For long positions
long_mask = (
long_entries.any(axis=1)
if isinstance(long_entries, pd.DataFrame)
else long_entries
)
for entry_idx in close.index[long_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=True,
maintenance_margin_rate=maintenance_margin_rate
)
subsequent = low.loc[entry_idx:]
if (subsequent < liq_price).any():
liq_bar = subsequent[subsequent < liq_price].index[0]
logger.warning(
"LIQUIDATION WARNING (Long): Entry at %s ($%.2f), "
"would liquidate at %s (liq_price=$%.2f, low=$%.2f)",
entry_idx, entry_price, liq_bar, liq_price, low.loc[liq_bar]
)
warnings += 1
# For short positions
short_mask = (
short_entries.any(axis=1)
if isinstance(short_entries, pd.DataFrame)
else short_entries
)
for entry_idx in close.index[short_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=False,
maintenance_margin_rate=maintenance_margin_rate
)
subsequent = high.loc[entry_idx:]
if (subsequent > liq_price).any():
liq_bar = subsequent[subsequent > liq_price].index[0]
logger.warning(
"LIQUIDATION WARNING (Short): Entry at %s ($%.2f), "
"would liquidate at %s (liq_price=$%.2f, high=$%.2f)",
entry_idx, entry_price, liq_bar, liq_price, high.loc[liq_bar]
)
warnings += 1
return warnings

frontend/.gitignore

@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

frontend/.vscode/extensions.json

@@ -0,0 +1,3 @@
{
"recommendations": ["Vue.volar"]
}

frontend/README.md

@@ -0,0 +1,5 @@
# Vue 3 + TypeScript + Vite
This template should help get you started developing with Vue 3 and TypeScript in Vite. It uses Vue 3 `<script setup>` SFCs; check out the [script setup docs](https://v3.vuejs.org/api/sfc-script-setup.html#sfc-script-setup) to learn more.
Learn more about the recommended Project Setup and IDE Support in the [Vue Docs TypeScript Guide](https://vuejs.org/guide/typescript/overview.html#project-setup).

frontend/index.html

@@ -0,0 +1,24 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lowkey Backtest</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
<style>
body {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
code, pre, .font-mono, input, select {
font-family: 'JetBrains Mono', 'Fira Code', monospace;
}
</style>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>

frontend/package-lock.json (generated; diff suppressed because it is too large)

frontend/package.json

@@ -0,0 +1,27 @@
{
"name": "frontend",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vue-tsc -b && vite build",
"preview": "vite preview"
},
"dependencies": {
"axios": "^1.13.2",
"plotly.js-dist-min": "^3.3.1",
"vue": "^3.5.24",
"vue-router": "^4.6.4"
},
"devDependencies": {
"@tailwindcss/vite": "^4.1.18",
"@types/node": "^24.10.1",
"@vitejs/plugin-vue": "^6.0.1",
"@vue/tsconfig": "^0.8.1",
"tailwindcss": "^4.1.18",
"typescript": "~5.9.3",
"vite": "^7.2.4",
"vue-tsc": "^3.1.4"
}
}

frontend/public/vite.svg

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="31.88" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 257"><defs><linearGradient id="IconifyId1813088fe1fbc01fb466" x1="-.828%" x2="57.636%" y1="7.652%" y2="78.411%"><stop offset="0%" stop-color="#41D1FF"></stop><stop offset="100%" stop-color="#BD34FE"></stop></linearGradient><linearGradient id="IconifyId1813088fe1fbc01fb467" x1="43.376%" x2="50.316%" y1="2.242%" y2="89.03%"><stop offset="0%" stop-color="#FFEA83"></stop><stop offset="8.333%" stop-color="#FFDD35"></stop><stop offset="100%" stop-color="#FFA800"></stop></linearGradient></defs><path fill="url(#IconifyId1813088fe1fbc01fb466)" d="M255.153 37.938L134.897 252.976c-2.483 4.44-8.862 4.466-11.382.048L.875 37.958c-2.746-4.814 1.371-10.646 6.827-9.67l120.385 21.517a6.537 6.537 0 0 0 2.322-.004l117.867-21.483c5.438-.991 9.574 4.796 6.877 9.62Z"></path><path fill="url(#IconifyId1813088fe1fbc01fb467)" d="M185.432.063L96.44 17.501a3.268 3.268 0 0 0-2.634 3.014l-5.474 92.456a3.268 3.268 0 0 0 3.997 3.378l24.777-5.718c2.318-.535 4.413 1.507 3.936 3.838l-7.361 36.047c-.495 2.426 1.782 4.5 4.151 3.78l15.304-4.649c2.372-.72 4.652 1.36 4.15 3.788l-11.698 56.621c-.732 3.542 3.979 5.473 5.943 2.437l1.313-2.028l72.516-144.72c1.215-2.423-.88-5.186-3.54-4.672l-25.505 4.922c-2.396.462-4.435-1.77-3.759-4.114l16.646-57.705c.677-2.35-1.37-4.583-3.769-4.113Z"></path></svg>


frontend/src/App.vue

@@ -0,0 +1,72 @@
<script setup lang="ts">
import { ref } from 'vue'
import { RouterLink, RouterView } from 'vue-router'
import RunHistory from '@/components/RunHistory.vue'
const historyOpen = ref(true)
</script>
<template>
<div class="flex h-screen overflow-hidden">
<!-- Sidebar Navigation -->
<aside class="w-16 bg-bg-secondary border-r border-border flex flex-col items-center py-4 gap-4">
<!-- Logo -->
<div class="w-10 h-10 rounded-lg bg-accent-blue flex items-center justify-center text-black font-bold text-lg">
LB
</div>
<!-- Nav Links -->
<nav class="flex flex-col gap-2 mt-4">
<RouterLink
to="/"
class="w-10 h-10 rounded-lg flex items-center justify-center hover:bg-bg-hover transition-colors"
:class="{ 'bg-bg-tertiary': $route.path === '/' }"
title="Dashboard"
>
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
</RouterLink>
<RouterLink
to="/compare"
class="w-10 h-10 rounded-lg flex items-center justify-center hover:bg-bg-hover transition-colors"
:class="{ 'bg-bg-tertiary': $route.path === '/compare' }"
title="Compare Runs"
>
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 17V7m0 10a2 2 0 01-2 2H5a2 2 0 01-2-2V7a2 2 0 012-2h2a2 2 0 012 2m0 10a2 2 0 002 2h2a2 2 0 002-2M9 7a2 2 0 012-2h2a2 2 0 012 2m0 10V7m0 10a2 2 0 002 2h2a2 2 0 002-2V7a2 2 0 00-2-2h-2a2 2 0 00-2 2" />
</svg>
</RouterLink>
</nav>
<!-- Spacer -->
<div class="flex-1"></div>
<!-- Toggle History -->
<button
@click="historyOpen = !historyOpen"
class="w-10 h-10 rounded-lg flex items-center justify-center hover:bg-bg-hover transition-colors"
:class="{ 'bg-bg-tertiary': historyOpen }"
title="Toggle Run History"
>
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</button>
</aside>
<!-- Main Content -->
<main class="flex-1 overflow-auto">
<RouterView />
</main>
<!-- Run History Sidebar -->
<aside
v-if="historyOpen"
class="w-72 bg-bg-secondary border-l border-border overflow-hidden flex flex-col"
>
<RunHistory />
</aside>
</div>
</template>


@@ -0,0 +1,81 @@
/**
* API client for Lowkey Backtest backend.
*/
import axios from 'axios'
import type {
StrategiesResponse,
DataStatusResponse,
BacktestRequest,
BacktestResult,
BacktestListResponse,
CompareResult,
} from './types'
const api = axios.create({
baseURL: '/api',
headers: {
'Content-Type': 'application/json',
},
})
/**
* Get list of available strategies with parameters.
*/
export async function getStrategies(): Promise<StrategiesResponse> {
const response = await api.get<StrategiesResponse>('/strategies')
return response.data
}
/**
* Get list of available symbols with data status.
*/
export async function getSymbols(): Promise<DataStatusResponse> {
const response = await api.get<DataStatusResponse>('/symbols')
return response.data
}
/**
* Run a backtest with the given configuration.
*/
export async function runBacktest(request: BacktestRequest): Promise<BacktestResult> {
const response = await api.post<BacktestResult>('/backtest', request)
return response.data
}
/**
* Get list of saved backtest runs.
*/
export async function getBacktests(params?: {
limit?: number
offset?: number
strategy?: string
symbol?: string
}): Promise<BacktestListResponse> {
const response = await api.get<BacktestListResponse>('/backtests', { params })
return response.data
}
/**
* Get a specific backtest run by ID.
*/
export async function getBacktest(runId: string): Promise<BacktestResult> {
const response = await api.get<BacktestResult>(`/backtest/${runId}`)
return response.data
}
/**
* Delete a backtest run.
*/
export async function deleteBacktest(runId: string): Promise<void> {
await api.delete(`/backtest/${runId}`)
}
/**
* Compare multiple backtest runs.
*/
export async function compareRuns(runIds: string[]): Promise<CompareResult> {
const response = await api.post<CompareResult>('/compare', { run_ids: runIds })
return response.data
}
export default api

frontend/src/api/types.ts

@@ -0,0 +1,131 @@
/**
* TypeScript types matching the FastAPI Pydantic schemas.
*/
// Strategy types
export interface StrategyInfo {
name: string
display_name: string
market_type: string
default_leverage: number
default_params: Record<string, unknown>
grid_params: Record<string, unknown>
}
export interface StrategiesResponse {
strategies: StrategyInfo[]
}
// Symbol/Data types
export interface SymbolInfo {
symbol: string
exchange: string
market_type: string
timeframes: string[]
start_date: string | null
end_date: string | null
row_count: number
}
export interface DataStatusResponse {
symbols: SymbolInfo[]
}
// Backtest types
export interface BacktestRequest {
strategy: string
symbol: string
exchange?: string
timeframe?: string
market_type?: string
start_date?: string | null
end_date?: string | null
init_cash?: number
leverage?: number | null
fees?: number | null
slippage?: number
sl_stop?: number | null
tp_stop?: number | null
sl_trail?: boolean
params?: Record<string, unknown>
}
export interface TradeRecord {
entry_time: string
exit_time: string | null
entry_price: number
exit_price: number | null
size: number
direction: string
pnl: number | null
return_pct: number | null
status: string
}
export interface EquityPoint {
timestamp: string
value: number
drawdown: number
}
export interface BacktestMetrics {
total_return: number
benchmark_return: number
alpha: number
sharpe_ratio: number
max_drawdown: number
win_rate: number
total_trades: number
profit_factor: number | null
avg_trade_return: number | null
total_fees: number
total_funding: number
liquidation_count: number
liquidation_loss: number
adjusted_return: number | null
}
export interface BacktestResult {
run_id: string
strategy: string
symbol: string
market_type: string
timeframe: string
start_date: string
end_date: string
leverage: number
params: Record<string, unknown>
metrics: BacktestMetrics
equity_curve: EquityPoint[]
trades: TradeRecord[]
created_at: string
}
export interface BacktestSummary {
run_id: string
strategy: string
symbol: string
market_type: string
timeframe: string
total_return: number
sharpe_ratio: number
max_drawdown: number
total_trades: number
created_at: string
params: Record<string, unknown>
}
export interface BacktestListResponse {
runs: BacktestSummary[]
total: number
}
// Comparison types
export interface CompareRequest {
run_ids: string[]
}
export interface CompareResult {
runs: BacktestResult[]
param_diff: Record<string, unknown[]>
}


@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="37.07" height="36" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 198"><path fill="#41B883" d="M204.8 0H256L128 220.8L0 0h97.92L128 51.2L157.44 0h47.36Z"></path><path fill="#41B883" d="m0 0l128 220.8L256 0h-51.2L128 132.48L50.56 0H0Z"></path><path fill="#35495E" d="M50.56 0L128 133.12L204.8 0h-47.36L128 51.2L97.92 0H50.56Z"></path></svg>



@@ -0,0 +1,186 @@
<script setup lang="ts">
import { ref, computed, watch, onMounted } from 'vue'
import { useBacktest } from '@/composables/useBacktest'
import type { BacktestRequest } from '@/api/types'
const { strategies, symbols, loading, init, executeBacktest } = useBacktest()
// Form state
const selectedStrategy = ref('')
const selectedSymbol = ref('')
const selectedMarket = ref('perpetual')
const timeframe = ref('1h')
const initCash = ref(10000)
const leverage = ref<number | null>(null)
const slStop = ref<number | null>(null)
const tpStop = ref<number | null>(null)
const params = ref<Record<string, number | boolean>>({})
// Initialize
onMounted(async () => {
await init()
if (strategies.value.length > 0 && strategies.value[0]) {
selectedStrategy.value = strategies.value[0].name
}
})
// Get current strategy config
const currentStrategy = computed(() =>
strategies.value.find(s => s.name === selectedStrategy.value)
)
// Filter symbols by market type
const filteredSymbols = computed(() =>
symbols.value.filter(s => s.market_type === selectedMarket.value)
)
// Update params when strategy changes
watch(selectedStrategy, (name) => {
const strategy = strategies.value.find(s => s.name === name)
if (strategy) {
params.value = { ...strategy.default_params } as Record<string, number | boolean>
selectedMarket.value = strategy.market_type
leverage.value = strategy.default_leverage > 1 ? strategy.default_leverage : null
}
})
// Update symbol when market changes
watch([filteredSymbols, selectedMarket], () => {
const firstSymbol = filteredSymbols.value[0]
if (filteredSymbols.value.length > 0 && firstSymbol && !filteredSymbols.value.find(s => s.symbol === selectedSymbol.value)) {
selectedSymbol.value = firstSymbol.symbol
}
})
async function handleSubmit() {
if (!selectedStrategy.value || !selectedSymbol.value) return
const request: BacktestRequest = {
strategy: selectedStrategy.value,
symbol: selectedSymbol.value,
market_type: selectedMarket.value,
timeframe: timeframe.value,
init_cash: initCash.value,
leverage: leverage.value,
sl_stop: slStop.value,
tp_stop: tpStop.value,
params: params.value,
}
await executeBacktest(request)
}
function getParamType(value: unknown): 'number' | 'boolean' | 'unknown' {
if (typeof value === 'boolean') return 'boolean'
if (typeof value === 'number') return 'number'
return 'unknown'
}
</script>
<template>
<div class="card">
<h2 class="text-lg font-semibold mb-4">Backtest Configuration</h2>
<form @submit.prevent="handleSubmit" class="space-y-4">
<!-- Strategy -->
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Strategy</label>
<select v-model="selectedStrategy" class="w-full">
<option v-for="s in strategies" :key="s.name" :value="s.name">
{{ s.display_name }}
</option>
</select>
</div>
<!-- Market Type & Symbol -->
<div class="grid grid-cols-2 gap-3">
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Market</label>
<select v-model="selectedMarket" class="w-full">
<option value="spot">Spot</option>
<option value="perpetual">Perpetual</option>
</select>
</div>
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Symbol</label>
<select v-model="selectedSymbol" class="w-full">
<option v-for="s in filteredSymbols" :key="s.symbol" :value="s.symbol">
{{ s.symbol }}
</option>
</select>
</div>
</div>
<!-- Timeframe & Cash -->
<div class="grid grid-cols-2 gap-3">
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Timeframe</label>
<select v-model="timeframe" class="w-full">
<option value="1h">1 Hour</option>
<option value="4h">4 Hours</option>
<option value="1d">1 Day</option>
</select>
</div>
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Initial Cash</label>
<input type="number" v-model.number="initCash" class="w-full" min="100" step="100" />
</div>
</div>
<!-- Leverage (perpetual only) -->
<div v-if="selectedMarket === 'perpetual'" class="grid grid-cols-3 gap-3">
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Leverage</label>
<input type="number" v-model.number="leverage" class="w-full" min="1" max="100" placeholder="1" />
</div>
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Stop Loss %</label>
<input type="number" v-model.number="slStop" class="w-full" min="0" max="100" step="0.1" placeholder="None" />
</div>
<div>
<label class="block text-xs text-text-secondary uppercase mb-1">Take Profit %</label>
<input type="number" v-model.number="tpStop" class="w-full" min="0" max="100" step="0.1" placeholder="None" />
</div>
</div>
<!-- Strategy Parameters -->
<div v-if="currentStrategy && Object.keys(params).length > 0">
<h3 class="text-sm font-medium text-text-secondary mb-2">Strategy Parameters</h3>
<div class="grid grid-cols-2 gap-3">
<div v-for="(value, key) in params" :key="key">
<label class="block text-xs text-text-secondary uppercase mb-1">
{{ String(key).replace(/_/g, ' ') }}
</label>
<template v-if="getParamType(value) === 'boolean'">
<input
type="checkbox"
:checked="Boolean(value)"
@change="params[key] = ($event.target as HTMLInputElement).checked"
class="w-5 h-5"
/>
</template>
<template v-else>
<input
type="number"
:value="value"
@input="params[key] = parseFloat(($event.target as HTMLInputElement).value)"
class="w-full"
step="any"
/>
</template>
</div>
</div>
</div>
<!-- Submit -->
<button
type="submit"
class="btn btn-primary w-full"
:disabled="loading || !selectedStrategy || !selectedSymbol"
>
<span v-if="loading" class="spinner"></span>
<span v-else>Run Backtest</span>
</button>
</form>
</div>
</template>


@@ -0,0 +1,88 @@
<script setup lang="ts">
import { ref, watch, onMounted, onUnmounted } from 'vue'
import Plotly from 'plotly.js-dist-min'
import type { EquityPoint } from '@/api/types'
const props = defineProps<{
data: EquityPoint[]
title?: string
}>()
const chartRef = ref<HTMLDivElement | null>(null)
const CHART_COLORS = {
equity: '#58a6ff',
grid: '#30363d',
text: '#8b949e',
}
function renderChart() {
if (!chartRef.value || props.data.length === 0) return
const timestamps = props.data.map(p => p.timestamp)
const values = props.data.map(p => p.value)
const traces: Plotly.Data[] = [
{
x: timestamps,
y: values,
type: 'scatter',
mode: 'lines',
name: 'Portfolio Value',
line: { color: CHART_COLORS.equity, width: 2 },
hovertemplate: '%{x}<br>Value: $%{y:,.2f}<extra></extra>',
},
]
const layout: Partial<Plotly.Layout> = {
title: props.title ? {
text: props.title,
font: { color: CHART_COLORS.text, size: 14 },
} : undefined,
paper_bgcolor: 'transparent',
plot_bgcolor: 'transparent',
margin: { l: 60, r: 20, t: props.title ? 40 : 20, b: 40 },
xaxis: {
showgrid: true,
gridcolor: CHART_COLORS.grid,
tickfont: { color: CHART_COLORS.text, size: 10 },
linecolor: CHART_COLORS.grid,
},
yaxis: {
showgrid: true,
gridcolor: CHART_COLORS.grid,
tickfont: { color: CHART_COLORS.text, size: 10 },
linecolor: CHART_COLORS.grid,
tickprefix: '$',
hoverformat: ',.2f',
},
showlegend: false,
hovermode: 'x unified',
}
const config: Partial<Plotly.Config> = {
responsive: true,
displayModeBar: false,
}
Plotly.react(chartRef.value, traces, layout, config)
}
watch(() => props.data, renderChart, { deep: true })
onMounted(() => {
renderChart()
window.addEventListener('resize', renderChart)
})
onUnmounted(() => {
window.removeEventListener('resize', renderChart)
if (chartRef.value) {
Plotly.purge(chartRef.value)
}
})
</script>
<template>
<div ref="chartRef" class="w-full h-full min-h-[300px]"></div>
</template>


@@ -0,0 +1,41 @@
<script setup lang="ts">
import { ref } from 'vue'
defineProps<{ msg: string }>()
const count = ref(0)
</script>
<template>
<h1>{{ msg }}</h1>
<div class="card">
<button type="button" @click="count++">count is {{ count }}</button>
<p>
Edit
<code>components/HelloWorld.vue</code> to test HMR
</p>
</div>
<p>
Check out
<a href="https://vuejs.org/guide/quick-start.html#local" target="_blank"
>create-vue</a
>, the official Vue + Vite starter
</p>
<p>
Learn more about IDE Support for Vue in the
<a
href="https://vuejs.org/guide/scaling-up/tooling.html#ide-support"
target="_blank"
>Vue Docs Scaling up Guide</a
>.
</p>
<p class="read-the-docs">Click on the Vite and Vue logos to learn more</p>
</template>
<style scoped>
.read-the-docs {
color: #888;
}
</style>


@@ -0,0 +1,144 @@
<script setup lang="ts">
import type { BacktestMetrics } from '@/api/types'
const props = defineProps<{
metrics: BacktestMetrics
leverage?: number
marketType?: string
}>()
function formatPercent(val: number): string {
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
}
function formatNumber(val: number | null | undefined, decimals = 2): string {
if (val === null || val === undefined) return '-'
return val.toFixed(decimals)
}
function formatCurrency(val: number): string {
return '$' + val.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 })
}
</script>
<template>
<div class="grid grid-cols-2 md:grid-cols-4 gap-4">
<!-- Total Return -->
<div class="card">
<div class="metric-label">Strategy Return</div>
<div
class="metric-value"
:class="metrics.total_return >= 0 ? 'profit' : 'loss'"
>
{{ formatPercent(metrics.total_return) }}
</div>
<div v-if="metrics.adjusted_return !== null && metrics.adjusted_return !== metrics.total_return" class="text-xs text-text-muted mt-1">
Adj: {{ formatPercent(metrics.adjusted_return) }}
</div>
</div>
<!-- Benchmark Return -->
<div class="card">
<div class="metric-label">Benchmark (B&H)</div>
<div
class="metric-value"
:class="metrics.benchmark_return >= 0 ? 'profit' : 'loss'"
>
{{ formatPercent(metrics.benchmark_return) }}
</div>
<div class="text-xs text-text-muted mt-1">
Market change
</div>
</div>
<!-- Alpha -->
<div class="card">
<div class="metric-label">Alpha</div>
<div
class="metric-value"
:class="metrics.alpha >= 0 ? 'profit' : 'loss'"
>
{{ formatPercent(metrics.alpha) }}
</div>
<div class="text-xs text-text-muted mt-1">
vs Buy & Hold
</div>
</div>
<!-- Sharpe Ratio -->
<div class="card">
<div class="metric-label">Sharpe Ratio</div>
<div
class="metric-value"
:class="metrics.sharpe_ratio >= 1 ? 'profit' : metrics.sharpe_ratio < 0 ? 'loss' : ''"
>
{{ formatNumber(metrics.sharpe_ratio) }}
</div>
</div>
<!-- Max Drawdown -->
<div class="card">
<div class="metric-label">Max Drawdown</div>
<div class="metric-value loss">
{{ formatPercent(metrics.max_drawdown) }}
</div>
</div>
<!-- Win Rate -->
<div class="card">
<div class="metric-label">Win Rate</div>
<div
class="metric-value"
:class="metrics.win_rate >= 50 ? 'profit' : 'loss'"
>
{{ formatNumber(metrics.win_rate, 1) }}%
</div>
</div>
<!-- Total Trades -->
<div class="card">
<div class="metric-label">Total Trades</div>
<div class="metric-value">
{{ metrics.total_trades }}
</div>
</div>
<!-- Profit Factor -->
<div class="card">
<div class="metric-label">Profit Factor</div>
<div
class="metric-value"
:class="(metrics.profit_factor || 0) >= 1 ? 'profit' : 'loss'"
>
{{ formatNumber(metrics.profit_factor) }}
</div>
</div>
<!-- Total Fees -->
<div class="card">
<div class="metric-label">Total Fees</div>
<div class="metric-value text-warning">
{{ formatCurrency(metrics.total_fees) }}
</div>
</div>
<!-- Funding (perpetual only) -->
<div v-if="marketType === 'perpetual'" class="card">
<div class="metric-label">Funding Paid</div>
<div class="metric-value text-warning">
{{ formatCurrency(metrics.total_funding) }}
</div>
</div>
<!-- Liquidations (if any) -->
<div v-if="metrics.liquidation_count > 0" class="card">
<div class="metric-label">Liquidations</div>
<div class="metric-value loss">
{{ metrics.liquidation_count }}
</div>
<div class="text-xs text-text-muted mt-1">
Lost: {{ formatCurrency(metrics.liquidation_loss) }}
</div>
</div>
</div>
</template>


@@ -0,0 +1,141 @@
<script setup lang="ts">
import { onMounted } from 'vue'
import { useBacktest } from '@/composables/useBacktest'
import { useRouter } from 'vue-router'
const router = useRouter()
const {
runs,
currentResult,
selectedRuns,
refreshRuns,
loadRun,
removeRun,
toggleRunSelection
} = useBacktest()
onMounted(() => {
refreshRuns()
})
function formatDate(iso: string): string {
const d = new Date(iso)
return d.toLocaleDateString('en-US', {
month: 'short',
day: 'numeric',
hour: '2-digit',
minute: '2-digit'
})
}
function formatReturn(val: number): string {
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
}
async function handleClick(runId: string) {
await loadRun(runId)
router.push('/')
}
function handleCheckbox(e: Event, runId: string) {
e.stopPropagation()
toggleRunSelection(runId)
}
function handleDelete(e: Event, runId: string) {
e.stopPropagation()
if (confirm('Delete this run?')) {
removeRun(runId)
}
}
</script>
<template>
<div class="flex flex-col h-full">
<!-- Header -->
<div class="p-4 border-b border-border">
<h2 class="text-sm font-semibold text-text-secondary uppercase tracking-wide">
Run History
</h2>
<p class="text-xs text-text-muted mt-1">
{{ runs.length }} runs | {{ selectedRuns.length }} selected
</p>
</div>
<!-- Run List -->
<div class="flex-1 overflow-y-auto">
<div
v-for="run in runs"
:key="run.run_id"
@click="handleClick(run.run_id)"
class="p-3 border-b border-border-muted cursor-pointer hover:bg-bg-hover transition-colors"
:class="{ 'bg-bg-tertiary': currentResult?.run_id === run.run_id }"
>
<div class="flex items-start gap-2">
<!-- Checkbox for comparison -->
<input
type="checkbox"
:checked="selectedRuns.includes(run.run_id)"
@click="handleCheckbox($event, run.run_id)"
class="mt-1 w-4 h-4 rounded border-border bg-bg-tertiary"
/>
<div class="flex-1 min-w-0">
<!-- Strategy & Symbol -->
<div class="flex items-center gap-2">
<span class="font-medium text-sm truncate">{{ run.strategy }}</span>
<span class="text-xs px-1.5 py-0.5 rounded bg-bg-tertiary text-text-secondary">
{{ run.symbol }}
</span>
</div>
<!-- Metrics -->
<div class="flex items-center gap-3 mt-1">
<span
class="text-sm font-mono"
:class="run.total_return >= 0 ? 'profit' : 'loss'"
>
{{ formatReturn(run.total_return) }}
</span>
<span class="text-xs text-text-muted">
SR {{ run.sharpe_ratio.toFixed(2) }}
</span>
</div>
<!-- Date -->
<div class="text-xs text-text-muted mt-1">
{{ formatDate(run.created_at) }}
</div>
</div>
<!-- Delete button -->
<button
@click="handleDelete($event, run.run_id)"
class="p-1 rounded hover:bg-loss/20 text-text-muted hover:text-loss transition-colors"
title="Delete run"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
</svg>
</button>
</div>
</div>
<!-- Empty state -->
<div v-if="runs.length === 0" class="p-8 text-center text-text-muted">
<p>No runs yet.</p>
<p class="text-xs mt-1">Run a backtest to see results here.</p>
</div>
</div>
<!-- Compare Button -->
<div v-if="selectedRuns.length >= 2" class="p-4 border-t border-border">
<router-link
to="/compare"
class="btn btn-primary w-full"
>
Compare {{ selectedRuns.length }} Runs
</router-link>
</div>
</div>
</template>


@@ -0,0 +1,157 @@
<script setup lang="ts">
import { ref, computed } from 'vue'
import type { TradeRecord } from '@/api/types'
const props = defineProps<{
trades: TradeRecord[]
}>()
type SortKey = 'entry_time' | 'pnl' | 'return_pct' | 'size'
const sortKey = ref<SortKey>('entry_time')
const sortDesc = ref(true)
const sortedTrades = computed(() => {
return [...props.trades].sort((a, b) => {
let aVal: number | string = 0
let bVal: number | string = 0
switch (sortKey.value) {
case 'entry_time':
aVal = a.entry_time
bVal = b.entry_time
break
case 'pnl':
aVal = a.pnl ?? 0
bVal = b.pnl ?? 0
break
case 'return_pct':
aVal = a.return_pct ?? 0
bVal = b.return_pct ?? 0
break
case 'size':
aVal = a.size
bVal = b.size
break
}
if (aVal < bVal) return sortDesc.value ? 1 : -1
if (aVal > bVal) return sortDesc.value ? -1 : 1
return 0
})
})
function toggleSort(key: SortKey) {
if (sortKey.value === key) {
sortDesc.value = !sortDesc.value
} else {
sortKey.value = key
sortDesc.value = true
}
}
function formatDate(iso: string): string {
if (!iso) return '-'
const d = new Date(iso)
return d.toLocaleDateString('en-US', {
month: 'short',
day: 'numeric',
hour: '2-digit',
minute: '2-digit',
})
}
function formatPrice(val: number | null): string {
if (val === null) return '-'
return val.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 })
}
function formatPnL(val: number | null): string {
if (val === null) return '-'
const sign = val >= 0 ? '+' : ''
return sign + val.toLocaleString('en-US', { minimumFractionDigits: 2, maximumFractionDigits: 2 })
}
function formatReturn(val: number | null): string {
if (val === null) return '-'
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
}
</script>
<template>
<div class="card overflow-hidden">
<div class="flex items-center justify-between mb-4">
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide">
Trade Log
</h3>
<span class="text-xs text-text-muted">{{ trades.length }} trades</span>
</div>
<div class="overflow-x-auto max-h-[400px] overflow-y-auto">
<table class="min-w-full">
<thead class="sticky top-0 bg-bg-card">
<tr>
<th
@click="toggleSort('entry_time')"
class="cursor-pointer hover:text-text-primary"
>
Entry Time
<span v-if="sortKey === 'entry_time'">{{ sortDesc ? ' ▼' : ' ▲' }}</span>
</th>
<th>Exit Time</th>
<th>Direction</th>
<th>Entry</th>
<th>Exit</th>
<th
@click="toggleSort('size')"
class="cursor-pointer hover:text-text-primary"
>
Size
<span v-if="sortKey === 'size'">{{ sortDesc ? ' ▼' : ' ▲' }}</span>
</th>
<th
@click="toggleSort('pnl')"
class="cursor-pointer hover:text-text-primary"
>
PnL
<span v-if="sortKey === 'pnl'">{{ sortDesc ? ' ▼' : ' ▲' }}</span>
</th>
<th
@click="toggleSort('return_pct')"
class="cursor-pointer hover:text-text-primary"
>
Return
<span v-if="sortKey === 'return_pct'">{{ sortDesc ? ' ▼' : ' ▲' }}</span>
</th>
</tr>
</thead>
<tbody>
<tr v-for="(trade, idx) in sortedTrades" :key="idx">
<td class="text-text-secondary">{{ formatDate(trade.entry_time) }}</td>
<td class="text-text-secondary">{{ formatDate(trade.exit_time || '') }}</td>
<td>
<span
class="px-2 py-0.5 rounded text-xs font-medium"
:class="trade.direction === 'Long' ? 'bg-profit/20 text-profit' : 'bg-loss/20 text-loss'"
>
{{ trade.direction }}
</span>
</td>
<td>${{ formatPrice(trade.entry_price) }}</td>
<td>${{ formatPrice(trade.exit_price) }}</td>
<td>{{ trade.size.toFixed(4) }}</td>
<td :class="(trade.pnl ?? 0) >= 0 ? 'profit' : 'loss'">
${{ formatPnL(trade.pnl) }}
</td>
<td :class="(trade.return_pct ?? 0) >= 0 ? 'profit' : 'loss'">
{{ formatReturn(trade.return_pct) }}
</td>
</tr>
</tbody>
</table>
<div v-if="trades.length === 0" class="p-8 text-center text-text-muted">
No trades executed.
</div>
</div>
</div>
</template>


@@ -0,0 +1,150 @@
/**
* Composable for managing backtest state across components.
*/
import { ref, computed } from 'vue'
import type { BacktestResult, BacktestSummary, StrategyInfo, SymbolInfo } from '@/api/types'
import { getStrategies, getSymbols, getBacktests, getBacktest, runBacktest, deleteBacktest } from '@/api/client'
import type { BacktestRequest } from '@/api/types'
// Shared state
const strategies = ref<StrategyInfo[]>([])
const symbols = ref<SymbolInfo[]>([])
const runs = ref<BacktestSummary[]>([])
const currentResult = ref<BacktestResult | null>(null)
const selectedRuns = ref<string[]>([])
const loading = ref(false)
const error = ref<string | null>(null)
// Computed
const symbolsByMarket = computed(() => {
const grouped: Record<string, SymbolInfo[]> = {}
for (const s of symbols.value) {
const key = s.market_type
if (!grouped[key]) grouped[key] = []
grouped[key].push(s)
}
return grouped
})
export function useBacktest() {
/**
* Load strategies and symbols on app init.
*/
async function init() {
try {
const [stratRes, symRes] = await Promise.all([
getStrategies(),
getSymbols(),
])
strategies.value = stratRes.strategies
symbols.value = symRes.symbols
} catch (e) {
error.value = `Failed to load initial data: ${e}`
}
}
/**
* Refresh the run history list.
*/
async function refreshRuns() {
try {
const res = await getBacktests({ limit: 100 })
runs.value = res.runs
} catch (e) {
error.value = `Failed to load runs: ${e}`
}
}
/**
* Execute a new backtest.
*/
async function executeBacktest(request: BacktestRequest) {
loading.value = true
error.value = null
try {
const result = await runBacktest(request)
currentResult.value = result
await refreshRuns()
return result
} catch (e: unknown) {
const msg = e instanceof Error ? e.message : String(e)
error.value = `Backtest failed: ${msg}`
throw e
} finally {
loading.value = false
}
}
/**
* Load a specific run by ID.
*/
async function loadRun(runId: string) {
loading.value = true
error.value = null
try {
const result = await getBacktest(runId)
currentResult.value = result
return result
} catch (e) {
error.value = `Failed to load run: ${e}`
throw e
} finally {
loading.value = false
}
}
/**
* Delete a run.
*/
async function removeRun(runId: string) {
try {
await deleteBacktest(runId)
await refreshRuns()
if (currentResult.value?.run_id === runId) {
currentResult.value = null
}
selectedRuns.value = selectedRuns.value.filter(id => id !== runId)
} catch (e) {
error.value = `Failed to delete run: ${e}`
}
}
/**
* Toggle run selection for comparison.
*/
function toggleRunSelection(runId: string) {
const idx = selectedRuns.value.indexOf(runId)
if (idx >= 0) {
selectedRuns.value.splice(idx, 1)
} else if (selectedRuns.value.length < 5) {
selectedRuns.value.push(runId)
}
}
/**
* Clear all selections.
*/
function clearSelections() {
selectedRuns.value = []
}
return {
// State
strategies,
symbols,
symbolsByMarket,
runs,
currentResult,
selectedRuns,
loading,
error,
// Actions
init,
refreshRuns,
executeBacktest,
loadRun,
removeRun,
toggleRunSelection,
clearSelections,
}
}

frontend/src/main.ts

@@ -0,0 +1,8 @@
import { createApp } from 'vue'
import App from './App.vue'
import router from './router'
import './style.css'
const app = createApp(App)
app.use(router)
app.mount('#app')


@@ -0,0 +1,21 @@
import { createRouter, createWebHistory } from 'vue-router'
import DashboardView from '@/views/DashboardView.vue'
import CompareView from '@/views/CompareView.vue'
const router = createRouter({
history: createWebHistory(import.meta.env.BASE_URL),
routes: [
{
path: '/',
name: 'dashboard',
component: DashboardView,
},
{
path: '/compare',
name: 'compare',
component: CompareView,
},
],
})
export default router

frontend/src/style.css

@@ -0,0 +1,198 @@
@import "tailwindcss";
/* QuantConnect-inspired dark theme */
@theme {
/* Background colors */
--color-bg-primary: #0d1117;
--color-bg-secondary: #161b22;
--color-bg-tertiary: #21262d;
--color-bg-card: #1c2128;
--color-bg-hover: #30363d;
/* Text colors */
--color-text-primary: #e6edf3;
--color-text-secondary: #8b949e;
--color-text-muted: #6e7681;
/* Accent colors */
--color-accent-blue: #58a6ff;
--color-accent-purple: #a371f7;
--color-accent-cyan: #39d4e8;
/* Status colors */
--color-profit: #3fb950;
--color-loss: #f85149;
--color-warning: #d29922;
/* Border colors */
--color-border: #30363d;
--color-border-muted: #21262d;
/* Chart colors for comparison */
--color-chart-1: #58a6ff;
--color-chart-2: #a371f7;
--color-chart-3: #39d4e8;
--color-chart-4: #f0883e;
--color-chart-5: #db61a2;
}
/* Base styles */
body {
background-color: var(--color-bg-primary);
color: var(--color-text-primary);
font-family: 'JetBrains Mono', 'Fira Code', 'SF Mono', Consolas, monospace;
}
/* Scrollbar styling */
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: var(--color-bg-secondary);
}
::-webkit-scrollbar-thumb {
background: var(--color-bg-hover);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: var(--color-border);
}
/* Input styling */
input[type="number"],
input[type="text"],
select {
background-color: var(--color-bg-tertiary);
border: 1px solid var(--color-border);
color: var(--color-text-primary);
border-radius: 6px;
padding: 0.5rem 0.75rem;
font-size: 0.875rem;
transition: border-color 0.2s, box-shadow 0.2s;
}
input[type="number"]:focus,
input[type="text"]:focus,
select:focus {
outline: none;
border-color: var(--color-accent-blue);
box-shadow: 0 0 0 3px rgba(88, 166, 255, 0.2);
}
/* Button base */
.btn {
display: inline-flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
padding: 0.5rem 1rem;
font-size: 0.875rem;
font-weight: 500;
border-radius: 6px;
transition: all 0.2s;
cursor: pointer;
border: 1px solid transparent;
}
.btn-primary {
background-color: var(--color-accent-blue);
color: #000;
}
.btn-primary:hover {
background-color: #79b8ff;
}
.btn-primary:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-secondary {
background-color: var(--color-bg-tertiary);
border-color: var(--color-border);
color: var(--color-text-primary);
}
.btn-secondary:hover {
background-color: var(--color-bg-hover);
}
/* Card styling */
.card {
background-color: var(--color-bg-card);
border: 1px solid var(--color-border-muted);
border-radius: 8px;
padding: 1rem;
}
/* Profit/Loss coloring */
.profit {
color: var(--color-profit);
}
.loss {
color: var(--color-loss);
}
/* Metric value styling */
.metric-value {
font-size: 1.5rem;
font-weight: 600;
font-variant-numeric: tabular-nums;
}
.metric-label {
font-size: 0.75rem;
color: var(--color-text-secondary);
text-transform: uppercase;
letter-spacing: 0.05em;
}
/* Table styling */
table {
width: 100%;
border-collapse: collapse;
}
th {
text-align: left;
padding: 0.75rem;
font-size: 0.75rem;
font-weight: 500;
color: var(--color-text-secondary);
text-transform: uppercase;
letter-spacing: 0.05em;
border-bottom: 1px solid var(--color-border);
}
td {
padding: 0.75rem;
font-size: 0.875rem;
border-bottom: 1px solid var(--color-border-muted);
font-variant-numeric: tabular-nums;
}
tr:hover td {
background-color: var(--color-bg-hover);
}
/* Loading spinner */
.spinner {
width: 20px;
height: 20px;
border: 2px solid var(--color-border);
border-top-color: var(--color-accent-blue);
border-radius: 50%;
animation: spin 0.8s linear infinite;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}

frontend/src/types/plotly.d.ts

@@ -0,0 +1,4 @@
declare module 'plotly.js-dist-min' {
import Plotly from 'plotly.js'
export default Plotly
}


@@ -0,0 +1,362 @@
<script setup lang="ts">
import { ref, watch, onMounted, onUnmounted } from 'vue'
import { useRouter } from 'vue-router'
import Plotly from 'plotly.js-dist-min'
import { useBacktest } from '@/composables/useBacktest'
import { compareRuns } from '@/api/client'
import type { BacktestResult, CompareResult } from '@/api/types'
const router = useRouter()
const { selectedRuns, clearSelections } = useBacktest()
const chartRef = ref<HTMLDivElement | null>(null)
const loading = ref(false)
const error = ref<string | null>(null)
const compareResult = ref<CompareResult | null>(null)
const CHART_COLORS = [
'#58a6ff', // blue
'#a371f7', // purple
'#39d4e8', // cyan
'#f0883e', // orange
'#db61a2', // pink
]
async function loadComparison() {
if (selectedRuns.value.length < 2) {
router.push('/')
return
}
loading.value = true
error.value = null
try {
compareResult.value = await compareRuns(selectedRuns.value)
renderChart()
} catch (e) {
error.value = `Failed to load comparison: ${e}`
} finally {
loading.value = false
}
}
function renderChart() {
if (!chartRef.value || !compareResult.value) return
const traces: Plotly.Data[] = compareResult.value.runs.map((run, idx) => {
// Normalize equity curves to start at 100 for comparison
const startValue = run.equity_curve[0]?.value || 1
const normalizedValues = run.equity_curve.map(p => (p.value / startValue) * 100)
return {
x: run.equity_curve.map(p => p.timestamp),
y: normalizedValues,
type: 'scatter',
mode: 'lines',
name: run.params.period ? `${run.strategy} (${run.params.period})` : run.strategy,
line: { color: CHART_COLORS[idx % CHART_COLORS.length], width: 2 },
hovertemplate: `%{x}<br>${run.strategy}: %{y:.2f}<extra></extra>`,
}
})
const layout: Partial<Plotly.Layout> = {
title: {
text: 'Normalized Equity Comparison (Base 100)',
font: { color: '#8b949e', size: 14 },
},
paper_bgcolor: 'transparent',
plot_bgcolor: 'transparent',
margin: { l: 60, r: 20, t: 50, b: 40 },
xaxis: {
showgrid: true,
gridcolor: '#30363d',
tickfont: { color: '#8b949e', size: 10 },
linecolor: '#30363d',
},
yaxis: {
showgrid: true,
gridcolor: '#30363d',
tickfont: { color: '#8b949e', size: 10 },
linecolor: '#30363d',
title: { text: 'Normalized Value', font: { color: '#8b949e' } },
},
legend: {
orientation: 'h',
yanchor: 'bottom',
y: 1.02,
xanchor: 'left',
x: 0,
font: { color: '#8b949e' },
},
hovermode: 'x unified',
}
Plotly.react(chartRef.value, traces, layout, { responsive: true, displayModeBar: false })
}
function formatPercent(val: number): string {
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%'
}
function formatNumber(val: number | null): string {
if (val === null) return '-'
return val.toFixed(2)
}
function getBestIndex(runs: BacktestResult[], metric: keyof BacktestResult['metrics'], higher = true): number {
let bestIdx = 0
let bestVal = runs[0]?.metrics[metric] ?? 0
for (let i = 1; i < runs.length; i++) {
const val = runs[i]?.metrics[metric] ?? 0
const isBetter = higher ? (val as number) > (bestVal as number) : (val as number) < (bestVal as number)
if (isBetter) {
bestIdx = i
bestVal = val
}
}
return bestIdx
}
function handleClearAndBack() {
clearSelections()
router.push('/')
}
onMounted(() => {
loadComparison()
window.addEventListener('resize', renderChart)
})
onUnmounted(() => {
window.removeEventListener('resize', renderChart)
if (chartRef.value) {
Plotly.purge(chartRef.value)
}
})
// deep: true is required because selections are mutated in place (push/splice),
// which a shallow watch on the ref would not detect
watch(selectedRuns, () => {
if (selectedRuns.value.length >= 2) {
loadComparison()
}
}, { deep: true })
</script>
<template>
<div class="p-6 space-y-6">
<!-- Header -->
<div class="flex items-center justify-between">
<div>
<h1 class="text-2xl font-bold">Compare Runs</h1>
<p class="text-text-secondary text-sm mt-1">
Comparing {{ selectedRuns.length }} backtest runs
</p>
</div>
<button @click="handleClearAndBack" class="btn btn-secondary">
Clear & Back
</button>
</div>
<!-- Error -->
<div v-if="error" class="p-4 rounded-lg bg-loss/10 border border-loss/30 text-loss">
{{ error }}
</div>
<!-- Loading -->
<div v-if="loading" class="card flex items-center justify-center h-[400px]">
<div class="spinner" style="width: 40px; height: 40px;"></div>
</div>
<!-- Comparison Results -->
<template v-else-if="compareResult">
<!-- Equity Curve Comparison -->
<div class="card">
<div ref="chartRef" class="h-[400px]"></div>
</div>
<!-- Metrics Comparison Table -->
<div class="card overflow-x-auto">
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-4">
Metrics Comparison
</h3>
<table class="min-w-full">
<thead>
<tr>
<th>Metric</th>
<th
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
>
<span
class="inline-block w-3 h-3 rounded-full mr-2"
:style="{ backgroundColor: CHART_COLORS[idx] }"
></span>
{{ run.strategy }}
</th>
</tr>
</thead>
<tbody>
<!-- Total Return -->
<tr>
<td class="font-medium">Strategy Return</td>
<td
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
:class="[
run.metrics.total_return >= 0 ? 'profit' : 'loss',
idx === getBestIndex(compareResult.runs, 'total_return') ? 'font-bold' : ''
]"
>
{{ formatPercent(run.metrics.total_return) }}
</td>
</tr>
<!-- Benchmark Return -->
<tr>
<td class="font-medium">Benchmark (B&H)</td>
<td
v-for="run in compareResult.runs"
:key="run.run_id"
class="text-center"
:class="run.metrics.benchmark_return >= 0 ? 'profit' : 'loss'"
>
{{ formatPercent(run.metrics.benchmark_return) }}
</td>
</tr>
<!-- Alpha -->
<tr>
<td class="font-medium">Alpha</td>
<td
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
:class="[
run.metrics.alpha >= 0 ? 'profit' : 'loss',
idx === getBestIndex(compareResult.runs, 'alpha') ? 'font-bold' : ''
]"
>
{{ formatPercent(run.metrics.alpha) }}
</td>
</tr>
<!-- Sharpe Ratio -->
<tr>
<td class="font-medium">Sharpe Ratio</td>
<td
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
:class="idx === getBestIndex(compareResult.runs, 'sharpe_ratio') ? 'font-bold text-accent-blue' : ''"
>
{{ formatNumber(run.metrics.sharpe_ratio) }}
</td>
</tr>
<!-- Max Drawdown -->
<tr>
<td class="font-medium">Max Drawdown</td>
<td
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center loss"
:class="idx === getBestIndex(compareResult.runs, 'max_drawdown', false) ? 'font-bold' : ''"
>
{{ formatPercent(run.metrics.max_drawdown) }}
</td>
</tr>
<!-- Win Rate -->
<tr>
<td class="font-medium">Win Rate</td>
<td
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
:class="idx === getBestIndex(compareResult.runs, 'win_rate') ? 'font-bold text-profit' : ''"
>
{{ formatNumber(run.metrics.win_rate) }}%
</td>
</tr>
<!-- Total Trades -->
<tr>
<td class="font-medium">Total Trades</td>
<td
v-for="run in compareResult.runs"
:key="run.run_id"
class="text-center"
>
{{ run.metrics.total_trades }}
</td>
</tr>
<!-- Profit Factor -->
<tr>
<td class="font-medium">Profit Factor</td>
<td
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
:class="idx === getBestIndex(compareResult.runs, 'profit_factor') ? 'font-bold text-profit' : ''"
>
{{ formatNumber(run.metrics.profit_factor) }}
</td>
</tr>
</tbody>
</table>
</div>
<!-- Parameter Differences -->
<div v-if="Object.keys(compareResult.param_diff).length > 0" class="card">
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-4">
Parameter Differences
</h3>
<table class="min-w-full">
<thead>
<tr>
<th>Parameter</th>
<th
v-for="(run, idx) in compareResult.runs"
:key="run.run_id"
class="text-center"
>
<span
class="inline-block w-3 h-3 rounded-full mr-2"
:style="{ backgroundColor: CHART_COLORS[idx] }"
></span>
Run {{ idx + 1 }}
</th>
</tr>
</thead>
<tbody>
<tr v-for="(values, key) in compareResult.param_diff" :key="key">
<td class="font-medium font-mono">{{ key }}</td>
<td
v-for="(val, idx) in values"
:key="idx"
class="text-center font-mono"
>
{{ val ?? '-' }}
</td>
</tr>
</tbody>
</table>
</div>
</template>
<!-- No Selection -->
<div v-else class="card flex items-center justify-center h-[400px]">
<div class="text-center text-text-muted">
<p>Select at least 2 runs from the history to compare.</p>
<button @click="router.push('/')" class="btn btn-secondary mt-4">
Go to Dashboard
</button>
</div>
</div>
</div>
</template>


@@ -0,0 +1,136 @@
<script setup lang="ts">
import { useBacktest } from '@/composables/useBacktest'
import BacktestConfig from '@/components/BacktestConfig.vue'
import EquityCurve from '@/components/EquityCurve.vue'
import MetricsPanel from '@/components/MetricsPanel.vue'
import TradeLog from '@/components/TradeLog.vue'
const { currentResult, loading, error } = useBacktest()
</script>
<template>
<div class="p-6 space-y-6">
<!-- Header -->
<div class="flex items-center justify-between">
<div>
<h1 class="text-2xl font-bold">Lowkey Backtest</h1>
<p class="text-text-secondary text-sm mt-1">
Run and analyze trading strategy backtests
</p>
</div>
</div>
<!-- Error Banner -->
<div
v-if="error"
class="p-4 rounded-lg bg-loss/10 border border-loss/30 text-loss"
>
{{ error }}
</div>
<!-- Main Grid -->
<div class="grid grid-cols-1 lg:grid-cols-3 gap-6">
<!-- Config Panel (Left) -->
<div class="lg:col-span-1">
<BacktestConfig />
</div>
<!-- Results (Right) -->
<div class="lg:col-span-2 space-y-6">
<!-- Loading State -->
<div
v-if="loading"
class="card flex items-center justify-center h-[400px]"
>
<div class="text-center">
<div class="spinner mx-auto mb-4" style="width: 40px; height: 40px;"></div>
<p class="text-text-secondary">Running backtest...</p>
</div>
</div>
<!-- Results Display -->
<template v-else-if="currentResult">
<!-- Result Header -->
<div class="card">
<div class="flex items-center justify-between">
<div>
<h2 class="text-lg font-semibold">
{{ currentResult.strategy }} on {{ currentResult.symbol }}
</h2>
<p class="text-sm text-text-secondary mt-1">
{{ currentResult.start_date }} - {{ currentResult.end_date }}
<span class="mx-2">|</span>
{{ currentResult.market_type.toUpperCase() }}
<span v-if="currentResult.leverage > 1" class="ml-2">
{{ currentResult.leverage }}x
</span>
</p>
</div>
<div class="text-right">
<div
class="text-2xl font-bold"
:class="currentResult.metrics.total_return >= 0 ? 'profit' : 'loss'"
>
{{ currentResult.metrics.total_return >= 0 ? '+' : '' }}{{ currentResult.metrics.total_return.toFixed(2) }}%
</div>
<div class="text-sm text-text-secondary">
Sharpe: {{ currentResult.metrics.sharpe_ratio.toFixed(2) }}
</div>
</div>
</div>
</div>
<!-- Equity Curve -->
<div class="card">
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-4">
Equity Curve
</h3>
<div class="h-[350px]">
<EquityCurve :data="currentResult.equity_curve" />
</div>
</div>
<!-- Metrics -->
<MetricsPanel
:metrics="currentResult.metrics"
:leverage="currentResult.leverage"
:market-type="currentResult.market_type"
/>
<!-- Trade Log -->
<TradeLog :trades="currentResult.trades" />
<!-- Parameters Used -->
<div class="card">
<h3 class="text-sm font-semibold text-text-secondary uppercase tracking-wide mb-3">
Parameters
</h3>
<div class="flex flex-wrap gap-2">
<span
v-for="(value, key) in currentResult.params"
:key="key"
class="px-2 py-1 rounded bg-bg-tertiary text-sm font-mono"
>
{{ key }}: {{ value }}
</span>
</div>
</div>
</template>
<!-- Empty State -->
<div
v-else
class="card flex items-center justify-center h-[400px]"
>
<div class="text-center text-text-muted">
<svg class="w-16 h-16 mx-auto mb-4 opacity-50" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
<p>Configure and run a backtest to see results.</p>
<p class="text-xs mt-2">Or select a run from history.</p>
</div>
</div>
</div>
</div>
</div>
</template>


@@ -0,0 +1,20 @@
{
"extends": "@vue/tsconfig/tsconfig.dom.json",
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"types": ["vite/client"],
"baseUrl": ".",
"paths": {
"@/*": ["./src/*"]
},
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"]
}

frontend/tsconfig.json

@@ -0,0 +1,7 @@
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" }
]
}


@@ -0,0 +1,26 @@
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
"target": "ES2023",
"lib": ["ES2023"],
"module": "ESNext",
"types": ["node"],
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["vite.config.ts"]
}

frontend/vite.config.ts

@@ -0,0 +1,22 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import tailwindcss from '@tailwindcss/vite'
import { fileURLToPath, URL } from 'node:url'
// https://vite.dev/config/
export default defineConfig({
plugins: [vue(), tailwindcss()],
resolve: {
alias: {
'@': fileURLToPath(new URL('./src', import.meta.url))
}
},
server: {
proxy: {
'/api': {
target: 'http://127.0.0.1:8000',
changeOrigin: true
}
}
}
})


@@ -1,3 +0,0 @@
from .supertrend import add_supertrends, compute_meta_trend
__all__ = ["add_supertrends", "compute_meta_trend"]


@@ -1,58 +0,0 @@
from __future__ import annotations
import pandas as pd
import numpy as np
def _atr(high: pd.Series, low: pd.Series, close: pd.Series, period: int) -> pd.Series:
hl = (high - low).abs()
hc = (high - close.shift()).abs()
lc = (low - close.shift()).abs()
tr = pd.concat([hl, hc, lc], axis=1).max(axis=1)
return tr.rolling(period, min_periods=period).mean()
def supertrend_series(df: pd.DataFrame, length: int, multiplier: float) -> pd.Series:
atr = _atr(df["High"], df["Low"], df["Close"], length)
hl2 = (df["High"] + df["Low"]) / 2
upper = hl2 + multiplier * atr
lower = hl2 - multiplier * atr
trend = pd.Series(index=df.index, dtype=float)
dir_up = True
prev_upper = np.nan
prev_lower = np.nan
for i in range(len(df)):
if i == 0 or pd.isna(atr.iat[i]):
trend.iat[i] = np.nan
prev_upper = upper.iat[i]
prev_lower = lower.iat[i]
continue
cu = min(upper.iat[i], prev_upper) if dir_up else upper.iat[i]
cl = max(lower.iat[i], prev_lower) if not dir_up else lower.iat[i]
if df["Close"].iat[i] > cu:
dir_up = True
elif df["Close"].iat[i] < cl:
dir_up = False
prev_upper = cu if dir_up else upper.iat[i]
prev_lower = lower.iat[i] if dir_up else cl
trend.iat[i] = cl if dir_up else cu
return trend
def add_supertrends(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.DataFrame:
out = df.copy()
for length, mult in settings:
col = f"supertrend_{length}_{mult}"
out[col] = supertrend_series(out, length, mult)
out[f"bull_{length}_{mult}"] = (out["Close"] >= out[col]).astype(int)
return out
def compute_meta_trend(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.Series:
bull_cols = [f"bull_{l}_{m}" for l, m in settings]
return (df[bull_cols].sum(axis=1) == len(bull_cols)).astype(int)
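
For intuition, `compute_meta_trend` above flags a bar as bullish only when every SuperTrend setting agrees. A minimal check with hypothetical bull flags (the column names mirror the `bull_{length}_{mult}` convention in `add_supertrends`):

```python
import pandas as pd

# Hypothetical bull flags for three SuperTrend settings
flags = pd.DataFrame({
    "bull_10_2.0": [1, 1, 0, 1],
    "bull_20_3.0": [1, 0, 0, 1],
    "bull_30_4.0": [1, 1, 1, 1],
})

# Bullish only where all settings agree, matching compute_meta_trend's logic
meta = (flags.sum(axis=1) == flags.shape[1]).astype(int)
```

Here `meta` is 1 only on the first and last rows, where all three settings are bullish simultaneously.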


@@ -1,10 +0,0 @@
from __future__ import annotations
import pandas as pd
def precompute_slices(df: pd.DataFrame) -> pd.DataFrame:
return df # hook for future use
def entry_slippage_row(price: float, qty: float, slippage_bps: float) -> float:
return price + price * (slippage_bps / 1e4)
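
The bps conversion in `entry_slippage_row` is easy to sanity-check: 10 bps on a $100 fill should add $0.10.

```python
def entry_slippage_row(price: float, qty: float, slippage_bps: float) -> float:
    # qty is accepted for interface symmetry but unused here, as in the source
    return price + price * (slippage_bps / 1e4)

# 10 bps = 0.10% of 100.0 = 0.10
filled = entry_slippage_row(100.0, 1.0, slippage_bps=10)
```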

live_trading/README.md

@@ -0,0 +1,135 @@
# Live Trading - Regime Reversion Strategy
This module implements live trading for the ML-based regime detection and mean reversion strategy on OKX perpetual futures.
## Overview
The strategy trades ETH perpetual futures based on:
1. **BTC/ETH Spread Z-Score**: Identifies when ETH is cheap or expensive relative to BTC
2. **Random Forest ML Model**: Predicts probability of successful mean reversion
3. **Funding Rate Filter**: Avoids trades in overheated/oversold market conditions
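
The z-score in step 1 can be sketched as follows. This is a rough illustration only; the function name, use of log prices, and the 24-bar window (matching the `z_window` default below) are assumptions, not the module's actual API:

```python
import numpy as np
import pandas as pd

def spread_zscore(eth_close: pd.Series, btc_close: pd.Series, window: int = 24) -> pd.Series:
    """Rolling z-score of the log ETH/BTC spread (illustrative sketch)."""
    spread = np.log(eth_close / btc_close)
    mean = spread.rolling(window, min_periods=window).mean()
    std = spread.rolling(window, min_periods=window).std()
    return (spread - mean) / std
```

A strongly negative z-score suggests ETH is cheap relative to BTC (a long-ETH reversion candidate); a strongly positive one suggests the opposite.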
## Setup
### 1. API Keys
The bot loads OKX API credentials from `../BTC_spot_MVRV/.env`.
**IMPORTANT: OKX uses SEPARATE API keys for live vs demo trading!**
#### Option A: Demo Trading (Recommended for Testing)
1. Go to [OKX Demo Trading](https://www.okx.com/demo-trading)
2. Create a demo account if you haven't already
3. Generate API keys from the demo environment
4. Set in `.env`:
```env
OKX_API_KEY=your_demo_api_key
OKX_SECRET=your_demo_secret
OKX_PASSWORD=your_demo_passphrase
OKX_DEMO_MODE=true
```
#### Option B: Live Trading (Real Funds)
Use your existing live API keys with:
```env
OKX_API_KEY=your_live_api_key
OKX_SECRET=your_live_secret
OKX_PASSWORD=your_live_passphrase
OKX_DEMO_MODE=false
```
**Note:** You cannot use live API keys with `OKX_DEMO_MODE=true` or vice versa.
OKX will return error `50101: APIKey does not match current environment`.
### 2. Dependencies
All dependencies are already in the project's `pyproject.toml`. No additional installation needed.
## Usage
### Run with Demo Account (Recommended First)
```bash
cd /path/to/lowkey_backtest
uv run python -m live_trading.main
```
### Command Line Options
```bash
# Custom position size
uv run python -m live_trading.main --max-position 500
# Custom leverage
uv run python -m live_trading.main --leverage 2
# Custom cycle interval (in seconds)
uv run python -m live_trading.main --interval 1800
# Combine options
uv run python -m live_trading.main --max-position 1000 --leverage 3 --interval 3600
```
### Live Trading (Use with Caution)
```bash
# Requires OKX_DEMO_MODE=false in .env
uv run python -m live_trading.main --live
```
## Architecture
```
live_trading/
__init__.py # Module initialization
config.py # Configuration loading
okx_client.py # OKX API wrapper
data_feed.py # Real-time OHLCV data
position_manager.py # Position tracking
live_regime_strategy.py # Strategy logic
main.py # Entry point
.env.example # Environment template
README.md # This file
```
## Strategy Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
| `z_window` | 24 | Rolling window for Z-Score (hours) |
| `model_prob_threshold` | 0.5 | ML probability threshold for entry |
| `funding_threshold` | 0.005 | Funding rate filter threshold |
| `stop_loss_pct` | 6% | Stop-loss percentage |
| `take_profit_pct` | 5% | Take-profit percentage |
## Files Generated
- `live_trading/positions.json` - Open positions persistence
- `live_trading/trade_log.csv` - Trade history
- `live_trading/regime_model.pkl` - Trained ML model
- `logs/live_trading.log` - Trading logs
## Risk Warning
This is experimental trading software. Use at your own risk:
- Always start with demo trading
- Never risk more than you can afford to lose
- Monitor the bot regularly
- Have a kill switch ready (Ctrl+C)
## Troubleshooting
### API Key Issues
- Ensure API keys have trading permissions
- For demo trading, use demo-specific API keys
- Check that passphrase matches exactly
### No Signals Generated
- The strategy requires the ML model to be trained
- Need at least 200 candles of data
- Model trains automatically on first run
### Position Sync Issues
- The bot syncs with exchange positions on each cycle
- If positions are closed manually, the bot will detect this

live_trading/__init__.py

@@ -0,0 +1,6 @@
"""
Live Trading Module for Regime Reversion Strategy on OKX.
This module implements live trading using the ML-based regime detection
and mean reversion strategy on OKX perpetual futures.
"""

live_trading/config.py

@@ -0,0 +1,114 @@
"""
Configuration for Live Trading.
Loads OKX API credentials from environment variables.
Uses demo/sandbox mode by default for paper trading.
"""
import os
from pathlib import Path
from dataclasses import dataclass, field
from dotenv import load_dotenv
load_dotenv()
@dataclass
class OKXConfig:
"""OKX API configuration."""
api_key: str = field(default_factory=lambda: "")
secret: str = field(default_factory=lambda: "")
password: str = field(default_factory=lambda: "")
demo_mode: bool = field(default_factory=lambda: True)
def __post_init__(self):
"""Load credentials based on demo mode setting."""
# Check demo mode first
self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")
if self.demo_mode:
# Load demo-specific credentials if available
self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
else:
# Load live credentials
self.api_key = os.getenv("OKX_API_KEY", "")
self.secret = os.getenv("OKX_SECRET", "")
self.password = os.getenv("OKX_PASSWORD", "")
def validate(self) -> None:
"""Validate that required credentials are present."""
mode = "demo" if self.demo_mode else "live"
if not self.api_key:
raise ValueError(f"OKX API key not set for {mode} mode")
if not self.secret:
raise ValueError(f"OKX secret not set for {mode} mode")
if not self.password:
raise ValueError(f"OKX password not set for {mode} mode")
@dataclass
class TradingConfig:
"""Trading parameters configuration."""
# Trading pairs
eth_symbol: str = "ETH/USDT:USDT" # ETH perpetual (primary trading asset)
btc_symbol: str = "BTC/USDT:USDT" # BTC perpetual (context asset)
# Timeframe
timeframe: str = "1h"
candles_to_fetch: int = 500 # Enough for feature calculation
# Position sizing
max_position_usdt: float = -1.0 # Max position size in USDT. If <= 0, use all available funds
min_position_usdt: float = 10.0 # Min position size in USDT
leverage: int = 1 # Leverage (1x = no leverage)
margin_mode: str = "cross" # "cross" or "isolated"
# Risk management
stop_loss_pct: float = 0.06 # 6% stop loss
take_profit_pct: float = 0.05 # 5% take profit
max_concurrent_positions: int = 1 # Max open positions
# Strategy parameters (from regime_strategy.py)
z_entry_threshold: float = 1.0 # Enter when |Z| > 1.0
z_window: int = 24 # 24h rolling Z-score window
model_prob_threshold: float = 0.5 # ML model probability threshold
funding_threshold: float = 0.005 # Funding rate filter threshold
# Execution
sleep_seconds: int = 3600 # Run every hour (1h candles)
slippage_pct: float = 0.001 # 0.1% slippage buffer
@dataclass
class PathConfig:
"""File paths configuration."""
base_dir: Path = field(
default_factory=lambda: Path(__file__).parent.parent
)
data_dir: Path = field(default=None)
logs_dir: Path = field(default=None)
model_path: Path = field(default=None)
positions_file: Path = field(default=None)
trade_log_file: Path = field(default=None)
cq_data_path: Path = field(default=None)
def __post_init__(self):
self.data_dir = self.base_dir / "data"
self.logs_dir = self.base_dir / "logs"
self.model_path = self.base_dir / "live_trading" / "regime_model.pkl"
self.positions_file = self.base_dir / "live_trading" / "positions.json"
self.trade_log_file = self.base_dir / "live_trading" / "trade_log.csv"
self.cq_data_path = self.data_dir / "cq_training_data.csv"
# Ensure directories exist
self.data_dir.mkdir(parents=True, exist_ok=True)
self.logs_dir.mkdir(parents=True, exist_ok=True)
def get_config():
"""Get all configuration objects."""
okx = OKXConfig()
trading = TradingConfig()
paths = PathConfig()
return okx, trading, paths

live_trading/data_feed.py

@@ -0,0 +1,216 @@
"""
Data Feed for Live Trading.
Fetches real-time OHLCV data from OKX and prepares features
for the regime strategy.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
import pandas as pd
import numpy as np
import ta
from .okx_client import OKXClient
from .config import TradingConfig, PathConfig
logger = logging.getLogger(__name__)
class DataFeed:
"""
Real-time data feed for the regime strategy.
Fetches BTC and ETH OHLCV data from OKX and calculates
the spread-based features required by the ML model.
"""
def __init__(
self,
okx_client: OKXClient,
trading_config: TradingConfig,
path_config: PathConfig
):
self.client = okx_client
self.config = trading_config
self.paths = path_config
self.cq_data: Optional[pd.DataFrame] = None
self._load_cq_data()
def _load_cq_data(self) -> None:
"""Load CryptoQuant on-chain data if available."""
try:
if self.paths.cq_data_path.exists():
self.cq_data = pd.read_csv(
self.paths.cq_data_path,
index_col='timestamp',
parse_dates=True
)
if self.cq_data.index.tz is None:
self.cq_data.index = self.cq_data.index.tz_localize('UTC')
logger.info(f"Loaded CryptoQuant data: {len(self.cq_data)} rows")
except Exception as e:
logger.warning(f"Could not load CryptoQuant data: {e}")
self.cq_data = None
def fetch_ohlcv_data(self) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Fetch OHLCV data for BTC and ETH.
Returns:
Tuple of (btc_df, eth_df) DataFrames
"""
# Fetch BTC data
btc_ohlcv = self.client.fetch_ohlcv(
self.config.btc_symbol,
self.config.timeframe,
self.config.candles_to_fetch
)
btc_df = self._ohlcv_to_dataframe(btc_ohlcv)
# Fetch ETH data
eth_ohlcv = self.client.fetch_ohlcv(
self.config.eth_symbol,
self.config.timeframe,
self.config.candles_to_fetch
)
eth_df = self._ohlcv_to_dataframe(eth_ohlcv)
logger.info(
f"Fetched {len(btc_df)} BTC candles and {len(eth_df)} ETH candles"
)
return btc_df, eth_df
def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
"""Convert OHLCV list to DataFrame."""
df = pd.DataFrame(
ohlcv,
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
return df
def calculate_features(
self,
btc_df: pd.DataFrame,
eth_df: pd.DataFrame
) -> pd.DataFrame:
"""
Calculate spread-based features for the regime strategy.
Args:
btc_df: BTC OHLCV DataFrame
eth_df: ETH OHLCV DataFrame
Returns:
DataFrame with calculated features
"""
# Align indices
common_idx = btc_df.index.intersection(eth_df.index)
df_btc = btc_df.loc[common_idx].copy()
df_eth = eth_df.loc[common_idx].copy()
# Calculate spread (ETH/BTC ratio)
spread = df_eth['close'] / df_btc['close']
# Z-Score of spread
z_window = self.config.z_window
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Spread technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume ratio
vol_ratio = df_eth['volume'] / df_btc['volume']
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
# Volatility
ret_btc = df_btc['close'].pct_change()
ret_eth = df_eth['close'].pct_change()
vol_btc = ret_btc.rolling(window=z_window).std()
vol_eth = ret_eth.rolling(window=z_window).std()
vol_spread_ratio = vol_eth / vol_btc
# Build features DataFrame
features = pd.DataFrame(index=spread.index)
features['spread'] = spread
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
features['vol_diff_ratio'] = vol_spread_ratio
# Add price data for reference
features['btc_close'] = df_btc['close']
features['eth_close'] = df_eth['close']
features['eth_volume'] = df_eth['volume']
# Merge CryptoQuant data if available
if self.cq_data is not None:
cq_aligned = self.cq_data.reindex(features.index, method='ffill')
# Calculate derived features
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
cq_aligned['funding_diff'] = (
cq_aligned['eth_funding'] - cq_aligned['btc_funding']
)
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
cq_aligned['inflow_ratio'] = (
cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
)
features = features.join(cq_aligned)
return features.dropna()
def get_latest_data(self) -> Optional[pd.DataFrame]:
"""
Fetch and process latest market data.
Returns:
DataFrame with features or None on error
"""
try:
btc_df, eth_df = self.fetch_ohlcv_data()
features = self.calculate_features(btc_df, eth_df)
if features.empty:
logger.warning("No valid features calculated")
return None
logger.info(
f"Latest data: ETH={features['eth_close'].iloc[-1]:.2f}, "
f"BTC={features['btc_close'].iloc[-1]:.2f}, "
f"Z-Score={features['z_score'].iloc[-1]:.3f}"
)
return features
except Exception as e:
logger.error(f"Error fetching market data: {e}", exc_info=True)
return None
def get_current_funding_rates(self) -> dict:
"""
Get current funding rates for BTC and ETH.
Returns:
Dictionary with 'btc_funding' and 'eth_funding' rates
"""
btc_funding = self.client.get_funding_rate(self.config.btc_symbol)
eth_funding = self.client.get_funding_rate(self.config.eth_symbol)
return {
'btc_funding': btc_funding,
'eth_funding': eth_funding,
'funding_diff': eth_funding - btc_funding,
}

live_trading/env.template

@@ -0,0 +1,15 @@
# OKX API Credentials Template
# Copy this file to .env and fill in your credentials
# For demo trading, use your OKX demo account API keys
# Generate keys at: https://www.okx.com/account/my-api (Demo Trading section)
OKX_API_KEY=your_api_key_here
OKX_SECRET=your_secret_key_here
OKX_PASSWORD=your_passphrase_here
# Demo Mode: Set to "true" for paper trading (sandbox)
# Set to "false" for live trading with real funds
OKX_DEMO_MODE=true
# CryptoQuant API (optional, for on-chain features)
CRYPTOQUANT_API_KEY=your_cryptoquant_api_key_here

live_trading/live_regime_strategy.py

@@ -0,0 +1,287 @@
"""
Live Regime Reversion Strategy.
Adapts the backtest regime strategy for live trading.
Uses a pre-trained ML model or trains on historical data.
"""
import logging
import pickle
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from .config import TradingConfig, PathConfig
logger = logging.getLogger(__name__)
class LiveRegimeStrategy:
"""
Live trading implementation of the ML-based regime detection
and mean reversion strategy.
Logic:
1. Calculates BTC/ETH spread Z-Score
2. Uses Random Forest to predict reversion probability
3. Applies funding rate filter
4. Generates long/short signals on ETH perpetual
"""
def __init__(
self,
trading_config: TradingConfig,
path_config: PathConfig
):
self.config = trading_config
self.paths = path_config
self.model: Optional[RandomForestClassifier] = None
self.feature_cols: Optional[list] = None
self._load_or_train_model()
def _load_or_train_model(self) -> None:
"""Load pre-trained model or train a new one."""
if self.paths.model_path.exists():
try:
with open(self.paths.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info(f"Loaded model from {self.paths.model_path}")
return
except Exception as e:
logger.warning(f"Could not load model: {e}")
logger.info("No pre-trained model found. Will train on first data batch.")
def save_model(self) -> None:
"""Save trained model to file."""
if self.model is None:
return
try:
with open(self.paths.model_path, 'wb') as f:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
}, f)
logger.info(f"Saved model to {self.paths.model_path}")
except Exception as e:
logger.error(f"Could not save model: {e}")
def train_model(self, features: pd.DataFrame) -> None:
"""
Train the Random Forest model on historical data.
Args:
features: DataFrame with calculated features
"""
logger.info(f"Training model on {len(features)} samples...")
z_thresh = self.config.z_entry_threshold
horizon = 102 # Optimal horizon from research
profit_target = 0.005 # 0.5% profit threshold
# Define targets
future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
target_short = features['spread'] * (1 - profit_target)
target_long = features['spread'] * (1 + profit_target)
success_short = (features['z_score'] > z_thresh) & (future_min < target_short)
success_long = (features['z_score'] < -z_thresh) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Exclude non-feature columns
exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
self.feature_cols = [c for c in features.columns if c not in exclude]
# Clean features
X = features[self.feature_cols].fillna(0)
X = X.replace([np.inf, -np.inf], 0)
# Remove rows with invalid targets
valid_mask = ~np.isnan(targets) & future_min.notna().values & future_max.notna().values
X_clean = X[valid_mask]
y_clean = targets[valid_mask]
if len(X_clean) < 100:
logger.warning("Not enough data to train model")
return
# Train model
self.model = RandomForestClassifier(
n_estimators=300,
max_depth=5,
min_samples_leaf=30,
class_weight={0: 1, 1: 3},
random_state=42
)
self.model.fit(X_clean, y_clean)
logger.info(f"Model trained on {len(X_clean)} samples")
self.save_model()
def generate_signal(
self,
features: pd.DataFrame,
current_funding: dict
) -> dict:
"""
Generate trading signal from latest features.
Args:
features: DataFrame with calculated features
current_funding: Dictionary with funding rate data
Returns:
Signal dictionary with action, side, confidence, etc.
"""
if self.model is None:
# Train model if not available
if len(features) >= 200:
self.train_model(features)
else:
return {'action': 'hold', 'reason': 'model_not_trained'}
if self.model is None:
return {'action': 'hold', 'reason': 'insufficient_data_for_training'}
# Get latest row
latest = features.iloc[-1]
z_score = latest['z_score']
eth_price = latest['eth_close']
btc_price = latest['btc_close']
# Prepare features for prediction
X = features[self.feature_cols].iloc[[-1]].fillna(0)
X = X.replace([np.inf, -np.inf], 0)
# Get prediction probability
prob = self.model.predict_proba(X)[0, 1]
# Apply thresholds
z_thresh = self.config.z_entry_threshold
prob_thresh = self.config.model_prob_threshold
# Determine signal direction
signal = {
'action': 'hold',
'side': None,
'probability': prob,
'z_score': z_score,
'eth_price': eth_price,
'btc_price': btc_price,
'reason': '',
}
# Check for entry conditions
if prob > prob_thresh:
if z_score > z_thresh:
# Spread high (ETH expensive relative to BTC) -> Short ETH
signal['action'] = 'entry'
signal['side'] = 'short'
signal['reason'] = f'z_score={z_score:.2f}>threshold, prob={prob:.2f}'
elif z_score < -z_thresh:
# Spread low (ETH cheap relative to BTC) -> Long ETH
signal['action'] = 'entry'
signal['side'] = 'long'
signal['reason'] = f'z_score={z_score:.2f}<-threshold, prob={prob:.2f}'
else:
signal['reason'] = f'z_score={z_score:.2f} within threshold'
else:
signal['reason'] = f'prob={prob:.2f}<threshold'
# Apply funding rate filter
if signal['action'] == 'entry':
btc_funding = current_funding.get('btc_funding', 0)
funding_thresh = self.config.funding_threshold
if signal['side'] == 'long' and btc_funding > funding_thresh:
# High positive funding = overheated, don't go long
signal['action'] = 'hold'
signal['reason'] = f'funding_filter_blocked_long (funding={btc_funding:.4f})'
elif signal['side'] == 'short' and btc_funding < -funding_thresh:
# High negative funding = oversold, don't go short
signal['action'] = 'hold'
signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'
# Check for exit conditions (mean reversion complete)
if signal['action'] == 'hold':
# Z-score crossed back through 0
if abs(z_score) < 0.3:
signal['action'] = 'check_exit'
signal['reason'] = f'z_score_reverted_to_mean ({z_score:.2f})'
logger.info(
f"Signal: {signal['action']} {signal['side'] or ''} "
f"(prob={prob:.2f}, z={z_score:.2f}, reason={signal['reason']})"
)
return signal
def calculate_position_size(
self,
signal: dict,
available_usdt: float
) -> float:
"""
Calculate position size based on signal confidence.
Args:
signal: Signal dictionary with probability
available_usdt: Available USDT balance
Returns:
Position size in USDT
"""
prob = signal.get('probability', 0.5)
# Base size: if max_position_usdt <= 0, use all available funds
if self.config.max_position_usdt <= 0:
base_size = available_usdt
else:
base_size = min(available_usdt, self.config.max_position_usdt)
# Scale by probability (1.0x at 0.5 prob, up to 1.6x at 0.8 prob)
scale = 1.0 + (prob - 0.5) * 2.0
scale = max(1.0, min(scale, 2.0)) # Clamp between 1x and 2x
size = base_size * scale
# Ensure minimum position size
if size < self.config.min_position_usdt:
return 0.0
return min(size, available_usdt * 0.95) # Leave 5% buffer
def calculate_sl_tp(
self,
entry_price: float,
side: str
) -> tuple[float, float]:
"""
Calculate stop-loss and take-profit prices.
Args:
entry_price: Entry price
side: "long" or "short"
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
sl_pct = self.config.stop_loss_pct
tp_pct = self.config.take_profit_pct
if side == "long":
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else: # short
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit

live_trading/main.py

@@ -0,0 +1,390 @@
#!/usr/bin/env python3
"""
Live Trading Bot for Regime Reversion Strategy on OKX.
This script runs the regime-based mean reversion strategy
on ETH perpetual futures using OKX exchange.
Usage:
# Run with demo account (default)
uv run python -m live_trading.main
# Run with specific settings
uv run python -m live_trading.main --max-position 500 --leverage 2
"""
import argparse
import logging
import signal
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from live_trading.config import get_config, OKXConfig, TradingConfig, PathConfig
from live_trading.okx_client import OKXClient
from live_trading.data_feed import DataFeed
from live_trading.position_manager import PositionManager
from live_trading.live_regime_strategy import LiveRegimeStrategy
def setup_logging(log_dir: Path) -> logging.Logger:
"""Configure logging for the trading bot."""
log_file = log_dir / "live_trading.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout),
],
force=True
)
return logging.getLogger(__name__)
class LiveTradingBot:
"""
Main trading bot orchestrator.
Coordinates data fetching, signal generation, and order execution
in a continuous loop.
"""
def __init__(
self,
okx_config: OKXConfig,
trading_config: TradingConfig,
path_config: PathConfig
):
self.okx_config = okx_config
self.trading_config = trading_config
self.path_config = path_config
self.logger = logging.getLogger(__name__)
self.running = True
# Initialize components
self.logger.info("Initializing trading bot components...")
self.okx_client = OKXClient(okx_config, trading_config)
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
self.position_manager = PositionManager(
self.okx_client, trading_config, path_config
)
self.strategy = LiveRegimeStrategy(trading_config, path_config)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
self._print_startup_banner()
def _print_startup_banner(self) -> None:
"""Print startup information."""
mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
print("=" * 60)
print(f" Regime Reversion Strategy - Live Trading Bot")
print("=" * 60)
print(f" Mode: {mode}")
print(f" Trading Pair: {self.trading_config.eth_symbol}")
print(f" Context Pair: {self.trading_config.btc_symbol}")
print(f" Timeframe: {self.trading_config.timeframe}")
max_pos = (f"${self.trading_config.max_position_usdt:.2f}" if self.trading_config.max_position_usdt > 0 else "All available funds")
print(f" Max Position: {max_pos}")
print(f" Leverage: {self.trading_config.leverage}x")
print(f" Stop Loss: {self.trading_config.stop_loss_pct * 100:.1f}%")
print(f" Take Profit: {self.trading_config.take_profit_pct * 100:.1f}%")
print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
print("=" * 60)
if not self.okx_config.demo_mode:
print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")
def _handle_shutdown(self, signum, frame) -> None:
"""Handle shutdown signals gracefully."""
self.logger.info("Shutdown signal received, stopping...")
self.running = False
def run_trading_cycle(self) -> None:
"""
Execute one trading cycle.
1. Fetch latest market data
2. Update open positions
3. Generate trading signal
4. Execute trades if signal triggers
"""
cycle_start = datetime.now(timezone.utc)
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
try:
# 1. Fetch market data
features = self.data_feed.get_latest_data()
if features is None or features.empty:
self.logger.warning("No market data available, skipping cycle")
return
# Get current prices
eth_price = features['eth_close'].iloc[-1]
btc_price = features['btc_close'].iloc[-1]
current_prices = {
self.trading_config.eth_symbol: eth_price,
self.trading_config.btc_symbol: btc_price,
}
# 2. Update existing positions (check SL/TP)
closed_trades = self.position_manager.update_positions(current_prices)
if closed_trades:
for trade in closed_trades:
self.logger.info(
f"Trade closed: {trade['trade_id']} "
f"PnL=${trade['pnl_usd']:.2f} ({trade['reason']})"
)
# 3. Sync with exchange positions
self.position_manager.sync_with_exchange()
# 4. Get current funding rates
funding = self.data_feed.get_current_funding_rates()
# 5. Generate trading signal
signal = self.strategy.generate_signal(features, funding)
# 6. Execute trades based on signal
if signal['action'] == 'entry':
self._execute_entry(signal, eth_price)
elif signal['action'] == 'check_exit':
self._execute_exit(signal)
# 7. Log portfolio summary
summary = self.position_manager.get_portfolio_summary()
self.logger.info(
f"Portfolio: {summary['open_positions']} positions, "
f"exposure=${summary['total_exposure_usdt']:.2f}, "
f"unrealized_pnl=${summary['total_unrealized_pnl']:.2f}"
)
except Exception as e:
self.logger.error(f"Trading cycle error: {e}", exc_info=True)
# Save positions on error
self.position_manager.save_positions()
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
def _execute_entry(self, signal: dict, current_price: float) -> None:
"""Execute entry trade."""
symbol = self.trading_config.eth_symbol
side = signal['side']
# Check if we can open a position
if not self.position_manager.can_open_position():
self.logger.info("Cannot open position: max positions reached")
return
# Get account balance
balance = self.okx_client.get_balance()
available_usdt = balance['free']
# Calculate position size
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
if size_usdt <= 0:
self.logger.info("Position size too small, skipping entry")
return
size_eth = size_usdt / current_price
# Calculate SL/TP
stop_loss, take_profit = self.strategy.calculate_sl_tp(current_price, side)
self.logger.info(
f"Executing {side.upper()} entry: {size_eth:.4f} ETH @ {current_price:.2f} "
f"(${size_usdt:.2f}), SL={stop_loss:.2f}, TP={take_profit:.2f}"
)
try:
# Place market order
order_side = "buy" if side == "long" else "sell"
order = self.okx_client.place_market_order(symbol, order_side, size_eth)
# Get filled price (handle None values from OKX response)
filled_price = order.get('average') or order.get('price') or current_price
filled_amount = order.get('filled') or order.get('amount') or size_eth
# Ensure we have valid numeric values
if filled_price is None or filled_price == 0:
self.logger.warning(f"No fill price in order response, using current price: {current_price}")
filled_price = current_price
if filled_amount is None or filled_amount == 0:
self.logger.warning(f"No fill amount in order response, using requested: {size_eth}")
filled_amount = size_eth
# Recalculate SL/TP with filled price
stop_loss, take_profit = self.strategy.calculate_sl_tp(filled_price, side)
# Get order ID from response
order_id = order.get('id', '')
# Record position locally
position = self.position_manager.open_position(
symbol=symbol,
side=side,
entry_price=filled_price,
size=filled_amount,
stop_loss_price=stop_loss,
take_profit_price=take_profit,
order_id=order_id,
)
if position:
self.logger.info(
f"Position opened: {position.trade_id}, "
f"{filled_amount:.4f} ETH @ {filled_price:.2f}"
)
# Try to set SL/TP on exchange
try:
self.okx_client.set_stop_loss_take_profit(
symbol, side, filled_amount, stop_loss, take_profit
)
except Exception as e:
self.logger.warning(f"Could not set SL/TP on exchange: {e}")
except Exception as e:
self.logger.error(f"Order execution failed: {e}", exc_info=True)
def _execute_exit(self, signal: dict) -> None:
"""Execute exit based on mean reversion signal."""
symbol = self.trading_config.eth_symbol
# Get position for ETH
position = self.position_manager.get_position_for_symbol(symbol)
if not position:
return
current_price = signal.get('eth_price', position.current_price)
self.logger.info(
f"Mean reversion exit signal: closing {position.trade_id} "
f"@ {current_price:.2f}"
)
try:
# Close position on exchange
exit_order = self.okx_client.close_position(symbol)
exit_order_id = exit_order.get('id', '') if exit_order else ''
# Record closure locally
self.position_manager.close_position(
position.trade_id,
current_price,
reason="mean_reversion_complete",
exit_order_id=exit_order_id,
)
except Exception as e:
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
def run(self) -> None:
"""Main trading loop."""
self.logger.info("Starting trading loop...")
while self.running:
try:
self.run_trading_cycle()
if self.running:
sleep_seconds = self.trading_config.sleep_seconds
minutes = sleep_seconds // 60
self.logger.info(f"Sleeping for {minutes} minutes...")
# Sleep in smaller chunks to allow faster shutdown
for _ in range(sleep_seconds):
if not self.running:
break
time.sleep(1)
except KeyboardInterrupt:
self.logger.info("Keyboard interrupt received")
break
except Exception as e:
self.logger.error(f"Unexpected error in main loop: {e}", exc_info=True)
time.sleep(60) # Wait before retry
# Cleanup
self.logger.info("Shutting down...")
self.position_manager.save_positions()
self.logger.info("Shutdown complete")
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Live Trading Bot for Regime Reversion Strategy"
)
parser.add_argument(
"--max-position",
type=float,
default=None,
help="Maximum position size in USDT"
)
parser.add_argument(
"--leverage",
type=int,
default=None,
help="Trading leverage (1-125)"
)
parser.add_argument(
"--interval",
type=int,
default=None,
help="Trading cycle interval in seconds"
)
parser.add_argument(
"--live",
action="store_true",
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
)
return parser.parse_args()
def main():
"""Main entry point."""
args = parse_args()
# Load configuration
okx_config, trading_config, path_config = get_config()
# Apply command line overrides
if args.max_position is not None:
trading_config.max_position_usdt = args.max_position
if args.leverage is not None:
trading_config.leverage = args.leverage
if args.interval is not None:
trading_config.sleep_seconds = args.interval
if args.live:
okx_config.demo_mode = False
# Setup logging
logger = setup_logging(path_config.logs_dir)
try:
# Create and run bot
bot = LiveTradingBot(okx_config, trading_config, path_config)
bot.run()
except ValueError as e:
logger.error(f"Configuration error: {e}")
sys.exit(1)
except Exception as e:
logger.error(f"Fatal error: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

live_trading/multi_pair/README.md

@@ -0,0 +1,145 @@
# Multi-Pair Divergence Live Trading
This module implements live trading for the Multi-Pair Divergence Selection Strategy on OKX perpetual futures.
## Overview
The strategy scans 10 cryptocurrency pairs for spread divergence opportunities:
1. **Pair Universe**: Top 10 assets by market cap (BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT)
2. **Spread Z-Score**: Identifies when pairs are divergent from their historical mean
3. **Universal ML Model**: Predicts probability of successful mean reversion
4. **Dynamic Selection**: Trades the pair with highest divergence score
## Prerequisites
Before running live trading, you must train the model via backtesting:
```bash
uv run python scripts/run_multi_pair_backtest.py
```
This creates `data/multi_pair_model.pkl`, which the live trading bot requires.
## Setup
### 1. API Keys
Same as single-pair trading. Set in `.env`:
```env
OKX_API_KEY=your_api_key
OKX_SECRET=your_secret
OKX_PASSWORD=your_passphrase
OKX_DEMO_MODE=true # Use demo for testing
```
### 2. Dependencies
All dependencies are in `pyproject.toml`. No additional installation needed.
## Usage
### Run with Demo Account (Recommended First)
```bash
uv run python -m live_trading.multi_pair.main
```
### Command Line Options
```bash
# Custom position size
uv run python -m live_trading.multi_pair.main --max-position 500
# Custom leverage
uv run python -m live_trading.multi_pair.main --leverage 2
# Custom cycle interval (in seconds)
uv run python -m live_trading.multi_pair.main --interval 1800
# Combine options
uv run python -m live_trading.multi_pair.main --max-position 1000 --leverage 3 --interval 3600
```
### Live Trading (Use with Caution)
```bash
uv run python -m live_trading.multi_pair.main --live
```
## How It Works
### Each Trading Cycle
1. **Fetch Data**: Retrieves OHLCV candles for all 10 assets from OKX
2. **Calculate Features**: Computes Z-Score, RSI, and volatility for all 45 pair combinations
3. **Score Pairs**: Uses the ML model to rank pairs by divergence score (|Z| × probability)
4. **Check Exits**: If a position is open, checks for mean reversion and SL/TP triggers
5. **Enter Best**: If flat, enters the highest-scoring divergent pair
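The scoring and selection step can be sketched as follows. The pair names and readings here are invented for the example; the thresholds mirror the defaults below:

```python
# Illustrative latest readings per pair: (z_score, ML probability).
# Pair names and numbers are invented for the example.
latest = {
    "ETH/BTC": (-1.8, 0.62),
    "SOL/BTC": (0.7, 0.71),    # fails the |Z| > 1.0 entry threshold
    "DOGE/XRP": (2.1, 0.48),   # fails the probability > 0.5 threshold
}

Z_ENTRY, PROB_MIN = 1.0, 0.5

# Keep only pairs passing both filters, scored by |Z| * probability
candidates = {
    pair: abs(z) * p
    for pair, (z, p) in latest.items()
    if abs(z) > Z_ENTRY and p > PROB_MIN
}
best = max(candidates, key=candidates.get) if candidates else None
print(best)  # ETH/BTC
```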
### Entry Conditions
- |Z-Score| > 1.0 (spread diverged from mean)
- ML probability > 0.5 (model predicts successful reversion)
- Funding rate filter passes (avoid crowded trades)
### Exit Conditions
- Mean reversion: |Z-Score| returns to ~0
- Stop-loss: ATR-based (default ~6%)
- Take-profit: ATR-based (default ~5%)
## Strategy Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
| `z_exit_threshold` | 0.0 | Exit when Z reverts to mean |
| `z_window` | 24 | Rolling window for Z-Score (hours) |
| `prob_threshold` | 0.5 | ML probability threshold for entry |
| `funding_threshold` | 0.0005 | Funding rate filter (0.05%) |
| `sl_atr_multiplier` | 10.0 | Stop-loss as ATR multiple |
| `tp_atr_multiplier` | 8.0 | Take-profit as ATR multiple |
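The ATR multipliers interact with the percentage bounds in `config.py`. A minimal sketch, assuming stops are derived by clamping `atr_pct * multiplier` to the configured bounds (the function name is hypothetical; the defaults mirror the table above):

```python
def atr_stop_distances(
    atr_pct: float,
    sl_mult: float = 10.0,
    tp_mult: float = 8.0,
    sl_bounds: tuple[float, float] = (0.02, 0.10),
    tp_bounds: tuple[float, float] = (0.02, 0.15),
) -> tuple[float, float]:
    """Return (stop_loss_pct, take_profit_pct), each clamped to its bounds."""
    sl = min(max(atr_pct * sl_mult, sl_bounds[0]), sl_bounds[1])
    tp = min(max(atr_pct * tp_mult, tp_bounds[0]), tp_bounds[1])
    return sl, tp

# A 0.6% hourly ATR yields the documented ~6% stop and ~5% target
sl_pct, tp_pct = atr_stop_distances(0.006)
print(f"SL {sl_pct:.1%}, TP {tp_pct:.1%}")  # SL 6.0%, TP 4.8%
```

Volatile pairs therefore get wider stops, but never wider than the hard caps.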
## Files
### Input
- `data/multi_pair_model.pkl` - Pre-trained ML model (required)
### Output
- `logs/multi_pair_live.log` - Trading logs
- `live_trading/multi_pair_positions.json` - Position persistence
- `live_trading/multi_pair_trade_log.csv` - Trade history
## Architecture
```
live_trading/multi_pair/
__init__.py # Module exports
config.py # Configuration classes
data_feed.py # Multi-asset OHLCV fetcher
strategy.py # ML scoring and signal generation
main.py # Bot orchestrator
README.md # This file
```
## Differences from Single-Pair
| Aspect | Single-Pair | Multi-Pair |
|--------|-------------|------------|
| Assets | ETH only (BTC context) | 10 assets, 45 pairs |
| Model | ETH-specific | Universal across pairs |
| Selection | Fixed pair | Dynamic best pair |
| Stops | Fixed 6%/5% | ATR-based dynamic |
## Risk Warning
This is experimental trading software. Use at your own risk:
- Always start with demo trading
- Never risk more than you can afford to lose
- Monitor the bot regularly
- The model was trained on historical data and may not predict future performance


@@ -0,0 +1,11 @@
"""Multi-Pair Divergence Live Trading Module."""
from .config import MultiPairLiveConfig, get_multi_pair_config
from .data_feed import MultiPairDataFeed
from .strategy import LiveMultiPairStrategy
__all__ = [
"MultiPairLiveConfig",
"get_multi_pair_config",
"MultiPairDataFeed",
"LiveMultiPairStrategy",
]


@@ -0,0 +1,145 @@
"""
Configuration for Multi-Pair Live Trading.
Extends the base live trading config with multi-pair specific settings.
"""
import os
from pathlib import Path
from dataclasses import dataclass, field
from dotenv import load_dotenv
load_dotenv()
@dataclass
class OKXConfig:
"""OKX API configuration."""
    # Defaults are placeholders; real values are loaded from the
    # environment in __post_init__ based on demo_mode.
    api_key: str = ""
    secret: str = ""
    password: str = ""
    demo_mode: bool = True
def __post_init__(self):
"""Load credentials based on demo mode setting."""
self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")
if self.demo_mode:
self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
else:
self.api_key = os.getenv("OKX_API_KEY", "")
self.secret = os.getenv("OKX_SECRET", "")
self.password = os.getenv("OKX_PASSWORD", "")
def validate(self) -> None:
"""Validate that required credentials are present."""
mode = "demo" if self.demo_mode else "live"
if not self.api_key:
raise ValueError(f"OKX API key not set for {mode} mode")
if not self.secret:
raise ValueError(f"OKX secret not set for {mode} mode")
if not self.password:
raise ValueError(f"OKX password not set for {mode} mode")
@dataclass
class MultiPairLiveConfig:
"""
Configuration for multi-pair live trading.
Combines trading parameters, strategy settings, and risk management.
"""
# Asset Universe (top 10 by market cap perpetuals)
assets: list[str] = field(default_factory=lambda: [
"BTC/USDT:USDT", "ETH/USDT:USDT", "SOL/USDT:USDT", "XRP/USDT:USDT",
"BNB/USDT:USDT", "DOGE/USDT:USDT", "ADA/USDT:USDT", "AVAX/USDT:USDT",
"LINK/USDT:USDT", "DOT/USDT:USDT"
])
# Timeframe
timeframe: str = "1h"
candles_to_fetch: int = 500 # Enough for feature calculation
# Z-Score Thresholds
z_window: int = 24
z_entry_threshold: float = 1.0
z_exit_threshold: float = 0.0 # Exit at mean reversion
# ML Thresholds
prob_threshold: float = 0.5
# Position sizing
max_position_usdt: float = -1.0 # If <= 0, use all available funds
min_position_usdt: float = 10.0
leverage: int = 1
margin_mode: str = "cross"
max_concurrent_positions: int = 1 # Trade one pair at a time
# Risk Management - ATR-Based Stops
atr_period: int = 14
sl_atr_multiplier: float = 10.0
tp_atr_multiplier: float = 8.0
# Fallback fixed percentages
base_sl_pct: float = 0.06
base_tp_pct: float = 0.05
# ATR bounds
min_sl_pct: float = 0.02
max_sl_pct: float = 0.10
min_tp_pct: float = 0.02
max_tp_pct: float = 0.15
# Funding Rate Filter
funding_threshold: float = 0.0005 # 0.05%
# Trade Management
min_hold_bars: int = 0
cooldown_bars: int = 0
# Execution
sleep_seconds: int = 3600 # Run every hour
slippage_pct: float = 0.001
def get_asset_short_name(self, symbol: str) -> str:
"""Convert symbol to short name (e.g., BTC/USDT:USDT -> btc)."""
return symbol.split("/")[0].lower()
def get_pair_count(self) -> int:
"""Calculate number of unique pairs from asset list."""
n = len(self.assets)
return n * (n - 1) // 2
@dataclass
class PathConfig:
"""File paths configuration."""
base_dir: Path = field(
default_factory=lambda: Path(__file__).parent.parent.parent
)
    # Derived paths, populated in __post_init__
    data_dir: Path | None = None
    logs_dir: Path | None = None
    model_path: Path | None = None
    positions_file: Path | None = None
    trade_log_file: Path | None = None
def __post_init__(self):
self.data_dir = self.base_dir / "data"
self.logs_dir = self.base_dir / "logs"
# Use the same model as backtesting
self.model_path = self.base_dir / "data" / "multi_pair_model.pkl"
self.positions_file = self.base_dir / "live_trading" / "multi_pair_positions.json"
self.trade_log_file = self.base_dir / "live_trading" / "multi_pair_trade_log.csv"
# Ensure directories exist
self.data_dir.mkdir(parents=True, exist_ok=True)
self.logs_dir.mkdir(parents=True, exist_ok=True)
def get_multi_pair_config() -> tuple[OKXConfig, MultiPairLiveConfig, PathConfig]:
"""Get all configuration objects for multi-pair trading."""
okx = OKXConfig()
trading = MultiPairLiveConfig()
paths = PathConfig()
return okx, trading, paths


@@ -0,0 +1,336 @@
"""
Multi-Pair Data Feed for Live Trading.
Fetches real-time OHLCV and funding data for all assets in the universe.
"""
import logging
from itertools import combinations
import pandas as pd
import numpy as np
import ta
from live_trading.okx_client import OKXClient
from .config import MultiPairLiveConfig, PathConfig
logger = logging.getLogger(__name__)
class TradingPair:
"""
Represents a tradeable pair for spread analysis.
Attributes:
base_asset: First asset symbol (e.g., ETH/USDT:USDT)
quote_asset: Second asset symbol (e.g., BTC/USDT:USDT)
pair_id: Unique identifier
"""
def __init__(self, base_asset: str, quote_asset: str):
self.base_asset = base_asset
self.quote_asset = quote_asset
self.pair_id = f"{base_asset}__{quote_asset}"
@property
def name(self) -> str:
"""Human-readable pair name."""
base = self.base_asset.split("/")[0]
quote = self.quote_asset.split("/")[0]
return f"{base}/{quote}"
def __hash__(self):
return hash(self.pair_id)
def __eq__(self, other):
if not isinstance(other, TradingPair):
return False
return self.pair_id == other.pair_id
class MultiPairDataFeed:
"""
Real-time data feed for multi-pair strategy.
Fetches OHLCV data for all assets and calculates spread features
for all pair combinations.
"""
def __init__(
self,
okx_client: OKXClient,
config: MultiPairLiveConfig,
path_config: PathConfig
):
self.client = okx_client
self.config = config
self.paths = path_config
# Cache for asset data
self._asset_data: dict[str, pd.DataFrame] = {}
self._funding_rates: dict[str, float] = {}
self._pairs: list[TradingPair] = []
# Generate pairs
self._generate_pairs()
def _generate_pairs(self) -> None:
"""Generate all unique pairs from asset universe."""
self._pairs = []
for base, quote in combinations(self.config.assets, 2):
pair = TradingPair(base_asset=base, quote_asset=quote)
self._pairs.append(pair)
logger.info("Generated %d pairs from %d assets",
len(self._pairs), len(self.config.assets))
@property
def pairs(self) -> list[TradingPair]:
"""Get list of trading pairs."""
return self._pairs
def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]:
"""
Fetch OHLCV data for all assets.
Returns:
Dictionary mapping symbol to OHLCV DataFrame
"""
self._asset_data = {}
for symbol in self.config.assets:
try:
ohlcv = self.client.fetch_ohlcv(
symbol,
self.config.timeframe,
self.config.candles_to_fetch
)
df = self._ohlcv_to_dataframe(ohlcv)
if len(df) >= 200:
self._asset_data[symbol] = df
logger.debug("Fetched %s: %d candles", symbol, len(df))
else:
logger.warning("Skipping %s: insufficient data (%d)",
symbol, len(df))
except Exception as e:
logger.error("Error fetching %s: %s", symbol, e)
logger.info("Fetched data for %d/%d assets",
len(self._asset_data), len(self.config.assets))
return self._asset_data
def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
"""Convert OHLCV list to DataFrame."""
df = pd.DataFrame(
ohlcv,
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
return df
def fetch_all_funding_rates(self) -> dict[str, float]:
"""
Fetch current funding rates for all assets.
Returns:
Dictionary mapping symbol to funding rate
"""
self._funding_rates = {}
for symbol in self.config.assets:
try:
rate = self.client.get_funding_rate(symbol)
self._funding_rates[symbol] = rate
except Exception as e:
logger.warning("Could not get funding for %s: %s", symbol, e)
self._funding_rates[symbol] = 0.0
return self._funding_rates
def calculate_pair_features(
self,
pair: TradingPair
) -> pd.DataFrame | None:
"""
Calculate features for a single pair.
Args:
pair: Trading pair
Returns:
DataFrame with features, or None if insufficient data
"""
base = pair.base_asset
quote = pair.quote_asset
if base not in self._asset_data or quote not in self._asset_data:
return None
df_base = self._asset_data[base]
df_quote = self._asset_data[quote]
# Align indices
common_idx = df_base.index.intersection(df_quote.index)
if len(common_idx) < 200:
return None
df_a = df_base.loc[common_idx]
df_b = df_quote.loc[common_idx]
# Calculate spread (base / quote)
spread = df_a['close'] / df_b['close']
# Z-Score
z_window = self.config.z_window
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Spread Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume Analysis
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
# Volatility
ret_a = df_a['close'].pct_change()
ret_b = df_b['close'].pct_change()
vol_a = ret_a.rolling(window=z_window).std()
vol_b = ret_b.rolling(window=z_window).std()
vol_spread_ratio = vol_a / (vol_b + 1e-10)
# Realized Volatility
realized_vol_a = ret_a.rolling(window=24).std()
realized_vol_b = ret_b.rolling(window=24).std()
# ATR (Average True Range)
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
tr_a = pd.concat([
high_a - low_a,
(high_a - close_a.shift(1)).abs(),
(low_a - close_a.shift(1)).abs()
], axis=1).max(axis=1)
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
atr_pct_a = atr_a / close_a
# Build feature DataFrame
features = pd.DataFrame(index=common_idx)
features['pair_id'] = pair.pair_id
features['base_asset'] = base
features['quote_asset'] = quote
# Price data
features['spread'] = spread
features['base_close'] = df_a['close']
features['quote_close'] = df_b['close']
features['base_volume'] = df_a['volume']
# Core Features
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio_rel
features['vol_diff_ratio'] = vol_spread_ratio
# Volatility
features['realized_vol_base'] = realized_vol_a
features['realized_vol_quote'] = realized_vol_b
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
# ATR
features['atr_base'] = atr_a
features['atr_pct_base'] = atr_pct_a
# Pair encoding
assets = self.config.assets
features['base_idx'] = assets.index(base) if base in assets else -1
features['quote_idx'] = assets.index(quote) if quote in assets else -1
# Funding rates
base_funding = self._funding_rates.get(base, 0.0)
quote_funding = self._funding_rates.get(quote, 0.0)
features['base_funding'] = base_funding
features['quote_funding'] = quote_funding
features['funding_diff'] = base_funding - quote_funding
features['funding_avg'] = (base_funding + quote_funding) / 2
# Drop NaN rows in core features
core_cols = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'atr_base', 'atr_pct_base'
]
features = features.dropna(subset=core_cols)
return features
def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]:
"""
Calculate features for all pairs.
Returns:
Dictionary mapping pair_id to feature DataFrame
"""
all_features = {}
for pair in self._pairs:
features = self.calculate_pair_features(pair)
if features is not None and len(features) > 0:
all_features[pair.pair_id] = features
logger.info("Calculated features for %d/%d pairs",
len(all_features), len(self._pairs))
return all_features
def get_latest_data(self) -> dict[str, pd.DataFrame] | None:
"""
Fetch and process latest market data for all pairs.
Returns:
Dictionary of pair features or None on error
"""
try:
# Fetch OHLCV for all assets
self.fetch_all_ohlcv()
if len(self._asset_data) < 2:
logger.warning("Insufficient assets fetched")
return None
# Fetch funding rates
self.fetch_all_funding_rates()
# Calculate features for all pairs
pair_features = self.calculate_all_pair_features()
if not pair_features:
logger.warning("No pair features calculated")
return None
logger.info("Processed %d pairs with valid features", len(pair_features))
return pair_features
except Exception as e:
logger.error("Error fetching market data: %s", e, exc_info=True)
return None
def get_pair_by_id(self, pair_id: str) -> TradingPair | None:
"""Get pair object by ID."""
for pair in self._pairs:
if pair.pair_id == pair_id:
return pair
return None
def get_current_price(self, symbol: str) -> float | None:
"""Get current price for a symbol."""
if symbol in self._asset_data:
return self._asset_data[symbol]['close'].iloc[-1]
return None


@@ -0,0 +1,609 @@
#!/usr/bin/env python3
"""
Multi-Pair Divergence Live Trading Bot.
Trades the top 10 cryptocurrency pairs based on spread divergence
using a universal ML model for signal generation.
Usage:
# Run with demo account (default)
uv run python -m live_trading.multi_pair.main
# Run with specific settings
uv run python -m live_trading.multi_pair.main --max-position 500 --leverage 2
"""
import argparse
import logging
import signal
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from live_trading.okx_client import OKXClient
from live_trading.position_manager import PositionManager
from live_trading.multi_pair.config import (
OKXConfig, MultiPairLiveConfig, PathConfig, get_multi_pair_config
)
from live_trading.multi_pair.data_feed import MultiPairDataFeed, TradingPair
from live_trading.multi_pair.strategy import LiveMultiPairStrategy
def setup_logging(log_dir: Path) -> logging.Logger:
"""Configure logging for the trading bot."""
log_file = log_dir / "multi_pair_live.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout),
],
force=True
)
return logging.getLogger(__name__)
@dataclass
class PositionState:
"""Track current position state for multi-pair."""
pair: TradingPair | None = None
pair_id: str | None = None
direction: str | None = None
entry_price: float = 0.0
size: float = 0.0
stop_loss: float = 0.0
take_profit: float = 0.0
entry_time: datetime | None = None
class MultiPairLiveTradingBot:
"""
Main trading bot for multi-pair divergence strategy.
Coordinates data fetching, pair scoring, and order execution.
"""
def __init__(
self,
okx_config: OKXConfig,
trading_config: MultiPairLiveConfig,
path_config: PathConfig
):
self.okx_config = okx_config
self.trading_config = trading_config
self.path_config = path_config
self.logger = logging.getLogger(__name__)
self.running = True
# Initialize components
self.logger.info("Initializing multi-pair trading bot...")
# Create OKX client with adapted config
self._adapted_trading_config = self._adapt_config_for_okx_client()
self.okx_client = OKXClient(okx_config, self._adapted_trading_config)
# Initialize data feed
self.data_feed = MultiPairDataFeed(
self.okx_client, trading_config, path_config
)
# Initialize position manager (reuse from single-pair)
self.position_manager = PositionManager(
self.okx_client, self._adapted_trading_config, path_config
)
# Initialize strategy
self.strategy = LiveMultiPairStrategy(trading_config, path_config)
# Current position state
self.position = PositionState()
# Register signal handlers
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
self._print_startup_banner()
# Sync with exchange positions on startup
self._sync_position_from_exchange()
def _adapt_config_for_okx_client(self):
"""Create config compatible with OKXClient."""
# OKXClient expects specific attributes
@dataclass
class AdaptedConfig:
eth_symbol: str = "ETH/USDT:USDT"
btc_symbol: str = "BTC/USDT:USDT"
timeframe: str = "1h"
candles_to_fetch: int = 500
max_position_usdt: float = -1.0
min_position_usdt: float = 10.0
leverage: int = 1
margin_mode: str = "cross"
stop_loss_pct: float = 0.06
take_profit_pct: float = 0.05
max_concurrent_positions: int = 1
z_entry_threshold: float = 1.0
z_window: int = 24
model_prob_threshold: float = 0.5
funding_threshold: float = 0.0005
sleep_seconds: int = 3600
slippage_pct: float = 0.001
adapted = AdaptedConfig()
adapted.timeframe = self.trading_config.timeframe
adapted.candles_to_fetch = self.trading_config.candles_to_fetch
adapted.max_position_usdt = self.trading_config.max_position_usdt
adapted.min_position_usdt = self.trading_config.min_position_usdt
adapted.leverage = self.trading_config.leverage
adapted.margin_mode = self.trading_config.margin_mode
adapted.max_concurrent_positions = self.trading_config.max_concurrent_positions
adapted.sleep_seconds = self.trading_config.sleep_seconds
adapted.slippage_pct = self.trading_config.slippage_pct
return adapted
def _print_startup_banner(self) -> None:
"""Print startup information."""
mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
print("=" * 60)
print(" Multi-Pair Divergence Strategy - Live Trading Bot")
print("=" * 60)
print(f" Mode: {mode}")
print(f" Assets: {len(self.trading_config.assets)} assets")
print(f" Pairs: {self.trading_config.get_pair_count()} pairs")
print(f" Timeframe: {self.trading_config.timeframe}")
        max_pos = self.trading_config.max_position_usdt
        max_pos_label = f"${max_pos:,.2f}" if max_pos > 0 else "All available"
        print(f" Max Position: {max_pos_label}")
print(f" Leverage: {self.trading_config.leverage}x")
print(f" Z-Entry: > {self.trading_config.z_entry_threshold}")
print(f" Prob Threshold: > {self.trading_config.prob_threshold}")
print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
print("=" * 60)
print(f" Assets: {', '.join([a.split('/')[0] for a in self.trading_config.assets])}")
print("=" * 60)
if not self.okx_config.demo_mode:
print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")
def _handle_shutdown(self, signum, frame) -> None:
"""Handle shutdown signals gracefully."""
self.logger.info("Shutdown signal received, stopping...")
self.running = False
def _sync_position_from_exchange(self) -> bool:
"""
Sync internal position state with exchange positions.
Checks for existing open positions on the exchange and updates
internal state to match. This prevents stacking positions when
the bot is restarted.
Returns:
True if a position was synced, False otherwise
"""
try:
positions = self.okx_client.get_positions()
if not positions:
if self.position.pair is not None:
# Position was closed externally (e.g., SL/TP hit)
self.logger.info(
"Position %s was closed externally, resetting state",
self.position.pair.name if self.position.pair else "unknown"
)
self.position = PositionState()
return False
# Check each position against our tradeable assets
our_assets = set(self.trading_config.assets)
for pos in positions:
pos_symbol = pos.get('symbol', '')
contracts = abs(float(pos.get('contracts', 0)))
if contracts == 0:
continue
# Check if this position is for one of our assets
if pos_symbol not in our_assets:
continue
# Found a position for one of our assets
side = pos.get('side', 'long')
entry_price = float(pos.get('entryPrice', 0))
unrealized_pnl = float(pos.get('unrealizedPnl', 0))
# If we already track this position, just update
if (self.position.pair is not None and
self.position.pair.base_asset == pos_symbol):
self.logger.debug(
"Position already tracked: %s %s %.2f contracts",
side, pos_symbol, contracts
)
return True
# New position found - sync it
# Find or create a TradingPair for this position
matched_pair = None
for pair in self.data_feed.pairs:
if pair.base_asset == pos_symbol:
matched_pair = pair
break
if matched_pair is None:
# Create a placeholder pair (we don't know the quote asset)
matched_pair = TradingPair(
base_asset=pos_symbol,
quote_asset="UNKNOWN"
)
# Calculate approximate SL/TP based on config defaults
sl_pct = self.trading_config.base_sl_pct
tp_pct = self.trading_config.base_tp_pct
if side == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else:
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
self.position = PositionState(
pair=matched_pair,
pair_id=matched_pair.pair_id,
direction=side,
entry_price=entry_price,
size=contracts,
stop_loss=stop_loss,
take_profit=take_profit,
entry_time=None # Unknown for synced positions
)
self.logger.info(
"Synced existing position from exchange: %s %s %.4f @ %.4f (PnL: %.2f)",
side.upper(),
pos_symbol,
contracts,
entry_price,
unrealized_pnl
)
return True
# No matching positions found
if self.position.pair is not None:
self.logger.info(
"Position %s no longer exists on exchange, resetting state",
self.position.pair.name
)
self.position = PositionState()
return False
except Exception as e:
self.logger.error("Failed to sync position from exchange: %s", e)
return False
def run_trading_cycle(self) -> None:
"""
Execute one trading cycle.
1. Sync position state with exchange
2. Fetch latest market data for all assets
3. Calculate features for all pairs
4. Score pairs and find best opportunity
5. Check exit conditions for current position
6. Execute trades if needed
"""
cycle_start = datetime.now(timezone.utc)
self.logger.info("--- Trading Cycle Start: %s ---", cycle_start.isoformat())
try:
# 1. Sync position state with exchange (detect SL/TP closures)
self._sync_position_from_exchange()
# 2. Fetch all market data
pair_features = self.data_feed.get_latest_data()
if pair_features is None:
self.logger.warning("No market data available, skipping cycle")
return
            # 3. Check exit conditions for current position
if self.position.pair is not None:
exit_signal = self.strategy.check_exit_signal(
pair_features,
self.position.pair_id
)
if exit_signal['action'] == 'exit':
self._execute_exit(exit_signal)
else:
# Check SL/TP
current_price = self.data_feed.get_current_price(
self.position.pair.base_asset
)
if current_price:
sl_tp_exit = self._check_sl_tp(current_price)
if sl_tp_exit:
self._execute_exit({'reason': sl_tp_exit})
            # 4. Generate entry signal if no position
if self.position.pair is None:
entry_signal = self.strategy.generate_signal(
pair_features,
self.data_feed.pairs
)
if entry_signal['action'] == 'entry':
self._execute_entry(entry_signal)
            # 5. Log status
if self.position.pair:
self.logger.info(
"Position: %s %s, entry=%.4f, current PnL check pending",
self.position.direction,
self.position.pair.name,
self.position.entry_price
)
else:
self.logger.info("No open position")
except Exception as e:
self.logger.error("Trading cycle error: %s", e, exc_info=True)
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
self.logger.info("--- Cycle completed in %.1fs ---", cycle_duration)
def _check_sl_tp(self, current_price: float) -> str | None:
"""Check stop-loss and take-profit levels."""
if self.position.direction == 'long':
if current_price <= self.position.stop_loss:
return f"stop_loss ({current_price:.4f} <= {self.position.stop_loss:.4f})"
if current_price >= self.position.take_profit:
return f"take_profit ({current_price:.4f} >= {self.position.take_profit:.4f})"
else: # short
if current_price >= self.position.stop_loss:
return f"stop_loss ({current_price:.4f} >= {self.position.stop_loss:.4f})"
if current_price <= self.position.take_profit:
return f"take_profit ({current_price:.4f} <= {self.position.take_profit:.4f})"
return None
def _execute_entry(self, signal: dict) -> None:
"""Execute entry trade."""
pair = signal['pair']
symbol = pair.base_asset # Trade the base asset
direction = signal['direction']
self.logger.info(
"Entry signal: %s %s (z=%.2f, p=%.2f, score=%.3f)",
direction.upper(),
pair.name,
signal['z_score'],
signal['probability'],
signal['divergence_score']
)
# Get account balance
try:
balance = self.okx_client.get_balance()
available_usdt = balance['free']
except Exception as e:
self.logger.error("Could not get balance: %s", e)
return
# Calculate position size
size_usdt = self.strategy.calculate_position_size(
signal['divergence_score'],
available_usdt
)
if size_usdt <= 0:
self.logger.info("Position size too small, skipping entry")
return
current_price = signal['base_price']
size_asset = size_usdt / current_price
# Calculate SL/TP
stop_loss, take_profit = self.strategy.calculate_sl_tp(
current_price,
direction,
signal['atr'],
signal['atr_pct']
)
self.logger.info(
"Executing %s entry: %.6f %s @ %.4f ($%.2f), SL=%.4f, TP=%.4f",
direction.upper(),
size_asset,
symbol.split('/')[0],
current_price,
size_usdt,
stop_loss,
take_profit
)
try:
# Place market order
order_side = "buy" if direction == "long" else "sell"
order = self.okx_client.place_market_order(symbol, order_side, size_asset)
            # Fall back to pre-trade estimates when the order response
            # is missing fill details (None or 0 are both falsy here)
            filled_price = order.get('average') or order.get('price') or current_price
            filled_amount = order.get('filled') or order.get('amount') or size_asset
# Recalculate SL/TP with filled price
stop_loss, take_profit = self.strategy.calculate_sl_tp(
filled_price, direction, signal['atr'], signal['atr_pct']
)
# Update position state
self.position = PositionState(
pair=pair,
pair_id=pair.pair_id,
direction=direction,
entry_price=filled_price,
size=filled_amount,
stop_loss=stop_loss,
take_profit=take_profit,
entry_time=datetime.now(timezone.utc)
)
self.logger.info(
"Position opened: %s %s %.6f @ %.4f",
direction.upper(),
pair.name,
filled_amount,
filled_price
)
# Try to set SL/TP on exchange
try:
self.okx_client.set_stop_loss_take_profit(
symbol, direction, filled_amount, stop_loss, take_profit
)
except Exception as e:
self.logger.warning("Could not set SL/TP on exchange: %s", e)
except Exception as e:
self.logger.error("Order execution failed: %s", e, exc_info=True)
def _execute_exit(self, signal: dict) -> None:
"""Execute exit trade."""
if self.position.pair is None:
return
symbol = self.position.pair.base_asset
reason = signal.get('reason', 'unknown')
self.logger.info(
"Exit signal: %s %s, reason: %s",
self.position.direction,
self.position.pair.name,
reason
)
try:
# Close position on exchange
self.okx_client.close_position(symbol)
self.logger.info(
"Position closed: %s %s",
self.position.direction,
self.position.pair.name
)
# Reset position state
self.position = PositionState()
except Exception as e:
self.logger.error("Exit execution failed: %s", e, exc_info=True)
def run(self) -> None:
"""Main trading loop."""
self.logger.info("Starting multi-pair trading loop...")
while self.running:
try:
self.run_trading_cycle()
if self.running:
sleep_seconds = self.trading_config.sleep_seconds
minutes = sleep_seconds // 60
self.logger.info("Sleeping for %d minutes...", minutes)
for _ in range(sleep_seconds):
if not self.running:
break
time.sleep(1)
except KeyboardInterrupt:
self.logger.info("Keyboard interrupt received")
break
except Exception as e:
self.logger.error("Unexpected error in main loop: %s", e, exc_info=True)
time.sleep(60)
self.logger.info("Shutting down...")
self.logger.info("Shutdown complete")
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Multi-Pair Divergence Live Trading Bot"
)
parser.add_argument(
"--max-position",
type=float,
default=None,
help="Maximum position size in USDT"
)
parser.add_argument(
"--leverage",
type=int,
default=None,
help="Trading leverage (1-125)"
)
parser.add_argument(
"--interval",
type=int,
default=None,
help="Trading cycle interval in seconds"
)
parser.add_argument(
"--live",
action="store_true",
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
)
return parser.parse_args()
def main():
"""Main entry point."""
args = parse_args()
# Load configuration
okx_config, trading_config, path_config = get_multi_pair_config()
# Apply command line overrides
if args.max_position is not None:
trading_config.max_position_usdt = args.max_position
if args.leverage is not None:
trading_config.leverage = args.leverage
if args.interval is not None:
trading_config.sleep_seconds = args.interval
if args.live:
okx_config.demo_mode = False
# Setup logging
logger = setup_logging(path_config.logs_dir)
try:
# Validate config
okx_config.validate()
# Create and run bot
bot = MultiPairLiveTradingBot(okx_config, trading_config, path_config)
bot.run()
except ValueError as e:
logger.error("Configuration error: %s", e)
sys.exit(1)
except Exception as e:
logger.error("Fatal error: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()


@@ -0,0 +1,396 @@
"""
Live Multi-Pair Divergence Strategy.
Scores all pairs and selects the best divergence opportunity for trading.
Uses the pre-trained universal ML model from backtesting.
"""
import logging
import pickle
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
# Opt-in to future pandas behavior to silence FutureWarning on fillna
pd.set_option('future.no_silent_downcasting', True)
from .config import MultiPairLiveConfig, PathConfig
from .data_feed import TradingPair
logger = logging.getLogger(__name__)
@dataclass
class DivergenceSignal:
"""
Signal for a divergent pair.
Attributes:
pair: Trading pair
z_score: Current Z-Score of the spread
probability: ML model probability of profitable reversion
divergence_score: Combined score (|z_score| * probability)
direction: 'long' or 'short' (relative to base asset)
base_price: Current price of base asset
quote_price: Current price of quote asset
atr: Average True Range in price units
atr_pct: ATR as percentage of price
"""
pair: TradingPair
z_score: float
probability: float
divergence_score: float
direction: str
base_price: float
quote_price: float
atr: float
atr_pct: float
base_funding: float = 0.0
class LiveMultiPairStrategy:
"""
Live trading implementation of multi-pair divergence strategy.
Scores all pairs using the universal ML model and selects
the best opportunity for mean-reversion trading.
"""
def __init__(
self,
config: MultiPairLiveConfig,
path_config: PathConfig
):
self.config = config
self.paths = path_config
self.model: RandomForestClassifier | None = None
self.feature_cols: list[str] | None = None
self._load_model()
def _load_model(self) -> None:
"""Load pre-trained model from backtesting."""
if self.paths.model_path.exists():
try:
with open(self.paths.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info("Loaded model from %s", self.paths.model_path)
except Exception as e:
logger.error("Could not load model: %s", e)
raise ValueError(
f"Failed to load multi-pair model from {self.paths.model_path}: {e}. "
"Re-run the backtest to retrain and save the model."
) from e
else:
raise ValueError(
f"Multi-pair model not found at {self.paths.model_path}. "
"Run the backtest first to train the model."
)
def score_pairs(
self,
pair_features: dict[str, pd.DataFrame],
pairs: list[TradingPair]
) -> list[DivergenceSignal]:
"""
Score all pairs and return ranked signals.
Args:
pair_features: Feature DataFrames by pair_id
pairs: List of TradingPair objects
Returns:
List of DivergenceSignal sorted by score (descending)
"""
if self.model is None:
logger.warning("Model not loaded")
return []
signals = []
pair_map = {p.pair_id: p for p in pairs}
for pair_id, features in pair_features.items():
if pair_id not in pair_map:
continue
pair = pair_map[pair_id]
# Get latest features
if len(features) == 0:
continue
latest = features.iloc[-1]
z_score = latest['z_score']
# Skip if Z-score below threshold
if abs(z_score) < self.config.z_entry_threshold:
continue
# Prepare features for prediction
# Handle missing feature columns gracefully
available_cols = [c for c in self.feature_cols if c in latest.index]
missing_cols = [c for c in self.feature_cols if c not in latest.index]
if missing_cols:
logger.debug("Missing feature columns: %s", missing_cols)
feature_row = latest[available_cols].fillna(0)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
# Create full feature vector with zeros for missing
X_dict = {c: 0 for c in self.feature_cols}
for col in available_cols:
X_dict[col] = feature_row[col]
X = pd.DataFrame([X_dict])
# Predict probability
prob = self.model.predict_proba(X)[0, 1]
# Skip if probability below threshold
if prob < self.config.prob_threshold:
continue
# Apply funding rate filter
base_funding = latest.get('base_funding', 0) or 0
funding_thresh = self.config.funding_threshold
if z_score > 0: # Short signal
if base_funding < -funding_thresh:
logger.debug(
"Skipping %s short: funding too negative (%.4f)",
pair.name, base_funding
)
continue
else: # Long signal
if base_funding > funding_thresh:
logger.debug(
"Skipping %s long: funding too positive (%.4f)",
pair.name, base_funding
)
continue
# Calculate divergence score
divergence_score = abs(z_score) * prob
# Determine direction
direction = 'short' if z_score > 0 else 'long'
signal = DivergenceSignal(
pair=pair,
z_score=z_score,
probability=prob,
divergence_score=divergence_score,
direction=direction,
base_price=latest['base_close'],
quote_price=latest['quote_close'],
atr=latest.get('atr_base', 0),
atr_pct=latest.get('atr_pct_base', 0.02),
base_funding=base_funding
)
signals.append(signal)
# Sort by divergence score (highest first)
signals.sort(key=lambda s: s.divergence_score, reverse=True)
if signals:
logger.info(
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)",
len(signals),
signals[0].pair.name,
signals[0].divergence_score,
signals[0].z_score,
signals[0].probability,
signals[0].direction
)
else:
logger.info("No pairs meet entry criteria")
return signals
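The ranking above reduces to sorting candidates by |z-score| × probability. A minimal standalone sketch with hypothetical pair names and pre-computed values (no ML model involved):

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    name: str
    z_score: float
    probability: float

    @property
    def divergence_score(self) -> float:
        # Stronger divergence and higher model confidence rank first
        return abs(self.z_score) * self.probability

# Hypothetical pairs with assumed z-scores and model probabilities
candidates = [
    Candidate("ETH/BTC", z_score=-2.1, probability=0.55),
    Candidate("SOL/BTC", z_score=1.4, probability=0.80),
    Candidate("DOGE/ETH", z_score=2.5, probability=0.40),
]
ranked = sorted(candidates, key=lambda c: c.divergence_score, reverse=True)
best = ranked[0]
print(best.name, round(best.divergence_score, 3))  # ETH/BTC 1.155
```

Note that a large |z| with a weak probability (DOGE/ETH here) ranks below a moderate |z| the model is confident about.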
def select_best_pair(
self,
signals: list[DivergenceSignal]
) -> DivergenceSignal | None:
"""
Select the best pair from scored signals.
Args:
signals: List of DivergenceSignal (pre-sorted by score)
Returns:
Best signal or None if no valid candidates
"""
if not signals:
return None
return signals[0]
def generate_signal(
self,
pair_features: dict[str, pd.DataFrame],
pairs: list[TradingPair]
) -> dict:
"""
Generate trading signal from latest features.
Args:
pair_features: Feature DataFrames by pair_id
pairs: List of TradingPair objects
Returns:
Signal dictionary with action, pair, direction, etc.
"""
# Score all pairs
signals = self.score_pairs(pair_features, pairs)
# Select best
best = self.select_best_pair(signals)
if best is None:
return {
'action': 'hold',
'reason': 'no_valid_signals'
}
return {
'action': 'entry',
'pair': best.pair,
'pair_id': best.pair.pair_id,
'direction': best.direction,
'z_score': best.z_score,
'probability': best.probability,
'divergence_score': best.divergence_score,
'base_price': best.base_price,
'quote_price': best.quote_price,
'atr': best.atr,
'atr_pct': best.atr_pct,
'base_funding': best.base_funding,
'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}'
}
def check_exit_signal(
self,
pair_features: dict[str, pd.DataFrame],
current_pair_id: str
) -> dict:
"""
Check if current position should be exited.
Exit conditions:
1. Z-Score reverted to mean (|Z| < threshold)
Args:
pair_features: Feature DataFrames by pair_id
current_pair_id: Current position's pair ID
Returns:
Signal dictionary with action and reason
"""
if current_pair_id not in pair_features:
return {
'action': 'exit',
'reason': 'pair_data_missing'
}
features = pair_features[current_pair_id]
if len(features) == 0:
return {
'action': 'exit',
'reason': 'no_data'
}
latest = features.iloc[-1]
z_score = latest['z_score']
# Check mean reversion
if abs(z_score) < self.config.z_exit_threshold:
return {
'action': 'exit',
'reason': f'mean_reversion (z={z_score:.2f})'
}
return {
'action': 'hold',
'z_score': z_score,
'reason': f'holding (z={z_score:.2f})'
}
def calculate_sl_tp(
self,
entry_price: float,
direction: str,
atr: float,
atr_pct: float
) -> tuple[float, float]:
"""
Calculate ATR-based dynamic stop-loss and take-profit prices.
Args:
entry_price: Entry price
direction: 'long' or 'short'
atr: ATR in price units
atr_pct: ATR as percentage of price
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
if atr > 0 and atr_pct > 0:
sl_distance = atr * self.config.sl_atr_multiplier
tp_distance = atr * self.config.tp_atr_multiplier
sl_pct = sl_distance / entry_price
tp_pct = tp_distance / entry_price
else:
sl_pct = self.config.base_sl_pct
tp_pct = self.config.base_tp_pct
# Apply bounds
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
if direction == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else:
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit
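The clamping logic above can be checked with a standalone sketch. The multipliers and bounds below are assumed illustrative values, not the module's configured defaults:

```python
def calc_sl_tp(entry_price, direction, atr, sl_mult=1.5, tp_mult=3.0,
               min_sl=0.01, max_sl=0.05, min_tp=0.02, max_tp=0.10,
               base_sl=0.02, base_tp=0.04):
    # ATR-based distances, falling back to fixed percentages when ATR is unusable
    if atr > 0:
        sl_pct = atr * sl_mult / entry_price
        tp_pct = atr * tp_mult / entry_price
    else:
        sl_pct, tp_pct = base_sl, base_tp
    # Clamp into the configured bounds
    sl_pct = max(min_sl, min(sl_pct, max_sl))
    tp_pct = max(min_tp, min(tp_pct, max_tp))
    if direction == "long":
        return entry_price * (1 - sl_pct), entry_price * (1 + tp_pct)
    return entry_price * (1 + sl_pct), entry_price * (1 - tp_pct)

# ETH long at 3000 with ATR of 60: sl_pct = 60*1.5/3000 = 0.03, tp_pct = 0.06
sl, tp = calc_sl_tp(3000.0, "long", 60.0)
print(round(sl, 2), round(tp, 2))  # 2910.0 3180.0
```

With ATR missing (0), the same call falls back to the fixed base percentages before clamping.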
def calculate_position_size(
self,
divergence_score: float,
available_usdt: float
) -> float:
"""
Calculate position size based on divergence score.
Args:
divergence_score: Combined score (|z| * prob)
available_usdt: Available USDT balance
Returns:
Position size in USDT
"""
if self.config.max_position_usdt <= 0:
base_size = available_usdt
else:
base_size = min(available_usdt, self.config.max_position_usdt)
# Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score)
base_threshold = 0.5
if divergence_score <= base_threshold:
scale = 1.0
else:
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
scale = min(scale, 2.0)
size = base_size * scale
if size < self.config.min_position_usdt:
return 0.0
return min(size, available_usdt * 0.95)
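The sizing curve above is linear in the divergence score between 0.5 and 1.0, then flat at 2.0x. A standalone sketch with assumed limits (the real values come from `MultiPairLiveConfig`):

```python
def position_size(divergence_score, available_usdt,
                  max_position=1000.0, min_position=10.0):
    # Use all available funds when no cap is configured (cap <= 0)
    base = available_usdt if max_position <= 0 else min(available_usdt, max_position)
    # Scale 1.0x at score <= 0.5, ramping linearly to 2.0x at score >= 1.0
    scale = 1.0 if divergence_score <= 0.5 else min(1.0 + (divergence_score - 0.5) / 0.5, 2.0)
    size = base * scale
    if size < min_position:
        return 0.0
    # Leave a 5% buffer of the balance unspent
    return min(size, available_usdt * 0.95)

print(position_size(0.5, 2000.0))            # 1000.0
print(position_size(0.75, 2000.0))           # 1500.0
print(round(position_size(1.2, 2000.0), 2))  # 1900.0 (capped at 95% of balance)
```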

live_trading/okx_client.py Normal file

@@ -0,0 +1,338 @@
"""
OKX Exchange Client for Live Trading.
Handles connection to OKX API, order execution, and account management.
Supports demo/sandbox mode for paper trading.
"""
import logging
from typing import Optional
from datetime import datetime, timezone
import ccxt
from .config import OKXConfig, TradingConfig
logger = logging.getLogger(__name__)
class OKXClient:
"""
OKX Exchange client wrapper using CCXT.
Supports both live and demo (sandbox) trading modes.
Demo mode uses OKX's official sandbox environment.
"""
def __init__(self, okx_config: OKXConfig, trading_config: TradingConfig):
self.okx_config = okx_config
self.trading_config = trading_config
self.exchange: Optional[ccxt.okx] = None
self._setup_exchange()
def _setup_exchange(self) -> None:
"""Initialize CCXT OKX exchange instance."""
self.okx_config.validate()
config = {
'apiKey': self.okx_config.api_key,
'secret': self.okx_config.secret,
'password': self.okx_config.password,
'sandbox': self.okx_config.demo_mode,
'options': {
'defaultType': 'swap', # Perpetual futures
},
'timeout': 30000,
'enableRateLimit': True,
}
self.exchange = ccxt.okx(config)
mode_str = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
logger.info(f"OKX Exchange initialized in {mode_str} mode")
# Configure trading settings
self._configure_trading_settings()
def _configure_trading_settings(self) -> None:
"""Configure leverage and margin mode."""
symbol = self.trading_config.eth_symbol
leverage = self.trading_config.leverage
margin_mode = self.trading_config.margin_mode
try:
# Set position mode to one-way (net) first
self.exchange.set_position_mode(False) # False = one-way mode
logger.info("Position mode set to One-Way (Net)")
except Exception as e:
# Position mode might already be set
logger.debug(f"Position mode setting: {e}")
try:
# Set margin mode with leverage parameter (required by OKX)
self.exchange.set_margin_mode(
margin_mode,
symbol,
params={'lever': leverage}
)
logger.info(
f"Margin mode set to {margin_mode} with {leverage}x leverage "
f"for {symbol}"
)
except Exception as e:
logger.warning(f"Could not set margin mode: {e}")
# Try setting leverage separately
try:
self.exchange.set_leverage(leverage, symbol)
logger.info(f"Leverage set to {leverage}x for {symbol}")
except Exception as e2:
logger.warning(f"Could not set leverage: {e2}")
def fetch_ohlcv(
self,
symbol: str,
timeframe: str = "1h",
limit: int = 500
) -> list:
"""
Fetch OHLCV candle data.
Args:
symbol: Trading pair symbol (e.g., "ETH/USDT:USDT")
timeframe: Candle timeframe (e.g., "1h")
limit: Number of candles to fetch
Returns:
List of OHLCV data
"""
return self.exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
def get_balance(self) -> dict:
"""
Get account balance.
Returns:
Balance dictionary with 'total' and 'free' USDT amounts
"""
balance = self.exchange.fetch_balance()
return {
'total': balance.get('USDT', {}).get('total', 0),
'free': balance.get('USDT', {}).get('free', 0),
}
def get_positions(self) -> list:
"""
Get open positions.
Returns:
List of open position dictionaries
"""
positions = self.exchange.fetch_positions()
return [p for p in positions if float(p.get('contracts', 0)) != 0]
def get_position(self, symbol: str) -> Optional[dict]:
"""
Get position for a specific symbol.
Args:
symbol: Trading pair symbol
Returns:
Position dictionary or None if no position
"""
positions = self.get_positions()
for pos in positions:
if pos.get('symbol') == symbol:
return pos
return None
def place_market_order(
self,
symbol: str,
side: str,
amount: float,
reduce_only: bool = False
) -> dict:
"""
Place a market order.
Args:
symbol: Trading pair symbol
side: "buy" or "sell"
amount: Order amount in base currency
reduce_only: If True, only reduce existing position
Returns:
Order result dictionary
"""
params = {
'tdMode': self.trading_config.margin_mode,
}
if reduce_only:
params['reduceOnly'] = True
order = self.exchange.create_market_order(
symbol, side, amount, params=params
)
logger.info(
f"Market {side.upper()} order placed: {amount} {symbol} "
f"@ market price, order_id={order['id']}"
)
return order
def place_limit_order(
self,
symbol: str,
side: str,
amount: float,
price: float,
reduce_only: bool = False
) -> dict:
"""
Place a limit order.
Args:
symbol: Trading pair symbol
side: "buy" or "sell"
amount: Order amount in base currency
price: Limit price
reduce_only: If True, only reduce existing position
Returns:
Order result dictionary
"""
params = {
'tdMode': self.trading_config.margin_mode,
}
if reduce_only:
params['reduceOnly'] = True
order = self.exchange.create_limit_order(
symbol, side, amount, price, params=params
)
logger.info(
f"Limit {side.upper()} order placed: {amount} {symbol} "
f"@ {price}, order_id={order['id']}"
)
return order
def set_stop_loss_take_profit(
self,
symbol: str,
side: str,
amount: float,
stop_loss_price: float,
take_profit_price: float
) -> tuple:
"""
Set stop-loss and take-profit orders.
Args:
symbol: Trading pair symbol
side: Position side ("long" or "short")
amount: Position size
stop_loss_price: Stop-loss trigger price
take_profit_price: Take-profit trigger price
Returns:
Tuple of (sl_order, tp_order)
"""
# For long position: SL sells, TP sells
# For short position: SL buys, TP buys
close_side = "sell" if side == "long" else "buy"
# Stop-loss order
try:
sl_order = self.exchange.create_order(
symbol, 'market', close_side, amount,
params={
'tdMode': self.trading_config.margin_mode,
'reduceOnly': True,
'slTriggerPx': str(stop_loss_price),
'slOrdPx': '-1', # Market price
}
)
logger.info(f"Stop-loss set at {stop_loss_price}")
except Exception as e:
logger.warning(f"Could not set stop-loss: {e}")
sl_order = None
# Take-profit order
try:
tp_order = self.exchange.create_order(
symbol, 'market', close_side, amount,
params={
'tdMode': self.trading_config.margin_mode,
'reduceOnly': True,
'tpTriggerPx': str(take_profit_price),
'tpOrdPx': '-1', # Market price
}
)
logger.info(f"Take-profit set at {take_profit_price}")
except Exception as e:
logger.warning(f"Could not set take-profit: {e}")
tp_order = None
return sl_order, tp_order
def close_position(self, symbol: str) -> Optional[dict]:
"""
Close an open position.
Args:
symbol: Trading pair symbol
Returns:
Order result or None if no position
"""
position = self.get_position(symbol)
if not position:
logger.info(f"No open position for {symbol}")
return None
contracts = abs(float(position.get('contracts', 0)))
if contracts == 0:
return None
side = position.get('side', 'long')
close_side = "sell" if side == "long" else "buy"
order = self.place_market_order(
symbol, close_side, contracts, reduce_only=True
)
logger.info(f"Position closed for {symbol}")
return order
def get_ticker(self, symbol: str) -> dict:
"""
Get current ticker/price for a symbol.
Args:
symbol: Trading pair symbol
Returns:
Ticker dictionary with 'last', 'bid', 'ask' prices
"""
return self.exchange.fetch_ticker(symbol)
def get_funding_rate(self, symbol: str) -> float:
"""
Get current funding rate for a perpetual symbol.
Args:
symbol: Trading pair symbol
Returns:
Current funding rate as decimal
"""
try:
funding = self.exchange.fetch_funding_rate(symbol)
return float(funding.get('fundingRate', 0))
except Exception as e:
logger.warning(f"Could not fetch funding rate: {e}")
return 0.0
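The nested `.get` chain used by `get_balance` degrades gracefully when a currency is absent. A minimal sketch run against a hand-written payload shaped like a ccxt `fetch_balance()` result (the payload is assumed, not fetched from OKX):

```python
def extract_usdt(balance: dict) -> dict:
    # Missing 'USDT' key yields an empty dict, so both lookups fall back to 0
    usdt = balance.get('USDT', {})
    return {'total': usdt.get('total', 0), 'free': usdt.get('free', 0)}

fake_balance = {'USDT': {'total': 1500.0, 'free': 1200.0, 'used': 300.0}}
print(extract_usdt(fake_balance))  # {'total': 1500.0, 'free': 1200.0}
print(extract_usdt({}))            # {'total': 0, 'free': 0}
```

This keeps the trading loop from raising on a fresh or empty account.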


@@ -0,0 +1,369 @@
"""
Position Manager for Live Trading.
Tracks open positions, manages risk, and handles SL/TP logic.
"""
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field, asdict
from .okx_client import OKXClient
from .config import TradingConfig, PathConfig
logger = logging.getLogger(__name__)
@dataclass
class Position:
"""Represents an open trading position."""
trade_id: str
symbol: str
side: str # "long" or "short"
entry_price: float
entry_time: str # ISO format
size: float # Amount in base currency (e.g., ETH)
size_usdt: float # Notional value in USDT
stop_loss_price: float
take_profit_price: float
current_price: float = 0.0
unrealized_pnl: float = 0.0
unrealized_pnl_pct: float = 0.0
order_id: str = "" # Entry order ID from exchange
def update_pnl(self, current_price: float) -> None:
"""Update unrealized PnL based on current price."""
self.current_price = current_price
if self.side == "long":
self.unrealized_pnl = (current_price - self.entry_price) * self.size
self.unrealized_pnl_pct = (current_price / self.entry_price - 1) * 100
else: # short
self.unrealized_pnl = (self.entry_price - current_price) * self.size
self.unrealized_pnl_pct = (1 - current_price / self.entry_price) * 100
def should_stop_loss(self, current_price: float) -> bool:
"""Check if stop-loss should trigger."""
if self.side == "long":
return current_price <= self.stop_loss_price
return current_price >= self.stop_loss_price
def should_take_profit(self, current_price: float) -> bool:
"""Check if take-profit should trigger."""
if self.side == "long":
return current_price >= self.take_profit_price
return current_price <= self.take_profit_price
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: dict) -> 'Position':
"""Create Position from dictionary."""
return cls(**data)
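The PnL bookkeeping above is symmetric for longs and shorts. A trimmed stand-in (not the full `Position` dataclass) shows the short case, where price falling below entry is a gain:

```python
from dataclasses import dataclass

@dataclass
class Pos:  # trimmed stand-in for the Position dataclass above
    side: str
    entry_price: float
    size: float
    unrealized_pnl: float = 0.0
    unrealized_pnl_pct: float = 0.0

    def update_pnl(self, price: float) -> None:
        if self.side == "long":
            self.unrealized_pnl = (price - self.entry_price) * self.size
            self.unrealized_pnl_pct = (price / self.entry_price - 1) * 100
        else:  # short: profit when price drops below entry
            self.unrealized_pnl = (self.entry_price - price) * self.size
            self.unrealized_pnl_pct = (1 - price / self.entry_price) * 100

# Short 2 ETH at 3000; price drops to 2850 -> +300 USDT, +5%
p = Pos(side="short", entry_price=3000.0, size=2.0)
p.update_pnl(2850.0)
print(p.unrealized_pnl, round(p.unrealized_pnl_pct, 6))  # 300.0 5.0
```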
class PositionManager:
"""
Manages trading positions with persistence.
Tracks open positions, enforces risk limits, and handles
position lifecycle (open, update, close).
"""
def __init__(
self,
okx_client: OKXClient,
trading_config: TradingConfig,
path_config: PathConfig
):
self.client = okx_client
self.config = trading_config
self.paths = path_config
self.positions: dict[str, Position] = {}
self.trade_log: list[dict] = []
self._load_positions()
def _load_positions(self) -> None:
"""Load positions from file."""
if self.paths.positions_file.exists():
try:
with open(self.paths.positions_file, 'r') as f:
data = json.load(f)
for trade_id, pos_data in data.items():
self.positions[trade_id] = Position.from_dict(pos_data)
logger.info(f"Loaded {len(self.positions)} positions from file")
except Exception as e:
logger.warning(f"Could not load positions: {e}")
def save_positions(self) -> None:
"""Save positions to file."""
try:
data = {
trade_id: pos.to_dict()
for trade_id, pos in self.positions.items()
}
with open(self.paths.positions_file, 'w') as f:
json.dump(data, f, indent=2)
logger.debug(f"Saved {len(self.positions)} positions")
except Exception as e:
logger.error(f"Could not save positions: {e}")
def can_open_position(self) -> bool:
"""Check if we can open a new position."""
return len(self.positions) < self.config.max_concurrent_positions
def get_position_for_symbol(self, symbol: str) -> Optional[Position]:
"""Get position for a specific symbol."""
for pos in self.positions.values():
if pos.symbol == symbol:
return pos
return None
def open_position(
self,
symbol: str,
side: str,
entry_price: float,
size: float,
stop_loss_price: float,
take_profit_price: float,
order_id: str = ""
) -> Optional[Position]:
"""
Open a new position.
Args:
symbol: Trading pair symbol
side: "long" or "short"
entry_price: Entry price
size: Position size in base currency
stop_loss_price: Stop-loss price
take_profit_price: Take-profit price
order_id: Entry order ID from exchange
Returns:
Position object or None if failed
"""
if not self.can_open_position():
logger.warning("Cannot open position: max concurrent positions reached")
return None
# Check if already have position for this symbol
existing = self.get_position_for_symbol(symbol)
if existing:
logger.warning(f"Already have position for {symbol}")
return None
# Generate trade ID
now = datetime.now(timezone.utc)
trade_id = f"{symbol}_{now.strftime('%Y%m%d_%H%M%S')}"
position = Position(
trade_id=trade_id,
symbol=symbol,
side=side,
entry_price=entry_price,
entry_time=now.isoformat(),
size=size,
size_usdt=entry_price * size,
stop_loss_price=stop_loss_price,
take_profit_price=take_profit_price,
current_price=entry_price,
order_id=order_id,
)
self.positions[trade_id] = position
self.save_positions()
logger.info(
f"Opened {side.upper()} position: {size} {symbol} @ {entry_price}, "
f"SL={stop_loss_price}, TP={take_profit_price}"
)
return position
def close_position(
self,
trade_id: str,
exit_price: float,
reason: str = "manual",
exit_order_id: str = ""
) -> Optional[dict]:
"""
Close a position and record the trade.
Args:
trade_id: Position trade ID
exit_price: Exit price
reason: Reason for closing (e.g., "stop_loss", "take_profit", "signal")
exit_order_id: Exit order ID from exchange
Returns:
Trade record dictionary
"""
if trade_id not in self.positions:
logger.warning(f"Position {trade_id} not found")
return None
position = self.positions[trade_id]
position.update_pnl(exit_price)
# Calculate final PnL
entry_time = datetime.fromisoformat(position.entry_time)
exit_time = datetime.now(timezone.utc)
hold_duration = (exit_time - entry_time).total_seconds() / 3600 # hours
trade_record = {
'trade_id': trade_id,
'symbol': position.symbol,
'side': position.side,
'entry_price': position.entry_price,
'exit_price': exit_price,
'size': position.size,
'size_usdt': position.size_usdt,
'pnl_usd': position.unrealized_pnl,
'pnl_pct': position.unrealized_pnl_pct,
'entry_time': position.entry_time,
'exit_time': exit_time.isoformat(),
'hold_duration_hours': hold_duration,
'reason': reason,
'order_id_entry': position.order_id,
'order_id_exit': exit_order_id,
}
self.trade_log.append(trade_record)
del self.positions[trade_id]
self.save_positions()
self._append_trade_log(trade_record)
logger.info(
f"Closed {position.side.upper()} position: {position.size} {position.symbol} "
f"@ {exit_price}, PnL=${position.unrealized_pnl:.2f} ({position.unrealized_pnl_pct:.2f}%), "
f"reason={reason}"
)
return trade_record
def _append_trade_log(self, trade_record: dict) -> None:
"""Append trade record to CSV log file."""
import csv
file_exists = self.paths.trade_log_file.exists()
with open(self.paths.trade_log_file, 'a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
if not file_exists:
writer.writeheader()
writer.writerow(trade_record)
def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
"""
Update all positions with current prices and check SL/TP.
Args:
current_prices: Dictionary of symbol -> current price
Returns:
List of closed trade records
"""
closed_trades = []
for trade_id in list(self.positions.keys()):
position = self.positions[trade_id]
if position.symbol not in current_prices:
continue
current_price = current_prices[position.symbol]
position.update_pnl(current_price)
# Check stop-loss
if position.should_stop_loss(current_price):
logger.warning(
f"Stop-loss triggered for {trade_id} at {current_price}"
)
# Close position on exchange
exit_order_id = ""
try:
exit_order = self.client.close_position(position.symbol)
exit_order_id = exit_order.get('id', '') if exit_order else ''
except Exception as e:
logger.error(f"Failed to close position on exchange: {e}")
record = self.close_position(trade_id, current_price, "stop_loss", exit_order_id)
if record:
closed_trades.append(record)
continue
# Check take-profit
if position.should_take_profit(current_price):
logger.info(
f"Take-profit triggered for {trade_id} at {current_price}"
)
# Close position on exchange
exit_order_id = ""
try:
exit_order = self.client.close_position(position.symbol)
exit_order_id = exit_order.get('id', '') if exit_order else ''
except Exception as e:
logger.error(f"Failed to close position on exchange: {e}")
record = self.close_position(trade_id, current_price, "take_profit", exit_order_id)
if record:
closed_trades.append(record)
self.save_positions()
return closed_trades
def sync_with_exchange(self) -> None:
"""
Sync local positions with exchange positions.
Reconciles any discrepancies between local tracking
and actual exchange positions.
"""
try:
exchange_positions = self.client.get_positions()
exchange_symbols = {p['symbol'] for p in exchange_positions}
# Check for positions we have locally but not on exchange
for trade_id in list(self.positions.keys()):
pos = self.positions[trade_id]
if pos.symbol not in exchange_symbols:
logger.warning(
f"Position {trade_id} not found on exchange, removing"
)
# Get last price and close
try:
ticker = self.client.get_ticker(pos.symbol)
exit_price = ticker['last']
except Exception:
exit_price = pos.current_price
self.close_position(trade_id, exit_price, "sync_removed")
logger.info(f"Position sync complete: {len(self.positions)} local positions")
except Exception as e:
logger.error(f"Position sync failed: {e}")
def get_portfolio_summary(self) -> dict:
"""
Get portfolio summary.
Returns:
Dictionary with portfolio statistics
"""
total_exposure = sum(p.size_usdt for p in self.positions.values())
total_unrealized_pnl = sum(p.unrealized_pnl for p in self.positions.values())
return {
'open_positions': len(self.positions),
'total_exposure_usdt': total_exposure,
'total_unrealized_pnl': total_unrealized_pnl,
'positions': [p.to_dict() for p in self.positions.values()],
}
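The persistence path in `save_positions` / `_load_positions` is a plain dataclass-to-JSON round trip keyed by trade ID. A self-contained sketch with a trimmed stand-in for `Position` (field set assumed for illustration):

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class Pos:  # trimmed stand-in for the Position dataclass persisted by PositionManager
    trade_id: str
    symbol: str
    side: str
    entry_price: float
    size: float

# dataclass -> dict -> JSON string -> dict -> dataclass, as in save/load
positions = {"t1": Pos("t1", "ETH/USDT:USDT", "long", 3000.0, 0.5)}
payload = json.dumps({tid: asdict(p) for tid, p in positions.items()}, indent=2)
restored = {tid: Pos(**d) for tid, d in json.loads(payload).items()}
assert restored == positions
print(restored["t1"].symbol)  # ETH/USDT:USDT
```

Because `Position.from_dict` is just `cls(**data)`, any field added to the dataclass persists with no schema migration beyond a default value.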


@@ -1,11 +0,0 @@
from __future__ import annotations
from pathlib import Path
import pandas as pd
def write_trade_log(trades: list[dict], path: Path) -> None:
if not trades:
return
df = pd.DataFrame(trades)
path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(path, index=False)

main.py Normal file

@@ -0,0 +1,10 @@
"""
Lowkey Backtest CLI - VectorBT Edition
A backtesting framework supporting multiple market types (spot, perpetual)
with realistic trading simulation including leverage, funding, and shorts.
"""
from engine.cli import main
if __name__ == "__main__":
main()


@@ -1,11 +0,0 @@
from __future__ import annotations
TAKER_FEE_BPS_DEFAULT = 10.0 # 0.10%
def okx_fee(fee_bps: float, notional_usd: float) -> float:
return notional_usd * (fee_bps / 1e4)
def estimate_slippage_rate(slippage_bps: float, notional_usd: float) -> float:
return notional_usd * (slippage_bps / 1e4)


@@ -1,54 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import pandas as pd
@dataclass
class Perf:
total_return: float
max_drawdown: float
sharpe_ratio: float
win_rate: float
num_trades: int
final_equity: float
initial_equity: float
num_stop_losses: int
total_fees: float
total_slippage_usd: float
avg_slippage_bps: float
def compute_metrics(equity_curve: pd.Series, trades: list[dict]) -> Perf:
ret = equity_curve.pct_change().fillna(0.0)
total_return = equity_curve.iat[-1] / equity_curve.iat[0] - 1.0
cummax = equity_curve.cummax()
dd = (equity_curve / cummax - 1.0).min()
max_drawdown = dd
if ret.std(ddof=0) > 0:
sharpe = (ret.mean() / ret.std(ddof=0)) * np.sqrt(252 * 24 * 60) # minute bars -> annualized
else:
sharpe = 0.0
closes = [t for t in trades if t.get("side") == "SELL"]
wins = [t for t in closes if t.get("pnl", 0.0) > 0]
win_rate = (len(wins) / len(closes)) if closes else 0.0
fees = sum(t.get("fee", 0.0) for t in trades)
slip = sum(t.get("slippage", 0.0) for t in trades)
slippage_bps = [t.get("slippage_bps", 0.0) for t in trades if "slippage_bps" in t]
return Perf(
total_return=total_return,
max_drawdown=max_drawdown,
sharpe_ratio=sharpe,
win_rate=win_rate,
num_trades=len(closes),
final_equity=float(equity_curve.iat[-1]),
initial_equity=float(equity_curve.iat[0]),
num_stop_losses=sum(1 for t in closes if t.get("reason") == "stop"),
total_fees=fees,
total_slippage_usd=slip,
avg_slippage_bps=float(np.mean(slippage_bps)) if slippage_bps else 0.0,
)


@@ -5,5 +5,29 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"ccxt>=4.5.32",
"numpy>=2.3.2",
"pandas>=2.3.1",
"ta>=0.11.0",
"vectorbt>=0.28.2",
"scikit-learn>=1.6.0",
"matplotlib>=3.10.0",
"plotly>=5.24.0",
"requests>=2.32.5",
"python-dotenv>=1.2.1",
# API dependencies
"fastapi>=0.115.0",
"uvicorn[standard]>=0.34.0",
"sqlalchemy>=2.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
]
[tool.pytest.ini_options]
pythonpath = ["."]
markers = [
"network: marks tests as requiring network access",
]


@@ -0,0 +1,342 @@
"""
Regime Detection Research Script with Walk-Forward Training.
Tests multiple holding horizons to find optimal parameters
without look-ahead bias.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pandas as pd
import numpy as np
import ta
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score
from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import get_logger
logger = get_logger(__name__)
# Configuration
TRAIN_RATIO = 0.7 # 70% train, 30% test
PROFIT_THRESHOLD = 0.005 # 0.5% profit target
Z_WINDOW = 24
FEE_RATE = 0.001 # 0.1% round-trip fee
def load_data():
"""Load and align BTC/ETH data."""
dm = DataManager()
df_btc = dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
df_eth = dm.load_data("okx", "ETH-USDT", "1h", MarketType.SPOT)
# Filter to Oct-Dec 2025
start = pd.Timestamp("2025-10-01", tz="UTC")
end = pd.Timestamp("2025-12-31", tz="UTC")
df_btc = df_btc[(df_btc.index >= start) & (df_btc.index <= end)]
df_eth = df_eth[(df_eth.index >= start) & (df_eth.index <= end)]
# Align indices
common = df_btc.index.intersection(df_eth.index)
df_btc = df_btc.loc[common]
df_eth = df_eth.loc[common]
logger.info(f"Loaded {len(common)} aligned hourly bars")
return df_btc, df_eth
def load_cryptoquant_data():
"""Load CryptoQuant on-chain data if available."""
try:
cq_path = "data/cq_training_data.csv"
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
if cq_df.index.tz is None:
cq_df.index = cq_df.index.tz_localize('UTC')
logger.info(f"Loaded CryptoQuant data: {len(cq_df)} rows")
return cq_df
except Exception as e:
logger.warning(f"CryptoQuant data not available: {e}")
return None
def calculate_features(df_btc, df_eth, cq_df=None):
"""Calculate all features for the model."""
spread = df_eth['close'] / df_btc['close']
# Z-Score
rolling_mean = spread.rolling(window=Z_WINDOW).mean()
rolling_std = spread.rolling(window=Z_WINDOW).std()
z_score = (spread - rolling_mean) / rolling_std
# Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume
vol_ratio = df_eth['volume'] / df_btc['volume']
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
# Volatility
ret_btc = df_btc['close'].pct_change()
ret_eth = df_eth['close'].pct_change()
vol_btc = ret_btc.rolling(window=Z_WINDOW).std()
vol_eth = ret_eth.rolling(window=Z_WINDOW).std()
vol_spread_ratio = vol_eth / vol_btc
features = pd.DataFrame(index=spread.index)
features['spread'] = spread
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
features['vol_diff_ratio'] = vol_spread_ratio
# Add CQ features if available
if cq_df is not None:
cq_aligned = cq_df.reindex(features.index, method='ffill')
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)  # +1 guards against division by zero
features = features.join(cq_aligned)
return features.dropna()
def calculate_targets(features, horizon):
"""Calculate target labels for a given horizon."""
spread = features['spread']
z_score = features['z_score']
# For Short (Z > 1): Did spread drop below target?
future_min = spread.rolling(window=horizon).min().shift(-horizon)
target_short = spread * (1 - PROFIT_THRESHOLD)
success_short = (z_score > 1.0) & (future_min < target_short)
# For Long (Z < -1): Did spread rise above target?
future_max = spread.rolling(window=horizon).max().shift(-horizon)
target_long = spread * (1 + PROFIT_THRESHOLD)
success_long = (z_score < -1.0) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Create valid mask (rows with complete future data)
valid_mask = future_min.notna() & future_max.notna()
return targets, valid_mask, future_min, future_max
def calculate_mae(features, predictions, test_idx, horizon):
"""Calculate Maximum Adverse Excursion for predicted trades."""
test_features = features.loc[test_idx]
spread = test_features['spread']
z_score = test_features['z_score']
mae_values = []
for idx, pred in zip(test_idx, predictions):
if pred != 1:
continue
entry_spread = spread.loc[idx]
z = z_score.loc[idx]
# Get future spread values
future_idx = features.index.get_loc(idx)
future_end = min(future_idx + horizon, len(features))
future_spreads = features['spread'].iloc[future_idx:future_end]
if len(future_spreads) < 2:
continue
if z > 1.0: # Short trade
max_adverse = (future_spreads.max() - entry_spread) / entry_spread
else: # Long trade
max_adverse = (entry_spread - future_spreads.min()) / entry_spread
mae_values.append(max_adverse * 100) # As percentage
return np.mean(mae_values) if mae_values else 0.0
def calculate_net_profit(features, predictions, test_idx, horizon):
"""Calculate estimated net profit including fees."""
test_features = features.loc[test_idx]
spread = test_features['spread']
z_score = test_features['z_score']
total_pnl = 0.0
n_trades = 0
for idx, pred in zip(test_idx, predictions):
if pred != 1:
continue
entry_spread = spread.loc[idx]
z = z_score.loc[idx]
# Get future spread values
future_idx = features.index.get_loc(idx)
future_end = min(future_idx + horizon, len(features))
future_spreads = features['spread'].iloc[future_idx:future_end]
if len(future_spreads) < 2:
continue
# Calculate PnL based on direction
if z > 1.0: # Short trade - profit if spread drops
exit_spread = future_spreads.iloc[-1] # Exit at horizon
pnl = (entry_spread - exit_spread) / entry_spread
else: # Long trade - profit if spread rises
exit_spread = future_spreads.iloc[-1]
pnl = (exit_spread - entry_spread) / entry_spread
# Subtract fees
net_pnl = pnl - FEE_RATE
total_pnl += net_pnl
n_trades += 1
return total_pnl, n_trades
def test_horizon(features, horizon):
"""Test a single horizon with walk-forward training."""
# Calculate targets
targets, valid_mask, _, _ = calculate_targets(features, horizon)
# Walk-forward split
n_samples = len(features)
train_size = int(n_samples * TRAIN_RATIO)
train_features = features.iloc[:train_size]
test_features = features.iloc[train_size:]
train_targets = targets[:train_size]
test_targets = targets[train_size:]
train_valid = valid_mask.iloc[:train_size]
test_valid = valid_mask.iloc[train_size:]
# Prepare training data (only valid rows)
exclude = ['spread']
cols = [c for c in features.columns if c not in exclude]
X_train = train_features[cols].fillna(0).replace([np.inf, -np.inf], 0)
X_train_valid = X_train[train_valid]
y_train_valid = train_targets[train_valid]
if len(X_train_valid) < 50:
return None # Not enough training data
# Train model
model = RandomForestClassifier(
n_estimators=300, max_depth=5, min_samples_leaf=30,
class_weight={0: 1, 1: 3}, random_state=42
)
model.fit(X_train_valid, y_train_valid)
# Predict on test set
X_test = test_features[cols].fillna(0).replace([np.inf, -np.inf], 0)
predictions = model.predict(X_test)
# Only evaluate on valid test rows (those with complete future data)
test_valid_mask = test_valid.values
y_test_valid = test_targets[test_valid_mask]
pred_valid = predictions[test_valid_mask]
if len(y_test_valid) < 10:
return None
# Calculate metrics
f1 = f1_score(y_test_valid, pred_valid, zero_division=0)
# Calculate MAE and Net Profit on ALL test predictions (not just valid targets)
test_idx = test_features.index
avg_mae = calculate_mae(features, predictions, test_idx, horizon)
net_pnl, n_trades = calculate_net_profit(features, predictions, test_idx, horizon)
return {
'horizon': horizon,
'f1_score': f1,
'avg_mae': avg_mae,
'net_pnl': net_pnl,
'n_trades': n_trades,
'train_samples': len(X_train_valid),
'test_samples': len(X_test)
}
def test_horizons(features, horizons):
"""Test multiple horizons and return comparison."""
results = []
print("\n" + "=" * 80)
print("WALK-FORWARD HORIZON OPTIMIZATION")
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
print("=" * 80)
for h in horizons:
result = test_horizon(features, h)
if result:
results.append(result)
print(f"Horizon {h:3d}h: F1={result['f1_score']:.3f}, "
f"MAE={result['avg_mae']:.2f}%, "
f"Net PnL={result['net_pnl']*100:.2f}%, "
f"Trades={result['n_trades']}")
return results
def main():
"""Main research function."""
# Load data
df_btc, df_eth = load_data()
cq_df = load_cryptoquant_data()
# Calculate features
features = calculate_features(df_btc, df_eth, cq_df)
logger.info(f"Calculated {len(features)} feature rows with {len(features.columns)} columns")
# Test horizons from 6h to 150h
horizons = list(range(6, 151, 6)) # 6, 12, 18, ..., 150
results = test_horizons(features, horizons)
if not results:
print("No valid results!")
return
# Find best by different metrics
results_df = pd.DataFrame(results)
print("\n" + "=" * 80)
print("BEST HORIZONS BY METRIC")
print("=" * 80)
best_f1 = results_df.loc[results_df['f1_score'].idxmax()]
print(f"Best F1 Score: {best_f1['horizon']:.0f}h (F1={best_f1['f1_score']:.3f})")
best_pnl = results_df.loc[results_df['net_pnl'].idxmax()]
print(f"Best Net PnL: {best_pnl['horizon']:.0f}h (PnL={best_pnl['net_pnl']*100:.2f}%)")
lowest_mae = results_df.loc[results_df['avg_mae'].idxmin()]
print(f"Lowest MAE: {lowest_mae['horizon']:.0f}h (MAE={lowest_mae['avg_mae']:.2f}%)")
# Save results
output_path = "research/horizon_optimization_results.csv"
results_df.to_csv(output_path, index=False)
print(f"\nResults saved to {output_path}")
return results_df
if __name__ == "__main__":
main()
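The labeling trick in `calculate_targets` asks whether the spread touches a profit target at any point within the look-ahead horizon, by turning a trailing rolling min/max into a forward-looking one with a negative shift. A minimal sketch of that mechanic on toy numbers (hypothetical data, Z-score gate omitted for brevity):

```python
import pandas as pd

# Toy spread series (hypothetical numbers): rises, then mean-reverts
spread = pd.Series([1.00, 1.02, 1.04, 1.03, 1.00, 0.98, 0.99, 1.01])

HORIZON = 3              # look-ahead bars (the real script sweeps 6h..150h)
PROFIT_THRESHOLD = 0.02  # 2% target just for this toy example

# rolling(...).shift(-HORIZON) turns a trailing min/max into a *forward* one:
# future_min[t] is the lowest spread over bars t+1 .. t+HORIZON
future_min = spread.rolling(window=HORIZON).min().shift(-HORIZON)
future_max = spread.rolling(window=HORIZON).max().shift(-HORIZON)

# A short labelled at t succeeds if the spread later trades below the target;
# a long succeeds if it later trades above it
success_short = future_min < spread * (1 - PROFIT_THRESHOLD)
success_long = future_max > spread * (1 + PROFIT_THRESHOLD)

# Bars near the end lack complete future data and are masked out before training
valid = future_min.notna() & future_max.notna()
```

In the full script both conditions are additionally gated on the Z-score sign before being combined into a single binary target.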

BIN research/regime_results.png (new binary file, 289 KiB; not shown)

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""
Download historical data for Multi-Pair Divergence Strategy.
Downloads 1h OHLCV data for top 10 cryptocurrencies from OKX.
"""
import sys
sys.path.insert(0, '.')
from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import setup_logging, get_logger
from strategies.multi_pair import MultiPairConfig
logger = get_logger(__name__)
def main():
"""Download data for all configured assets."""
setup_logging()
config = MultiPairConfig()
dm = DataManager()
logger.info("Downloading data for %d assets...", len(config.assets))
for symbol in config.assets:
logger.info("Downloading %s perpetual 1h data...", symbol)
try:
df = dm.download_data(
exchange_id=config.exchange_id,
symbol=symbol,
timeframe=config.timeframe,
market_type=MarketType.PERPETUAL
)
if df is not None:
logger.info("Downloaded %d candles for %s", len(df), symbol)
else:
logger.warning("No data downloaded for %s", symbol)
except Exception as e:
logger.error("Failed to download %s: %s", symbol, e)
logger.info("Download complete!")
if __name__ == "__main__":
main()


@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Run Multi-Pair Divergence Strategy backtest and compare with baseline.
Compares the multi-pair strategy against the single-pair BTC/ETH regime strategy.
"""
import sys
sys.path.insert(0, '.')
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import setup_logging, get_logger
from engine.reporting import Reporter
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
from strategies.regime_strategy import RegimeReversionStrategy
from engine.market import MarketType
logger = get_logger(__name__)
def run_baseline():
"""Run baseline BTC/ETH regime strategy."""
logger.info("=" * 60)
logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
logger.info("=" * 60)
dm = DataManager()
bt = Backtester(dm)
strategy = RegimeReversionStrategy()
result = bt.run_strategy(
strategy,
'okx',
'ETH-USDT',
timeframe='1h',
init_cash=10000
)
logger.info("Baseline Results:")
logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
logger.info(" Total Trades: %d", result.portfolio.trades.count())
logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
return result
def run_multi_pair(assets: list[str] | None = None):
"""Run multi-pair divergence strategy."""
logger.info("=" * 60)
logger.info("MULTI-PAIR: Divergence Selection Strategy")
logger.info("=" * 60)
dm = DataManager()
bt = Backtester(dm)
# Use provided assets or default
if assets:
config = MultiPairConfig(assets=assets)
else:
config = MultiPairConfig()
logger.info("Configured %d assets, %d pairs", len(config.assets), config.get_pair_count())
strategy = MultiPairDivergenceStrategy(config=config)
result = bt.run_strategy(
strategy,
'okx',
'ETH-USDT', # Reference asset (not used for trading, just index alignment)
timeframe='1h',
init_cash=10000
)
logger.info("Multi-Pair Results:")
logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
logger.info(" Total Trades: %d", result.portfolio.trades.count())
logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
return result
def compare_results(baseline, multi_pair):
"""Compare and display results."""
logger.info("=" * 60)
logger.info("COMPARISON")
logger.info("=" * 60)
baseline_return = baseline.portfolio.total_return() * 100
multi_return = multi_pair.portfolio.total_return() * 100
improvement = multi_return - baseline_return
logger.info("Baseline Return: %.2f%%", baseline_return)
logger.info("Multi-Pair Return: %.2f%%", multi_return)
logger.info("Improvement: %.2f%% (%.1fx)",
improvement,
multi_return / baseline_return if baseline_return != 0 else 0)
baseline_trades = baseline.portfolio.trades.count()
multi_trades = multi_pair.portfolio.trades.count()
logger.info("Baseline Trades: %d", baseline_trades)
logger.info("Multi-Pair Trades: %d", multi_trades)
return {
'baseline_return': baseline_return,
'multi_pair_return': multi_return,
'improvement': improvement,
'baseline_trades': baseline_trades,
'multi_pair_trades': multi_trades
}
def main():
"""Main entry point."""
setup_logging()
# Check available assets
dm = DataManager()
available = []
for symbol in MultiPairConfig().assets:
try:
dm.load_data('okx', symbol, '1h', market_type=MarketType.PERPETUAL)
available.append(symbol)
except FileNotFoundError:
pass
if len(available) < 2:
logger.error(
"Need at least 2 assets to run multi-pair strategy. "
"Run: uv run python scripts/download_multi_pair_data.py"
)
return
logger.info("Found data for %d assets: %s", len(available), available)
# Run baseline
baseline_result = run_baseline()
# Run multi-pair
multi_result = run_multi_pair(available)
# Compare
comparison = compare_results(baseline_result, multi_result)
# Save reports
reporter = Reporter()
reporter.save_reports(multi_result, "multi_pair_divergence")
logger.info("Reports saved to backtest_logs/")
if __name__ == "__main__":
main()

80
strategies/base.py Normal file

@@ -0,0 +1,80 @@
"""
Base strategy class for all trading strategies.
Strategies should inherit from BaseStrategy and implement the run() method.
"""
from abc import ABC, abstractmethod
import pandas as pd
from engine.market import MarketType
class BaseStrategy(ABC):
"""
Abstract base class for trading strategies.
Class Attributes:
default_market_type: Default market type for this strategy
default_leverage: Default leverage (only applies to perpetuals)
default_sl_stop: Default stop-loss percentage
default_tp_stop: Default take-profit percentage
default_sl_trail: Whether stop-loss is trailing by default
"""
# Market configuration defaults
default_market_type: MarketType = MarketType.SPOT
default_leverage: int = 1
# Risk management defaults (can be overridden per strategy)
default_sl_stop: float | None = None
default_tp_stop: float | None = None
default_sl_trail: bool = False
def __init__(self, **kwargs):
self.params = kwargs
@abstractmethod
def run(
self,
close: pd.Series,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Run the strategy logic.
Args:
close: Price series (can be multiple columns for grid search)
**kwargs: Additional data (high, low, open, volume) and parameters
Returns:
Tuple of 4 DataFrames/Series:
- long_entries: Boolean signals to open long positions
- long_exits: Boolean signals to close long positions
- short_entries: Boolean signals to open short positions
- short_exits: Boolean signals to close short positions
Note:
For spot markets, short signals will be ignored.
For backward compatibility, strategies can return 2-tuple (entries, exits)
which will be interpreted as long-only signals.
"""
pass
def get_indicator(self, ind_cls, *args, **kwargs):
"""Helper to run a vectorbt indicator."""
return ind_cls.run(*args, **kwargs)
@staticmethod
def create_empty_signals(reference: pd.Series | pd.DataFrame) -> pd.DataFrame:
"""
Create an empty (all False) signal DataFrame matching the reference shape.
Args:
reference: Series or DataFrame to match shape/index
Returns:
DataFrame of False values with same shape as reference
"""
if isinstance(reference, pd.DataFrame):
return pd.DataFrame(False, index=reference.index, columns=reference.columns)
return pd.Series(False, index=reference.index)
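The 4-tuple contract that `run()` must satisfy can be illustrated without the engine. A self-contained sketch (hypothetical `sma_cross_run`, plain pandas instead of vectorbt, not the repo's actual implementation):

```python
import pandas as pd

def sma_cross_run(close: pd.Series, fast: int = 2, slow: int = 3):
    """Toy run() honouring BaseStrategy's 4-tuple signal contract."""
    fast_ma = close.rolling(fast).mean()
    slow_ma = close.rolling(slow).mean()
    above = fast_ma > slow_ma  # NaN comparisons evaluate to False
    # Signal only on the bar of the cross, not while the condition merely holds
    long_entries = above & ~above.shift(1, fill_value=False)
    long_exits = ~above & above.shift(1, fill_value=False)
    # Spot-style strategy: short legs all-False, as create_empty_signals() does
    empty = pd.Series(False, index=close.index)
    return long_entries, long_exits, empty.copy(), empty.copy()

close = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
le, lx, se, sx = sma_cross_run(close)
```

On this monotonically rising series the fast MA crosses above the slow MA exactly once and never crosses back, so there is a single long entry and no exits.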

97
strategies/examples.py Normal file

@@ -0,0 +1,97 @@
"""
Example trading strategies for backtesting.
These are simple strategies demonstrating the framework usage.
"""
import pandas as pd
import vectorbt as vbt
from engine.market import MarketType
from strategies.base import BaseStrategy
class RsiStrategy(BaseStrategy):
"""
RSI mean-reversion strategy.
Long entry when RSI crosses below oversold level.
Long exit when RSI crosses above overbought level.
"""
default_market_type = MarketType.SPOT
default_leverage = 1
def run(
self,
close: pd.Series,
period: int = 14,
rsi_lower: int = 30,
rsi_upper: int = 70,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Generate RSI-based trading signals.
Args:
close: Price series
period: RSI calculation period
rsi_lower: Oversold threshold (buy signal)
rsi_upper: Overbought threshold (sell signal)
Returns:
4-tuple of (long_entries, long_exits, short_entries, short_exits)
"""
# Calculate RSI
rsi = vbt.RSI.run(close, window=period)
# Long signals: buy oversold, sell overbought
long_entries = rsi.rsi_crossed_below(rsi_lower)
long_exits = rsi.rsi_crossed_above(rsi_upper)
# No short signals for this strategy (spot-focused)
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits
class MaCrossStrategy(BaseStrategy):
"""
Moving Average crossover strategy.
Long entry when fast MA crosses above slow MA.
Long exit when fast MA crosses below slow MA.
"""
default_market_type = MarketType.SPOT
default_leverage = 1
def run(
self,
close: pd.Series,
fast_window: int = 10,
slow_window: int = 20,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Generate MA crossover trading signals.
Args:
close: Price series
fast_window: Fast MA period
slow_window: Slow MA period
Returns:
4-tuple of (long_entries, long_exits, short_entries, short_exits)
"""
# Calculate Moving Averages
fast_ma = vbt.MA.run(close, window=fast_window)
slow_ma = vbt.MA.run(close, window=slow_window)
# Long signals
long_entries = fast_ma.ma_crossed_above(slow_ma)
long_exits = fast_ma.ma_crossed_below(slow_ma)
# No short signals for this strategy
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits

164
strategies/factory.py Normal file

@@ -0,0 +1,164 @@
"""
Strategy factory for creating strategy instances with their parameters.
Centralizes strategy creation and parameter configuration.
"""
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from strategies.base import BaseStrategy
@dataclass
class StrategyConfig:
"""
Configuration for a strategy including default and grid parameters.
Attributes:
strategy_class: The strategy class to instantiate
default_params: Parameters for single backtest runs
grid_params: Parameters for grid search optimization
"""
strategy_class: type[BaseStrategy]
default_params: dict[str, Any] = field(default_factory=dict)
grid_params: dict[str, Any] = field(default_factory=dict)
def _build_registry() -> dict[str, StrategyConfig]:
"""
Build the strategy registry lazily to avoid circular imports.
Returns:
Dictionary mapping strategy names to their configurations
"""
# Import here to avoid circular imports
from strategies.examples import MaCrossStrategy, RsiStrategy
from strategies.supertrend import MetaSupertrendStrategy
from strategies.regime_strategy import RegimeReversionStrategy
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
return {
"rsi": StrategyConfig(
strategy_class=RsiStrategy,
default_params={
'period': 14,
'rsi_lower': 30,
'rsi_upper': 70
},
grid_params={
'period': np.arange(10, 25, 2),
'rsi_lower': [20, 30, 40],
'rsi_upper': [60, 70, 80]
}
),
"macross": StrategyConfig(
strategy_class=MaCrossStrategy,
default_params={
'fast_window': 10,
'slow_window': 20
},
grid_params={
'fast_window': np.arange(5, 20, 5),
'slow_window': np.arange(20, 60, 10)
}
),
"meta_st": StrategyConfig(
strategy_class=MetaSupertrendStrategy,
default_params={
'period1': 12, 'multiplier1': 3.0,
'period2': 10, 'multiplier2': 1.0,
'period3': 11, 'multiplier3': 2.0
},
grid_params={
'multiplier1': [2.0, 3.0, 4.0],
'period1': [10, 12, 14],
'period2': 11, 'multiplier2': 2.0,  # scalars: held fixed during the grid search
'period3': 12, 'multiplier3': 1.0
}
),
"regime": StrategyConfig(
strategy_class=RegimeReversionStrategy,
default_params={
# Optimal from walk-forward research (research/horizon_optimization_results.csv)
'horizon': 102, # 4.25 days - best Net PnL
'z_window': 24, # 24h rolling Z-score window
'z_entry_threshold': 1.0, # Enter when |Z| > 1.0
'profit_target': 0.005, # 0.5% target for ML labels
'stop_loss': 0.06, # 6% stop loss
'take_profit': 0.05, # 5% take profit
'train_ratio': 0.7, # 70% train / 30% test
'trend_window': 0, # SMA trend filter disabled
'use_funding_filter': True, # funding-rate filter enabled
'funding_threshold': 0.005 # threshold shown profitable in prior testing
},
grid_params={
'horizon': [84, 96, 102, 108, 120],
'z_entry_threshold': [0.8, 1.0, 1.2],
'stop_loss': [0.04, 0.06, 0.08],
'funding_threshold': [0.005, 0.01, 0.02]
}
),
"multi_pair": StrategyConfig(
strategy_class=MultiPairDivergenceStrategy,
default_params={
# Multi-pair divergence strategy uses config object
# Parameters passed here will override MultiPairConfig defaults
},
grid_params={
'z_entry_threshold': [0.8, 1.0, 1.2],
'prob_threshold': [0.4, 0.5, 0.6],
'correlation_threshold': [0.75, 0.85, 0.95]
}
)
}
# Module-level cache for the registry
_REGISTRY_CACHE: dict[str, StrategyConfig] | None = None
def get_registry() -> dict[str, StrategyConfig]:
"""Get the strategy registry, building it on first access."""
global _REGISTRY_CACHE
if _REGISTRY_CACHE is None:
_REGISTRY_CACHE = _build_registry()
return _REGISTRY_CACHE
def get_strategy_names() -> list[str]:
"""
Get list of available strategy names.
Returns:
List of strategy name strings
"""
return list(get_registry().keys())
def get_strategy(name: str, is_grid: bool = False) -> tuple[BaseStrategy, dict[str, Any]]:
"""
Create a strategy instance with appropriate parameters.
Args:
name: Strategy identifier (e.g., 'rsi', 'macross', 'meta_st')
is_grid: If True, return grid search parameters
Returns:
Tuple of (strategy instance, parameters dict)
Raises:
KeyError: If strategy name is not found in registry
"""
registry = get_registry()
if name not in registry:
available = ", ".join(registry.keys())
raise KeyError(f"Unknown strategy '{name}'. Available: {available}")
config = registry[name]
strategy = config.strategy_class()
params = config.grid_params if is_grid else config.default_params
return strategy, params.copy()


@@ -0,0 +1,24 @@
"""
Multi-Pair Divergence Selection Strategy.
Extends regime detection to multiple cryptocurrency pairs and dynamically
selects the most divergent pair for trading.
"""
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer
from .strategy import MultiPairDivergenceStrategy
from .funding import FundingRateFetcher
__all__ = [
"MultiPairConfig",
"PairScanner",
"TradingPair",
"CorrelationFilter",
"MultiPairFeatureEngine",
"DivergenceScorer",
"MultiPairDivergenceStrategy",
"FundingRateFetcher",
]


@@ -0,0 +1,88 @@
"""
Configuration for Multi-Pair Divergence Strategy.
"""
from dataclasses import dataclass, field
@dataclass
class MultiPairConfig:
"""
Configuration parameters for multi-pair divergence strategy.
Attributes:
assets: List of asset symbols to analyze (top 10 by market cap)
z_window: Rolling window for Z-Score calculation (hours)
z_entry_threshold: Minimum |Z-Score| to consider for entry
prob_threshold: Minimum ML probability to consider for entry
correlation_threshold: Max correlation to allow between pairs
correlation_window: Rolling window for correlation (hours)
atr_period: ATR lookback period for dynamic stops
sl_atr_multiplier: Stop-loss as multiple of ATR
tp_atr_multiplier: Take-profit as multiple of ATR
train_ratio: Walk-forward train/test split ratio
horizon: Look-ahead horizon for target calculation (hours)
profit_target: Minimum profit threshold for target labels
funding_threshold: Funding rate threshold for filtering
"""
# Asset Universe
assets: list[str] = field(default_factory=lambda: [
"BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
"DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT"
])
# Z-Score Thresholds
z_window: int = 24
z_entry_threshold: float = 1.0
# ML Thresholds
prob_threshold: float = 0.5
train_ratio: float = 0.7
horizon: int = 102
profit_target: float = 0.005
# Correlation Filtering
correlation_threshold: float = 0.85
correlation_window: int = 168 # 7 days in hours
# Risk Management - ATR-Based Stops
# SL/TP are calculated as multiples of ATR
# Mean ATR for crypto is ~0.6% per hour, so:
# - 10x ATR = ~6% SL (matches previous fixed 6%)
# - 8x ATR = ~5% TP (matches previous fixed 5%)
atr_period: int = 14 # ATR lookback period (hours for 1h timeframe)
sl_atr_multiplier: float = 10.0 # Stop-loss = entry +/- (ATR * multiplier)
tp_atr_multiplier: float = 8.0 # Take-profit = entry +/- (ATR * multiplier)
# Fallback fixed percentages (used if ATR is unavailable)
base_sl_pct: float = 0.06
base_tp_pct: float = 0.05
# ATR bounds to prevent extreme stops
min_sl_pct: float = 0.02 # Minimum 2% stop-loss
max_sl_pct: float = 0.10 # Maximum 10% stop-loss
min_tp_pct: float = 0.02 # Minimum 2% take-profit
max_tp_pct: float = 0.15 # Maximum 15% take-profit
volatility_window: int = 24
# Funding Rate Filter
# OKX funding rates are typically 0.0001 (0.01%) per 8h
# Extreme funding is > 0.0005 (0.05%) which indicates crowded trade
funding_threshold: float = 0.0005 # 0.05% - filter extreme funding
# Trade Management
# Note: Setting min_hold_bars=0 and z_exit_threshold=0 gives best results
# The mean-reversion exit at Z=0 is the primary profit driver
min_hold_bars: int = 0 # Disabled - let mean reversion drive exits
switch_threshold: float = 999.0 # Disabled - don't switch mid-trade
cooldown_bars: int = 0 # Disabled - enter when signal appears
z_exit_threshold: float = 0.0 # Exit at Z=0 (mean reversion complete)
# Exchange
exchange_id: str = "okx"
timeframe: str = "1h"
def get_pair_count(self) -> int:
"""Calculate number of unique pairs from asset list."""
n = len(self.assets)
return n * (n - 1) // 2
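The ATR comments above amount to a multiply-then-clamp. A sketch of the implied stop sizing (multipliers and bounds copied from the config; the ~0.6%/hour ATR figure is the one cited in the comments, and `atr_stop_pcts` is a hypothetical helper, not a function in the module):

```python
def atr_stop_pcts(atr_pct: float,
                  sl_mult: float = 10.0, tp_mult: float = 8.0) -> tuple[float, float]:
    """ATR (as a fraction of price) -> clamped (stop-loss, take-profit) fractions."""
    sl = min(max(atr_pct * sl_mult, 0.02), 0.10)  # clamp to [min_sl_pct, max_sl_pct]
    tp = min(max(atr_pct * tp_mult, 0.02), 0.15)  # clamp to [min_tp_pct, max_tp_pct]
    return sl, tp

# ~0.6%/hour ATR reproduces roughly the old fixed stops: SL ~6%, TP ~4.8%
sl, tp = atr_stop_pcts(0.006)

# Sanity check on the universe size: 10 assets -> 10*9/2 unique pairs,
# matching get_pair_count() above and the "45 pairs" in the commit message
n_pairs = 10 * 9 // 2
```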


@@ -0,0 +1,173 @@
"""
Correlation Filter for Multi-Pair Divergence Strategy.
Calculates rolling correlation matrix and filters pairs
to avoid highly correlated positions.
"""
import pandas as pd
import numpy as np
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair
logger = get_logger(__name__)
class CorrelationFilter:
"""
Computes asset correlations and filters candidate pairs based on them.
Uses rolling correlation of returns to identify assets
moving together, avoiding redundant positions.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self._correlation_matrix: pd.DataFrame | None = None
self._last_update_idx: int = -1
def calculate_correlation_matrix(
self,
price_data: dict[str, pd.Series],
current_idx: int | None = None
) -> pd.DataFrame:
"""
Calculate rolling correlation matrix between all assets.
Args:
price_data: Dictionary mapping asset symbols to price series
current_idx: Current bar index (for caching)
Returns:
Correlation matrix DataFrame
"""
# Use cached if recent
if (
current_idx is not None
and self._correlation_matrix is not None
and current_idx - self._last_update_idx < 24 # Update every 24 bars
):
return self._correlation_matrix
# Calculate returns
returns = {}
for symbol, prices in price_data.items():
returns[symbol] = prices.pct_change()
returns_df = pd.DataFrame(returns)
# Rolling correlation
window = self.config.correlation_window
# Get latest correlation (last row of rolling correlation)
if len(returns_df) >= window:
rolling_corr = returns_df.rolling(window=window).corr()
# Extract last timestamp correlation matrix
last_idx = returns_df.index[-1]
corr_matrix = rolling_corr.loc[last_idx]
else:
# Fallback to full-period correlation if not enough data
corr_matrix = returns_df.corr()
self._correlation_matrix = corr_matrix
if current_idx is not None:
self._last_update_idx = current_idx
return corr_matrix
def filter_pairs(
self,
pairs: list[TradingPair],
current_position_asset: str | None,
price_data: dict[str, pd.Series],
current_idx: int | None = None
) -> list[TradingPair]:
"""
Filter pairs based on correlation with current position.
If we have an open position in an asset, exclude pairs where
either asset is highly correlated with the held asset.
Args:
pairs: List of candidate pairs
current_position_asset: Currently held asset (or None)
price_data: Dictionary of price series by symbol
current_idx: Current bar index for caching
Returns:
Filtered list of pairs
"""
if current_position_asset is None:
return pairs
corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
threshold = self.config.correlation_threshold
filtered = []
for pair in pairs:
# Check correlation of base and quote with held asset
base_corr = self._get_correlation(
corr_matrix, pair.base_asset, current_position_asset
)
quote_corr = self._get_correlation(
corr_matrix, pair.quote_asset, current_position_asset
)
# Filter if either asset highly correlated with position
if abs(base_corr) > threshold or abs(quote_corr) > threshold:
logger.debug(
"Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
pair.name, base_corr, quote_corr, current_position_asset
)
continue
filtered.append(pair)
if len(filtered) < len(pairs):
logger.info(
"Correlation filter: %d/%d pairs remaining (held: %s)",
len(filtered), len(pairs), current_position_asset
)
return filtered
def _get_correlation(
self,
corr_matrix: pd.DataFrame,
asset1: str,
asset2: str
) -> float:
"""
Get correlation between two assets from matrix.
Args:
corr_matrix: Correlation matrix
asset1: First asset symbol
asset2: Second asset symbol
Returns:
Correlation coefficient (-1 to 1), or 0 if not found
"""
if asset1 == asset2:
return 1.0
try:
value = corr_matrix.loc[asset1, asset2]
# Rolling correlation can be NaN early in the sample; treat as uncorrelated
return 0.0 if pd.isna(value) else float(value)
except KeyError:
return 0.0
def get_correlation_report(
self,
price_data: dict[str, pd.Series]
) -> pd.DataFrame:
"""
Generate a readable correlation report.
Args:
price_data: Dictionary of price series
Returns:
Correlation matrix as DataFrame
"""
return self.calculate_correlation_matrix(price_data)
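`calculate_correlation_matrix` relies on `DataFrame.rolling(...).corr()` returning a (timestamp, asset) MultiIndex frame, whose last slice equals a plain `.corr()` over the trailing window. A sketch on seeded synthetic returns (real symbols reused for familiarity, data is random):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
idx = pd.date_range("2025-10-01", periods=400, freq="h", tz="UTC")
base = rng.normal(0, 0.01, len(idx))
returns_df = pd.DataFrame({
    "BTC-USDT": base,
    "ETH-USDT": base + rng.normal(0, 0.004, len(idx)),  # tracks BTC closely
    "DOGE-USDT": rng.normal(0, 0.01, len(idx)),         # independent
}, index=idx)

window = 168  # correlation_window: 7 days of hourly bars
rolling_corr = returns_df.rolling(window=window).corr()
# MultiIndex (timestamp, asset); slice the last timestamp for the current matrix
corr_matrix = rolling_corr.loc[returns_df.index[-1]]
btc_eth = corr_matrix.loc["BTC-USDT", "ETH-USDT"]
btc_doge = corr_matrix.loc["BTC-USDT", "DOGE-USDT"]

# Cheaper equivalent of that last slice, if only the latest matrix is needed:
trailing = returns_df.tail(window).corr()
```

With the default 0.85 threshold, the BTC/ETH pair here would be filtered against a held position while BTC/DOGE would pass.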


@@ -0,0 +1,311 @@
"""
Divergence Scorer for Multi-Pair Strategy.
Ranks pairs by divergence score and selects the best candidate.
"""
from dataclasses import dataclass
from typing import Optional
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import pickle
from pathlib import Path
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair
logger = get_logger(__name__)
@dataclass
class DivergenceSignal:
"""
Signal for a divergent pair.
Attributes:
pair: Trading pair
z_score: Current Z-Score of the spread
probability: ML model probability of profitable reversion
divergence_score: Combined score (|z_score| * probability)
direction: 'long' or 'short' (relative to base asset)
base_price: Current price of base asset
quote_price: Current price of quote asset
atr: Average True Range in price units
atr_pct: ATR as percentage of price
"""
pair: TradingPair
z_score: float
probability: float
divergence_score: float
direction: str
base_price: float
quote_price: float
atr: float
atr_pct: float
timestamp: pd.Timestamp
class DivergenceScorer:
"""
Scores and ranks pairs by divergence potential.
Uses ML model predictions combined with Z-Score magnitude
to identify the most promising mean-reversion opportunity.
"""
def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
self.config = config
self.model_path = Path(model_path)
self.model: RandomForestClassifier | None = None
self.feature_cols: list[str] | None = None
self._load_model()
def _load_model(self) -> None:
"""Load pre-trained model if available."""
if self.model_path.exists():
try:
with open(self.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info("Loaded model from %s", self.model_path)
except Exception as e:
logger.warning("Could not load model: %s", e)
def save_model(self) -> None:
"""Save trained model."""
if self.model is None:
return
self.model_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.model_path, 'wb') as f:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
}, f)
logger.info("Saved model to %s", self.model_path)
def train_model(
self,
combined_features: pd.DataFrame,
pair_features: dict[str, pd.DataFrame]
) -> None:
"""
Train universal model on all pairs.
Args:
combined_features: Combined feature DataFrame from all pairs
pair_features: Individual pair feature DataFrames (for target calculation)
"""
logger.info("Training universal model on %d samples...", len(combined_features))
z_thresh = self.config.z_entry_threshold
horizon = self.config.horizon
profit_target = self.config.profit_target
# Calculate targets for each pair
all_targets = []
all_features = []
for pair_id, features in pair_features.items():
if len(features) < horizon + 50:
continue
spread = features['spread']
z_score = features['z_score']
# Future price movements
future_min = spread.rolling(window=horizon).min().shift(-horizon)
future_max = spread.rolling(window=horizon).max().shift(-horizon)
# Target labels
target_short = spread * (1 - profit_target)
target_long = spread * (1 + profit_target)
success_short = (z_score > z_thresh) & (future_min < target_short)
success_long = (z_score < -z_thresh) & (future_max > target_long)
targets = (success_short | success_long).astype(int).to_numpy()
# Valid mask (exclude rows without complete future data)
valid_mask = future_min.notna() & future_max.notna()
# Collect valid samples
valid_features = features[valid_mask]
valid_targets = targets[valid_mask.values]
if len(valid_features) > 0:
all_features.append(valid_features)
all_targets.extend(valid_targets)
if not all_features:
logger.warning("No valid training samples")
return
# Combine all training data
X_df = pd.concat(all_features, ignore_index=True)
y = np.array(all_targets)
# Get feature columns
exclude_cols = [
'pair_id', 'base_asset', 'quote_asset',
'spread', 'base_close', 'quote_close', 'base_volume'
]
self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]
# Prepare features
X = X_df[self.feature_cols].fillna(0)
X = X.replace([np.inf, -np.inf], 0)
# Train model
self.model = RandomForestClassifier(
n_estimators=300,
max_depth=5,
min_samples_leaf=30,
class_weight={0: 1, 1: 3},
random_state=42
)
self.model.fit(X, y)
logger.info(
"Model trained on %d samples, %d features, %.1f%% positive class",
len(X), len(self.feature_cols), y.mean() * 100
)
self.save_model()
def score_pairs(
self,
pair_features: dict[str, pd.DataFrame],
pairs: list[TradingPair],
timestamp: pd.Timestamp | None = None
) -> list[DivergenceSignal]:
"""
Score all pairs and return ranked signals.
Args:
pair_features: Feature DataFrames by pair_id
pairs: List of TradingPair objects
timestamp: Current timestamp for feature extraction
Returns:
List of DivergenceSignal sorted by score (descending)
"""
if self.model is None:
logger.warning("Model not trained, returning empty signals")
return []
signals = []
pair_map = {p.pair_id: p for p in pairs}
for pair_id, features in pair_features.items():
if pair_id not in pair_map:
continue
pair = pair_map[pair_id]
# Get latest features
if timestamp is not None:
valid = features[features.index <= timestamp]
if len(valid) == 0:
continue
latest = valid.iloc[-1]
ts = valid.index[-1]
else:
latest = features.iloc[-1]
ts = features.index[-1]
z_score = latest['z_score']
# Skip if Z-score below threshold
if abs(z_score) < self.config.z_entry_threshold:
continue
# Prepare features for prediction
feature_row = latest[self.feature_cols].fillna(0).infer_objects(copy=False)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
X = pd.DataFrame([feature_row.values], columns=self.feature_cols)
# Predict probability
prob = self.model.predict_proba(X)[0, 1]
# Skip if probability below threshold
if prob < self.config.prob_threshold:
continue
# Apply funding rate filter
# Block trades where funding opposes our direction
base_funding = latest.get('base_funding', 0.0)
if pd.isna(base_funding):  # 'or 0' would let NaN through (NaN is truthy)
    base_funding = 0.0
funding_thresh = self.config.funding_threshold
if z_score > 0: # Short signal
# High negative funding = shorts are paying -> skip
if base_funding < -funding_thresh:
logger.debug(
"Skipping %s short: funding too negative (%.4f)",
pair.name, base_funding
)
continue
else: # Long signal
# High positive funding = longs are paying -> skip
if base_funding > funding_thresh:
logger.debug(
"Skipping %s long: funding too positive (%.4f)",
pair.name, base_funding
)
continue
# Calculate divergence score
divergence_score = abs(z_score) * prob
# Determine direction
# Z > 0: Spread high (base expensive vs quote) -> Short base
# Z < 0: Spread low (base cheap vs quote) -> Long base
direction = 'short' if z_score > 0 else 'long'
signal = DivergenceSignal(
pair=pair,
z_score=z_score,
probability=prob,
divergence_score=divergence_score,
direction=direction,
base_price=latest['base_close'],
quote_price=latest['quote_close'],
atr=latest.get('atr_base', 0),
atr_pct=latest.get('atr_pct_base', 0.02),
timestamp=ts
)
signals.append(signal)
# Sort by divergence score (highest first)
signals.sort(key=lambda s: s.divergence_score, reverse=True)
if signals:
logger.debug(
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
len(signals),
signals[0].pair.name,
signals[0].divergence_score,
signals[0].z_score,
signals[0].probability
)
return signals
def select_best_pair(
self,
signals: list[DivergenceSignal]
) -> DivergenceSignal | None:
"""
Select the best pair from scored signals.
Args:
signals: List of DivergenceSignal (pre-sorted by score)
Returns:
Best signal or None if no valid candidates
"""
if not signals:
return None
return signals[0]
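The target construction in `train_model` — label a bar positive when the spread is stretched beyond the Z-threshold and then reverts by `profit_target` within `horizon` bars — can be distilled into a standalone sketch (the `label_reversions` helper is hypothetical, not part of the module):

```python
import pandas as pd

def label_reversions(spread: pd.Series, z_score: pd.Series,
                     z_thresh: float = 2.0, horizon: int = 24,
                     profit_target: float = 0.01) -> pd.Series:
    """1 where a stretched spread reverts by profit_target within horizon bars."""
    # Min/max over the NEXT `horizon` bars (current bar excluded), matching
    # the rolling(...).min().shift(-horizon) construction in train_model.
    future_min = spread.rolling(window=horizon).min().shift(-horizon)
    future_max = spread.rolling(window=horizon).max().shift(-horizon)
    success_short = (z_score > z_thresh) & (future_min < spread * (1 - profit_target))
    success_long = (z_score < -z_thresh) & (future_max > spread * (1 + profit_target))
    # Trailing bars without complete future data compare as False -> label 0;
    # train_model additionally masks them out via future_min/future_max notna().
    return (success_short | success_long).astype(int)
```

The class imbalance this produces is why `train_model` weights the positive class 3:1 in the RandomForest.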

@@ -0,0 +1,433 @@
"""
Feature Engineering for Multi-Pair Divergence Strategy.
Calculates features for all pairs in the universe, including
spread technicals, volatility, and on-chain data.
"""
import pandas as pd
import numpy as np
import ta
from engine.logging_config import get_logger
from engine.data_manager import DataManager
from engine.market import MarketType
from .config import MultiPairConfig
from .pair_scanner import TradingPair
from .funding import FundingRateFetcher
logger = get_logger(__name__)
class MultiPairFeatureEngine:
"""
Calculates features for multiple trading pairs.
Generates consistent feature sets across all pairs for
the universal ML model.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self.dm = DataManager()
self.funding_fetcher = FundingRateFetcher()
self._funding_data: pd.DataFrame | None = None
def load_all_assets(
self,
start_date: str | None = None,
end_date: str | None = None
) -> dict[str, pd.DataFrame]:
"""
Load OHLCV data for all assets in the universe.
Args:
start_date: Start date filter (YYYY-MM-DD)
end_date: End date filter (YYYY-MM-DD)
Returns:
Dictionary mapping symbol to OHLCV DataFrame
"""
data = {}
market_type = MarketType.PERPETUAL
for symbol in self.config.assets:
try:
df = self.dm.load_data(
self.config.exchange_id,
symbol,
self.config.timeframe,
market_type
)
# Apply date filters
if start_date:
df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
if end_date:
df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
if len(df) >= 200: # Minimum data requirement
data[symbol] = df
logger.debug("Loaded %s: %d bars", symbol, len(df))
else:
logger.warning(
"Skipping %s: insufficient data (%d bars)",
symbol, len(df)
)
except FileNotFoundError:
logger.warning("Data not found for %s", symbol)
except Exception as e:
logger.error("Error loading %s: %s", symbol, e)
logger.info("Loaded %d/%d assets", len(data), len(self.config.assets))
return data
def load_funding_data(
self,
start_date: str | None = None,
end_date: str | None = None,
use_cache: bool = True
) -> pd.DataFrame:
"""
Load funding rate data for all assets.
Args:
start_date: Start date filter
end_date: End date filter
use_cache: Whether to use cached data
Returns:
DataFrame with funding rates for all assets
"""
self._funding_data = self.funding_fetcher.get_funding_data(
self.config.assets,
start_date=start_date,
end_date=end_date,
use_cache=use_cache
)
if self._funding_data is not None and not self._funding_data.empty:
logger.info(
"Loaded funding data: %d rows, %d assets",
len(self._funding_data),
len(self._funding_data.columns)
)
else:
logger.warning("No funding data available")
return self._funding_data
def calculate_pair_features(
self,
pair: TradingPair,
asset_data: dict[str, pd.DataFrame],
on_chain_data: pd.DataFrame | None = None
) -> pd.DataFrame | None:
"""
Calculate features for a single pair.
Args:
pair: Trading pair
asset_data: Dictionary of OHLCV DataFrames by symbol
on_chain_data: Optional on-chain data (funding, inflows)
Returns:
DataFrame with features, or None if insufficient data
"""
base = pair.base_asset
quote = pair.quote_asset
if base not in asset_data or quote not in asset_data:
return None
df_base = asset_data[base]
df_quote = asset_data[quote]
# Align indices
common_idx = df_base.index.intersection(df_quote.index)
if len(common_idx) < 200:
logger.debug("Pair %s: insufficient aligned data", pair.name)
return None
df_a = df_base.loc[common_idx]
df_b = df_quote.loc[common_idx]
# Calculate spread (base / quote)
spread = df_a['close'] / df_b['close']
# Z-Score
z_window = self.config.z_window
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Spread Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume Analysis
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
# Volatility
ret_a = df_a['close'].pct_change()
ret_b = df_b['close'].pct_change()
vol_a = ret_a.rolling(window=z_window).std()
vol_b = ret_b.rolling(window=z_window).std()
vol_spread_ratio = vol_a / (vol_b + 1e-10)
# Realized Volatility (for dynamic SL/TP)
realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std()
realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std()
# ATR (Average True Range) for dynamic stops
# True Range = max(high-low, |high-prev_close|, |low-prev_close|); ATR = rolling mean of TR
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close']
# True Range for base asset
tr_a = pd.concat([
high_a - low_a,
(high_a - close_a.shift(1)).abs(),
(low_a - close_a.shift(1)).abs()
], axis=1).max(axis=1)
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
# True Range for quote asset
tr_b = pd.concat([
high_b - low_b,
(high_b - close_b.shift(1)).abs(),
(low_b - close_b.shift(1)).abs()
], axis=1).max(axis=1)
atr_b = tr_b.rolling(window=self.config.atr_period).mean()
# ATR as percentage of price (normalized)
atr_pct_a = atr_a / close_a
atr_pct_b = atr_b / close_b
# Build feature DataFrame
features = pd.DataFrame(index=common_idx)
features['pair_id'] = pair.pair_id
features['base_asset'] = base
features['quote_asset'] = quote
# Price data (for reference, not features)
features['spread'] = spread
features['base_close'] = df_a['close']
features['quote_close'] = df_b['close']
features['base_volume'] = df_a['volume']
# Core Features
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio_rel
features['vol_diff_ratio'] = vol_spread_ratio
# Volatility for SL/TP
features['realized_vol_base'] = realized_vol_a
features['realized_vol_quote'] = realized_vol_b
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
# ATR for dynamic stops (in price units and as percentage)
features['atr_base'] = atr_a
features['atr_quote'] = atr_b
features['atr_pct_base'] = atr_pct_a
features['atr_pct_quote'] = atr_pct_b
features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2
# Pair encoding (for universal model)
# Using base and quote indices for hierarchical encoding
assets = self.config.assets
features['base_idx'] = assets.index(base) if base in assets else -1
features['quote_idx'] = assets.index(quote) if quote in assets else -1
# Add funding and on-chain features
# Funding data is always added from self._funding_data (OKX, all 10 assets)
# On-chain data is optional (CryptoQuant, BTC/ETH only)
features = self._add_on_chain_features(
features, on_chain_data, base, quote
)
# Drop rows with NaN in core features only (not funding/on-chain)
core_cols = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
'atr_base', 'atr_pct_base' # ATR is core for SL/TP
]
features = features.dropna(subset=core_cols)
# Fill missing funding/on-chain features with 0 (neutral)
optional_cols = [
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg',
'base_inflow', 'quote_inflow', 'inflow_ratio'
]
for col in optional_cols:
if col in features.columns:
features[col] = features[col].fillna(0)
return features
def calculate_all_pair_features(
self,
pairs: list[TradingPair],
asset_data: dict[str, pd.DataFrame],
on_chain_data: pd.DataFrame | None = None
) -> dict[str, pd.DataFrame]:
"""
Calculate features for all pairs.
Args:
pairs: List of trading pairs
asset_data: Dictionary of OHLCV DataFrames
on_chain_data: Optional on-chain data
Returns:
Dictionary mapping pair_id to feature DataFrame
"""
all_features = {}
for pair in pairs:
features = self.calculate_pair_features(
pair, asset_data, on_chain_data
)
if features is not None and len(features) > 0:
all_features[pair.pair_id] = features
logger.info(
"Calculated features for %d/%d pairs",
len(all_features), len(pairs)
)
return all_features
def get_combined_features(
self,
pair_features: dict[str, pd.DataFrame],
timestamp: pd.Timestamp | None = None
) -> pd.DataFrame:
"""
Combine all pair features into a single DataFrame.
Useful for batch model prediction across all pairs.
Args:
pair_features: Dictionary of feature DataFrames by pair_id
timestamp: Optional specific timestamp to filter to
Returns:
Combined DataFrame with all pairs as rows
"""
if not pair_features:
return pd.DataFrame()
if timestamp is not None:
# Get latest row from each pair at or before timestamp
rows = []
for pair_id, features in pair_features.items():
valid = features[features.index <= timestamp]
if len(valid) > 0:
row = valid.iloc[-1:].copy()
rows.append(row)
if rows:
return pd.concat(rows, ignore_index=False)
return pd.DataFrame()
# Combine all features (for training)
return pd.concat(pair_features.values(), ignore_index=False)
def _add_on_chain_features(
self,
features: pd.DataFrame,
on_chain_data: pd.DataFrame | None,
base_asset: str,
quote_asset: str
) -> pd.DataFrame:
"""
Add on-chain and funding rate features for the pair.
Uses funding data from OKX (all 10 assets) and on-chain data
from CryptoQuant (BTC/ETH only for inflows).
"""
base_short = base_asset.replace('-USDT', '').lower()
quote_short = quote_asset.replace('-USDT', '').lower()
# Add funding rates from cached funding data
if self._funding_data is not None and not self._funding_data.empty:
funding_aligned = self._funding_data.reindex(
features.index, method='ffill'
)
base_funding_col = f'{base_short}_funding'
quote_funding_col = f'{quote_short}_funding'
if base_funding_col in funding_aligned.columns:
features['base_funding'] = funding_aligned[base_funding_col]
if quote_funding_col in funding_aligned.columns:
features['quote_funding'] = funding_aligned[quote_funding_col]
# Funding difference (positive = base has higher funding)
if 'base_funding' in features.columns and 'quote_funding' in features.columns:
features['funding_diff'] = (
features['base_funding'] - features['quote_funding']
)
# Funding sentiment: average of both assets
features['funding_avg'] = (
features['base_funding'] + features['quote_funding']
) / 2
# Add on-chain features from CryptoQuant (BTC/ETH only)
if on_chain_data is not None and not on_chain_data.empty:
cq_aligned = on_chain_data.reindex(features.index, method='ffill')
# Inflows (only available for BTC/ETH)
base_inflow_col = f'{base_short}_inflow'
quote_inflow_col = f'{quote_short}_inflow'
if base_inflow_col in cq_aligned.columns:
features['base_inflow'] = cq_aligned[base_inflow_col]
if quote_inflow_col in cq_aligned.columns:
features['quote_inflow'] = cq_aligned[quote_inflow_col]
if 'base_inflow' in features.columns and 'quote_inflow' in features.columns:
features['inflow_ratio'] = (
features['base_inflow'] /
(features['quote_inflow'] + 1)
)
return features
def get_feature_columns(self) -> list[str]:
"""
Get list of feature columns for ML model.
Excludes metadata and target-related columns.
Returns:
List of feature column names
"""
# Core features (always present)
core_features = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
'base_idx', 'quote_idx'
]
# Funding features (now available for all 10 assets via OKX)
funding_features = [
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg'
]
# On-chain features (BTC/ETH only via CryptoQuant)
onchain_features = [
'base_inflow', 'quote_inflow', 'inflow_ratio'
]
return core_features + funding_features + onchain_features
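The spread/Z-Score transform at the heart of `calculate_pair_features` reduces to a few lines; a minimal sketch (hypothetical `spread_z_score` helper, with `window` standing in for the module's `z_window`):

```python
import pandas as pd

def spread_z_score(base_close: pd.Series, quote_close: pd.Series,
                   window: int = 50) -> pd.Series:
    """Rolling Z-Score of the base/quote price ratio."""
    spread = base_close / quote_close          # relative value of base vs quote
    mean = spread.rolling(window=window).mean()
    std = spread.rolling(window=window).std()  # sample std (ddof=1), pandas default
    return (spread - mean) / std
```

A reading above +2 means the base asset is roughly two rolling standard deviations expensive relative to the quote asset — exactly the condition that maps to a short-base signal in `score_pairs`.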

@@ -0,0 +1,272 @@
"""
Funding Rate Fetcher for Multi-Pair Strategy.
Fetches historical funding rates from OKX for all assets.
CryptoQuant only supports BTC/ETH, so we use OKX for the full universe.
"""
import time
from pathlib import Path
from datetime import datetime, timezone
import ccxt
import pandas as pd
from engine.logging_config import get_logger
logger = get_logger(__name__)
class FundingRateFetcher:
"""
Fetches and caches funding rate data from OKX.
OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC).
This fetcher retrieves historical funding rate data and aligns it
to hourly candles for use in the multi-pair strategy.
"""
def __init__(self, cache_dir: str = "data/funding"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.exchange: ccxt.okx | None = None
def _init_exchange(self) -> None:
"""Initialize OKX exchange connection."""
if self.exchange is None:
self.exchange = ccxt.okx({
'enableRateLimit': True,
'options': {'defaultType': 'swap'}
})
self.exchange.load_markets()
def fetch_funding_history(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100
) -> pd.DataFrame:
"""
Fetch historical funding rates for a symbol.
Args:
symbol: Asset symbol (e.g., 'BTC-USDT')
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
limit: Max records per request
Returns:
DataFrame with funding rate history
"""
self._init_exchange()
# Convert symbol format
base = symbol.replace('-USDT', '')
okx_symbol = f"{base}/USDT:USDT"
try:
# OKX funding rate history endpoint
# Uses fetch_funding_rate_history if available
all_funding = []
# Parse dates
if start_date:
since = self.exchange.parse8601(f"{start_date}T00:00:00Z")
else:
# Default to 1 year ago
since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
if end_date:
until = self.exchange.parse8601(f"{end_date}T23:59:59Z")
else:
until = self.exchange.milliseconds()
# Fetch in batches
current_since = since
while current_since < until:
try:
funding = self.exchange.fetch_funding_rate_history(
okx_symbol,
since=current_since,
limit=limit
)
if not funding:
break
all_funding.extend(funding)
# Move to next batch
last_ts = funding[-1]['timestamp']
if last_ts <= current_since:
break
current_since = last_ts + 1
time.sleep(0.1) # Rate limit
except Exception as e:
logger.warning(
"Error fetching funding batch for %s: %s",
symbol, str(e)[:50]
)
break
if not all_funding:
return pd.DataFrame()
# Convert to DataFrame
df = pd.DataFrame(all_funding)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'})
df.sort_index(inplace=True)
# Remove duplicates
df = df[~df.index.duplicated(keep='first')]
logger.info("Fetched %d funding records for %s", len(df), symbol)
return df
except Exception as e:
logger.error("Failed to fetch funding for %s: %s", symbol, e)
return pd.DataFrame()
def fetch_all_assets(
self,
assets: list[str],
start_date: str | None = None,
end_date: str | None = None
) -> pd.DataFrame:
"""
Fetch funding rates for all assets and combine.
Args:
assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT'])
start_date: Start date
end_date: End date
Returns:
Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc.
"""
combined = pd.DataFrame()
for symbol in assets:
df = self.fetch_funding_history(symbol, start_date, end_date)
if df.empty:
continue
# Rename column
asset_name = symbol.replace('-USDT', '').lower()
col_name = f"{asset_name}_funding"
df = df.rename(columns={'funding_rate': col_name})
if combined.empty:
combined = df
else:
combined = combined.join(df, how='outer')
time.sleep(0.2) # Be nice to API
# Forward fill to hourly (funding is every 8h)
if not combined.empty:
combined = combined.sort_index()
combined = combined.ffill()
return combined
def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None:
"""Save funding data to cache file."""
path = self.cache_dir / filename
df.to_csv(path)
logger.info("Saved funding rates to %s", path)
def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None:
"""Load funding data from cache if available."""
path = self.cache_dir / filename
if path.exists():
df = pd.read_csv(path, index_col='timestamp', parse_dates=True)
logger.info("Loaded funding rates from cache: %d rows", len(df))
return df
return None
def get_funding_data(
self,
assets: list[str],
start_date: str | None = None,
end_date: str | None = None,
use_cache: bool = True,
force_refresh: bool = False
) -> pd.DataFrame:
"""
Get funding data, using cache if available.
Args:
assets: List of asset symbols
start_date: Start date
end_date: End date
use_cache: Whether to use cached data
force_refresh: Force refresh even if cache exists
Returns:
DataFrame with funding rates for all assets
"""
cache_file = "funding_rates.csv"
# Try cache first
if use_cache and not force_refresh:
cached = self.load_from_cache(cache_file)
if cached is not None:
# Check if cache covers requested range
if start_date and end_date:
start_ts = pd.Timestamp(start_date, tz='UTC')
end_ts = pd.Timestamp(end_date, tz='UTC')
if cached.index.min() <= start_ts and cached.index.max() >= end_ts:
# Filter to requested range
return cached[(cached.index >= start_ts) & (cached.index <= end_ts)]
# Fetch fresh data
logger.info("Fetching fresh funding rate data...")
df = self.fetch_all_assets(assets, start_date, end_date)
if not df.empty and use_cache:
self.save_to_cache(df, cache_file)
return df
def download_funding_data():
"""Download funding data for all multi-pair assets."""
from strategies.multi_pair.config import MultiPairConfig
config = MultiPairConfig()
fetcher = FundingRateFetcher()
# Fetch last year of data
end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
start_date = (datetime.now(timezone.utc) - pd.Timedelta(days=365)).strftime("%Y-%m-%d")
logger.info("Downloading funding rates for %d assets...", len(config.assets))
logger.info("Date range: %s to %s", start_date, end_date)
df = fetcher.get_funding_data(
config.assets,
start_date=start_date,
end_date=end_date,
force_refresh=True
)
if not df.empty:
logger.info("Downloaded %d funding rate records", len(df))
logger.info("Columns: %s", list(df.columns))
else:
logger.warning("No funding data downloaded")
return df
if __name__ == "__main__":
from engine.logging_config import setup_logging
setup_logging()
download_funding_data()
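Since OKX settles funding every 8 hours while the strategy works on hourly candles, the alignment performed via `reindex(..., method='ffill')` is the load-bearing step; a minimal sketch (hypothetical `align_funding_to_hourly` helper, assuming a sortable timestamp index):

```python
import pandas as pd

def align_funding_to_hourly(funding: pd.DataFrame,
                            hourly_index: pd.DatetimeIndex) -> pd.DataFrame:
    """Carry each 8h funding settlement forward until the next one."""
    return funding.sort_index().reindex(hourly_index, method='ffill')

# 8-hourly settlements at 00:00 / 08:00 / 16:00 UTC
settlements = pd.date_range("2024-01-01", periods=3, freq="8h", tz="UTC")
funding = pd.DataFrame({"btc_funding": [0.0001, 0.0003, -0.0002]},
                       index=settlements)
hourly = pd.date_range("2024-01-01", periods=24, freq="1h", tz="UTC")
aligned = align_funding_to_hourly(funding, hourly)
```

Each hourly bar thus sees the most recent settled rate, never a future one, which keeps the funding features free of lookahead.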

@@ -0,0 +1,168 @@
"""
Pair Scanner for Multi-Pair Divergence Strategy.
Generates all possible pairs from asset universe and checks tradeability.
"""
from dataclasses import dataclass
from itertools import combinations
from typing import Optional
import ccxt
from engine.logging_config import get_logger
from .config import MultiPairConfig
logger = get_logger(__name__)
@dataclass
class TradingPair:
"""
Represents a tradeable pair for spread analysis.
Attributes:
base_asset: First asset in the pair (numerator)
quote_asset: Second asset in the pair (denominator)
pair_id: Unique identifier for the pair
is_direct: Whether pair can be traded directly on exchange
exchange_symbol: Symbol for direct trading (if available)
"""
base_asset: str
quote_asset: str
pair_id: str
is_direct: bool = False
exchange_symbol: Optional[str] = None
@property
def name(self) -> str:
"""Human-readable pair name."""
return f"{self.base_asset}/{self.quote_asset}"
def __hash__(self):
return hash(self.pair_id)
def __eq__(self, other):
if not isinstance(other, TradingPair):
return False
return self.pair_id == other.pair_id
class PairScanner:
"""
Scans and generates tradeable pairs from asset universe.
Checks OKX for directly tradeable cross-pairs and generates
synthetic pairs via USDT for others.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self.exchange: Optional[ccxt.Exchange] = None
self._available_markets: set[str] = set()
def _init_exchange(self) -> None:
"""Initialize exchange connection for market lookup."""
if self.exchange is None:
exchange_class = getattr(ccxt, self.config.exchange_id)
self.exchange = exchange_class({'enableRateLimit': True})
self.exchange.load_markets()
self._available_markets = set(self.exchange.symbols)
logger.info(
"Loaded %d markets from %s",
len(self._available_markets),
self.config.exchange_id
)
def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]:
"""
Generate all unique pairs from asset universe.
Args:
check_exchange: Whether to check OKX for direct trading
Returns:
List of TradingPair objects
"""
if check_exchange:
self._init_exchange()
pairs = []
assets = self.config.assets
for base, quote in combinations(assets, 2):
pair_id = f"{base}__{quote}"
# Check if directly tradeable as cross-pair on OKX
is_direct = False
exchange_symbol = None
if check_exchange:
# Probe for a directly listed perpetual cross-pair.
# OKX perpetuals are typically USDT-margined, so we check the
# '<BASE>/<QUOTE>:USDT' form; true cross-pairs are rare on OKX.
cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT"
if cross_symbol in self._available_markets:
is_direct = True
exchange_symbol = cross_symbol
pair = TradingPair(
base_asset=base,
quote_asset=quote,
pair_id=pair_id,
is_direct=is_direct,
exchange_symbol=exchange_symbol
)
pairs.append(pair)
# Log summary
direct_count = sum(1 for p in pairs if p.is_direct)
logger.info(
"Generated %d pairs: %d direct, %d synthetic",
len(pairs), direct_count, len(pairs) - direct_count
)
return pairs
def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]:
"""
Get list of symbols needed to calculate all pair spreads.
For synthetic pairs, we need both USDT pairs.
For direct pairs, we still load USDT pairs for simplicity.
Args:
pairs: List of trading pairs
Returns:
List of unique symbols to load (e.g., ['BTC-USDT', 'ETH-USDT'])
"""
symbols = set()
for pair in pairs:
symbols.add(pair.base_asset)
symbols.add(pair.quote_asset)
return list(symbols)
def filter_by_assets(
self,
pairs: list[TradingPair],
exclude_assets: list[str]
) -> list[TradingPair]:
"""
Filter pairs that contain any of the excluded assets.
Args:
pairs: List of trading pairs
exclude_assets: Assets to exclude
Returns:
Filtered list of pairs
"""
if not exclude_assets:
return pairs
exclude_set = set(exclude_assets)
return [
p for p in pairs
if p.base_asset not in exclude_set
and p.quote_asset not in exclude_set
]

@@ -0,0 +1,525 @@
"""
Multi-Pair Divergence Selection Strategy.
Main strategy class that orchestrates pair scanning, feature calculation,
model training, and signal generation for backtesting.
"""
from dataclasses import dataclass
from typing import Optional
import pandas as pd
import numpy as np
from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer, DivergenceSignal
logger = get_logger(__name__)
@dataclass
class PositionState:
"""Tracks current position state."""
pair: TradingPair | None = None
direction: str | None = None # 'long' or 'short'
entry_price: float = 0.0
entry_idx: int = -1
stop_loss: float = 0.0
take_profit: float = 0.0
atr: float = 0.0 # ATR at entry for reference
last_exit_idx: int = -100 # For cooldown tracking
class MultiPairDivergenceStrategy(BaseStrategy):
"""
Multi-Pair Divergence Selection Strategy.
Scans multiple cryptocurrency pairs for spread divergence,
selects the most divergent pair using ML-enhanced scoring,
and trades mean-reversion opportunities.
Key Features:
- Universal ML model across all pairs
- Correlation-based pair filtering
- Dynamic SL/TP based on volatility
- Walk-forward training
"""
def __init__(
self,
config: MultiPairConfig | None = None,
model_path: str = "data/multi_pair_model.pkl"
):
super().__init__()
self.config = config or MultiPairConfig()
# Initialize components
self.pair_scanner = PairScanner(self.config)
self.correlation_filter = CorrelationFilter(self.config)
self.feature_engine = MultiPairFeatureEngine(self.config)
self.divergence_scorer = DivergenceScorer(self.config, model_path)
# Strategy configuration
self.default_market_type = MarketType.PERPETUAL
self.default_leverage = 1
# Runtime state
self.pairs: list[TradingPair] = []
self.asset_data: dict[str, pd.DataFrame] = {}
self.pair_features: dict[str, pd.DataFrame] = {}
self.position = PositionState()
self.train_end_idx: int = 0
def run(self, close: pd.Series, **kwargs) -> tuple:
"""
Execute the multi-pair divergence strategy.
This method is called by the backtester with the primary asset's
close prices. For multi-pair, we load all assets internally.
Args:
close: Primary close prices (used for index alignment)
**kwargs: Additional data (high, low, volume)
Returns:
Tuple of (long_entries, long_exits, short_entries, short_exits, size)
"""
logger.info("Starting Multi-Pair Divergence Strategy")
# 1. Load all asset data
start_date = close.index.min().strftime("%Y-%m-%d")
end_date = close.index.max().strftime("%Y-%m-%d")
self.asset_data = self.feature_engine.load_all_assets(
start_date=start_date,
end_date=end_date
)
# 1b. Load funding rate data for all assets
self.feature_engine.load_funding_data(
start_date=start_date,
end_date=end_date,
use_cache=True
)
if len(self.asset_data) < 2:
logger.error("Insufficient assets loaded, need at least 2")
return self._empty_signals(close)
# 2. Generate pairs
self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)
# Filter to pairs with available data
available_assets = set(self.asset_data.keys())
self.pairs = [
p for p in self.pairs
if p.base_asset in available_assets
and p.quote_asset in available_assets
]
logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))
# 3. Calculate features for all pairs
self.pair_features = self.feature_engine.calculate_all_pair_features(
self.pairs, self.asset_data
)
if not self.pair_features:
logger.error("No pair features calculated")
return self._empty_signals(close)
# 4. Align to common index
common_index = self._get_common_index()
if len(common_index) < 200:
logger.error("Insufficient common data across pairs")
return self._empty_signals(close)
# 5. Walk-forward split
n_samples = len(common_index)
train_size = int(n_samples * self.config.train_ratio)
self.train_end_idx = train_size
train_end_date = common_index[train_size - 1]
test_start_date = common_index[train_size]
logger.info(
"Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
train_size, train_end_date.strftime('%Y-%m-%d'),
n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
)
# 6. Train model on training period
if self.divergence_scorer.model is None:
train_features = {
pid: feat[feat.index <= train_end_date]
for pid, feat in self.pair_features.items()
}
combined = self.feature_engine.get_combined_features(train_features)
self.divergence_scorer.train_model(combined, train_features)
# 7. Generate signals for test period
return self._generate_signals(common_index, train_size, close)
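The 70/30 walk-forward split above is plain index arithmetic; a minimal standalone sketch (the `split_walk_forward` helper name is illustrative, not part of the strategy):

```python
def split_walk_forward(n_samples: int, train_ratio: float) -> tuple:
    """Return (train_size, test_size) for a chronological train/test split."""
    train_size = int(n_samples * train_ratio)
    return train_size, n_samples - train_size

# 1000 hourly bars at the default 70% ratio -> 700 train / 300 test
train_size, test_size = split_walk_forward(1000, 0.7)
```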
def _generate_signals(
self,
index: pd.DatetimeIndex,
train_size: int,
reference_close: pd.Series
) -> tuple:
"""
Generate entry/exit signals for the test period.
Iterates through each bar in the test period, scoring pairs
and generating signals based on divergence scores.
"""
# Initialize signal arrays aligned to reference close
long_entries = pd.Series(False, index=reference_close.index)
long_exits = pd.Series(False, index=reference_close.index)
short_entries = pd.Series(False, index=reference_close.index)
short_exits = pd.Series(False, index=reference_close.index)
size = pd.Series(1.0, index=reference_close.index)
# Track position state
self.position = PositionState()
# Price data for correlation calculation
price_data = {
symbol: df['close'] for symbol, df in self.asset_data.items()
}
# Iterate through test period
test_indices = index[train_size:]
trade_count = 0
for i, timestamp in enumerate(test_indices):
current_idx = train_size + i
# Check exit conditions first
if self.position.pair is not None:
# Enforce minimum hold period
bars_held = current_idx - self.position.entry_idx
if bars_held < self.config.min_hold_bars:
# Only allow SL/TP exits during min hold period
should_exit, exit_reason = self._check_sl_tp_only(timestamp)
else:
should_exit, exit_reason = self._check_exit(timestamp)
if should_exit:
# Map exit signal to reference index
if timestamp in reference_close.index:
if self.position.direction == 'long':
long_exits.loc[timestamp] = True
else:
short_exits.loc[timestamp] = True
logger.debug(
"Exit %s %s at %s: %s (held %d bars)",
self.position.direction,
self.position.pair.name,
timestamp.strftime('%Y-%m-%d %H:%M'),
exit_reason,
bars_held
)
self.position = PositionState(last_exit_idx=current_idx)
# Score pairs (with correlation filter if position exists)
held_asset = None
if self.position.pair is not None:
held_asset = self.position.pair.base_asset
# Filter pairs by correlation
candidate_pairs = self.correlation_filter.filter_pairs(
self.pairs,
held_asset,
price_data,
current_idx
)
# Get candidate features
candidate_features = {
pid: feat for pid, feat in self.pair_features.items()
if any(p.pair_id == pid for p in candidate_pairs)
}
# Score pairs
signals = self.divergence_scorer.score_pairs(
candidate_features,
candidate_pairs,
timestamp
)
# Get best signal
best = self.divergence_scorer.select_best_pair(signals)
if best is None:
continue
# Check if we should switch positions or enter new
should_enter = False
# Check cooldown
bars_since_exit = current_idx - self.position.last_exit_idx
in_cooldown = bars_since_exit < self.config.cooldown_bars
if self.position.pair is None and not in_cooldown:
# No position and not in cooldown, can enter
should_enter = True
elif self.position.pair is not None:
# Check if we should switch (requires min hold + significant improvement)
bars_held = current_idx - self.position.entry_idx
current_score = self._get_current_score(timestamp)
if (bars_held >= self.config.min_hold_bars and
best.divergence_score > current_score * self.config.switch_threshold):
# New opportunity is significantly better
if timestamp in reference_close.index:
if self.position.direction == 'long':
long_exits.loc[timestamp] = True
else:
short_exits.loc[timestamp] = True
self.position = PositionState(last_exit_idx=current_idx)
should_enter = True
if should_enter:
# Calculate ATR-based dynamic SL/TP
sl_price, tp_price = self._calculate_sl_tp(
best.base_price,
best.direction,
best.atr,
best.atr_pct
)
# Set position
self.position = PositionState(
pair=best.pair,
direction=best.direction,
entry_price=best.base_price,
entry_idx=current_idx,
stop_loss=sl_price,
take_profit=tp_price,
atr=best.atr
)
# Calculate position size based on divergence
pos_size = self._calculate_size(best.divergence_score)
# Generate entry signal
if timestamp in reference_close.index:
if best.direction == 'long':
long_entries.loc[timestamp] = True
else:
short_entries.loc[timestamp] = True
size.loc[timestamp] = pos_size
trade_count += 1
logger.debug(
"Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
best.direction,
best.pair.name,
timestamp.strftime('%Y-%m-%d %H:%M'),
best.z_score,
best.probability,
best.divergence_score
)
logger.info("Generated %d trades in test period", trade_count)
return long_entries, long_exits, short_entries, short_exits, size
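The entry gating inside the loop above combines a post-exit cooldown with a minimum hold before switching pairs; a condensed sketch of that decision (the `can_enter` helper is illustrative, not part of the class):

```python
def can_enter(has_position: bool, bars_since_exit: int, cooldown_bars: int,
              bars_held: int, min_hold_bars: int,
              new_score: float, current_score: float,
              switch_threshold: float) -> bool:
    """Flat entries must respect the cooldown; switching out of an open
    position requires the minimum hold AND a significantly better score."""
    if not has_position:
        return bars_since_exit >= cooldown_bars
    return (bars_held >= min_hold_bars
            and new_score > current_score * switch_threshold)
```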
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check if current position should be exited.
Exit conditions:
1. Z-Score reverted to mean (|Z| < threshold)
2. Stop-loss hit
3. Take-profit hit
Returns:
Tuple of (should_exit, reason)
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
z_score = latest['z_score']
current_price = latest['base_close']
# Check mean reversion (primary exit)
if abs(z_score) < self.config.z_exit_threshold:
return True, f"mean_reversion (z={z_score:.2f})"
# Check SL/TP
return self._check_sl_tp(current_price)
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check only stop-loss and take-profit conditions.
Used during minimum hold period.
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
current_price = latest['base_close']
return self._check_sl_tp(current_price)
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
"""Check stop-loss and take-profit levels."""
if self.position.direction == 'long':
if current_price <= self.position.stop_loss:
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
if current_price >= self.position.take_profit:
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
else: # short
if current_price >= self.position.stop_loss:
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
if current_price <= self.position.take_profit:
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
return False, ""
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
"""Get current position's divergence score for comparison."""
if self.position.pair is None:
return 0.0
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return 0.0
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return 0.0
latest = valid.iloc[-1]
z_score = abs(latest['z_score'])
# Re-score with model
if self.divergence_scorer.model is not None:
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
X = pd.DataFrame(
[feature_row.values],
columns=self.divergence_scorer.feature_cols
)
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
return z_score * prob
return z_score * 0.5
def _calculate_sl_tp(
self,
entry_price: float,
direction: str,
atr: float,
atr_pct: float
) -> tuple[float, float]:
"""
Calculate ATR-based dynamic stop-loss and take-profit prices.
Uses ATR (Average True Range) to set stops that adapt to
each asset's volatility. More volatile assets get wider stops.
Args:
entry_price: Entry price
direction: 'long' or 'short'
atr: ATR in price units
atr_pct: ATR as percentage of price
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
# Calculate SL/TP as ATR multiples
if atr > 0 and atr_pct > 0:
# ATR-based calculation
sl_distance = atr * self.config.sl_atr_multiplier
tp_distance = atr * self.config.tp_atr_multiplier
# Convert to percentage for bounds checking
sl_pct = sl_distance / entry_price
tp_pct = tp_distance / entry_price
else:
# Fallback to fixed percentages if ATR unavailable
sl_pct = self.config.base_sl_pct
tp_pct = self.config.base_tp_pct
# Apply bounds to prevent extreme stops
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
# Calculate actual prices
if direction == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else: # short
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit
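The same ATR-to-percentage clamping can be restated standalone (the multipliers and bounds below are illustrative defaults, not the strategy config's values):

```python
def atr_sl_tp(entry_price: float, direction: str, atr: float,
              sl_mult: float = 1.5, tp_mult: float = 3.0,
              sl_bounds=(0.01, 0.08), tp_bounds=(0.02, 0.15)) -> tuple:
    """ATR-scaled stop/target distances, clamped to percentage bounds."""
    sl_pct = min(max(atr * sl_mult / entry_price, sl_bounds[0]), sl_bounds[1])
    tp_pct = min(max(atr * tp_mult / entry_price, tp_bounds[0]), tp_bounds[1])
    if direction == 'long':
        return entry_price * (1 - sl_pct), entry_price * (1 + tp_pct)
    return entry_price * (1 + sl_pct), entry_price * (1 - tp_pct)

# Entry at 100 with ATR 2 -> 3% stop, 6% target for a long
sl, tp = atr_sl_tp(100.0, 'long', 2.0)
```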
def _calculate_size(self, divergence_score: float) -> float:
"""
Calculate position size based on divergence score.
Higher divergence = larger position (up to 2x).
"""
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
base_threshold = 0.5
# Scale factor
if divergence_score <= base_threshold:
return 1.0
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
return min(scale, 2.0)
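The divergence-score sizing reduces to a clamped linear ramp: 1.0x at or below the threshold, 2.0x at twice the threshold. Restated standalone:

```python
def divergence_size(score: float, base_threshold: float = 0.5,
                    cap: float = 2.0) -> float:
    """1.0 at or below the threshold, ramping linearly to `cap` at 2x it."""
    if score <= base_threshold:
        return 1.0
    return min(1.0 + (score - base_threshold) / base_threshold, cap)
```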
def _get_common_index(self) -> pd.DatetimeIndex:
"""Get the intersection of all pair feature indices."""
if not self.pair_features:
return pd.DatetimeIndex([])
common = None
for features in self.pair_features.values():
if common is None:
common = features.index
else:
common = common.intersection(features.index)
return common.sort_values()
def _empty_signals(self, close: pd.Series) -> tuple:
"""Return empty signal arrays."""
empty = self.create_empty_signals(close)
size = pd.Series(1.0, index=close.index)
return empty, empty, empty, empty, size


@@ -0,0 +1,365 @@
import pandas as pd
import numpy as np
import ta
import vectorbt as vbt
from sklearn.ensemble import RandomForestClassifier
from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.data_manager import DataManager
from engine.logging_config import get_logger
logger = get_logger(__name__)
class RegimeReversionStrategy(BaseStrategy):
"""
ML-Based Regime Detection & Mean Reversion Strategy.
Logic:
1. Tracks the BTC/ETH Spread and its Z-Score (24h window).
2. Uses a Random Forest model to predict if an extreme Z-Score will revert profitably.
3. Features: Spread Technicals (RSI, ROC) + On-Chain Flows (Inflow, Funding).
4. Entry: When Model Probability > 0.5.
5. Exit: Z-Score reversion to 0 or SL/TP.
Walk-Forward Training:
- Trains on first `train_ratio` of data (default 70%)
- Generates signals only for remaining test period (30%)
- Eliminates look-ahead bias for realistic backtest results
"""
# Optimal parameters from walk-forward research (2025-10 to 2025-12)
# Research: research/horizon_optimization_results.csv
OPTIMAL_HORIZON = 102 # 4.25 days - best Net PnL (+232%)
OPTIMAL_Z_WINDOW = 24 # 24h rolling window for spread Z-score
OPTIMAL_TRAIN_RATIO = 0.7 # 70% train / 30% test split
OPTIMAL_PROFIT_TARGET = 0.005 # 0.5% profit threshold for target definition
OPTIMAL_Z_ENTRY = 1.0 # Enter when |Z| > 1.0
def __init__(self,
model_path: str = "data/regime_model.pkl",
horizon: int = OPTIMAL_HORIZON,
z_window: int = OPTIMAL_Z_WINDOW,
z_entry_threshold: float = OPTIMAL_Z_ENTRY,
profit_target: float = OPTIMAL_PROFIT_TARGET,
stop_loss: float = 0.06, # 6% - accommodates 1.95% avg MAE
take_profit: float = 0.05, # 5% swing target
train_ratio: float = OPTIMAL_TRAIN_RATIO,
trend_window: int = 0, # Disable SMA filter
use_funding_filter: bool = True, # Enable Funding Rate filter
funding_threshold: float = 0.005 # Tightened to 0.005%
):
super().__init__()
self.model_path = model_path
self.horizon = horizon
self.z_window = z_window
self.z_entry_threshold = z_entry_threshold
self.profit_target = profit_target
self.stop_loss = stop_loss
self.take_profit = take_profit
self.train_ratio = train_ratio
self.trend_window = trend_window
self.use_funding_filter = use_funding_filter
self.funding_threshold = funding_threshold
# Default Strategy Config
self.default_market_type = MarketType.PERPETUAL
self.default_leverage = 1
self.dm = DataManager()
self.model = None
self.feature_cols = None
self.train_end_idx = None # Will store the training cutoff point
def run(self, close, **kwargs):
"""
Execute the strategy logic.
We assume this strategy is run on ETH-USDT (the active asset).
We will fetch BTC-USDT internally to calculate the spread.
"""
# 1. Identify Context
# We need BTC data aligned with the incoming ETH 'close' series
start_date = close.index.min()
end_date = close.index.max()
logger.info("Fetching BTC context data...")
try:
# Load BTC data (Context) - Must match the timeframe of the backtest
# Research was done on 1h candles, so strategy should be run on 1h
# Use PERPETUAL data to match the trading instrument (ETH Perp)
df_btc = self.dm.load_data("okx", "BTC-USDT", "1h", MarketType.PERPETUAL)
# Align BTC to ETH (close)
df_btc = df_btc.reindex(close.index, method='ffill')
btc_close = df_btc['close']
except Exception as e:
logger.error(f"Failed to load BTC context: {e}")
# Match the success path's 5-tuple (signals + size)
empty = self.create_empty_signals(close)
size = pd.Series(1.0, index=close.index)
return empty, empty, empty, empty, size
# 2. Construct DataFrames for Feature Engineering
# We need volume/high/low for features, but 'run' signature primarily gives 'close'.
# kwargs might have high/low/volume if passed by Backtester.run_strategy
eth_vol = kwargs.get('volume')
if eth_vol is None:
logger.warning("Volume data missing. Feature calculation might fail.")
# Fall back to zero volume so downstream volume ratios stay defined
eth_vol = pd.Series(0, index=close.index)
# Construct dummy dfs for prepare_features
# We only really need Close and Volume for the current feature set
df_a = pd.DataFrame({'close': btc_close, 'volume': df_btc['volume']})
df_b = pd.DataFrame({'close': close, 'volume': eth_vol})
# 3. Load On-Chain Data (CryptoQuant)
# We use the saved CSV for training/inference
# In a live setting, this would query the API for recent data
cq_df = None
try:
cq_path = "data/cq_training_data.csv"
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
if cq_df.index.tz is None:
cq_df.index = cq_df.index.tz_localize('UTC')
except Exception:
logger.warning("CryptoQuant data not found. Running without on-chain features.")
# 4. Calculate Features
features = self.prepare_features(df_a, df_b, cq_df)
# 5. Walk-Forward Split
# Train on first `train_ratio` of data, test on remainder
n_samples = len(features)
train_size = int(n_samples * self.train_ratio)
train_features = features.iloc[:train_size]
test_features = features.iloc[train_size:]
train_end_date = train_features.index[-1]
test_start_date = test_features.index[0]
logger.info(
f"Walk-Forward Split: Train={len(train_features)} bars "
f"(until {train_end_date.strftime('%Y-%m-%d')}), "
f"Test={len(test_features)} bars "
f"(from {test_start_date.strftime('%Y-%m-%d')})"
)
# 6. Train Model on Training Period ONLY
if self.model is None:
logger.info("Training Regime Model on training period only...")
self.model, self.feature_cols = self.train_model(train_features)
# 7. Predict on TEST Period ONLY
# Use valid columns only
X_test = test_features[self.feature_cols].fillna(0)
X_test = X_test.replace([np.inf, -np.inf], 0)
# Predict Probabilities for test period
probs = self.model.predict_proba(X_test)[:, 1]
# 8. Generate Entry Signals (TEST period only)
# If Z > threshold (Spread High, ETH Expensive) -> Short ETH
# If Z < -threshold (Spread Low, ETH Cheap) -> Long ETH
z_thresh = self.z_entry_threshold
short_signal_test = (probs > 0.5) & (test_features['z_score'].values > z_thresh)
long_signal_test = (probs > 0.5) & (test_features['z_score'].values < -z_thresh)
# 8b. Apply Trend Filter (Macro Regime)
# Rule: Long only if BTC > SMA (Bull), Short only if BTC < SMA (Bear)
if self.trend_window > 0:
# Calculate SMA on full BTC history first
btc_sma = btc_close.rolling(window=self.trend_window).mean()
# Align with test period
test_btc_close = btc_close.reindex(test_features.index)
test_btc_sma = btc_sma.reindex(test_features.index)
# Define Regimes
is_bull = (test_btc_close > test_btc_sma).values
is_bear = (test_btc_close < test_btc_sma).values
# Apply Filter
long_signal_test = long_signal_test & is_bull
short_signal_test = short_signal_test & is_bear
# 8c. Apply Funding Rate Filter
# Rule: If Funding > Threshold (Greedy) -> No Longs.
# If Funding < -Threshold (Fearful) -> No Shorts.
if self.use_funding_filter and 'btc_funding' in test_features.columns:
funding = test_features['btc_funding'].values
thresh = self.funding_threshold
# Greedy market (high positive funding): longs pay shorts, and crowded
# longs often mark local tops. Shorts receive funding on top of the
# expected reversion, so block Longs when overheated.
# Fearful market (high negative funding) is the mirror case: block
# Shorts to avoid a short squeeze.
is_overheated = funding > thresh
is_oversold = funding < -thresh
# Block Longs if Overheated
long_signal_test = long_signal_test & (~is_overheated)
# Block Shorts if Oversold (Negative Funding) -> Risk of Short Squeeze
short_signal_test = short_signal_test & (~is_oversold)
n_blocked_long = (is_overheated & (probs > 0.5) & (test_features['z_score'].values < -z_thresh)).sum()
n_blocked_short = (is_oversold & (probs > 0.5) & (test_features['z_score'].values > z_thresh)).sum()
if n_blocked_long > 0 or n_blocked_short > 0:
logger.info(f"Funding Filter: Blocked {n_blocked_long} Longs, {n_blocked_short} Shorts")
# 9. Calculate Position Sizing (Probability-Based)
# Base size = 1.0 (100% of equity)
# Scale: 1.0 + (Prob - 0.5) * 2
# Example: Prob=0.6 -> Size=1.2, Prob=0.8 -> Size=1.6
# Align probabilities to close index
probs_series = pd.Series(0.0, index=test_features.index)
probs_series[:] = probs
probs_aligned = probs_series.reindex(close.index, fill_value=0.0)
# Calculate dynamic size
dynamic_size = 1.0 + (probs_aligned - 0.5) * 2.0
# Cap leverage between 1x and 2x
size = dynamic_size.clip(lower=1.0, upper=2.0)
# Create full-length signal series (False for training period)
long_entries = pd.Series(False, index=close.index)
short_entries = pd.Series(False, index=close.index)
# Map test signals to their correct indices
test_idx = test_features.index
for i, idx in enumerate(test_idx):
if idx in close.index:
long_entries.loc[idx] = bool(long_signal_test[i])
short_entries.loc[idx] = bool(short_signal_test[i])
# 10. Generate Exits
# Exit when Z-Score crosses back through 0 (mean reversion complete)
z_reindexed = features['z_score'].reindex(close.index, fill_value=0)
# Exit Long when Z > 0, Exit Short when Z < 0
long_exits = z_reindexed > 0
short_exits = z_reindexed < 0
# Log signal counts for verification
n_long = long_entries.sum()
n_short = short_entries.sum()
logger.info(f"Generated {n_long} long signals, {n_short} short signals (test period only)")
return long_entries, long_exits, short_entries, short_exits, size
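The probability-based sizing in step 9 maps model confidence linearly onto leverage (Prob=0.6 -> 1.2x, Prob=0.8 -> 1.6x, clipped to [1, 2]); as a standalone sketch:

```python
def prob_size(prob: float) -> float:
    """Size = 1 + (prob - 0.5) * 2, clipped to the [1.0, 2.0] range."""
    return min(max(1.0 + (prob - 0.5) * 2.0, 1.0), 2.0)
```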
def prepare_features(self, df_btc, df_eth, cq_df=None):
"""Replicate research feature engineering"""
# Align
common = df_btc.index.intersection(df_eth.index)
df_a = df_btc.loc[common].copy()
df_b = df_eth.loc[common].copy()
# Spread
spread = df_b['close'] / df_a['close']
# Z-Score
rolling_mean = spread.rolling(window=self.z_window).mean()
rolling_std = spread.rolling(window=self.z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume
vol_ratio = df_b['volume'] / df_a['volume']
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
# Volatility
ret_a = df_a['close'].pct_change()
ret_b = df_b['close'].pct_change()
vol_a = ret_a.rolling(window=self.z_window).std()
vol_b = ret_b.rolling(window=self.z_window).std()
vol_spread_ratio = vol_b / vol_a
features = pd.DataFrame(index=spread.index)
features['spread'] = spread
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
features['vol_diff_ratio'] = vol_spread_ratio
# CQ Merge
if cq_df is not None:
cq_aligned = cq_df.reindex(features.index, method='ffill')
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
features = features.join(cq_aligned)
return features.dropna()
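The core feature is the rolling z-score of the ETH/BTC spread. The same computation on plain lists, using the sample standard deviation to match pandas' `rolling().std()` default (the `rolling_z` helper is illustrative):

```python
from statistics import mean, stdev

def rolling_z(series, window):
    """Z-score of the last value against a trailing window (None until the
    window is filled; 0.0 where the window has zero variance)."""
    out = []
    for i in range(len(series)):
        if i + 1 < window:
            out.append(None)
            continue
        win = series[i + 1 - window:i + 1]
        mu, sd = mean(win), stdev(win)
        out.append((series[i] - mu) / sd if sd > 0 else 0.0)
    return out

# Flat spread, then a jump: the jump scores ~1.15 sigma over a 3-bar window
z_example = rolling_z([1.0, 1.0, 1.0, 2.0], 3)
```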
def train_model(self, train_features):
"""
Train Random Forest on training data only.
This method receives ONLY the training subset of features,
ensuring no look-ahead bias. The model learns from historical
patterns and is then applied to unseen test data.
Args:
train_features: DataFrame containing features for training period only
"""
threshold = self.profit_target
horizon = self.horizon
z_thresh = self.z_entry_threshold
# Define targets using ONLY training data
# For Short Spread (Z > threshold): Did spread drop below target within horizon?
future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
target_short = train_features['spread'] * (1 - threshold)
success_short = (train_features['z_score'] > z_thresh) & (future_min < target_short)
# For Long Spread (Z < -threshold): Did spread rise above target within horizon?
future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
target_long = train_features['spread'] * (1 + threshold)
success_long = (train_features['z_score'] < -z_thresh) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Build model
model = RandomForestClassifier(
n_estimators=300, max_depth=5, min_samples_leaf=30,
class_weight={0: 1, 1: 3}, random_state=42
)
# Exclude non-feature columns
exclude = ['spread']
cols = [c for c in train_features.columns if c not in exclude]
# Clean features
X_train = train_features[cols].fillna(0)
X_train = X_train.replace([np.inf, -np.inf], 0)
# Drop rows whose forward-looking window is incomplete
# (shift(-horizon) leaves NaNs at the end of the training period)
valid_mask = future_min.notna().values & future_max.notna().values
X_train_clean = X_train[valid_mask]
targets_clean = targets[valid_mask]
logger.info(f"Training on {len(X_train_clean)} valid samples (removed {len(X_train) - len(X_train_clean)} with incomplete future data)")
model.fit(X_train_clean, targets_clean)
return model, cols
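The target construction above asks whether the spread reaches a profit level within the horizon after an extreme z-score. A small pure-Python restatement of the short-side labeling (the `label_short_setups` helper is illustrative, not the training code):

```python
def label_short_setups(spread, z, horizon, threshold, z_thresh):
    """1 where Z > z_thresh AND the spread trades below spread*(1-threshold)
    within the next `horizon` bars; 0 otherwise (incomplete tails get 0,
    mirroring how rolling(horizon).min().shift(-horizon) yields NaN there)."""
    labels = []
    for i in range(len(spread)):
        future = spread[i + 1:i + 1 + horizon]
        if len(future) < horizon or z[i] <= z_thresh:
            labels.append(0)
            continue
        labels.append(int(min(future) < spread[i] * (1 - threshold)))
    return labels
```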


@@ -0,0 +1,6 @@
"""
Meta Supertrend strategy package.
"""
from .strategy import MetaSupertrendStrategy
__all__ = ['MetaSupertrendStrategy']


@@ -0,0 +1,128 @@
"""
Supertrend indicators and helper functions.
"""
import numpy as np
import vectorbt as vbt
from numba import njit
# --- Numba Compiled Helper Functions ---
@njit(cache=False) # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
"""Calculate True Range (Numba compiled)."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
tr = np.empty_like(close)
tr[0] = high[0] - low[0]
for i in range(1, len(close)):
tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
return tr
@njit(cache=False)
def get_atr_nb(high, low, close, period):
"""Calculate ATR using Wilder's Smoothing (Numba compiled)."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
# Ensure period is native Python int (critical for Numba array indexing)
n = len(close)
p = int(period)
tr = get_tr_nb(high, low, close)
atr = np.full(n, np.nan, dtype=np.float64)
if n < p:
return atr
# Initial ATR is simple average of TR
sum_tr = 0.0
for i in range(p):
sum_tr += tr[i]
atr[p - 1] = sum_tr / p
# Subsequent ATR is Wilder's smoothed
for i in range(p, n):
atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
return atr
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
"""Calculate SuperTrend completely in Numba."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
# Ensure params are native Python types (critical for Numba)
n = len(close)
p = int(period)
m = float(multiplier)
atr = get_atr_nb(high, low, close, p)
final_upper = np.full(n, np.nan, dtype=np.float64)
final_lower = np.full(n, np.nan, dtype=np.float64)
trend = np.ones(n, dtype=np.int8) # 1 Bull, -1 Bear
# Skip until we have valid ATR
start_idx = p
if start_idx >= n:
return trend
# Init first valid point
hl2 = (high[start_idx] + low[start_idx]) / 2
final_upper[start_idx] = hl2 + m * atr[start_idx]
final_lower[start_idx] = hl2 - m * atr[start_idx]
# Loop
for i in range(start_idx + 1, n):
cur_hl2 = (high[i] + low[i]) / 2
cur_atr = atr[i]
basic_upper = cur_hl2 + m * cur_atr
basic_lower = cur_hl2 - m * cur_atr
# Upper Band Logic
if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
final_upper[i] = basic_upper
else:
final_upper[i] = final_upper[i-1]
# Lower Band Logic
if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
final_lower[i] = basic_lower
else:
final_lower[i] = final_lower[i-1]
# Trend Logic
if trend[i-1] == 1:
if close[i] < final_lower[i-1]:
trend[i] = -1
else:
trend[i] = 1
else:
if close[i] > final_upper[i-1]:
trend[i] = 1
else:
trend[i] = -1
return trend
# --- VectorBT Indicator Factory ---
SuperTrendIndicator = vbt.IndicatorFactory(
class_name='SuperTrend',
short_name='st',
input_names=['high', 'low', 'close'],
param_names=['period', 'multiplier'],
output_names=['trend']
).from_apply_func(
get_supertrend_nb,
keep_pd=False, # Disable automatic Pandas wrapping of inputs
param_product=True # Enable Cartesian product for list params
)
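The band-ratcheting and flip rules in `get_supertrend_nb` can be exercised on toy data with a direct pure-Python transcription that takes a precomputed ATR (illustrative only, no Numba):

```python
def supertrend_trend(high, low, close, atr, m):
    """Trend array (1 bull / -1 bear) mirroring get_supertrend_nb's
    band-ratcheting and flip rules, given a precomputed ATR list."""
    n = len(close)
    trend = [1] * n
    fu = [None] * n  # final upper band
    fl = [None] * n  # final lower band
    start = next(i for i, a in enumerate(atr) if a == a)  # first non-NaN ATR
    hl2 = (high[start] + low[start]) / 2
    fu[start] = hl2 + m * atr[start]
    fl[start] = hl2 - m * atr[start]
    for i in range(start + 1, n):
        hl2 = (high[i] + low[i]) / 2
        bu, bl = hl2 + m * atr[i], hl2 - m * atr[i]
        # Bands only ratchet toward price unless the prior close broke them
        fu[i] = bu if (bu < fu[i-1] or close[i-1] > fu[i-1]) else fu[i-1]
        fl[i] = bl if (bl > fl[i-1] or close[i-1] < fl[i-1]) else fl[i-1]
        if trend[i-1] == 1:
            trend[i] = -1 if close[i] < fl[i-1] else 1
        else:
            trend[i] = 1 if close[i] > fu[i-1] else -1
    return trend

# Toy series: drop through the lower band, then rally back above the upper band
closes = [10.0, 10.0, 10.0, 7.0, 7.0, 12.0]
highs = [c + 0.5 for c in closes]
lows = [c - 0.5 for c in closes]
trend_example = supertrend_trend(highs, lows, closes, [1.0] * 6, 1.0)
```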
