352 lines
12 KiB
Python
352 lines
12 KiB
Python
"""
Core backtesting engine for running strategy simulations.

Supports multiple market types with realistic trading conditions.
"""
|
|
from dataclasses import dataclass
|
|
|
|
import pandas as pd
|
|
import vectorbt as vbt
|
|
|
|
from engine.data_manager import DataManager
|
|
from engine.logging_config import get_logger
|
|
from engine.market import MarketType, get_market_config
|
|
from engine.optimizer import WalkForwardOptimizer
|
|
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
|
|
from engine.risk import (
|
|
LiquidationEvent,
|
|
calculate_funding,
|
|
calculate_liquidation_adjustment,
|
|
inject_liquidation_exits,
|
|
)
|
|
from strategies.base import BaseStrategy
|
|
|
|
# Module-level logger obtained from the project's logging configuration.
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class BacktestResult:
    """
    Outcome of one backtest run, bundled with market-specific statistics.

    Attributes:
        portfolio: The VectorBT portfolio produced by the simulation.
        market_type: Market type the backtest was simulated under.
        leverage: Leverage that was actually applied.
        total_funding_paid: Cumulative funding fees (perpetuals only);
            defaults to 0.0.
        liquidation_count: How many positions were force-closed by
            liquidation; defaults to 0.
        liquidation_events: Per-event liquidation details, or None when
            not supplied.
        total_liquidation_loss: Total margin lost to liquidations;
            defaults to 0.0.
        adjusted_return: Percentage return after subtracting liquidation
            losses, or None when not supplied.
    """

    # Required fields. Keep them first and in this order, because callers
    # may build the result positionally.
    portfolio: vbt.Portfolio
    market_type: MarketType
    leverage: int

    # Optional metrics, each defaulting to a neutral value.
    total_funding_paid: float = 0.0
    liquidation_count: int = 0
    liquidation_events: list[LiquidationEvent] | None = None
    total_liquidation_loss: float = 0.0
    adjusted_return: float | None = None
