Remove deprecated modules and files related to the backtesting framework, including backtest.py, cli.py, config.py, data.py, intrabar.py, logging_utils.py, market_costs.py, metrics.py, trade.py, and supertrend indicators. Introduce a new structure for the backtesting engine with improved organization and functionality, including a CLI handler, data manager, and reporting capabilities. Update dependencies in pyproject.toml to support the new architecture.

This commit is contained in:
2026-01-12 21:11:39 +08:00
parent c4aa965a98
commit 44fac1ed25
37 changed files with 5253 additions and 393 deletions

352
engine/backtester.py Normal file
View File

@@ -0,0 +1,352 @@
"""
Core backtesting engine for running strategy simulations.
Supports multiple market types with realistic trading conditions.
"""
from dataclasses import dataclass
import pandas as pd
import vectorbt as vbt
from engine.data_manager import DataManager
from engine.logging_config import get_logger
from engine.market import MarketType, get_market_config
from engine.optimizer import WalkForwardOptimizer
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
from engine.risk import (
LiquidationEvent,
calculate_funding,
calculate_liquidation_adjustment,
inject_liquidation_exits,
)
from strategies.base import BaseStrategy
logger = get_logger(__name__)
@dataclass
class BacktestResult:
"""
Container for backtest results with market-specific metrics.
Attributes:
portfolio: VectorBT Portfolio object
market_type: Market type used for the backtest
leverage: Effective leverage used
total_funding_paid: Total funding fees paid (perpetuals only)
liquidation_count: Number of positions that were liquidated
liquidation_events: Detailed list of liquidation events
total_liquidation_loss: Total margin lost from liquidations
adjusted_return: Return adjusted for liquidation losses (percentage)
"""
portfolio: vbt.Portfolio
market_type: MarketType
leverage: int
total_funding_paid: float = 0.0
liquidation_count: int = 0
liquidation_events: list[LiquidationEvent] | None = None
total_liquidation_loss: float = 0.0
adjusted_return: float | None = None
class Backtester:
"""
Backtester supporting multiple market types with realistic simulation.
Features:
- Spot and Perpetual market support
- Long and short position handling
- Leverage simulation
- Funding rate calculation (perpetuals)
- Liquidation warnings
"""
def __init__(self, data_manager: DataManager):
self.dm = data_manager
def run_strategy(
self,
strategy: BaseStrategy,
exchange_id: str,
symbol: str,
timeframe: str = '1m',
start_date: str | None = None,
end_date: str | None = None,
init_cash: float = 10000,
fees: float | None = None,
slippage: float = 0.001,
sl_stop: float | None = None,
tp_stop: float | None = None,
sl_trail: bool = False,
leverage: int | None = None,
**strategy_params
) -> BacktestResult:
"""
Run a backtest with market-type-aware simulation.
Args:
strategy: Strategy instance to backtest
exchange_id: Exchange identifier (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Data timeframe (e.g., '1m', '1h', '1d')
start_date: Start date filter (YYYY-MM-DD)
end_date: End date filter (YYYY-MM-DD)
init_cash: Initial capital (margin for leveraged)
fees: Transaction fee override (uses market default if None)
slippage: Slippage percentage
sl_stop: Stop loss percentage
tp_stop: Take profit percentage
sl_trail: Enable trailing stop loss
leverage: Leverage override (uses strategy default if None)
**strategy_params: Additional strategy parameters
Returns:
BacktestResult with portfolio and market-specific metrics
"""
# Get market configuration from strategy
market_type = strategy.default_market_type
market_config = get_market_config(market_type)
# Resolve leverage and fees
effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
effective_fees = fees if fees is not None else market_config.taker_fee
# Load and filter data
df = self._load_data(
exchange_id, symbol, timeframe, market_type, start_date, end_date
)
close_price = df['close']
high_price = df['high']
low_price = df['low']
open_price = df['open']
volume = df['volume']
# Run strategy logic
signals = strategy.run(
close_price,
high=high_price,
low=low_price,
open=open_price,
volume=volume,
**strategy_params
)
# Normalize signals to 4-tuple format
signals = self._normalize_signals(signals, close_price, market_config)
long_entries, long_exits, short_entries, short_exits = signals
# Process liquidations - inject forced exits at liquidation points
liquidation_events: list[LiquidationEvent] = []
if effective_leverage > 1:
long_exits, short_exits, liquidation_events = inject_liquidation_exits(
close_price, high_price, low_price,
long_entries, long_exits,
short_entries, short_exits,
effective_leverage,
market_config.maintenance_margin_rate
)
# Calculate perpetual-specific metrics (after liquidation processing)
total_funding = 0.0
if market_type == MarketType.PERPETUAL:
total_funding = calculate_funding(
close_price,
long_entries, long_exits,
short_entries, short_exits,
market_config,
effective_leverage
)
# Run portfolio simulation with liquidation-aware exits
portfolio = self._run_portfolio(
close_price, market_config,
long_entries, long_exits,
short_entries, short_exits,
init_cash, effective_fees, slippage, timeframe,
sl_stop, tp_stop, sl_trail, effective_leverage
)
# Calculate adjusted returns accounting for liquidation losses
total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
liquidation_events, init_cash, effective_leverage
)
raw_return = portfolio.total_return().mean() * 100
adjusted_return = raw_return - liq_adjustment
if liquidation_events:
logger.info(
"Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
len(liquidation_events), total_liq_loss, liq_adjustment
)
logger.info(
"Backtest completed: %s market, %dx leverage, fees=%.4f%%",
market_type.value, effective_leverage, effective_fees * 100
)
return BacktestResult(
portfolio=portfolio,
market_type=market_type,
leverage=effective_leverage,
total_funding_paid=total_funding,
liquidation_count=len(liquidation_events),
liquidation_events=liquidation_events,
total_liquidation_loss=total_liq_loss,
adjusted_return=adjusted_return
)
def _resolve_leverage(
self,
leverage: int | None,
strategy: BaseStrategy,
market_type: MarketType
) -> int:
"""Resolve effective leverage from CLI, strategy default, or market type."""
effective = leverage or strategy.default_leverage
if market_type == MarketType.SPOT:
return 1 # Spot cannot have leverage
return effective
def _load_data(
self,
exchange_id: str,
symbol: str,
timeframe: str,
market_type: MarketType,
start_date: str | None,
end_date: str | None
) -> pd.DataFrame:
"""Load and filter OHLCV data."""
try:
df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
except FileNotFoundError:
logger.warning("Data not found locally. Attempting download...")
df = self.dm.download_data(
exchange_id, symbol, timeframe,
start_date, end_date, market_type
)
if start_date:
df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
if end_date:
df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
return df
def _normalize_signals(
self,
signals: tuple,
close: pd.Series,
market_config
) -> tuple:
"""
Normalize strategy signals to 4-tuple format.
Handles backward compatibility with 2-tuple (long-only) returns.
"""
if len(signals) == 2:
long_entries, long_exits = signals
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits
if len(signals) == 4:
long_entries, long_exits, short_entries, short_exits = signals
# Warn and clear short signals on spot markets
if not market_config.supports_short:
has_shorts = (
short_entries.any().any()
if hasattr(short_entries, 'any')
else short_entries.any()
)
if has_shorts:
logger.warning(
"Short signals detected but market type is SPOT. "
"Short signals will be ignored."
)
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits
raise ValueError(
f"Strategy must return 2 or 4 signal arrays, got {len(signals)}"
)
def _run_portfolio(
self,
close: pd.Series,
market_config,
long_entries, long_exits,
short_entries, short_exits,
init_cash: float,
fees: float,
slippage: float,
freq: str,
sl_stop: float | None,
tp_stop: float | None,
sl_trail: bool,
leverage: int
) -> vbt.Portfolio:
"""Select and run appropriate portfolio simulation."""
has_shorts = (
short_entries.any().any()
if hasattr(short_entries, 'any')
else short_entries.any()
)
if market_config.supports_short and has_shorts:
return run_long_short_portfolio(
close,
long_entries, long_exits,
short_entries, short_exits,
init_cash, fees, slippage, freq,
sl_stop, tp_stop, sl_trail, leverage
)
return run_long_only_portfolio(
close,
long_entries, long_exits,
init_cash, fees, slippage, freq,
sl_stop, tp_stop, sl_trail, leverage
)
def run_wfa(
self,
strategy: BaseStrategy,
exchange_id: str,
symbol: str,
param_grid: dict,
n_windows: int = 10,
timeframe: str = '1m'
):
"""
Execute Walk-Forward Analysis.
Args:
strategy: Strategy instance to optimize
exchange_id: Exchange identifier
symbol: Trading pair symbol
param_grid: Parameter grid for optimization
n_windows: Number of walk-forward windows
timeframe: Data timeframe to load
Returns:
Tuple of (results DataFrame, stitched equity curve)
"""
market_type = strategy.default_market_type
df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
wfa = WalkForwardOptimizer(self, strategy, param_grid)
results, stitched_curve = wfa.run(
df['close'],
high=df['high'],
low=df['low'],
n_windows=n_windows
)
return results, stitched_curve

