Remove deprecated modules and files related to the backtesting framework, including backtest.py, cli.py, config.py, data.py, intrabar.py, logging_utils.py, market_costs.py, metrics.py, trade.py, and supertrend indicators. Introduce a new structure for the backtesting engine with improved organization and functionality, including a CLI handler, data manager, and reporting capabilities. Update dependencies in pyproject.toml to support the new architecture.
This commit is contained in:
352
engine/backtester.py
Normal file
352
engine/backtester.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Core backtesting engine for running strategy simulations.
|
||||
|
||||
Supports multiple market types with realistic trading conditions.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType, get_market_config
|
||||
from engine.optimizer import WalkForwardOptimizer
|
||||
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
|
||||
from engine.risk import (
|
||||
LiquidationEvent,
|
||||
calculate_funding,
|
||||
calculate_liquidation_adjustment,
|
||||
inject_liquidation_exits,
|
||||
)
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BacktestResult:
    """
    Bundle of everything produced by a single backtest run.

    Attributes:
        portfolio: Underlying VectorBT Portfolio object
        market_type: Market type the simulation ran under
        leverage: Effective leverage applied
        total_funding_paid: Cumulative funding fees paid (perpetuals only)
        liquidation_count: How many positions were force-closed
        liquidation_events: Per-event liquidation details, if any
        total_liquidation_loss: Aggregate margin wiped out by liquidations
        adjusted_return: Percentage return after subtracting liquidation losses
    """
    # Required fields — always produced by Backtester.run_strategy.
    portfolio: vbt.Portfolio
    market_type: MarketType
    leverage: int
    # Optional fields — only meaningful for leveraged/perpetual runs.
    total_funding_paid: float = 0.0
    liquidation_count: int = 0
    liquidation_events: list[LiquidationEvent] | None = None
    total_liquidation_loss: float = 0.0
    adjusted_return: float | None = None
|
||||
|
||||
|
||||
class Backtester:
    """
    Market-type-aware backtester built on top of VectorBT.

    Capabilities:
    - Spot and Perpetual market simulation
    - Long and short position handling
    - Leverage with liquidation-aware forced exits
    - Funding fee accounting for perpetuals
    - Liquidation warnings
    """

    def __init__(self, data_manager: DataManager):
        # DataManager handles OHLCV loading and on-demand downloading.
        self.dm = data_manager

    @staticmethod
    def _has_short_signals(short_entries) -> bool:
        """Return True if any short-entry signal is set (Series or DataFrame)."""
        if hasattr(short_entries, 'any'):
            return short_entries.any().any()
        return short_entries.any()

    def run_strategy(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        init_cash: float = 10000,
        fees: float | None = None,
        slippage: float = 0.001,
        sl_stop: float | None = None,
        tp_stop: float | None = None,
        sl_trail: bool = False,
        leverage: int | None = None,
        **strategy_params
    ) -> BacktestResult:
        """
        Run a single backtest with market-type-aware simulation.

        Args:
            strategy: Strategy instance to backtest
            exchange_id: Exchange identifier (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Data timeframe (e.g., '1m', '1h', '1d')
            start_date: Inclusive start date filter (YYYY-MM-DD)
            end_date: Inclusive end date filter (YYYY-MM-DD)
            init_cash: Initial capital (margin when leveraged)
            fees: Transaction fee override; market taker fee when None
            slippage: Slippage percentage
            sl_stop: Stop loss percentage
            tp_stop: Take profit percentage
            sl_trail: Enable trailing stop loss
            leverage: Leverage override; strategy default when None
            **strategy_params: Extra keyword arguments forwarded to the strategy

        Returns:
            BacktestResult with the portfolio and market-specific metrics
        """
        # Market configuration is dictated by the strategy, not the caller.
        market_type = strategy.default_market_type
        market_config = get_market_config(market_type)

        # Resolve effective leverage and fee schedule.
        effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
        effective_fees = market_config.taker_fee if fees is None else fees

        # Load (or download) and date-filter the OHLCV data.
        data = self._load_data(
            exchange_id, symbol, timeframe, market_type, start_date, end_date
        )

        close_price = data['close']
        high_price = data['high']
        low_price = data['low']
        open_price = data['open']
        volume = data['volume']

        # Generate entry/exit signals from the strategy.
        raw_signals = strategy.run(
            close_price,
            high=high_price,
            low=low_price,
            open=open_price,
            volume=volume,
            **strategy_params
        )

        # Coerce the strategy output into the canonical 4-tuple form.
        long_entries, long_exits, short_entries, short_exits = (
            self._normalize_signals(raw_signals, close_price, market_config)
        )

        # With leverage, inject forced exits wherever a liquidation would occur.
        liquidation_events: list[LiquidationEvent] = []
        if effective_leverage > 1:
            long_exits, short_exits, liquidation_events = inject_liquidation_exits(
                close_price, high_price, low_price,
                long_entries, long_exits,
                short_entries, short_exits,
                effective_leverage,
                market_config.maintenance_margin_rate
            )

        # Funding fees only apply to perpetual markets (after liquidation pass).
        total_funding = 0.0
        if market_type == MarketType.PERPETUAL:
            total_funding = calculate_funding(
                close_price,
                long_entries, long_exits,
                short_entries, short_exits,
                market_config,
                effective_leverage
            )

        # Run the portfolio simulation against the liquidation-aware exits.
        portfolio = self._run_portfolio(
            close_price, market_config,
            long_entries, long_exits,
            short_entries, short_exits,
            init_cash, effective_fees, slippage, timeframe,
            sl_stop, tp_stop, sl_trail, effective_leverage
        )

        # Adjust the headline return for margin lost to liquidations.
        total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
            liquidation_events, init_cash, effective_leverage
        )

        raw_return = portfolio.total_return().mean() * 100
        adjusted_return = raw_return - liq_adjustment

        if liquidation_events:
            logger.info(
                "Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
                len(liquidation_events), total_liq_loss, liq_adjustment
            )

        logger.info(
            "Backtest completed: %s market, %dx leverage, fees=%.4f%%",
            market_type.value, effective_leverage, effective_fees * 100
        )

        return BacktestResult(
            portfolio=portfolio,
            market_type=market_type,
            leverage=effective_leverage,
            total_funding_paid=total_funding,
            liquidation_count=len(liquidation_events),
            liquidation_events=liquidation_events,
            total_liquidation_loss=total_liq_loss,
            adjusted_return=adjusted_return
        )

    def _resolve_leverage(
        self,
        leverage: int | None,
        strategy: BaseStrategy,
        market_type: MarketType
    ) -> int:
        """Resolve effective leverage: CLI override, else strategy default; spot is always 1x."""
        if market_type == MarketType.SPOT:
            return 1  # Spot cannot have leverage
        return leverage or strategy.default_leverage

    def _load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType,
        start_date: str | None,
        end_date: str | None
    ) -> pd.DataFrame:
        """Load OHLCV data, downloading on a cache miss, then apply date filters."""
        try:
            frame = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
        except FileNotFoundError:
            logger.warning("Data not found locally. Attempting download...")
            frame = self.dm.download_data(
                exchange_id, symbol, timeframe,
                start_date, end_date, market_type
            )

        # Both bounds are inclusive and interpreted as UTC.
        if start_date:
            frame = frame[frame.index >= pd.Timestamp(start_date, tz="UTC")]
        if end_date:
            frame = frame[frame.index <= pd.Timestamp(end_date, tz="UTC")]

        return frame

    def _normalize_signals(
        self,
        signals: tuple,
        close: pd.Series,
        market_config
    ) -> tuple:
        """
        Normalize strategy output to (long_entries, long_exits, short_entries, short_exits).

        A 2-tuple (long-only) is padded with empty short signals for
        backward compatibility; short signals are stripped on spot markets.
        """
        if len(signals) == 2:
            long_entries, long_exits = signals
            empty = BaseStrategy.create_empty_signals(long_entries)
            return long_entries, long_exits, empty, empty

        if len(signals) == 4:
            long_entries, long_exits, short_entries, short_exits = signals

            # Spot markets cannot short: warn once and blank the short legs.
            if not market_config.supports_short:
                if self._has_short_signals(short_entries):
                    logger.warning(
                        "Short signals detected but market type is SPOT. "
                        "Short signals will be ignored."
                    )
                    short_entries = BaseStrategy.create_empty_signals(long_entries)
                    short_exits = BaseStrategy.create_empty_signals(long_entries)

            return long_entries, long_exits, short_entries, short_exits

        raise ValueError(
            f"Strategy must return 2 or 4 signal arrays, got {len(signals)}"
        )

    def _run_portfolio(
        self,
        close: pd.Series,
        market_config,
        long_entries, long_exits,
        short_entries, short_exits,
        init_cash: float,
        fees: float,
        slippage: float,
        freq: str,
        sl_stop: float | None,
        tp_stop: float | None,
        sl_trail: bool,
        leverage: int
    ) -> vbt.Portfolio:
        """Dispatch to the long/short or long-only portfolio simulation."""
        if market_config.supports_short and self._has_short_signals(short_entries):
            return run_long_short_portfolio(
                close,
                long_entries, long_exits,
                short_entries, short_exits,
                init_cash, fees, slippage, freq,
                sl_stop, tp_stop, sl_trail, leverage
            )

        return run_long_only_portfolio(
            close,
            long_entries, long_exits,
            init_cash, fees, slippage, freq,
            sl_stop, tp_stop, sl_trail, leverage
        )

    def run_wfa(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        param_grid: dict,
        n_windows: int = 10,
        timeframe: str = '1m'
    ):
        """
        Execute Walk-Forward Analysis over the full cached dataset.

        Args:
            strategy: Strategy instance to optimize
            exchange_id: Exchange identifier
            symbol: Trading pair symbol
            param_grid: Parameter grid for optimization
            n_windows: Number of walk-forward windows
            timeframe: Data timeframe to load

        Returns:
            Tuple of (results DataFrame, stitched equity curve)
        """
        market_type = strategy.default_market_type
        frame = self.dm.load_data(exchange_id, symbol, timeframe, market_type)

        optimizer = WalkForwardOptimizer(self, strategy, param_grid)

        results, stitched_curve = optimizer.run(
            frame['close'],
            high=frame['high'],
            low=frame['low'],
            n_windows=n_windows
        )

        return results, stitched_curve