|
|
|
|
|
|
class Backtester:
    """
    Backtester supporting multiple market types with realistic simulation.

    Features:
        - Spot and Perpetual market support
        - Long and short position handling
        - Leverage simulation
        - Funding rate calculation (perpetuals)
        - Liquidation handling: forced exits are injected at liquidation
          points and returns are adjusted for the margin lost
    """

    def __init__(self, data_manager: DataManager):
        """
        Args:
            data_manager: Source used to load OHLCV data, and to download it
                when it is not available locally.
        """
        self.dm = data_manager

    def run_strategy(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        init_cash: float = 10000,
        fees: float | None = None,
        slippage: float = 0.001,
        sl_stop: float | None = None,
        tp_stop: float | None = None,
        sl_trail: bool = False,
        leverage: int | None = None,
        **strategy_params
    ) -> BacktestResult:
        """
        Run a backtest with market-type-aware simulation.

        Args:
            strategy: Strategy instance to backtest
            exchange_id: Exchange identifier (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Data timeframe (e.g., '1m', '1h', '1d')
            start_date: Start date filter (YYYY-MM-DD), interpreted as UTC
            end_date: End date filter (YYYY-MM-DD), interpreted as UTC and
                compared inclusively against bar timestamps
            init_cash: Initial capital (margin for leveraged)
            fees: Transaction fee override (uses market taker fee if None)
            slippage: Slippage percentage
            sl_stop: Stop loss percentage
            tp_stop: Take profit percentage
            sl_trail: Enable trailing stop loss
            leverage: Leverage override; falsy values (None/0) fall back to
                the strategy default. Always forced to 1 on spot markets.
            **strategy_params: Additional strategy parameters

        Returns:
            BacktestResult with portfolio and market-specific metrics

        Raises:
            ValueError: If the resolved leverage is invalid, no data remains
                after date filtering, or the strategy returns an unsupported
                signal structure.
        """
        # Market configuration comes from the strategy's default market type
        market_type = strategy.default_market_type
        market_config = get_market_config(market_type)

        # Resolve leverage and fees
        effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
        effective_fees = fees if fees is not None else market_config.taker_fee

        # Load and filter data
        df = self._load_data(
            exchange_id, symbol, timeframe, market_type, start_date, end_date
        )

        close_price = df['close']
        high_price = df['high']
        low_price = df['low']
        open_price = df['open']
        volume = df['volume']

        # Run strategy logic
        signals = strategy.run(
            close_price,
            high=high_price,
            low=low_price,
            open=open_price,
            volume=volume,
            **strategy_params
        )

        # Normalize signals to 4-tuple format
        long_entries, long_exits, short_entries, short_exits = self._normalize_signals(
            signals, close_price, market_config
        )

        # Process liquidations - inject forced exits at liquidation points
        liquidation_events: list[LiquidationEvent] = []
        if effective_leverage > 1:
            long_exits, short_exits, liquidation_events = inject_liquidation_exits(
                close_price, high_price, low_price,
                long_entries, long_exits,
                short_entries, short_exits,
                effective_leverage,
                market_config.maintenance_margin_rate
            )

        # Calculate perpetual-specific metrics (after liquidation processing
        # so funding stops accruing at forced exits)
        total_funding = 0.0
        if market_type == MarketType.PERPETUAL:
            total_funding = calculate_funding(
                close_price,
                long_entries, long_exits,
                short_entries, short_exits,
                market_config,
                effective_leverage
            )

        # Run portfolio simulation with liquidation-aware exits
        portfolio = self._run_portfolio(
            close_price, market_config,
            long_entries, long_exits,
            short_entries, short_exits,
            init_cash, effective_fees, slippage, timeframe,
            sl_stop, tp_stop, sl_trail, effective_leverage
        )

        # Calculate adjusted returns accounting for liquidation losses
        total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
            liquidation_events, init_cash, effective_leverage
        )

        # mean() collapses multi-column portfolios to a single figure
        raw_return = portfolio.total_return().mean() * 100
        adjusted_return = raw_return - liq_adjustment

        if liquidation_events:
            logger.info(
                "Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
                len(liquidation_events), total_liq_loss, liq_adjustment
            )

        logger.info(
            "Backtest completed: %s market, %dx leverage, fees=%.4f%%",
            market_type.value, effective_leverage, effective_fees * 100
        )

        return BacktestResult(
            portfolio=portfolio,
            market_type=market_type,
            leverage=effective_leverage,
            total_funding_paid=total_funding,
            liquidation_count=len(liquidation_events),
            liquidation_events=liquidation_events,
            total_liquidation_loss=total_liq_loss,
            adjusted_return=adjusted_return
        )

    def _resolve_leverage(
        self,
        leverage: int | None,
        strategy: BaseStrategy,
        market_type: MarketType
    ) -> int:
        """
        Resolve effective leverage from CLI, strategy default, or market type.

        Spot markets always resolve to 1. Otherwise, a truthy ``leverage``
        wins and falsy values (None/0) fall back to
        ``strategy.default_leverage``.

        Raises:
            ValueError: If the resolved leverage is missing or below 1.
        """
        if market_type == MarketType.SPOT:
            return 1  # Spot cannot have leverage

        effective = leverage or strategy.default_leverage
        # Fail fast: a missing or sub-1 leverage would otherwise surface later
        # as a TypeError or nonsensical liquidation/portfolio math.
        if effective is None or effective < 1:
            raise ValueError(f"Leverage must be >= 1, got {effective!r}")
        return effective

    def _load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType,
        start_date: str | None,
        end_date: str | None
    ) -> pd.DataFrame:
        """
        Load OHLCV data (downloading it if missing locally) and filter by date.

        Raises:
            ValueError: If no rows remain after loading and date filtering.
        """
        try:
            df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
        except FileNotFoundError:
            logger.warning("Data not found locally. Attempting download...")
            df = self.dm.download_data(
                exchange_id, symbol, timeframe,
                start_date, end_date, market_type
            )

        if start_date:
            df = df[df.index >= self._date_bound(start_date, df.index)]
        if end_date:
            # NOTE(review): a date-only end_date resolves to midnight UTC, so
            # intraday bars later on that day are excluded - confirm intended.
            df = df[df.index <= self._date_bound(end_date, df.index)]

        # An empty frame would otherwise fail obscurely inside the strategy
        # or the portfolio simulation.
        if df.empty:
            raise ValueError(
                f"No data available for {symbol} on {exchange_id} ({timeframe}) "
                f"between {start_date or 'start'} and {end_date or 'end'}"
            )

        return df

    @staticmethod
    def _date_bound(value: str, index: pd.Index) -> pd.Timestamp:
        """
        Build a UTC timestamp bound that can be compared with ``index``.

        The date string is interpreted as UTC. For a tz-naive index, the bound
        is returned naive (UTC wall time), which assumes the naive index is in
        UTC. For a tz-aware index, the bound is converted to the index's zone.
        """
        bound = pd.Timestamp(value, tz="UTC")
        index_tz = getattr(index, "tz", None)
        if index_tz is None:
            return bound.tz_localize(None)
        return bound.tz_convert(index_tz)

    @staticmethod
    def _has_any_signal(signals) -> bool:
        """
        Return True if any element of ``signals`` is truthy.

        Accepts a pandas Series/DataFrame or anything ``pd.DataFrame`` can
        wrap (e.g. a 1-D/2-D array or list). NaN values are skipped, following
        pandas' ``any()`` default.
        """
        if isinstance(signals, pd.DataFrame):
            return bool(signals.any().any())
        if isinstance(signals, pd.Series):
            return bool(signals.any())
        return bool(pd.DataFrame(signals).any().any())

    def _normalize_signals(
        self,
        signals: tuple,
        close: pd.Series,
        market_config
    ) -> tuple:
        """
        Normalize strategy signals to the 4-tuple format.

        Handles backward compatibility with 2-tuple (long-only) returns by
        adding empty short signals. On markets without short support, short
        signals are cleared, with a warning if any were set.

        Returns:
            (long_entries, long_exits, short_entries, short_exits)

        Raises:
            ValueError: If ``signals`` is not a tuple/list of 2 or 4 arrays.
        """
        # Guard first: len() on a bare Series/DataFrame would count rows and
        # could silently be misread as 2 or 4 signal arrays.
        if not isinstance(signals, (tuple, list)):
            raise ValueError(
                "Strategy must return a tuple of 2 or 4 signal arrays, "
                f"got {type(signals).__name__}"
            )

        if len(signals) == 2:
            long_entries, long_exits = signals
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)
            return long_entries, long_exits, short_entries, short_exits

        if len(signals) == 4:
            long_entries, long_exits, short_entries, short_exits = signals

            # Warn about, and always clear, short signals on spot markets
            if not market_config.supports_short:
                if self._has_any_signal(short_entries):
                    logger.warning(
                        "Short signals detected but market type is SPOT. "
                        "Short signals will be ignored."
                    )
                short_entries = BaseStrategy.create_empty_signals(long_entries)
                short_exits = BaseStrategy.create_empty_signals(long_entries)

            return long_entries, long_exits, short_entries, short_exits

        raise ValueError(
            f"Strategy must return 2 or 4 signal arrays, got {len(signals)}"
        )

    def _run_portfolio(
        self,
        close: pd.Series,
        market_config,
        long_entries, long_exits,
        short_entries, short_exits,
        init_cash: float,
        fees: float,
        slippage: float,
        freq: str,
        sl_stop: float | None,
        tp_stop: float | None,
        sl_trail: bool,
        leverage: int
    ) -> vbt.Portfolio:
        """
        Select and run the appropriate portfolio simulation.

        Uses the long/short simulation only when the market supports shorting
        and at least one short entry is present. Otherwise runs long-only.
        """
        if market_config.supports_short and self._has_any_signal(short_entries):
            return run_long_short_portfolio(
                close,
                long_entries, long_exits,
                short_entries, short_exits,
                init_cash, fees, slippage, freq,
                sl_stop, tp_stop, sl_trail, leverage
            )

        return run_long_only_portfolio(
            close,
            long_entries, long_exits,
            init_cash, fees, slippage, freq,
            sl_stop, tp_stop, sl_trail, leverage
        )

    def run_wfa(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        param_grid: dict,
        n_windows: int = 10,
        timeframe: str = '1m'
    ):
        """
        Execute Walk-Forward Analysis.

        Data is loaded through the same path as ``run_strategy``, so it is
        downloaded when missing locally.

        Args:
            strategy: Strategy instance to optimize
            exchange_id: Exchange identifier
            symbol: Trading pair symbol
            param_grid: Parameter grid for optimization
            n_windows: Number of walk-forward windows
            timeframe: Data timeframe to load

        Returns:
            Tuple of (results DataFrame, stitched equity curve)

        Raises:
            ValueError: If no data is available.
        """
        market_type = strategy.default_market_type
        df = self._load_data(
            exchange_id, symbol, timeframe, market_type, None, None
        )

        wfa = WalkForwardOptimizer(self, strategy, param_grid)

        results, stitched_curve = wfa.run(
            df['close'],
            high=df['high'],
            low=df['low'],
            n_windows=n_windows
        )

        return results, stitched_curve