243
engine/cli.py Normal file
View File

@@ -0,0 +1,243 @@
"""
CLI handler for Lowkey Backtest.
"""
import argparse
import sys
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import get_logger, setup_logging
from engine.market import MarketType
from engine.reporting import Reporter
from strategies.factory import get_strategy, get_strategy_names
logger = get_logger(__name__)
def create_parser() -> argparse.ArgumentParser:
"""Create and configure the argument parser."""
parser = argparse.ArgumentParser(
description="Lowkey Backtest CLI (VectorBT Edition)"
)
subparsers = parser.add_subparsers(dest="command", help="Command to run")
_add_download_parser(subparsers)
_add_backtest_parser(subparsers)
_add_wfa_parser(subparsers)
return parser
def _add_download_parser(subparsers) -> None:
"""Add download command parser."""
dl_parser = subparsers.add_parser("download", help="Download historical data")
dl_parser.add_argument("--exchange", "-e", type=str, default="okx")
dl_parser.add_argument("--pair", "-p", type=str, required=True)
dl_parser.add_argument("--timeframe", "-t", type=str, default="1m")
dl_parser.add_argument("--start", type=str, help="Start Date (YYYY-MM-DD)")
dl_parser.add_argument(
"--market", "-m",
type=str,
choices=["spot", "perpetual"],
default="spot"
)
def _add_backtest_parser(subparsers) -> None:
"""Add backtest command parser."""
strategy_choices = get_strategy_names()
bt_parser = subparsers.add_parser("backtest", help="Run a backtest")
bt_parser.add_argument(
"--strategy", "-s",
type=str,
choices=strategy_choices,
required=True
)
bt_parser.add_argument("--exchange", "-e", type=str, default="okx")
bt_parser.add_argument("--pair", "-p", type=str, required=True)
bt_parser.add_argument("--timeframe", "-t", type=str, default="1m")
bt_parser.add_argument("--start", type=str)
bt_parser.add_argument("--end", type=str)
bt_parser.add_argument("--grid", "-g", action="store_true")
bt_parser.add_argument("--plot", action="store_true")
# Risk parameters
bt_parser.add_argument("--sl", type=float, help="Stop Loss %%")
bt_parser.add_argument("--tp", type=float, help="Take Profit %%")
bt_parser.add_argument("--trail", action="store_true")
bt_parser.add_argument("--no-bear-exit", action="store_true")
# Cost parameters
bt_parser.add_argument("--fees", type=float, default=None)
bt_parser.add_argument("--slippage", type=float, default=0.001)
bt_parser.add_argument("--leverage", "-l", type=int, default=None)
def _add_wfa_parser(subparsers) -> None:
"""Add walk-forward analysis command parser."""
strategy_choices = get_strategy_names()
wfa_parser = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")
wfa_parser.add_argument(
"--strategy", "-s",
type=str,
choices=strategy_choices,
required=True
)
wfa_parser.add_argument("--pair", "-p", type=str, required=True)
wfa_parser.add_argument("--timeframe", "-t", type=str, default="1d")
wfa_parser.add_argument("--windows", "-w", type=int, default=10)
wfa_parser.add_argument("--plot", action="store_true")
def run_download(args) -> None:
"""Execute download command."""
dm = DataManager()
market_type = MarketType(args.market)
dm.download_data(
args.exchange,
args.pair,
args.timeframe,
start_date=args.start,
market_type=market_type
)
def run_backtest(args) -> None:
"""Execute backtest command."""
dm = DataManager()
bt = Backtester(dm)
reporter = Reporter()
strategy, params = get_strategy(args.strategy, args.grid)
# Apply CLI overrides for meta_st strategy
params = _apply_strategy_overrides(args, strategy, params)
if args.grid and args.strategy == "meta_st":
logger.info("Running Grid Search for Meta Supertrend...")
try:
result = bt.run_strategy(
strategy,
args.exchange,
args.pair,
timeframe=args.timeframe,
start_date=args.start,
end_date=args.end,
fees=args.fees,
slippage=args.slippage,
sl_stop=args.sl,
tp_stop=args.tp,
sl_trail=args.trail,
leverage=args.leverage,
**params
)
reporter.print_summary(result)
reporter.save_reports(result, f"{args.strategy}_{args.pair.replace('/','-')}")
if args.plot and not args.grid:
reporter.plot(result.portfolio)
elif args.plot and args.grid:
logger.info("Plotting skipped for Grid Search. Check CSV results.")
except Exception as e:
logger.error("Backtest failed: %s", e, exc_info=True)
def run_wfa(args) -> None:
"""Execute walk-forward analysis command."""
dm = DataManager()
bt = Backtester(dm)
reporter = Reporter()
strategy, params = get_strategy(args.strategy, is_grid=True)
logger.info(
"Running WFA on %s for %s (%s) with %d windows...",
args.strategy, args.pair, args.timeframe, args.windows
)
try:
results, stitched_curve = bt.run_wfa(
strategy,
"okx",
args.pair,
params,
n_windows=args.windows,
timeframe=args.timeframe
)
_log_wfa_results(results)
_save_wfa_results(args, results, stitched_curve, reporter)
except Exception as e:
logger.error("WFA failed: %s", e, exc_info=True)
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
"""Apply CLI argument overrides to strategy parameters."""
if args.strategy != "meta_st":
return params
if args.no_bear_exit:
params['exit_on_bearish_flip'] = False
if args.sl is None:
args.sl = strategy.default_sl_stop
if not args.trail:
args.trail = strategy.default_sl_trail
return params
def _log_wfa_results(results) -> None:
"""Log WFA results summary."""
logger.info("Walk-Forward Analysis Results:")
if results.empty or 'window' not in results.columns:
logger.warning("No valid WFA results. All windows may have failed.")
return
columns = ['window', 'train_score', 'test_score', 'test_return']
logger.info("\n%s", results[columns].to_string(index=False))
avg_test_sharpe = results['test_score'].mean()
avg_test_return = results['test_return'].mean()
logger.info("Average Test Sharpe: %.2f", avg_test_sharpe)
logger.info("Average Test Return: %.2f%%", avg_test_return * 100)
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
"""Save WFA results to file and optionally plot."""
if results.empty:
return
output_path = f"backtest_logs/wfa_{args.strategy}_{args.pair.replace('/','-')}.csv"
results.to_csv(output_path)
logger.info("Saved full results to %s", output_path)
if args.plot:
reporter.plot_wfa(results, stitched_curve)
def main():
"""Main entry point."""
setup_logging()
parser = create_parser()
args = parser.parse_args()
commands = {
"download": run_download,
"backtest": run_backtest,
"wfa": run_wfa,
}
if args.command in commands:
commands[args.command](args)
else:
parser.print_help()