|
||||
243
engine/cli.py
Normal file
243
engine/cli.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
CLI handler for Lowkey Backtest.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger, setup_logging
|
||||
from engine.market import MarketType
|
||||
from engine.reporting import Reporter
|
||||
from strategies.factory import get_strategy, get_strategy_names
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
    """Build the top-level argument parser with all sub-commands attached."""
    parser = argparse.ArgumentParser(
        description="Lowkey Backtest CLI (VectorBT Edition)"
    )
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Register every sub-command on the shared subparser group.
    for register in (_add_download_parser, _add_backtest_parser, _add_wfa_parser):
        register(subparsers)

    return parser
|
||||
|
||||
|
||||
def _add_download_parser(subparsers) -> None:
    """Register the 'download' sub-command and its options."""
    parser = subparsers.add_parser("download", help="Download historical data")
    parser.add_argument("--exchange", "-e", type=str, default="okx")
    parser.add_argument("--pair", "-p", type=str, required=True)
    parser.add_argument("--timeframe", "-t", type=str, default="1m")
    parser.add_argument("--start", type=str, help="Start Date (YYYY-MM-DD)")
    # Market type is restricted to the two supported values.
    parser.add_argument(
        "--market", "-m",
        type=str,
        choices=["spot", "perpetual"],
        default="spot"
    )
|
||||
|
||||
|
||||
def _add_backtest_parser(subparsers) -> None:
    """Register the 'backtest' sub-command and its options."""
    parser = subparsers.add_parser("backtest", help="Run a backtest")

    # Strategy selection is constrained to the registered strategy names.
    parser.add_argument(
        "--strategy", "-s",
        type=str,
        choices=get_strategy_names(),
        required=True
    )
    parser.add_argument("--exchange", "-e", type=str, default="okx")
    parser.add_argument("--pair", "-p", type=str, required=True)
    parser.add_argument("--timeframe", "-t", type=str, default="1m")
    parser.add_argument("--start", type=str)
    parser.add_argument("--end", type=str)
    parser.add_argument("--grid", "-g", action="store_true")
    parser.add_argument("--plot", action="store_true")

    # Risk parameters
    parser.add_argument("--sl", type=float, help="Stop Loss %%")
    parser.add_argument("--tp", type=float, help="Take Profit %%")
    parser.add_argument("--trail", action="store_true")
    parser.add_argument("--no-bear-exit", action="store_true")

    # Cost parameters
    parser.add_argument("--fees", type=float, default=None)
    parser.add_argument("--slippage", type=float, default=0.001)
    parser.add_argument("--leverage", "-l", type=int, default=None)
|
||||
|
||||
|
||||
def _add_wfa_parser(subparsers) -> None:
    """Register the 'wfa' (walk-forward analysis) sub-command and its options."""
    parser = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")

    # Strategy selection is constrained to the registered strategy names.
    parser.add_argument(
        "--strategy", "-s",
        type=str,
        choices=get_strategy_names(),
        required=True
    )
    parser.add_argument("--pair", "-p", type=str, required=True)
    parser.add_argument("--timeframe", "-t", type=str, default="1d")
    parser.add_argument("--windows", "-w", type=int, default=10)
    parser.add_argument("--plot", action="store_true")
|
||||
|
||||
|
||||
def run_download(args) -> None:
    """Execute the 'download' command: fetch and store historical OHLCV data."""
    manager = DataManager()
    manager.download_data(
        args.exchange,
        args.pair,
        args.timeframe,
        start_date=args.start,
        market_type=MarketType(args.market),
    )
|
||||
|
||||
|
||||
def run_backtest(args) -> None:
    """Execute the 'backtest' command: run a single backtest (or grid) and report."""
    bt = Backtester(DataManager())
    reporter = Reporter()

    strategy, params = get_strategy(args.strategy, args.grid)

    # CLI flags may override meta_st strategy defaults (mutates args in place).
    params = _apply_strategy_overrides(args, strategy, params)

    if args.grid and args.strategy == "meta_st":
        logger.info("Running Grid Search for Meta Supertrend...")

    try:
        result = bt.run_strategy(
            strategy,
            args.exchange,
            args.pair,
            timeframe=args.timeframe,
            start_date=args.start,
            end_date=args.end,
            fees=args.fees,
            slippage=args.slippage,
            sl_stop=args.sl,
            tp_stop=args.tp,
            sl_trail=args.trail,
            leverage=args.leverage,
            **params
        )

        reporter.print_summary(result)
        reporter.save_reports(result, f"{args.strategy}_{args.pair.replace('/','-')}")

        # Plotting only makes sense for a single run; grid output goes to CSV.
        if args.plot:
            if args.grid:
                logger.info("Plotting skipped for Grid Search. Check CSV results.")
            else:
                reporter.plot(result.portfolio)

    except Exception as e:
        logger.error("Backtest failed: %s", e, exc_info=True)
|
||||
|
||||
|
||||
def run_wfa(args) -> None:
    """Execute the 'wfa' command: walk-forward analysis with grid parameters."""
    bt = Backtester(DataManager())
    reporter = Reporter()

    # WFA always uses the grid variant of the strategy parameters.
    strategy, params = get_strategy(args.strategy, is_grid=True)

    logger.info(
        "Running WFA on %s for %s (%s) with %d windows...",
        args.strategy, args.pair, args.timeframe, args.windows
    )

    try:
        results, stitched_curve = bt.run_wfa(
            strategy,
            "okx",
            args.pair,
            params,
            n_windows=args.windows,
            timeframe=args.timeframe
        )

        _log_wfa_results(results)
        _save_wfa_results(args, results, stitched_curve, reporter)

    except Exception as e:
        logger.error("WFA failed: %s", e, exc_info=True)