209
engine/data_manager.py Normal file
View File

@@ -0,0 +1,209 @@
"""
Data management for OHLCV data download and storage.
Handles data retrieval from exchanges and local file management.
"""
import time
from datetime import datetime, timezone
from pathlib import Path
import ccxt
import pandas as pd
from engine.logging_config import get_logger
from engine.market import MarketType, get_ccxt_symbol
logger = get_logger(__name__)
class DataManager:
"""
Manages OHLCV data download and storage for different market types.
Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
"""
def __init__(self, data_dir: str = "data/ccxt"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.exchanges: dict[str, ccxt.Exchange] = {}
def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
"""Get or create a CCXT exchange instance."""
if exchange_id not in self.exchanges:
exchange_class = getattr(ccxt, exchange_id)
self.exchanges[exchange_id] = exchange_class({
'enableRateLimit': True,
})
return self.exchanges[exchange_id]
def _get_data_path(
self,
exchange_id: str,
symbol: str,
timeframe: str,
market_type: MarketType
) -> Path:
"""
Get the file path for storing/loading data.
Args:
exchange_id: Exchange name (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Candle timeframe (e.g., '1m')
market_type: Market type (spot or perpetual)
Returns:
Path to the CSV file
"""
safe_symbol = symbol.replace('/', '-')
return (
self.data_dir
/ exchange_id
/ market_type.value
/ safe_symbol
/ f"{timeframe}.csv"
)
def download_data(
self,
exchange_id: str,
symbol: str,
timeframe: str = '1m',
start_date: str | None = None,
end_date: str | None = None,
market_type: MarketType = MarketType.SPOT
) -> pd.DataFrame | None:
"""
Download OHLCV data from exchange and save to CSV.
Args:
exchange_id: Exchange name (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Candle timeframe (e.g., '1m')
start_date: Start date string (YYYY-MM-DD)
end_date: End date string (YYYY-MM-DD)
market_type: Market type (spot or perpetual)
Returns:
DataFrame with OHLCV data, or None if download failed
"""
exchange = self.get_exchange(exchange_id)
file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
file_path.parent.mkdir(parents=True, exist_ok=True)
ccxt_symbol = get_ccxt_symbol(symbol, market_type)
since, until = self._parse_date_range(exchange, start_date, end_date)
logger.info(
"Downloading %s (%s) from %s...",
symbol, market_type.value, exchange_id
)
all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)
if not all_ohlcv:
logger.warning("No data downloaded.")
return None
df = self._convert_to_dataframe(all_ohlcv)
df.to_csv(file_path)
logger.info("Saved %d candles to %s", len(df), file_path)
return df
def load_data(
self,
exchange_id: str,
symbol: str,
timeframe: str = '1m',
market_type: MarketType = MarketType.SPOT
) -> pd.DataFrame:
"""
Load saved OHLCV data for vectorbt.
Args:
exchange_id: Exchange name (e.g., 'okx')
symbol: Trading pair (e.g., 'BTC/USDT')
timeframe: Candle timeframe (e.g., '1m')
market_type: Market type (spot or perpetual)
Returns:
DataFrame with OHLCV data indexed by timestamp
Raises:
FileNotFoundError: If data file does not exist
"""
file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
if not file_path.exists():
raise FileNotFoundError(
f"Data not found at {file_path}. "
f"Run: uv run python main.py download --pair {symbol} "
f"--market {market_type.value}"
)
return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)
def _parse_date_range(
self,
exchange: ccxt.Exchange,
start_date: str | None,
end_date: str | None
) -> tuple[int, int]:
"""Parse date strings into millisecond timestamps."""
if start_date:
since = exchange.parse8601(f"{start_date}T00:00:00Z")
else:
since = exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
if end_date:
until = exchange.parse8601(f"{end_date}T23:59:59Z")
else:
until = exchange.milliseconds()
return since, until
def _fetch_all_candles(
self,
exchange: ccxt.Exchange,
symbol: str,
timeframe: str,
since: int,
until: int
) -> list:
"""Fetch all candles in the date range."""
all_ohlcv = []
while since < until:
try:
ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since, limit=100)
if not ohlcv:
break
all_ohlcv.extend(ohlcv)
since = ohlcv[-1][0] + 1
current_date = datetime.fromtimestamp(
since/1000, tz=timezone.utc
).strftime('%Y-%m-%d')
logger.debug("Fetched up to %s", current_date)
time.sleep(exchange.rateLimit / 1000)
except Exception as e:
logger.error("Error fetching data: %s", e)
break
return all_ohlcv
def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
"""Convert OHLCV list to DataFrame."""
df = pd.DataFrame(
ohlcv,
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
return df

124
engine/logging_config.py Normal file
View File

@@ -0,0 +1,124 @@
"""
Centralized logging configuration for the backtest engine.
Provides colored console output and rotating file logs.
"""
import logging
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path
# ANSI color codes for terminal output
class Colors:
"""ANSI escape codes for colored terminal output."""
RESET = "\033[0m"
BOLD = "\033[1m"
# Log level colors
DEBUG = "\033[36m" # Cyan
INFO = "\033[32m" # Green
WARNING = "\033[33m" # Yellow
ERROR = "\033[31m" # Red
CRITICAL = "\033[35m" # Magenta
class ColoredFormatter(logging.Formatter):
"""
Custom formatter that adds colors to log level names in terminal output.
"""
LEVEL_COLORS = {
logging.DEBUG: Colors.DEBUG,
logging.INFO: Colors.INFO,
logging.WARNING: Colors.WARNING,
logging.ERROR: Colors.ERROR,
logging.CRITICAL: Colors.CRITICAL,
}
def __init__(self, fmt: str = None, datefmt: str = None):
super().__init__(fmt, datefmt)
def format(self, record: logging.LogRecord) -> str:
# Save original levelname
original_levelname = record.levelname
# Add color to levelname
color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
record.levelname = f"{color}{record.levelname}{Colors.RESET}"
# Format the message
result = super().format(record)
# Restore original levelname
record.levelname = original_levelname
return result
def setup_logging(
log_dir: str = "logs",
log_level: int = logging.INFO,
console_level: int = logging.INFO,
max_bytes: int = 5 * 1024 * 1024, # 5MB
backup_count: int = 3
) -> None:
"""
Configure logging for the application.
Args:
log_dir: Directory for log files
log_level: File logging level
console_level: Console logging level
max_bytes: Max size per log file before rotation
backup_count: Number of backup files to keep
"""
log_path = Path(log_dir)
log_path.mkdir(parents=True, exist_ok=True)
# Get root logger
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # Capture all, handlers filter
# Clear existing handlers
root_logger.handlers.clear()
# Console handler with colors
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(console_level)
console_fmt = ColoredFormatter(
fmt="[%(asctime)s] %(levelname)s - %(message)s",
datefmt="%H:%M:%S"
)
console_handler.setFormatter(console_fmt)
root_logger.addHandler(console_handler)
# File handler with rotation
file_handler = RotatingFileHandler(
log_path / "backtest.log",
maxBytes=max_bytes,
backupCount=backup_count,
encoding="utf-8"
)
file_handler.setLevel(log_level)
file_fmt = logging.Formatter(
fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(file_fmt)
root_logger.addHandler(file_handler)
def get_logger(name: str) -> logging.Logger:
"""
Get a logger instance for the given module name.
Args:
name: Module name (typically __name__)
Returns:
Configured logger instance
"""
return logging.getLogger(name)

156
engine/market.py Normal file
View File

@@ -0,0 +1,156 @@
"""
Market type definitions and configuration for backtesting.
Supports different market types with their specific trading conditions:
- SPOT: No leverage, no funding, long-only
- PERPETUAL: Leverage, funding rates, long/short
"""
from dataclasses import dataclass
from enum import Enum
class MarketType(Enum):
"""Supported market types for backtesting."""
SPOT = "spot"
PERPETUAL = "perpetual"
@dataclass(frozen=True)
class MarketConfig:
"""
Configuration for a specific market type.
Attributes:
market_type: The market type enum value
maker_fee: Maker fee as decimal (e.g., 0.0008 = 0.08%)
taker_fee: Taker fee as decimal (e.g., 0.001 = 0.1%)
max_leverage: Maximum allowed leverage
funding_rate: Funding rate per 8 hours as decimal (perpetuals only)
funding_interval_hours: Hours between funding payments
maintenance_margin_rate: Rate for liquidation calculation
supports_short: Whether short-selling is supported
"""
market_type: MarketType
maker_fee: float
taker_fee: float
max_leverage: int
funding_rate: float
funding_interval_hours: int
maintenance_margin_rate: float
supports_short: bool
# OKX-based default configurations
SPOT_CONFIG = MarketConfig(
market_type=MarketType.SPOT,
maker_fee=0.0008, # 0.08%
taker_fee=0.0010, # 0.10%
max_leverage=1,
funding_rate=0.0,
funding_interval_hours=0,
maintenance_margin_rate=0.0,
supports_short=False,
)
PERPETUAL_CONFIG = MarketConfig(
market_type=MarketType.PERPETUAL,
maker_fee=0.0002, # 0.02%
taker_fee=0.0005, # 0.05%
max_leverage=125,
funding_rate=0.0001, # 0.01% per 8 hours (simplified average)
funding_interval_hours=8,
maintenance_margin_rate=0.004, # 0.4% for BTC on OKX
supports_short=True,
)
def get_market_config(market_type: MarketType) -> MarketConfig:
"""
Get the configuration for a specific market type.
Args:
market_type: The market type to get configuration for
Returns:
MarketConfig with default values for that market type
"""
configs = {
MarketType.SPOT: SPOT_CONFIG,
MarketType.PERPETUAL: PERPETUAL_CONFIG,
}
return configs[market_type]
def get_ccxt_symbol(symbol: str, market_type: MarketType) -> str:
"""
Convert a standard symbol to CCXT format for the given market type.
Args:
symbol: Standard symbol (e.g., 'BTC/USDT')
market_type: The market type
Returns:
CCXT-formatted symbol (e.g., 'BTC/USDT:USDT' for perpetuals)
"""
if market_type == MarketType.PERPETUAL:
# OKX perpetual format: BTC/USDT:USDT
quote = symbol.split('/')[1] if '/' in symbol else 'USDT'
return f"{symbol}:{quote}"
return symbol
def calculate_leverage_stop_loss(
leverage: int,
maintenance_margin_rate: float = 0.004
) -> float:
"""
Calculate the implicit stop-loss percentage from leverage.
At a given leverage, liquidation occurs when the position loses
approximately (1/leverage - maintenance_margin_rate) of its value.
Args:
leverage: Position leverage multiplier
maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)
Returns:
Stop-loss percentage as decimal (e.g., 0.196 for 19.6%)
"""
if leverage <= 1:
return 1.0 # No forced stop for spot
return (1 / leverage) - maintenance_margin_rate
def calculate_liquidation_price(
entry_price: float,
leverage: float,
is_long: bool,
maintenance_margin_rate: float = 0.004
) -> float:
"""
Calculate the liquidation price for a leveraged position.
Args:
entry_price: Position entry price
leverage: Position leverage
is_long: True for long positions, False for short
maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)
Returns:
Liquidation price
"""
if leverage <= 1:
return 0.0 if is_long else float('inf')
# Simplified liquidation formula
# Long: liq_price = entry * (1 - 1/leverage + maintenance_margin_rate)
# Short: liq_price = entry * (1 + 1/leverage - maintenance_margin_rate)
margin_ratio = 1 / leverage
if is_long:
liq_price = entry_price * (1 - margin_ratio + maintenance_margin_rate)
else:
liq_price = entry_price * (1 + margin_ratio - maintenance_margin_rate)
return liq_price

245
engine/optimizer.py Normal file
View File

@@ -0,0 +1,245 @@
"""
Walk-Forward Analysis optimizer for strategy parameter optimization.
Implements expanding window walk-forward analysis with train/test splits.
"""
import numpy as np
import pandas as pd
import vectorbt as vbt
from engine.logging_config import get_logger
logger = get_logger(__name__)
def create_rolling_windows(
index: pd.Index,
n_windows: int,
train_split: float = 0.7
):
"""
Create rolling train/test split indices using expanding window approach.
Args:
index: DataFrame index to split
n_windows: Number of walk-forward windows
train_split: Unused, kept for API compatibility
Yields:
Tuples of (train_idx, test_idx) numpy arrays
"""
chunks = np.array_split(index, n_windows + 1)
for i in range(n_windows):
train_idx = np.concatenate([c for c in chunks[:i+1]])
test_idx = chunks[i+1]
yield train_idx, test_idx
class WalkForwardOptimizer:
"""
Walk-Forward Analysis optimizer for strategy backtesting.
Optimizes strategy parameters on training windows and validates
on out-of-sample test windows.
"""
def __init__(
self,
backtester,
strategy,
param_grid: dict,
metric: str = 'Sharpe Ratio',
fees: float = 0.001,
freq: str = '1m'
):
"""
Initialize the optimizer.
Args:
backtester: Backtester instance
strategy: Strategy instance to optimize
param_grid: Parameter grid for optimization
metric: Performance metric to optimize
fees: Transaction fees for simulation
freq: Data frequency for portfolio simulation
"""
self.bt = backtester
self.strategy = strategy
self.param_grid = param_grid
self.metric = metric
self.fees = fees
self.freq = freq
# Separate grid params (lists) from fixed params (scalars)
self.grid_keys = []
self.fixed_params = {}
for k, v in param_grid.items():
if isinstance(v, (list, np.ndarray)):
self.grid_keys.append(k)
else:
self.fixed_params[k] = v
def run(
self,
close_price: pd.Series,
high: pd.Series | None = None,
low: pd.Series | None = None,
n_windows: int = 10
) -> tuple[pd.DataFrame, pd.Series | None]:
"""
Execute walk-forward analysis.
Args:
close_price: Close price series
high: High price series (optional)
low: Low price series (optional)
n_windows: Number of walk-forward windows
Returns:
Tuple of (results DataFrame, stitched equity curve)
"""
results = []
equity_curves = []
logger.info(
"Starting Walk-Forward Analysis with %d windows (Expanding Train)...",
n_windows
)
splitter = create_rolling_windows(close_price.index, n_windows)
for i, (train_idx, test_idx) in enumerate(splitter):
logger.info("Processing Window %d/%d...", i + 1, n_windows)
window_result = self._process_window(
i, train_idx, test_idx, close_price, high, low
)
if window_result is not None:
result_dict, eq_curve = window_result
results.append(result_dict)
equity_curves.append(eq_curve)
stitched_series = self._stitch_equity_curves(equity_curves)
return pd.DataFrame(results), stitched_series
def _process_window(
self,
window_idx: int,
train_idx: np.ndarray,
test_idx: np.ndarray,
close_price: pd.Series,
high: pd.Series | None,
low: pd.Series | None
) -> tuple[dict, pd.Series] | None:
"""Process a single WFA window."""
try:
# Slice data for train/test
train_close = close_price.loc[train_idx]
train_high = high.loc[train_idx] if high is not None else None
train_low = low.loc[train_idx] if low is not None else None
# Train phase: find best parameters
best_params, best_score = self._optimize_train(
train_close, train_high, train_low
)
# Test phase: validate with best params
test_close = close_price.loc[test_idx]
test_high = high.loc[test_idx] if high is not None else None
test_low = low.loc[test_idx] if low is not None else None
test_params = {**self.fixed_params, **best_params}
test_score, test_return, eq_curve = self._run_test(
test_close, test_high, test_low, test_params
)
return {
'window': window_idx + 1,
'train_start': train_idx[0],
'train_end': train_idx[-1],
'test_start': test_idx[0],
'test_end': test_idx[-1],
'best_params': best_params,
'train_score': best_score,
'test_score': test_score,
'test_return': test_return
}, eq_curve
except Exception as e:
logger.error("Error in window %d: %s", window_idx + 1, e, exc_info=True)
return None
def _optimize_train(
self,
close: pd.Series,
high: pd.Series | None,
low: pd.Series | None
) -> tuple[dict, float]:
"""Run grid search on training data to find best parameters."""
entries, exits = self.strategy.run(
close, high=high, low=low, **self.param_grid
)
pf_train = vbt.Portfolio.from_signals(
close, entries, exits,
fees=self.fees,
freq=self.freq
)
perf_stats = pf_train.sharpe_ratio()
perf_stats = perf_stats.fillna(-999)
best_idx = perf_stats.idxmax()
best_score = perf_stats.max()
# Extract best params from grid search
if len(self.grid_keys) == 1:
best_params = {self.grid_keys[0]: best_idx}
elif len(self.grid_keys) > 1:
best_params = dict(zip(self.grid_keys, best_idx))
else:
best_params = {}
return best_params, best_score
def _run_test(
self,
close: pd.Series,
high: pd.Series | None,
low: pd.Series | None,
params: dict
) -> tuple[float, float, pd.Series]:
"""Run test phase with given parameters."""
entries, exits = self.strategy.run(
close, high=high, low=low, **params
)
pf_test = vbt.Portfolio.from_signals(
close, entries, exits,
fees=self.fees,
freq=self.freq
)
return pf_test.sharpe_ratio(), pf_test.total_return(), pf_test.value()
def _stitch_equity_curves(
self,
equity_curves: list[pd.Series]
) -> pd.Series | None:
"""Stitch multiple equity curves into a continuous series."""
if not equity_curves:
return None
stitched = [equity_curves[0]]
for j in range(1, len(equity_curves)):
prev_end_val = stitched[-1].iloc[-1]
curr_curve = equity_curves[j]
init_cash = curr_curve.iloc[0]
# Scale curve to continue from previous end value
scaled_curve = (curr_curve / init_cash) * prev_end_val
stitched.append(scaled_curve)
return pd.concat(stitched)

148
engine/portfolio.py Normal file
View File

@@ -0,0 +1,148 @@
"""
Portfolio simulation utilities for backtesting.
Handles long-only and long/short portfolio creation using VectorBT.
"""
import pandas as pd
import vectorbt as vbt
from engine.logging_config import get_logger
logger = get_logger(__name__)
def run_long_only_portfolio(
close: pd.Series,
entries: pd.DataFrame,
exits: pd.DataFrame,
init_cash: float,
fees: float,
slippage: float,
freq: str,
sl_stop: float | None,
tp_stop: float | None,
sl_trail: bool,
leverage: int
) -> vbt.Portfolio:
"""
Run a long-only portfolio simulation.
Args:
close: Close price series
entries: Entry signals
exits: Exit signals
init_cash: Initial capital
fees: Transaction fee percentage
slippage: Slippage percentage
freq: Data frequency string
sl_stop: Stop loss percentage
tp_stop: Take profit percentage
sl_trail: Enable trailing stop loss
leverage: Leverage multiplier
Returns:
VectorBT Portfolio object
"""
effective_cash = init_cash * leverage
return vbt.Portfolio.from_signals(
close=close,
entries=entries,
exits=exits,
init_cash=effective_cash,
fees=fees,
slippage=slippage,
freq=freq,
sl_stop=sl_stop,
tp_stop=tp_stop,
sl_trail=sl_trail,
size=1.0,
size_type='percent',
)
def run_long_short_portfolio(
close: pd.Series,
long_entries: pd.DataFrame,
long_exits: pd.DataFrame,
short_entries: pd.DataFrame,
short_exits: pd.DataFrame,
init_cash: float,
fees: float,
slippage: float,
freq: str,
sl_stop: float | None,
tp_stop: float | None,
sl_trail: bool,
leverage: int
) -> vbt.Portfolio:
"""
Run a portfolio supporting both long and short positions.
Runs two separate portfolios (long and short) and combines results.
Each gets half the capital.
Args:
close: Close price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
init_cash: Initial capital
fees: Transaction fee percentage
slippage: Slippage percentage
freq: Data frequency string
sl_stop: Stop loss percentage
tp_stop: Take profit percentage
sl_trail: Enable trailing stop loss
leverage: Leverage multiplier
Returns:
VectorBT Portfolio object (long portfolio, short stats logged)
"""
effective_cash = init_cash * leverage
half_cash = effective_cash / 2
# Run long-only portfolio
long_pf = vbt.Portfolio.from_signals(
close=close,
entries=long_entries,
exits=long_exits,
direction='longonly',
init_cash=half_cash,
fees=fees,
slippage=slippage,
freq=freq,
sl_stop=sl_stop,
tp_stop=tp_stop,
sl_trail=sl_trail,
size=1.0,
size_type='percent',
)
# Run short-only portfolio
short_pf = vbt.Portfolio.from_signals(
close=close,
entries=short_entries,
exits=short_exits,
direction='shortonly',
init_cash=half_cash,
fees=fees,
slippage=slippage,
freq=freq,
sl_stop=sl_stop,
tp_stop=tp_stop,
sl_trail=sl_trail,
size=1.0,
size_type='percent',
)
# Log both portfolio stats
# TODO: Implement proper portfolio combination
logger.info(
"Long portfolio: %.2f%% return, Short portfolio: %.2f%% return",
long_pf.total_return().mean() * 100,
short_pf.total_return().mean() * 100
)
return long_pf

228
engine/reporting.py Normal file
View File

@@ -0,0 +1,228 @@
"""
Reporting module for backtest results.
Handles summary printing, CSV exports, and plotting.
"""
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
import vectorbt as vbt
from plotly.subplots import make_subplots
from engine.logging_config import get_logger
logger = get_logger(__name__)
class Reporter:
"""Reporter for backtest results with market-specific metrics."""
def __init__(self, output_dir: str = "backtest_logs"):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(exist_ok=True)
def print_summary(self, result) -> None:
"""
Print backtest summary to console via logger.
Args:
result: BacktestResult or vbt.Portfolio object
"""
(portfolio, market_type, leverage, funding_paid,
liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
# Extract period info
idx = portfolio.wrapper.index
start_date = idx[0].strftime("%Y-%m-%d")
end_date = idx[-1].strftime("%Y-%m-%d")
# Extract price info
close = portfolio.close
start_price = close.iloc[0].mean() if hasattr(close.iloc[0], 'mean') else close.iloc[0]
end_price = close.iloc[-1].mean() if hasattr(close.iloc[-1], 'mean') else close.iloc[-1]
price_change = ((end_price - start_price) / start_price) * 100
# Extract fees
stats = portfolio.stats()
total_fees = stats.get('Total Fees Paid', 0)
raw_return = portfolio.total_return().mean() * 100
# Build summary
summary_lines = [
"",
"=" * 50,
"BACKTEST RESULTS",
"=" * 50,
f"Market Type: [{market_type.upper()}]",
f"Leverage: [{leverage}x]",
f"Period: [{start_date} to {end_date}]",
f"Price: [{start_price:,.2f} -> {end_price:,.2f} ({price_change:+.2f}%)]",
]
# Show adjusted return if liquidations occurred
if liq_count > 0 and adjusted_return is not None:
summary_lines.append(f"Raw Return: [%{raw_return:.2f}] (before liq adjustment)")
summary_lines.append(f"Adj Return: [%{adjusted_return:.2f}] (after liq losses)")
else:
summary_lines.append(f"Total Return: [%{raw_return:.2f}]")
summary_lines.extend([
f"Sharpe Ratio: [{portfolio.sharpe_ratio().mean():.2f}]",
f"Max Drawdown: [%{portfolio.max_drawdown().mean() * 100:.2f}]",
f"Total Trades: [{portfolio.trades.count().mean():.0f}]",
f"Win Rate: [%{portfolio.trades.win_rate().mean() * 100:.2f}]",
f"Total Fees: [{total_fees:,.2f}]",
])
if funding_paid != 0:
summary_lines.append(f"Funding Paid: [{funding_paid:,.2f}]")
if liq_count > 0:
summary_lines.append(f"Liquidations: [{liq_count}] (${liq_loss:,.2f} margin lost)")
summary_lines.append("=" * 50)
logger.info("\n".join(summary_lines))
def save_reports(self, result, filename_prefix: str) -> None:
"""
Save trade log, stats, and liquidation events to CSV files.
Args:
result: BacktestResult or vbt.Portfolio object
filename_prefix: Prefix for output filenames
"""
(portfolio, market_type, leverage, funding_paid,
liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
# Save trades
self._save_csv(
data=portfolio.trades.records_readable,
path=self.output_dir / f"{filename_prefix}_trades.csv",
description="trade log"
)
# Save stats with market-specific additions
stats = portfolio.stats()
stats['Market Type'] = market_type
stats['Leverage'] = leverage
stats['Total Funding Paid'] = funding_paid
stats['Liquidations'] = liq_count
stats['Liquidation Loss'] = liq_loss
if adjusted_return is not None:
stats['Adjusted Return'] = adjusted_return
self._save_csv(
data=stats,
path=self.output_dir / f"{filename_prefix}_stats.csv",
description="stats"
)
# Save liquidation events if any
if hasattr(result, 'liquidation_events') and result.liquidation_events:
liq_df = pd.DataFrame([
{
'entry_time': e.entry_time,
'entry_price': e.entry_price,
'liquidation_time': e.liquidation_time,
'liquidation_price': e.liquidation_price,
'actual_price': e.actual_price,
'direction': e.direction,
'margin_lost_pct': e.margin_lost_pct
}
for e in result.liquidation_events
])
self._save_csv(
data=liq_df,
path=self.output_dir / f"{filename_prefix}_liquidations.csv",
description="liquidation events"
)
def plot(self, portfolio: vbt.Portfolio, show: bool = True) -> None:
"""Display portfolio plot."""
if show:
portfolio.plot().show()
def plot_wfa(
self,
wfa_results: pd.DataFrame,
stitched_curve: pd.Series | None = None,
show: bool = True
) -> None:
"""
Plot Walk-Forward Analysis results.
Args:
wfa_results: DataFrame with WFA window results
stitched_curve: Stitched out-of-sample equity curve
show: Whether to display the plot
"""
fig = make_subplots(
rows=2, cols=1,
shared_xaxes=False,
vertical_spacing=0.1,
subplot_titles=(
"Walk-Forward Test Scores (Sharpe)",
"Stitched Out-of-Sample Equity"
)
)
fig.add_trace(
go.Bar(
x=wfa_results['window'],
y=wfa_results['test_score'],
name="Test Sharpe"
),
row=1, col=1
)
if stitched_curve is not None:
fig.add_trace(
go.Scatter(
x=stitched_curve.index,
y=stitched_curve.values,
name="Equity",
mode='lines'
),
row=2, col=1
)
fig.update_layout(height=800, title_text="Walk-Forward Analysis Report")
if show:
fig.show()
def _extract_result_data(self, result) -> tuple:
"""
Extract data from BacktestResult or raw Portfolio.
Returns:
Tuple of (portfolio, market_type, leverage, funding_paid, liq_count,
liq_loss, adjusted_return)
"""
if hasattr(result, 'portfolio'):
return (
result.portfolio,
result.market_type.value,
result.leverage,
result.total_funding_paid,
result.liquidation_count,
getattr(result, 'total_liquidation_loss', 0.0),
getattr(result, 'adjusted_return', None)
)
return (result, "unknown", 1, 0.0, 0, 0.0, None)
def _save_csv(self, data, path: Path, description: str) -> None:
"""
Save data to CSV with consistent error handling.
Args:
data: DataFrame or Series to save
path: Output file path
description: Human-readable description for logging
"""
try:
data.to_csv(path)
logger.info("Saved %s to %s", description, path)
except Exception as e:
logger.error("Could not save %s: %s", description, e)

395
engine/risk.py Normal file
View File

@@ -0,0 +1,395 @@
"""
Risk management utilities for backtesting.
Handles funding rate calculations and liquidation detection.
"""
from dataclasses import dataclass
import pandas as pd
from engine.logging_config import get_logger
from engine.market import MarketConfig, calculate_liquidation_price
logger = get_logger(__name__)
@dataclass
class LiquidationEvent:
"""
Record of a liquidation event during backtesting.
Attributes:
entry_time: Timestamp when position was opened
entry_price: Price at position entry
liquidation_time: Timestamp when liquidation occurred
liquidation_price: Calculated liquidation price
actual_price: Actual price that triggered liquidation (high/low)
direction: 'long' or 'short'
margin_lost_pct: Percentage of margin lost (typically 100%)
"""
entry_time: pd.Timestamp
entry_price: float
liquidation_time: pd.Timestamp
liquidation_price: float
actual_price: float
direction: str
margin_lost_pct: float = 1.0
def calculate_funding(
close: pd.Series,
long_entries: pd.DataFrame,
long_exits: pd.DataFrame,
short_entries: pd.DataFrame,
short_exits: pd.DataFrame,
market_config: MarketConfig,
leverage: int
) -> float:
"""
Calculate total funding paid/received for perpetual positions.
Simplified model: applies funding rate every 8 hours to open positions.
Positive rate means longs pay shorts.
Args:
close: Price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
market_config: Market configuration with funding parameters
leverage: Position leverage
Returns:
Total funding paid (positive) or received (negative)
"""
if market_config.funding_interval_hours == 0:
return 0.0
funding_rate = market_config.funding_rate
interval_hours = market_config.funding_interval_hours
# Determine position state at each bar
long_position = long_entries.cumsum() - long_exits.cumsum()
short_position = short_entries.cumsum() - short_exits.cumsum()
# Clamp to 0/1 (either in position or not)
long_position = (long_position > 0).astype(int)
short_position = (short_position > 0).astype(int)
# Find funding timestamps (every 8 hours: 00:00, 08:00, 16:00 UTC)
funding_times = close.index[close.index.hour % interval_hours == 0]
total_funding = 0.0
for ts in funding_times:
if ts not in close.index:
continue
price = close.loc[ts]
# Long pays funding, short receives (when rate > 0)
if isinstance(long_position, pd.DataFrame):
long_open = long_position.loc[ts].any()
short_open = short_position.loc[ts].any()
else:
long_open = long_position.loc[ts] > 0
short_open = short_position.loc[ts] > 0
position_value = price * leverage
if long_open:
total_funding += position_value * funding_rate
if short_open:
total_funding -= position_value * funding_rate
return total_funding
def inject_liquidation_exits(
close: pd.Series,
high: pd.Series,
low: pd.Series,
long_entries: pd.DataFrame | pd.Series,
long_exits: pd.DataFrame | pd.Series,
short_entries: pd.DataFrame | pd.Series,
short_exits: pd.DataFrame | pd.Series,
leverage: int,
maintenance_margin_rate: float
) -> tuple[pd.DataFrame | pd.Series, pd.DataFrame | pd.Series, list[LiquidationEvent]]:
"""
Modify exit signals to force position closure at liquidation points.
This function simulates realistic liquidation behavior by:
1. Finding positions that would be liquidated before their normal exit
2. Injecting forced exit signals at the liquidation bar
3. Recording all liquidation events
Args:
close: Close price series
high: High price series
low: Low price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
leverage: Position leverage
maintenance_margin_rate: Maintenance margin rate for liquidation
Returns:
Tuple of (modified_long_exits, modified_short_exits, liquidation_events)
"""
if leverage <= 1:
return long_exits, short_exits, []
liquidation_events: list[LiquidationEvent] = []
# Convert to DataFrame if Series for consistent handling
is_series = isinstance(long_entries, pd.Series)
if is_series:
long_entries_df = long_entries.to_frame()
long_exits_df = long_exits.to_frame()
short_entries_df = short_entries.to_frame()
short_exits_df = short_exits.to_frame()
else:
long_entries_df = long_entries
long_exits_df = long_exits.copy()
short_entries_df = short_entries
short_exits_df = short_exits.copy()
modified_long_exits = long_exits_df.copy()
modified_short_exits = short_exits_df.copy()
# Process long positions
long_mask = long_entries_df.any(axis=1)
for entry_idx in close.index[long_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=True,
maintenance_margin_rate=maintenance_margin_rate
)
# Find the normal exit for this entry
subsequent_exits = long_exits_df.loc[entry_idx:].any(axis=1)
exit_indices = subsequent_exits[subsequent_exits].index
normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
# Check if liquidation occurs before normal exit
price_range = low.loc[entry_idx:normal_exit_idx]
if (price_range < liq_price).any():
liq_bar = price_range[price_range < liq_price].index[0]
# Inject forced exit at liquidation bar
for col in modified_long_exits.columns:
modified_long_exits.loc[liq_bar, col] = True
# Record the liquidation event
liquidation_events.append(LiquidationEvent(
entry_time=entry_idx,
entry_price=entry_price,
liquidation_time=liq_bar,
liquidation_price=liq_price,
actual_price=low.loc[liq_bar],
direction='long',
margin_lost_pct=1.0
))
logger.warning(
"LIQUIDATION (Long): Entry %s ($%.2f) -> Liquidated %s "
"(liq=$%.2f, low=$%.2f)",
entry_idx.strftime('%Y-%m-%d'), entry_price,
liq_bar.strftime('%Y-%m-%d'), liq_price, low.loc[liq_bar]
)
# Process short positions
short_mask = short_entries_df.any(axis=1)
for entry_idx in close.index[short_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=False,
maintenance_margin_rate=maintenance_margin_rate
)
# Find the normal exit for this entry
subsequent_exits = short_exits_df.loc[entry_idx:].any(axis=1)
exit_indices = subsequent_exits[subsequent_exits].index
normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
# Check if liquidation occurs before normal exit
price_range = high.loc[entry_idx:normal_exit_idx]
if (price_range > liq_price).any():
liq_bar = price_range[price_range > liq_price].index[0]
# Inject forced exit at liquidation bar
for col in modified_short_exits.columns:
modified_short_exits.loc[liq_bar, col] = True
# Record the liquidation event
liquidation_events.append(LiquidationEvent(
entry_time=entry_idx,
entry_price=entry_price,
liquidation_time=liq_bar,
liquidation_price=liq_price,
actual_price=high.loc[liq_bar],
direction='short',
margin_lost_pct=1.0
))
logger.warning(
"LIQUIDATION (Short): Entry %s ($%.2f) -> Liquidated %s "
"(liq=$%.2f, high=$%.2f)",
entry_idx.strftime('%Y-%m-%d'), entry_price,
liq_bar.strftime('%Y-%m-%d'), liq_price, high.loc[liq_bar]
)
# Convert back to Series if input was Series
if is_series:
modified_long_exits = modified_long_exits.iloc[:, 0]
modified_short_exits = modified_short_exits.iloc[:, 0]
return modified_long_exits, modified_short_exits, liquidation_events
def calculate_liquidation_adjustment(
liquidation_events: list[LiquidationEvent],
init_cash: float,
leverage: int
) -> tuple[float, float]:
"""
Calculate the return adjustment for liquidated positions.
VectorBT calculates trade P&L using close price at exit bar.
For liquidations, the actual loss is 100% of the position margin.
This function calculates the difference between what VectorBT
recorded and what actually would have happened.
In our portfolio setup:
- Long/short each get half the capital (init_cash * leverage / 2)
- Each trade uses 100% of that allocation (size=1.0, percent)
- On liquidation, the margin for that trade is lost entirely
The adjustment is the DIFFERENCE between:
- VectorBT's calculated P&L (exit at close price)
- Actual liquidation P&L (100% margin loss)
Args:
liquidation_events: List of liquidation events
init_cash: Initial portfolio cash (before leverage)
leverage: Position leverage used
Returns:
Tuple of (total_margin_lost, adjustment_pct)
- total_margin_lost: Estimated total margin lost from liquidations
- adjustment_pct: Percentage adjustment to apply to returns
"""
if not liquidation_events:
return 0.0, 0.0
# In our setup, each side (long/short) gets half the capital
# Margin per side = init_cash / 2
margin_per_side = init_cash / 2
# For each liquidation, VectorBT recorded some P&L based on close price
# The actual P&L should be -100% of the margin used for that trade
#
# We estimate the adjustment as:
# - Each liquidation should have resulted in ~-20% loss (at 5x leverage)
# - VectorBT may have recorded a different value
# - The margin loss is (1/leverage) per trade that gets liquidated
# Calculate liquidation loss rate based on leverage
# At 5x leverage, liquidation = ~19.6% adverse move = 100% margin loss
liq_loss_rate = 1.0 / leverage # Approximate loss per trade as % of position
# Count liquidations
n_liquidations = len(liquidation_events)
# Estimate total margin lost:
# Each liquidation on average loses the margin for that trade
# Since VectorBT uses half capital per side, and we trade 100% size,
# each liquidation loses approximately margin_per_side
# But we cap at available capital
total_margin_lost = min(n_liquidations * margin_per_side * liq_loss_rate, init_cash)
# Calculate as percentage of initial capital
adjustment_pct = (total_margin_lost / init_cash) * 100
return total_margin_lost, adjustment_pct
def check_liquidations(
close: pd.Series,
high: pd.Series,
low: pd.Series,
long_entries: pd.DataFrame,
long_exits: pd.DataFrame,
short_entries: pd.DataFrame,
short_exits: pd.DataFrame,
leverage: int,
maintenance_margin_rate: float
) -> int:
"""
Check for liquidation events and log warnings.
Args:
close: Close price series
high: High price series
low: Low price series
long_entries: Long entry signals
long_exits: Long exit signals
short_entries: Short entry signals
short_exits: Short exit signals
leverage: Position leverage
maintenance_margin_rate: Maintenance margin rate for liquidation
Returns:
Count of liquidation warnings
"""
warnings = 0
# For long positions
long_mask = (
long_entries.any(axis=1)
if isinstance(long_entries, pd.DataFrame)
else long_entries
)
for entry_idx in close.index[long_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=True,
maintenance_margin_rate=maintenance_margin_rate
)
subsequent = low.loc[entry_idx:]
if (subsequent < liq_price).any():
liq_bar = subsequent[subsequent < liq_price].index[0]
logger.warning(
"LIQUIDATION WARNING (Long): Entry at %s ($%.2f), "
"would liquidate at %s (liq_price=$%.2f, low=$%.2f)",
entry_idx, entry_price, liq_bar, liq_price, low.loc[liq_bar]
)
warnings += 1
# For short positions
short_mask = (
short_entries.any(axis=1)
if isinstance(short_entries, pd.DataFrame)
else short_entries
)
for entry_idx in close.index[short_mask]:
entry_price = close.loc[entry_idx]
liq_price = calculate_liquidation_price(
entry_price, leverage, is_long=False,
maintenance_margin_rate=maintenance_margin_rate
)
subsequent = high.loc[entry_idx:]
if (subsequent > liq_price).any():
liq_bar = subsequent[subsequent > liq_price].index[0]
logger.warning(
"LIQUIDATION WARNING (Short): Entry at %s ($%.2f), "
"would liquidate at %s (liq_price=$%.2f, high=$%.2f)",
entry_idx, entry_price, liq_bar, liq_price, high.loc[liq_bar]
)
warnings += 1
return warnings