|
||||
|
||||
|
||||
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
    """
    Apply CLI overrides to strategy parameters (meta_st only).

    Mutates ``args`` in place: fills in the strategy's default stop-loss
    and trailing flag when the user did not supply them.
    """
    # Only the meta_st strategy honors these CLI overrides.
    if args.strategy != "meta_st":
        return params

    if args.no_bear_exit:
        params['exit_on_bearish_flip'] = False

    # Back-fill risk settings from strategy defaults when unset on the CLI.
    if args.sl is None:
        args.sl = strategy.default_sl_stop
    if not args.trail:
        args.trail = strategy.default_sl_trail

    return params
|
||||
|
||||
|
||||
def _log_wfa_results(results) -> None:
    """Log a summary table and averages for walk-forward analysis results."""
    logger.info("Walk-Forward Analysis Results:")

    # Guard: nothing useful to report when every window failed.
    if results.empty or 'window' not in results.columns:
        logger.warning("No valid WFA results. All windows may have failed.")
        return

    summary_cols = ['window', 'train_score', 'test_score', 'test_return']
    logger.info("\n%s", results[summary_cols].to_string(index=False))

    logger.info("Average Test Sharpe: %.2f", results['test_score'].mean())
    logger.info("Average Test Return: %.2f%%", results['test_return'].mean() * 100)
|
||||
|
||||
|
||||
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
    """
    Save WFA results to a CSV under ``backtest_logs/`` and optionally plot.

    Args:
        args: Parsed CLI namespace (uses .strategy, .pair, .plot)
        results: WFA results DataFrame; empty frames are skipped
        stitched_curve: Stitched out-of-sample equity curve for plotting
        reporter: Reporter instance used for plotting
    """
    if results.empty:
        return

    # to_csv does not create missing directories — ensure the output dir exists
    # so a fresh checkout does not fail with FileNotFoundError.
    from pathlib import Path

    output_path = f"backtest_logs/wfa_{args.strategy}_{args.pair.replace('/','-')}.csv"
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    results.to_csv(output_path)
    logger.info("Saved full results to %s", output_path)

    if args.plot:
        reporter.plot_wfa(results, stitched_curve)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: configure logging, parse args, dispatch the command."""
    setup_logging()

    parser = create_parser()
    args = parser.parse_args()

    # Map each sub-command name to its handler.
    handlers = {
        "download": run_download,
        "backtest": run_backtest,
        "wfa": run_wfa,
    }

    handler = handlers.get(args.command)
    if handler is None:
        # No (or unknown) sub-command: show usage instead of failing.
        parser.print_help()
    else:
        handler(args)
|
||||
209
engine/data_manager.py
Normal file
209
engine/data_manager.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Data management for OHLCV data download and storage.
|
||||
|
||||
Handles data retrieval from exchanges and local file management.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import ccxt
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType, get_ccxt_symbol
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataManager:
    """
    Manages OHLCV data download and storage for different market types.

    Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
    """

    def __init__(self, data_dir: str = "data/ccxt"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Cache of instantiated CCXT exchange clients, keyed by exchange id.
        self.exchanges: dict[str, ccxt.Exchange] = {}

    def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
        """Get or create (and cache) a rate-limited CCXT exchange instance."""
        if exchange_id not in self.exchanges:
            exchange_class = getattr(ccxt, exchange_id)
            self.exchanges[exchange_id] = exchange_class({
                'enableRateLimit': True,
            })
        return self.exchanges[exchange_id]

    def _get_data_path(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType
    ) -> Path:
        """
        Get the file path for storing/loading data.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            Path to the CSV file
        """
        # '/' would create an unwanted directory level in the symbol segment.
        safe_symbol = symbol.replace('/', '-')
        return (
            self.data_dir
            / exchange_id
            / market_type.value
            / safe_symbol
            / f"{timeframe}.csv"
        )

    def download_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame | None:
        """
        Download OHLCV data from exchange and save to CSV.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            start_date: Start date string (YYYY-MM-DD); defaults to one year back
            end_date: End date string (YYYY-MM-DD); defaults to now
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data, or None if download failed
        """
        exchange = self.get_exchange(exchange_id)

        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        ccxt_symbol = get_ccxt_symbol(symbol, market_type)

        since, until = self._parse_date_range(exchange, start_date, end_date)

        logger.info(
            "Downloading %s (%s) from %s...",
            symbol, market_type.value, exchange_id
        )

        all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)

        if not all_ohlcv:
            logger.warning("No data downloaded.")
            return None

        df = self._convert_to_dataframe(all_ohlcv)
        df.to_csv(file_path)
        logger.info("Saved %d candles to %s", len(df), file_path)
        return df

    def load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame:
        """
        Load saved OHLCV data for vectorbt.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data indexed by timestamp

        Raises:
            FileNotFoundError: If data file does not exist
        """
        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)

        if not file_path.exists():
            raise FileNotFoundError(
                f"Data not found at {file_path}. "
                f"Run: uv run python main.py download --pair {symbol} "
                f"--market {market_type.value}"
            )

        return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)

    def _parse_date_range(
        self,
        exchange: ccxt.Exchange,
        start_date: str | None,
        end_date: str | None
    ) -> tuple[int, int]:
        """Parse date strings into millisecond timestamps (since, until)."""
        if start_date:
            since = exchange.parse8601(f"{start_date}T00:00:00Z")
        else:
            # Default window: one year back from now.
            since = exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000

        if end_date:
            until = exchange.parse8601(f"{end_date}T23:59:59Z")
        else:
            until = exchange.milliseconds()

        return since, until

    def _fetch_all_candles(
        self,
        exchange: ccxt.Exchange,
        symbol: str,
        timeframe: str,
        since: int,
        until: int
    ) -> list:
        """
        Fetch all candles in [since, until], paging through the exchange API.

        Candles past ``until`` are discarded: the final page is requested
        before the bound is known to be reached and can overshoot it.
        """
        all_ohlcv = []

        while since < until:
            try:
                ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since, limit=100)
                if not ohlcv:
                    break

                all_ohlcv.extend(ohlcv)
                # Advance past the last candle to avoid re-fetching it.
                since = ohlcv[-1][0] + 1

                current_date = datetime.fromtimestamp(
                    since/1000, tz=timezone.utc
                ).strftime('%Y-%m-%d')
                logger.debug("Fetched up to %s", current_date)

                # Respect the exchange's advertised rate limit.
                time.sleep(exchange.rateLimit / 1000)

            except Exception as e:
                logger.error("Error fetching data: %s", e)
                break

        # Drop any candles the last page fetched beyond the requested end.
        return [candle for candle in all_ohlcv if candle[0] <= until]

    def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
        """Convert an OHLCV list to a timestamp-indexed DataFrame (deduplicated)."""
        df = pd.DataFrame(
            ohlcv,
            columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
        )
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df.set_index('timestamp', inplace=True)
        # Some exchanges return overlapping pages; keep the first occurrence.
        return df[~df.index.duplicated(keep='first')]
|
||||
124
engine/logging_config.py
Normal file
124
engine/logging_config.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Centralized logging configuration for the backtest engine.
|
||||
|
||||
Provides colored console output and rotating file logs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
class Colors:
    """ANSI escape codes for colored terminal output."""

    RESET = "\033[0m"
    BOLD = "\033[1m"

    # One color per log level.
    DEBUG = "\033[36m"     # Cyan
    INFO = "\033[32m"      # Green
    WARNING = "\033[33m"   # Yellow
    ERROR = "\033[31m"     # Red
    CRITICAL = "\033[35m"  # Magenta
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
    """
    Formatter that wraps the level name in ANSI color codes for terminal output.

    The record's ``levelname`` is temporarily colorized for formatting and
    restored afterwards, so other handlers sharing the record are unaffected.
    (The redundant ``__init__`` that only forwarded to ``super()`` — and
    mis-annotated ``fmt`` as ``str`` instead of ``str | None`` — was removed;
    ``logging.Formatter.__init__`` already accepts the same arguments.)
    """

    LEVEL_COLORS = {
        logging.DEBUG: Colors.DEBUG,
        logging.INFO: Colors.INFO,
        logging.WARNING: Colors.WARNING,
        logging.ERROR: Colors.ERROR,
        logging.CRITICAL: Colors.CRITICAL,
    }

    def format(self, record: logging.LogRecord) -> str:
        """Format the record with a colorized level name, then restore it."""
        original_levelname = record.levelname

        # Unknown levels fall back to RESET (i.e. no visible color).
        color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
        record.levelname = f"{color}{record.levelname}{Colors.RESET}"

        result = super().format(record)

        # Restore so downstream handlers see the plain level name.
        record.levelname = original_levelname

        return result
|
||||
|
||||
|
||||
def setup_logging(
    log_dir: str = "logs",
    log_level: int = logging.INFO,
    console_level: int = logging.INFO,
    max_bytes: int = 5 * 1024 * 1024,  # 5MB
    backup_count: int = 3
) -> None:
    """
    Configure application-wide logging: colored console + rotating file.

    Args:
        log_dir: Directory for log files (created if missing)
        log_level: File logging level
        console_level: Console logging level
        max_bytes: Max size per log file before rotation
        backup_count: Number of rotated backup files to keep
    """
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    # Root logger captures everything; individual handlers filter by level.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    root_logger.handlers.clear()

    # Colored console output on stdout.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(console_level)
    console_handler.setFormatter(ColoredFormatter(
        fmt="[%(asctime)s] %(levelname)s - %(message)s",
        datefmt="%H:%M:%S"
    ))
    root_logger.addHandler(console_handler)

    # Plain-text rotating file output.
    file_handler = RotatingFileHandler(
        log_path / "backtest.log",
        maxBytes=max_bytes,
        backupCount=backup_count,
        encoding="utf-8"
    )
    file_handler.setLevel(log_level)
    file_handler.setFormatter(logging.Formatter(
        fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    ))
    root_logger.addHandler(file_handler)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance for the given module name.

    Thin wrapper around ``logging.getLogger`` so call sites stay decoupled
    from the logging module.

    Args:
        name: Module name (typically __name__)

    Returns:
        Configured logger instance
    """
    return logging.getLogger(name)
|
||||
156
engine/market.py
Normal file
156
engine/market.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Market type definitions and configuration for backtesting.
|
||||
|
||||
Supports different market types with their specific trading conditions:
|
||||
- SPOT: No leverage, no funding, long-only
|
||||
- PERPETUAL: Leverage, funding rates, long/short
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MarketType(Enum):
    """Supported market types for backtesting."""

    SPOT = "spot"
    PERPETUAL = "perpetual"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MarketConfig:
    """
    Immutable trading-condition profile for one market type.

    Attributes:
        market_type: The market type enum value
        maker_fee: Maker fee as decimal (e.g., 0.0008 = 0.08%)
        taker_fee: Taker fee as decimal (e.g., 0.001 = 0.1%)
        max_leverage: Maximum allowed leverage
        funding_rate: Funding rate per interval as decimal (perpetuals only)
        funding_interval_hours: Hours between funding payments
        maintenance_margin_rate: Rate used in liquidation calculation
        supports_short: Whether short-selling is supported
    """
    market_type: MarketType
    maker_fee: float
    taker_fee: float
    max_leverage: int
    funding_rate: float
    funding_interval_hours: int
    maintenance_margin_rate: float
    supports_short: bool
|
||||
|
||||
|
||||
# OKX-based default configurations

# Spot: cash market — no leverage, no funding, long-only.
SPOT_CONFIG = MarketConfig(
    market_type=MarketType.SPOT,
    maker_fee=0.0008,  # 0.08%
    taker_fee=0.0010,  # 0.10%
    max_leverage=1,
    funding_rate=0.0,
    funding_interval_hours=0,  # 0 -> funding disabled
    maintenance_margin_rate=0.0,
    supports_short=False,
)

# Perpetual swap: leveraged, funded every 8 hours, shorts allowed.
PERPETUAL_CONFIG = MarketConfig(
    market_type=MarketType.PERPETUAL,
    maker_fee=0.0002,  # 0.02%
    taker_fee=0.0005,  # 0.05%
    max_leverage=125,
    funding_rate=0.0001,  # 0.01% per 8 hours (simplified average)
    funding_interval_hours=8,
    maintenance_margin_rate=0.004,  # 0.4% for BTC on OKX
    supports_short=True,
)
|
||||
|
||||
|
||||
def get_market_config(market_type: MarketType) -> MarketConfig:
|
||||
"""
|
||||
Get the configuration for a specific market type.
|
||||
|
||||
Args:
|
||||
market_type: The market type to get configuration for
|
||||
|
||||
Returns:
|
||||
MarketConfig with default values for that market type
|
||||
"""
|
||||
configs = {
|
||||
MarketType.SPOT: SPOT_CONFIG,
|
||||
MarketType.PERPETUAL: PERPETUAL_CONFIG,
|
||||
}
|
||||
return configs[market_type]
|
||||
|
||||
|
||||
def get_ccxt_symbol(symbol: str, market_type: MarketType) -> str:
    """
    Convert a standard symbol to CCXT format for the given market type.

    Spot symbols pass through unchanged; perpetuals gain the settlement
    currency suffix used by OKX (e.g. 'BTC/USDT' -> 'BTC/USDT:USDT').

    Args:
        symbol: Standard symbol (e.g., 'BTC/USDT')
        market_type: The market type

    Returns:
        CCXT-formatted symbol (e.g., 'BTC/USDT:USDT' for perpetuals)
    """
    if market_type != MarketType.PERPETUAL:
        return symbol

    # OKX perpetual format: BTC/USDT:USDT.  The settlement currency is
    # the quote currency, defaulting to USDT when no '/' is present.
    if '/' in symbol:
        quote = symbol.split('/')[1]
    else:
        quote = 'USDT'
    return f"{symbol}:{quote}"
|
||||
|
||||
|
||||
def calculate_leverage_stop_loss(
    leverage: int,
    maintenance_margin_rate: float = 0.004
) -> float:
    """
    Calculate the implicit stop-loss percentage from leverage.

    A leveraged position is liquidated once it loses approximately
    ``1 / leverage - maintenance_margin_rate`` of its value, so that
    quantity acts as a hard stop regardless of the strategy's own exits.

    Args:
        leverage: Position leverage multiplier
        maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)

    Returns:
        Stop-loss percentage as decimal (e.g., 0.196 for 19.6%).
        Unleveraged (spot) positions return 1.0, i.e. no forced stop.
    """
    # Spot / unleveraged positions cannot be liquidated.
    if leverage <= 1:
        return 1.0

    margin_fraction = 1 / leverage
    return margin_fraction - maintenance_margin_rate
|
||||
|
||||
|
||||
def calculate_liquidation_price(
    entry_price: float,
    leverage: float,
    is_long: bool,
    maintenance_margin_rate: float = 0.004
) -> float:
    """
    Calculate the liquidation price for a leveraged position.

    Args:
        entry_price: Position entry price
        leverage: Position leverage
        is_long: True for long positions, False for short
        maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)

    Returns:
        Liquidation price.  Unleveraged positions return 0.0 (long) or
        +inf (short), i.e. a price that can never be reached.
    """
    # Spot / unleveraged positions cannot be liquidated.
    if leverage <= 1:
        return 0.0 if is_long else float('inf')

    # Simplified liquidation formula:
    #   long : entry * (1 - 1/leverage + maintenance_margin_rate)
    #   short: entry * (1 + 1/leverage - maintenance_margin_rate)
    margin_ratio = 1 / leverage
    if is_long:
        return entry_price * (1 - margin_ratio + maintenance_margin_rate)
    return entry_price * (1 + margin_ratio - maintenance_margin_rate)
|
||||
245
engine/optimizer.py
Normal file
245
engine/optimizer.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Walk-Forward Analysis optimizer for strategy parameter optimization.
|
||||
|
||||
Implements expanding window walk-forward analysis with train/test splits.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_rolling_windows(
    index: pd.Index,
    n_windows: int,
    train_split: float = 0.7
):
    """
    Create rolling train/test split indices using expanding window approach.

    The index is cut into ``n_windows + 1`` equal chunks.  Window *i*
    trains on chunks ``0..i`` (expanding) and tests on chunk ``i + 1``,
    so every test segment is strictly out-of-sample.

    Args:
        index: DataFrame index to split
        n_windows: Number of walk-forward windows
        train_split: Unused, kept for API compatibility

    Yields:
        Tuples of (train_idx, test_idx) numpy arrays
    """
    chunks = np.array_split(index, n_windows + 1)

    for i in range(n_windows):
        # Expanding train window: everything up to and including chunk i.
        # np.concatenate accepts the chunk slice directly — no need to
        # copy it through a list comprehension first.
        train_idx = np.concatenate(chunks[:i + 1])
        test_idx = chunks[i + 1]
        yield train_idx, test_idx
|
||||
|
||||
|
||||
class WalkForwardOptimizer:
    """
    Walk-Forward Analysis optimizer for strategy backtesting.

    Optimizes strategy parameters on training windows and validates
    on out-of-sample test windows.

    NOTE(review): ``metric`` is stored but ``_optimize_train`` always
    ranks candidates by Sharpe ratio — confirm whether other metrics
    should be honored.
    """

    def __init__(
        self,
        backtester,
        strategy,
        param_grid: dict,
        metric: str = 'Sharpe Ratio',
        fees: float = 0.001,
        freq: str = '1m'
    ):
        """
        Initialize the optimizer.

        Args:
            backtester: Backtester instance
            strategy: Strategy instance to optimize
            param_grid: Parameter grid for optimization; list/ndarray
                values are treated as search dimensions, scalars as
                fixed parameters
            metric: Performance metric to optimize
            fees: Transaction fees for simulation
            freq: Data frequency for portfolio simulation
        """
        self.bt = backtester
        self.strategy = strategy
        self.param_grid = param_grid
        self.metric = metric
        self.fees = fees
        self.freq = freq

        # Separate grid params (lists) from fixed params (scalars)
        self.grid_keys: list[str] = []
        self.fixed_params: dict = {}
        for k, v in param_grid.items():
            if isinstance(v, (list, np.ndarray)):
                self.grid_keys.append(k)
            else:
                self.fixed_params[k] = v

    def run(
        self,
        close_price: pd.Series,
        high: pd.Series | None = None,
        low: pd.Series | None = None,
        n_windows: int = 10
    ) -> tuple[pd.DataFrame, pd.Series | None]:
        """
        Execute walk-forward analysis.

        Windows that raise inside ``_process_window`` are logged and
        skipped, so the returned DataFrame may hold fewer than
        ``n_windows`` rows.

        Args:
            close_price: Close price series
            high: High price series (optional)
            low: Low price series (optional)
            n_windows: Number of walk-forward windows

        Returns:
            Tuple of (results DataFrame, stitched equity curve)
        """
        results = []
        equity_curves = []

        logger.info(
            "Starting Walk-Forward Analysis with %d windows (Expanding Train)...",
            n_windows
        )

        splitter = create_rolling_windows(close_price.index, n_windows)

        for i, (train_idx, test_idx) in enumerate(splitter):
            logger.info("Processing Window %d/%d...", i + 1, n_windows)

            window_result = self._process_window(
                i, train_idx, test_idx, close_price, high, low
            )

            # None signals a failed window (see _process_window) — drop it.
            if window_result is not None:
                result_dict, eq_curve = window_result
                results.append(result_dict)
                equity_curves.append(eq_curve)

        stitched_series = self._stitch_equity_curves(equity_curves)
        return pd.DataFrame(results), stitched_series

    def _process_window(
        self,
        window_idx: int,
        train_idx: np.ndarray,
        test_idx: np.ndarray,
        close_price: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None
    ) -> tuple[dict, pd.Series] | None:
        """Process a single WFA window.

        Returns the window's result dict and its test-phase equity curve,
        or None if anything in the window raised (errors are logged).
        """
        try:
            # Slice data for train/test
            train_close = close_price.loc[train_idx]
            train_high = high.loc[train_idx] if high is not None else None
            train_low = low.loc[train_idx] if low is not None else None

            # Train phase: find best parameters
            best_params, best_score = self._optimize_train(
                train_close, train_high, train_low
            )

            # Test phase: validate with best params
            test_close = close_price.loc[test_idx]
            test_high = high.loc[test_idx] if high is not None else None
            test_low = low.loc[test_idx] if low is not None else None

            # Fixed (scalar) params are re-merged with the winning grid
            # params so the strategy sees the complete parameter set.
            test_params = {**self.fixed_params, **best_params}
            test_score, test_return, eq_curve = self._run_test(
                test_close, test_high, test_low, test_params
            )

            return {
                'window': window_idx + 1,
                'train_start': train_idx[0],
                'train_end': train_idx[-1],
                'test_start': test_idx[0],
                'test_end': test_idx[-1],
                'best_params': best_params,
                'train_score': best_score,
                'test_score': test_score,
                'test_return': test_return
            }, eq_curve

        except Exception as e:
            # A failed window must not abort the whole analysis; run()
            # filters out the None result.
            logger.error("Error in window %d: %s", window_idx + 1, e, exc_info=True)
            return None

    def _optimize_train(
        self,
        close: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None
    ) -> tuple[dict, float]:
        """Run grid search on training data to find best parameters.

        Assumes the strategy broadcasts list-valued parameters into one
        signal column per combination — TODO confirm against BaseStrategy.
        """
        entries, exits = self.strategy.run(
            close, high=high, low=low, **self.param_grid
        )

        pf_train = vbt.Portfolio.from_signals(
            close, entries, exits,
            fees=self.fees,
            freq=self.freq
        )

        # Sentinel -999 ensures all-NaN parameter combos never win idxmax.
        perf_stats = pf_train.sharpe_ratio()
        perf_stats = perf_stats.fillna(-999)

        best_idx = perf_stats.idxmax()
        best_score = perf_stats.max()

        # Extract best params from grid search: a single grid key yields a
        # scalar index, multiple keys yield an index tuple.
        if len(self.grid_keys) == 1:
            best_params = {self.grid_keys[0]: best_idx}
        elif len(self.grid_keys) > 1:
            best_params = dict(zip(self.grid_keys, best_idx))
        else:
            best_params = {}

        return best_params, best_score

    def _run_test(
        self,
        close: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None,
        params: dict
    ) -> tuple[float, float, pd.Series]:
        """Run test phase with given parameters.

        Returns:
            Tuple of (sharpe ratio, total return, equity curve).
        """
        entries, exits = self.strategy.run(
            close, high=high, low=low, **params
        )

        pf_test = vbt.Portfolio.from_signals(
            close, entries, exits,
            fees=self.fees,
            freq=self.freq
        )

        return pf_test.sharpe_ratio(), pf_test.total_return(), pf_test.value()

    def _stitch_equity_curves(
        self,
        equity_curves: list[pd.Series]
    ) -> pd.Series | None:
        """Stitch multiple equity curves into a continuous series.

        Each window's curve is rescaled so it starts at the previous
        window's final value; this preserves per-window percentage
        returns while producing one continuous out-of-sample curve.
        """
        if not equity_curves:
            return None

        stitched = [equity_curves[0]]
        for j in range(1, len(equity_curves)):
            prev_end_val = stitched[-1].iloc[-1]
            curr_curve = equity_curves[j]
            init_cash = curr_curve.iloc[0]

            # Scale curve to continue from previous end value
            scaled_curve = (curr_curve / init_cash) * prev_end_val
            stitched.append(scaled_curve)

        return pd.concat(stitched)
|
||||
148
engine/portfolio.py
Normal file
148
engine/portfolio.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Portfolio simulation utilities for backtesting.
|
||||
|
||||
Handles long-only and long/short portfolio creation using VectorBT.
|
||||
"""
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_long_only_portfolio(
    close: pd.Series,
    entries: pd.DataFrame,
    exits: pd.DataFrame,
    init_cash: float,
    fees: float,
    slippage: float,
    freq: str,
    sl_stop: float | None,
    tp_stop: float | None,
    sl_trail: bool,
    leverage: int
) -> vbt.Portfolio:
    """
    Run a long-only portfolio simulation.

    Leverage is modelled simply by scaling the starting cash; margin and
    liquidation mechanics are not applied at this layer.

    Args:
        close: Close price series
        entries: Entry signals
        exits: Exit signals
        init_cash: Initial capital
        fees: Transaction fee percentage
        slippage: Slippage percentage
        freq: Data frequency string
        sl_stop: Stop loss percentage
        tp_stop: Take profit percentage
        sl_trail: Enable trailing stop loss
        leverage: Leverage multiplier

    Returns:
        VectorBT Portfolio object
    """
    # Leverage as notional scaling: trade with leverage-times the capital.
    effective_cash = init_cash * leverage

    # size=1.0 with size_type='percent' means every entry commits 100% of
    # the available (leveraged) cash.
    return vbt.Portfolio.from_signals(
        close=close,
        entries=entries,
        exits=exits,
        init_cash=effective_cash,
        fees=fees,
        slippage=slippage,
        freq=freq,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        sl_trail=sl_trail,
        size=1.0,
        size_type='percent',
    )
|
||||
|
||||
|
||||
def run_long_short_portfolio(
    close: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    init_cash: float,
    fees: float,
    slippage: float,
    freq: str,
    sl_stop: float | None,
    tp_stop: float | None,
    sl_trail: bool,
    leverage: int
) -> vbt.Portfolio:
    """
    Run a portfolio supporting both long and short positions.

    Runs two separate portfolios (long and short) and combines results.
    Each gets half the capital.

    NOTE(review): only the long portfolio is returned; the short
    portfolio's P&L is logged but otherwise discarded, so the returned
    object does NOT reflect combined long/short performance (see TODO
    below).

    Args:
        close: Close price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        init_cash: Initial capital
        fees: Transaction fee percentage
        slippage: Slippage percentage
        freq: Data frequency string
        sl_stop: Stop loss percentage
        tp_stop: Take profit percentage
        sl_trail: Enable trailing stop loss
        leverage: Leverage multiplier

    Returns:
        VectorBT Portfolio object (long portfolio, short stats logged)
    """
    # Leverage modelled as scaled starting cash, split evenly per side.
    effective_cash = init_cash * leverage
    half_cash = effective_cash / 2

    # Run long-only portfolio
    long_pf = vbt.Portfolio.from_signals(
        close=close,
        entries=long_entries,
        exits=long_exits,
        direction='longonly',
        init_cash=half_cash,
        fees=fees,
        slippage=slippage,
        freq=freq,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        sl_trail=sl_trail,
        size=1.0,
        size_type='percent',
    )

    # Run short-only portfolio
    short_pf = vbt.Portfolio.from_signals(
        close=close,
        entries=short_entries,
        exits=short_exits,
        direction='shortonly',
        init_cash=half_cash,
        fees=fees,
        slippage=slippage,
        freq=freq,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        sl_trail=sl_trail,
        size=1.0,
        size_type='percent',
    )

    # Log both portfolio stats
    # TODO: Implement proper portfolio combination
    logger.info(
        "Long portfolio: %.2f%% return, Short portfolio: %.2f%% return",
        long_pf.total_return().mean() * 100,
        short_pf.total_return().mean() * 100
    )

    return long_pf
|
||||
228
engine/reporting.py
Normal file
228
engine/reporting.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Reporting module for backtest results.
|
||||
|
||||
Handles summary printing, CSV exports, and plotting.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import vectorbt as vbt
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Reporter:
    """Reporter for backtest results with market-specific metrics.

    Accepts either a BacktestResult (preferred; carries market type,
    leverage, funding, and liquidation data) or a bare vbt.Portfolio,
    in which case market-specific fields fall back to neutral defaults
    (see _extract_result_data).
    """

    def __init__(self, output_dir: str = "backtest_logs"):
        # Create the output directory eagerly so later save calls cannot
        # fail on a missing path.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

    def print_summary(self, result) -> None:
        """
        Print backtest summary to console via logger.

        Args:
            result: BacktestResult or vbt.Portfolio object
        """
        (portfolio, market_type, leverage, funding_paid,
         liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)

        # Extract period info
        idx = portfolio.wrapper.index
        start_date = idx[0].strftime("%Y-%m-%d")
        end_date = idx[-1].strftime("%Y-%m-%d")

        # Extract price info. Multi-column portfolios (param grids) are
        # collapsed with .mean(); single-column ones are scalars already.
        close = portfolio.close
        start_price = close.iloc[0].mean() if hasattr(close.iloc[0], 'mean') else close.iloc[0]
        end_price = close.iloc[-1].mean() if hasattr(close.iloc[-1], 'mean') else close.iloc[-1]
        price_change = ((end_price - start_price) / start_price) * 100

        # Extract fees
        stats = portfolio.stats()
        total_fees = stats.get('Total Fees Paid', 0)

        raw_return = portfolio.total_return().mean() * 100

        # Build summary
        summary_lines = [
            "",
            "=" * 50,
            "BACKTEST RESULTS",
            "=" * 50,
            f"Market Type: [{market_type.upper()}]",
            f"Leverage: [{leverage}x]",
            f"Period: [{start_date} to {end_date}]",
            f"Price: [{start_price:,.2f} -> {end_price:,.2f} ({price_change:+.2f}%)]",
        ]

        # Show adjusted return if liquidations occurred
        if liq_count > 0 and adjusted_return is not None:
            summary_lines.append(f"Raw Return: [%{raw_return:.2f}] (before liq adjustment)")
            summary_lines.append(f"Adj Return: [%{adjusted_return:.2f}] (after liq losses)")
        else:
            summary_lines.append(f"Total Return: [%{raw_return:.2f}]")

        summary_lines.extend([
            f"Sharpe Ratio: [{portfolio.sharpe_ratio().mean():.2f}]",
            f"Max Drawdown: [%{portfolio.max_drawdown().mean() * 100:.2f}]",
            f"Total Trades: [{portfolio.trades.count().mean():.0f}]",
            f"Win Rate: [%{portfolio.trades.win_rate().mean() * 100:.2f}]",
            f"Total Fees: [{total_fees:,.2f}]",
        ])

        # Market-specific lines only when relevant.
        if funding_paid != 0:
            summary_lines.append(f"Funding Paid: [{funding_paid:,.2f}]")
        if liq_count > 0:
            summary_lines.append(f"Liquidations: [{liq_count}] (${liq_loss:,.2f} margin lost)")

        summary_lines.append("=" * 50)
        logger.info("\n".join(summary_lines))

    def save_reports(self, result, filename_prefix: str) -> None:
        """
        Save trade log, stats, and liquidation events to CSV files.

        Writes up to three files under output_dir:
        {prefix}_trades.csv, {prefix}_stats.csv, and (only when
        liquidations occurred) {prefix}_liquidations.csv.

        Args:
            result: BacktestResult or vbt.Portfolio object
            filename_prefix: Prefix for output filenames
        """
        (portfolio, market_type, leverage, funding_paid,
         liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)

        # Save trades
        self._save_csv(
            data=portfolio.trades.records_readable,
            path=self.output_dir / f"{filename_prefix}_trades.csv",
            description="trade log"
        )

        # Save stats with market-specific additions
        stats = portfolio.stats()
        stats['Market Type'] = market_type
        stats['Leverage'] = leverage
        stats['Total Funding Paid'] = funding_paid
        stats['Liquidations'] = liq_count
        stats['Liquidation Loss'] = liq_loss
        if adjusted_return is not None:
            stats['Adjusted Return'] = adjusted_return

        self._save_csv(
            data=stats,
            path=self.output_dir / f"{filename_prefix}_stats.csv",
            description="stats"
        )

        # Save liquidation events if any
        if hasattr(result, 'liquidation_events') and result.liquidation_events:
            liq_df = pd.DataFrame([
                {
                    'entry_time': e.entry_time,
                    'entry_price': e.entry_price,
                    'liquidation_time': e.liquidation_time,
                    'liquidation_price': e.liquidation_price,
                    'actual_price': e.actual_price,
                    'direction': e.direction,
                    'margin_lost_pct': e.margin_lost_pct
                }
                for e in result.liquidation_events
            ])
            self._save_csv(
                data=liq_df,
                path=self.output_dir / f"{filename_prefix}_liquidations.csv",
                description="liquidation events"
            )

    def plot(self, portfolio: vbt.Portfolio, show: bool = True) -> None:
        """Display the standard VectorBT portfolio plot (no-op when show=False)."""
        if show:
            portfolio.plot().show()

    def plot_wfa(
        self,
        wfa_results: pd.DataFrame,
        stitched_curve: pd.Series | None = None,
        show: bool = True
    ) -> None:
        """
        Plot Walk-Forward Analysis results.

        Top panel: per-window test Sharpe as bars.  Bottom panel: the
        stitched out-of-sample equity curve (omitted when None).

        Args:
            wfa_results: DataFrame with WFA window results
            stitched_curve: Stitched out-of-sample equity curve
            show: Whether to display the plot
        """
        fig = make_subplots(
            rows=2, cols=1,
            shared_xaxes=False,
            vertical_spacing=0.1,
            subplot_titles=(
                "Walk-Forward Test Scores (Sharpe)",
                "Stitched Out-of-Sample Equity"
            )
        )

        fig.add_trace(
            go.Bar(
                x=wfa_results['window'],
                y=wfa_results['test_score'],
                name="Test Sharpe"
            ),
            row=1, col=1
        )

        if stitched_curve is not None:
            fig.add_trace(
                go.Scatter(
                    x=stitched_curve.index,
                    y=stitched_curve.values,
                    name="Equity",
                    mode='lines'
                ),
                row=2, col=1
            )

        fig.update_layout(height=800, title_text="Walk-Forward Analysis Report")

        if show:
            fig.show()

    def _extract_result_data(self, result) -> tuple:
        """
        Extract data from BacktestResult or raw Portfolio.

        Duck-typed on the presence of a 'portfolio' attribute; a bare
        Portfolio gets neutral defaults for all market-specific fields.

        Returns:
            Tuple of (portfolio, market_type, leverage, funding_paid, liq_count,
                      liq_loss, adjusted_return)
        """
        if hasattr(result, 'portfolio'):
            return (
                result.portfolio,
                result.market_type.value,
                result.leverage,
                result.total_funding_paid,
                result.liquidation_count,
                getattr(result, 'total_liquidation_loss', 0.0),
                getattr(result, 'adjusted_return', None)
            )
        return (result, "unknown", 1, 0.0, 0, 0.0, None)

    def _save_csv(self, data, path: Path, description: str) -> None:
        """
        Save data to CSV with consistent error handling.

        Failures are logged, never raised — reporting must not abort a
        completed backtest.

        Args:
            data: DataFrame or Series to save
            path: Output file path
            description: Human-readable description for logging
        """
        try:
            data.to_csv(path)
            logger.info("Saved %s to %s", description, path)
        except Exception as e:
            logger.error("Could not save %s: %s", description, e)
|
||||
395
engine/risk.py
Normal file
395
engine/risk.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Risk management utilities for backtesting.
|
||||
|
||||
Handles funding rate calculations and liquidation detection.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketConfig, calculate_liquidation_price
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class LiquidationEvent:
    """
    Record of a liquidation event during backtesting.

    Attributes:
        entry_time: Timestamp when position was opened
        entry_price: Price at position entry
        liquidation_time: Timestamp when liquidation occurred
        liquidation_price: Calculated liquidation price
        actual_price: Actual price that triggered liquidation (high/low)
        direction: 'long' or 'short'
        margin_lost_pct: Percentage of margin lost (typically 100%)
    """
    entry_time: pd.Timestamp
    entry_price: float
    liquidation_time: pd.Timestamp
    liquidation_price: float
    actual_price: float
    direction: str
    # Liquidation is modelled as a total loss of the trade's margin.
    margin_lost_pct: float = 1.0
|
||||
|
||||
|
||||
def calculate_funding(
    close: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    market_config: MarketConfig,
    leverage: int
) -> float:
    """
    Calculate total funding paid/received for perpetual positions.

    Simplified model: applies funding rate every 8 hours to open positions.
    Positive rate means longs pay shorts.

    Args:
        close: Price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        market_config: Market configuration with funding parameters
        leverage: Position leverage

    Returns:
        Total funding paid (positive) or received (negative)
    """
    # funding_interval_hours == 0 marks a market with no funding (spot).
    if market_config.funding_interval_hours == 0:
        return 0.0

    funding_rate = market_config.funding_rate
    interval_hours = market_config.funding_interval_hours

    # Determine position state at each bar: cumulative entries minus
    # cumulative exits, then clamped to a 0/1 in-position flag.
    long_position = long_entries.cumsum() - long_exits.cumsum()
    short_position = short_entries.cumsum() - short_exits.cumsum()

    # Clamp to 0/1 (either in position or not)
    long_position = (long_position > 0).astype(int)
    short_position = (short_position > 0).astype(int)

    # Find funding timestamps (every 8 hours: 00:00, 08:00, 16:00 UTC)
    # NOTE(review): with sub-hourly bars, every bar inside a funding hour
    # matches this mask, so funding would be charged once per bar instead
    # of once per interval — confirm the expected data frequency.
    funding_times = close.index[close.index.hour % interval_hours == 0]

    total_funding = 0.0
    for ts in funding_times:
        # Defensive check; funding_times is derived from close.index, so
        # this normally never triggers.
        if ts not in close.index:
            continue
        price = close.loc[ts]

        # Long pays funding, short receives (when rate > 0)
        if isinstance(long_position, pd.DataFrame):
            long_open = long_position.loc[ts].any()
            short_open = short_position.loc[ts].any()
        else:
            long_open = long_position.loc[ts] > 0
            short_open = short_position.loc[ts] > 0

        # Notional approximated as price * leverage (assumes a one-unit
        # position) — TODO confirm against actual position sizing.
        position_value = price * leverage
        if long_open:
            total_funding += position_value * funding_rate
        if short_open:
            total_funding -= position_value * funding_rate

    return total_funding
|
||||
|
||||
|
||||
def inject_liquidation_exits(
    close: pd.Series,
    high: pd.Series,
    low: pd.Series,
    long_entries: pd.DataFrame | pd.Series,
    long_exits: pd.DataFrame | pd.Series,
    short_entries: pd.DataFrame | pd.Series,
    short_exits: pd.DataFrame | pd.Series,
    leverage: int,
    maintenance_margin_rate: float
) -> tuple[pd.DataFrame | pd.Series, pd.DataFrame | pd.Series, list[LiquidationEvent]]:
    """
    Modify exit signals to force position closure at liquidation points.

    This function simulates realistic liquidation behavior by:
    1. Finding positions that would be liquidated before their normal exit
    2. Injecting forced exit signals at the liquidation bar
    3. Recording all liquidation events

    NOTE(review): multi-column signal grids are collapsed with
    ``any(axis=1)`` and forced exits are injected into EVERY column —
    confirm this matches the portfolio's per-column semantics.

    Args:
        close: Close price series
        high: High price series
        low: Low price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        leverage: Position leverage
        maintenance_margin_rate: Maintenance margin rate for liquidation

    Returns:
        Tuple of (modified_long_exits, modified_short_exits, liquidation_events)
    """
    # Unleveraged positions cannot be liquidated — signals pass through.
    if leverage <= 1:
        return long_exits, short_exits, []

    liquidation_events: list[LiquidationEvent] = []

    # Convert to DataFrame if Series for consistent handling
    is_series = isinstance(long_entries, pd.Series)
    if is_series:
        long_entries_df = long_entries.to_frame()
        long_exits_df = long_exits.to_frame()
        short_entries_df = short_entries.to_frame()
        short_exits_df = short_exits.to_frame()
    else:
        long_entries_df = long_entries
        long_exits_df = long_exits.copy()
        short_entries_df = short_entries
        short_exits_df = short_exits.copy()

    modified_long_exits = long_exits_df.copy()
    modified_short_exits = short_exits_df.copy()

    # Process long positions: entry price is the close at the entry bar;
    # a long is liquidated when the LOW breaches its liquidation price.
    long_mask = long_entries_df.any(axis=1)
    for entry_idx in close.index[long_mask]:
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=True,
            maintenance_margin_rate=maintenance_margin_rate
        )

        # Find the normal exit for this entry (first exit at or after the
        # entry bar; falls back to the last bar if none exists).
        subsequent_exits = long_exits_df.loc[entry_idx:].any(axis=1)
        exit_indices = subsequent_exits[subsequent_exits].index
        normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]

        # Check if liquidation occurs before normal exit.  The .loc slice
        # is inclusive of both the entry and exit bars.
        price_range = low.loc[entry_idx:normal_exit_idx]
        if (price_range < liq_price).any():
            liq_bar = price_range[price_range < liq_price].index[0]

            # Inject forced exit at liquidation bar
            for col in modified_long_exits.columns:
                modified_long_exits.loc[liq_bar, col] = True

            # Record the liquidation event
            liquidation_events.append(LiquidationEvent(
                entry_time=entry_idx,
                entry_price=entry_price,
                liquidation_time=liq_bar,
                liquidation_price=liq_price,
                actual_price=low.loc[liq_bar],
                direction='long',
                margin_lost_pct=1.0
            ))

            logger.warning(
                "LIQUIDATION (Long): Entry %s ($%.2f) -> Liquidated %s "
                "(liq=$%.2f, low=$%.2f)",
                entry_idx.strftime('%Y-%m-%d'), entry_price,
                liq_bar.strftime('%Y-%m-%d'), liq_price, low.loc[liq_bar]
            )

    # Process short positions: mirror of the long path — a short is
    # liquidated when the HIGH breaches its liquidation price.
    short_mask = short_entries_df.any(axis=1)
    for entry_idx in close.index[short_mask]:
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=False,
            maintenance_margin_rate=maintenance_margin_rate
        )

        # Find the normal exit for this entry
        subsequent_exits = short_exits_df.loc[entry_idx:].any(axis=1)
        exit_indices = subsequent_exits[subsequent_exits].index
        normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]

        # Check if liquidation occurs before normal exit
        price_range = high.loc[entry_idx:normal_exit_idx]
        if (price_range > liq_price).any():
            liq_bar = price_range[price_range > liq_price].index[0]

            # Inject forced exit at liquidation bar
            for col in modified_short_exits.columns:
                modified_short_exits.loc[liq_bar, col] = True

            # Record the liquidation event
            liquidation_events.append(LiquidationEvent(
                entry_time=entry_idx,
                entry_price=entry_price,
                liquidation_time=liq_bar,
                liquidation_price=liq_price,
                actual_price=high.loc[liq_bar],
                direction='short',
                margin_lost_pct=1.0
            ))

            logger.warning(
                "LIQUIDATION (Short): Entry %s ($%.2f) -> Liquidated %s "
                "(liq=$%.2f, high=$%.2f)",
                entry_idx.strftime('%Y-%m-%d'), entry_price,
                liq_bar.strftime('%Y-%m-%d'), liq_price, high.loc[liq_bar]
            )

    # Convert back to Series if input was Series
    if is_series:
        modified_long_exits = modified_long_exits.iloc[:, 0]
        modified_short_exits = modified_short_exits.iloc[:, 0]

    return modified_long_exits, modified_short_exits, liquidation_events
|
||||
|
||||
|
||||
def calculate_liquidation_adjustment(
    liquidation_events: list[LiquidationEvent],
    init_cash: float,
    leverage: int
) -> tuple[float, float]:
    """
    Estimate how much portfolio returns must be adjusted for liquidations.

    VectorBT books each trade's P&L at the exit bar's close price, but a
    real liquidation wipes out 100% of the margin posted for that trade.
    This helper approximates the gap between what VectorBT recorded and
    what actually would have happened, so reporting can correct returns.

    Portfolio assumptions baked into the estimate:
      - long and short books each receive half the capital
        (init_cash * leverage / 2)
      - every trade deploys its side's full allocation (size=1.0, percent)
      - a liquidation forfeits the entire margin backing the trade

    Args:
        liquidation_events: Liquidation events detected during the run.
        init_cash: Starting portfolio cash, before leverage is applied.
        leverage: Leverage multiple used for positions.

    Returns:
        Tuple of (total_margin_lost, adjustment_pct):
          - total_margin_lost: Estimated total margin destroyed by liquidations.
          - adjustment_pct: That loss expressed as a percent of init_cash.
    """
    if not liquidation_events:
        return 0.0, 0.0

    # Each side (long/short) trades with half the starting capital.
    side_margin = init_cash / 2

    # At N-x leverage a liquidation corresponds to roughly a 1/N adverse
    # price move, i.e. a full loss of the 1/N margin slice behind the trade.
    per_event_loss = side_margin * (1.0 / leverage)

    # Aggregate across all events, capped at the deposit: the account
    # cannot lose more capital than was actually posted.
    total_margin_lost = min(len(liquidation_events) * per_event_loss, init_cash)

    adjustment_pct = (total_margin_lost / init_cash) * 100
    return total_margin_lost, adjustment_pct
|
||||
|
||||
|
||||
def check_liquidations(
    close: pd.Series,
    high: pd.Series,
    low: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    leverage: int,
    maintenance_margin_rate: float
) -> int:
    """
    Check for liquidation events and log warnings.

    Scans every bar that carries an entry signal and logs a warning if
    price subsequently crosses the theoretical liquidation level for a
    position opened at that bar's close. The scan runs from each entry to
    the end of the data — it does not stop at the position's exit — so
    this is a conservative screening tool, not an exact liquidation count.

    Args:
        close: Close price series.
        high: High price series.
        low: Low price series.
        long_entries: Long entry signals (Series or DataFrame).
        long_exits: Long exit signals (currently unused; kept for interface parity).
        short_entries: Short entry signals (Series or DataFrame).
        short_exits: Short exit signals (currently unused; kept for interface parity).
        leverage: Position leverage.
        maintenance_margin_rate: Maintenance margin rate for liquidation.

    Returns:
        Count of liquidation warnings logged.
    """

    def _scan_side(entries, extreme: pd.Series, is_long: bool) -> int:
        """Warn for each entry whose liquidation level is later breached; return the count."""
        # Collapse multi-column signal frames into one per-bar boolean mask.
        mask = entries.any(axis=1) if isinstance(entries, pd.DataFrame) else entries

        # Longs are liquidated on the low, shorts on the high; keep each
        # side's original message template verbatim.
        if is_long:
            template = (
                "LIQUIDATION WARNING (Long): Entry at %s ($%.2f), "
                "would liquidate at %s (liq_price=$%.2f, low=$%.2f)"
            )
        else:
            template = (
                "LIQUIDATION WARNING (Short): Entry at %s ($%.2f), "
                "would liquidate at %s (liq_price=$%.2f, high=$%.2f)"
            )

        count = 0
        for entry_idx in close.index[mask]:
            entry_price = close.loc[entry_idx]
            liq_price = calculate_liquidation_price(
                entry_price, leverage, is_long=is_long,
                maintenance_margin_rate=maintenance_margin_rate
            )

            # A long is liquidated when the low drops below the level; a
            # short when the high rises above it.
            subsequent = extreme.loc[entry_idx:]
            breached = subsequent < liq_price if is_long else subsequent > liq_price
            if breached.any():
                liq_bar = breached[breached].index[0]
                logger.warning(
                    template,
                    entry_idx, entry_price, liq_bar, liq_price,
                    extreme.loc[liq_bar],
                )
                count += 1
        return count

    return (
        _scan_side(long_entries, low, is_long=True)
        + _scan_side(short_entries, high, is_long=False)
    )
|
||||
Reference in New Issue
Block a user