Compare commits
2 Commits
c4aa965a98
...
e6d69ed04d
| Author | SHA1 | Date | |
|---|---|---|---|
| e6d69ed04d | |||
| 44fac1ed25 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -169,4 +169,5 @@ cython_debug/
|
||||
#.idea/
|
||||
|
||||
./logs/
|
||||
*.csv
|
||||
*.csv
|
||||
research/regime_results.html
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../OHLCVPredictor
|
||||
70
backtest.py
70
backtest.py
@@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from trade import TradeState, enter_long, exit_long, maybe_trailing_stop
|
||||
from indicators import add_supertrends, compute_meta_trend
|
||||
from metrics import compute_metrics
|
||||
from logging_utils import write_trade_log
|
||||
|
||||
# Supertrend (ATR period, multiplier) pairs whose combined signal defines the
# meta-trend used for entries/exits.
DEFAULT_ST_SETTINGS = [(12, 3.0), (10, 1.0), (11, 2.0)]


def backtest(
    df: pd.DataFrame,
    df_1min: pd.DataFrame,
    timeframe_minutes: int,
    stop_loss: float,
    exit_on_bearish_flip: bool,
    fee_bps: float,
    slippage_bps: float,
    log_path: Path | None = None,
):
    """Run a long-only meta-supertrend backtest over resampled bars.

    Entries are taken at the bar close when the meta-trend flips bullish;
    a trailing stop is then evaluated intra-bar on the 1-minute data.

    Args:
        df: Resampled OHLCV bars with "Timestamp"/"Close" columns.
        df_1min: 1-minute OHLCV bars used for intra-bar trailing-stop checks.
        timeframe_minutes: Bar size of ``df``, used to extend the last window.
        stop_loss: Trailing stop fraction (e.g. 0.02 = 2% below the high-water mark).
        exit_on_bearish_flip: If True, close the position when meta-trend turns 0.
        fee_bps: Taker fee in basis points (forwarded to TradeState).
        slippage_bps: Slippage in basis points (forwarded to TradeState).
        log_path: Optional CSV destination for the trade log.

    Returns:
        Tuple of (perf metrics from compute_metrics, equity curve Series, trade list).
    """
    # Indicator columns are added up front; meta_bull is 1 when the ensemble
    # of supertrends agrees on an uptrend (semantics live in indicators.py).
    df = add_supertrends(df, DEFAULT_ST_SETTINGS)
    df["meta_bull"] = compute_meta_trend(df, DEFAULT_ST_SETTINGS)

    state = TradeState(stop_loss_frac=stop_loss, fee_bps=fee_bps, slippage_bps=slippage_bps)
    equity, trades = [], []

    # NOTE(review): `.iat[i + 1]` below treats the iterrows label as a
    # position — assumes df has a RangeIndex (true after load_data's
    # reset_index, provided add_supertrends preserves it) — TODO confirm.
    for i, row in df.iterrows():
        price = float(row["Close"])
        ts = pd.Timestamp(row["Timestamp"])

        # Enter at bar close on a bullish meta-trend while flat.
        if state.qty <= 0 and row["meta_bull"] == 1:
            evt = enter_long(state, price)
            if evt:
                evt.update({"t": ts.isoformat(), "reason": "bull_flip"})
                trades.append(evt)

        # Window of 1-minute bars belonging to this higher-timeframe bar;
        # the final bar is padded by one bar width.
        start = ts
        end = df["Timestamp"].iat[i + 1] if i + 1 < len(df) else ts + pd.Timedelta(minutes=timeframe_minutes)

        if state.qty > 0:
            win = df_1min[(df_1min["Timestamp"] >= start) & (df_1min["Timestamp"] < end)]
            for _, m in win.iterrows():
                hi = float(m["High"])
                lo = float(m["Low"])
                # High-water mark since entry; `or hi` seeds it on the first minute.
                state.max_px = max(state.max_px or hi, hi)
                trail = state.max_px * (1.0 - state.stop_loss_frac)
                if lo <= trail:
                    # Assume the fill happens exactly at the trail level.
                    evt = exit_long(state, trail)
                    if evt:
                        # NOTE(review): assumes trades[-1] is this position's
                        # entry event (holds only if enter_long always logged
                        # an event for the open position) — TODO confirm.
                        prev = trades[-1]
                        pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
                        evt.update({"t": pd.Timestamp(m["Timestamp"]).isoformat(), "reason": "stop", "pnl": pnl})
                        trades.append(evt)
                    break

        # Bearish-flip exit at bar close (skipped if the stop already fired).
        if state.qty > 0 and exit_on_bearish_flip and row["meta_bull"] == 0:
            evt = exit_long(state, price)
            if evt:
                prev = trades[-1]
                pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
                evt.update({"t": ts.isoformat(), "reason": "bearish_flip", "pnl": pnl})
                trades.append(evt)

        # Mark-to-market equity at this bar's close.
        equity.append(state.cash + state.qty * price)

    equity_curve = pd.Series(equity, index=df["Timestamp"])
    if log_path:
        write_trade_log(trades, log_path)
    perf = compute_metrics(equity_curve, trades)
    return perf, equity_curve, trades
|
||||
80
cli.py
80
cli.py
@@ -1,80 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
from config import CLIConfig
|
||||
from data import load_data
|
||||
from backtest import backtest
|
||||
|
||||
def parse_args() -> CLIConfig:
    """Parse command-line arguments and return them as a CLIConfig."""
    parser = argparse.ArgumentParser(prog="bt", description="Simple supertrend backtester")
    parser.add_argument("start")
    parser.add_argument("end")
    # Either one timeframe, or a sweep over several (e.g. 5 15 60 240).
    parser.add_argument("--timeframe-minutes", type=int, default=15)
    parser.add_argument("--timeframes-minutes", nargs="+", type=int)
    parser.add_argument("--stop-loss", dest="stop_losses", type=float, nargs="+", default=[0.02, 0.05])
    parser.add_argument("--exit-on-bearish-flip", action="store_true")
    parser.add_argument("--csv", dest="csv_path", type=Path, required=True)
    parser.add_argument("--out-csv", type=Path, default=Path("summary.csv"))
    parser.add_argument("--log-dir", type=Path, default=Path("./logs"))
    parser.add_argument("--fee-bps", type=float, default=10.0)
    parser.add_argument("--slippage-bps", type=float, default=2.0)

    parsed = parser.parse_args()

    return CLIConfig(
        start=parsed.start,
        end=parsed.end,
        timeframe_minutes=parsed.timeframe_minutes,
        timeframes_minutes=parsed.timeframes_minutes,
        stop_losses=parsed.stop_losses,
        exit_on_bearish_flip=parsed.exit_on_bearish_flip,
        csv_path=parsed.csv_path,
        out_csv=parsed.out_csv,
        log_dir=parsed.log_dir,
        fee_bps=parsed.fee_bps,
        slippage_bps=parsed.slippage_bps,
    )
|
||||
|
||||
def main():
    """Run the backtest sweep defined by the CLI flags and print/save a summary."""
    cfg = parse_args()
    # Fall back to the single-timeframe flag when no sweep was requested.
    frames = cfg.timeframes_minutes or [cfg.timeframe_minutes]

    rows: list[dict] = []
    for tf_minutes in frames:
        df_1min, df = load_data(cfg.start, cfg.end, tf_minutes, cfg.csv_path)
        for stop in cfg.stop_losses:
            log_path = cfg.log_dir / f"{tf_minutes}m_sl{stop:.2%}.csv"
            perf, equity, _ = backtest(
                df=df,
                df_1min=df_1min,
                timeframe_minutes=tf_minutes,
                stop_loss=stop,
                exit_on_bearish_flip=cfg.exit_on_bearish_flip,
                fee_bps=cfg.fee_bps,
                slippage_bps=cfg.slippage_bps,
                log_path=log_path,
            )
            summary_row = {
                "timeframe": f"{tf_minutes}min",
                "stop_loss": stop,
                "exit_on_bearish_flip": cfg.exit_on_bearish_flip,
                "total_return": f"{perf.total_return:.2%}",
                "max_drawdown": f"{perf.max_drawdown:.2%}",
                "sharpe_ratio": f"{perf.sharpe_ratio:.2f}",
                "win_rate": f"{perf.win_rate:.2%}",
                "num_trades": perf.num_trades,
                "final_equity": f"${perf.final_equity:.2f}",
                "initial_equity": f"${perf.initial_equity:.2f}",
                "num_stop_losses": perf.num_stop_losses,
                "total_fees": perf.total_fees,
                "total_slippage_usd": perf.total_slippage_usd,
                "avg_slippage_bps": perf.avg_slippage_bps,
            }
            rows.append(summary_row)

    out = pd.DataFrame(rows)
    out.to_csv(cfg.out_csv, index=False)
    print(out.to_string(index=False))


if __name__ == "__main__":
    main()
|
||||
18
config.py
18
config.py
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
@dataclass
class CLIConfig:
    """Parsed command-line options for the supertrend backtester.

    Mirrors the flags accepted by ``parse_args`` one-to-one.
    """

    start: str                              # backtest window start (date string parsed with pd.Timestamp)
    end: str                                # backtest window end (date string)
    timeframe_minutes: int                  # single-timeframe run; used when timeframes_minutes is None
    timeframes_minutes: list[int] | None    # optional multi-timeframe sweep, e.g. [5, 15, 60, 240]
    stop_losses: Sequence[float]            # trailing-stop fractions to sweep (0.02 = 2%)
    exit_on_bearish_flip: bool              # close longs when the meta-trend flips bearish
    csv_path: Path | None                   # 1-minute OHLCV CSV input file
    out_csv: Path                           # destination for the summary table
    log_dir: Path                           # directory for per-run trade logs
    fee_bps: float                          # taker fee, basis points
    slippage_bps: float                     # assumed slippage, basis points
|
||||
24
data.py
24
data.py
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
def load_data(start: str, end: str, timeframe_minutes: int, csv_path: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load 1-minute OHLCV bars from a CSV and resample to the target timeframe.

    The CSV must have an epoch-seconds "Timestamp" column plus
    Open/High/Low/Close/Volume. Rows are filtered to [start, end] (UTC,
    inclusive) and sorted by time.

    Returns:
        Tuple of (filtered 1-minute frame, resampled frame). For a
        1-minute target the second frame is a copy of the first.
    """
    minutes = pd.read_csv(csv_path)
    minutes["Timestamp"] = pd.to_datetime(minutes["Timestamp"], unit="s", utc=True)

    window_start = pd.Timestamp(start, tz="UTC")
    window_end = pd.Timestamp(end, tz="UTC")
    in_window = (minutes["Timestamp"] >= window_start) & (minutes["Timestamp"] <= window_end)
    minutes = minutes[in_window].sort_values("Timestamp").reset_index(drop=True)

    if timeframe_minutes == 1:
        bars = minutes.copy()
    else:
        # Standard OHLCV aggregation per resample bucket; buckets with no
        # data are dropped rather than forward-filled.
        grouped = minutes.set_index("Timestamp").resample(f"{timeframe_minutes}min")
        bars = pd.DataFrame({
            "Open": grouped["Open"].first(),
            "High": grouped["High"].max(),
            "Low": grouped["Low"].min(),
            "Close": grouped["Close"].last(),
            "Volume": grouped["Volume"].sum(),
        }).dropna().reset_index()

    return minutes, bars
|
||||
352
engine/backtester.py
Normal file
352
engine/backtester.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Core backtesting engine for running strategy simulations.
|
||||
|
||||
Supports multiple market types with realistic trading conditions.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType, get_market_config
|
||||
from engine.optimizer import WalkForwardOptimizer
|
||||
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
|
||||
from engine.risk import (
|
||||
LiquidationEvent,
|
||||
calculate_funding,
|
||||
calculate_liquidation_adjustment,
|
||||
inject_liquidation_exits,
|
||||
)
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BacktestResult:
    """
    Container for backtest results with market-specific metrics.

    Attributes:
        portfolio: VectorBT Portfolio object
        market_type: Market type used for the backtest
        leverage: Effective leverage used
        total_funding_paid: Total funding fees paid (perpetuals only)
        liquidation_count: Number of positions that were liquidated
        liquidation_events: Detailed list of liquidation events
        total_liquidation_loss: Total margin lost from liquidations
        adjusted_return: Return adjusted for liquidation losses (percentage)
    """
    portfolio: vbt.Portfolio
    market_type: MarketType
    leverage: int
    # Market-specific extras default to "no impact" so spot results need no overrides.
    total_funding_paid: float = 0.0
    liquidation_count: int = 0
    liquidation_events: list[LiquidationEvent] | None = None
    total_liquidation_loss: float = 0.0
    adjusted_return: float | None = None
|
||||
|
||||
|
||||
class Backtester:
    """
    Backtester supporting multiple market types with realistic simulation.

    Features:
    - Spot and Perpetual market support
    - Long and short position handling
    - Leverage simulation
    - Funding rate calculation (perpetuals)
    - Liquidation warnings
    """

    def __init__(self, data_manager: DataManager):
        # DataManager handles all OHLCV loading/downloading.
        self.dm = data_manager

    def run_strategy(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        init_cash: float = 10000,
        fees: float | None = None,
        slippage: float = 0.001,
        sl_stop: float | None = None,
        tp_stop: float | None = None,
        sl_trail: bool = False,
        leverage: int | None = None,
        **strategy_params
    ) -> BacktestResult:
        """
        Run a backtest with market-type-aware simulation.

        Args:
            strategy: Strategy instance to backtest
            exchange_id: Exchange identifier (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Data timeframe (e.g., '1m', '1h', '1d')
            start_date: Start date filter (YYYY-MM-DD)
            end_date: End date filter (YYYY-MM-DD)
            init_cash: Initial capital (margin for leveraged)
            fees: Transaction fee override (uses market default if None)
            slippage: Slippage percentage
            sl_stop: Stop loss percentage
            tp_stop: Take profit percentage
            sl_trail: Enable trailing stop loss
            leverage: Leverage override (uses strategy default if None)
            **strategy_params: Additional strategy parameters

        Returns:
            BacktestResult with portfolio and market-specific metrics
        """
        # Get market configuration from strategy
        market_type = strategy.default_market_type
        market_config = get_market_config(market_type)

        # Resolve leverage and fees
        effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
        effective_fees = fees if fees is not None else market_config.taker_fee

        # Load and filter data
        df = self._load_data(
            exchange_id, symbol, timeframe, market_type, start_date, end_date
        )

        close_price = df['close']
        high_price = df['high']
        low_price = df['low']
        open_price = df['open']
        volume = df['volume']

        # Run strategy logic
        signals = strategy.run(
            close_price,
            high=high_price,
            low=low_price,
            open=open_price,
            volume=volume,
            **strategy_params
        )

        # Normalize signals to 4-tuple format
        signals = self._normalize_signals(signals, close_price, market_config)
        long_entries, long_exits, short_entries, short_exits = signals

        # Process liquidations - inject forced exits at liquidation points
        liquidation_events: list[LiquidationEvent] = []
        if effective_leverage > 1:
            long_exits, short_exits, liquidation_events = inject_liquidation_exits(
                close_price, high_price, low_price,
                long_entries, long_exits,
                short_entries, short_exits,
                effective_leverage,
                market_config.maintenance_margin_rate
            )

        # Calculate perpetual-specific metrics (after liquidation processing)
        total_funding = 0.0
        if market_type == MarketType.PERPETUAL:
            total_funding = calculate_funding(
                close_price,
                long_entries, long_exits,
                short_entries, short_exits,
                market_config,
                effective_leverage
            )

        # Run portfolio simulation with liquidation-aware exits.
        # NOTE: `timeframe` is forwarded as the portfolio frequency string.
        portfolio = self._run_portfolio(
            close_price, market_config,
            long_entries, long_exits,
            short_entries, short_exits,
            init_cash, effective_fees, slippage, timeframe,
            sl_stop, tp_stop, sl_trail, effective_leverage
        )

        # Calculate adjusted returns accounting for liquidation losses
        total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
            liquidation_events, init_cash, effective_leverage
        )

        # .mean() collapses a possible parameter-grid axis into one scalar.
        raw_return = portfolio.total_return().mean() * 100
        adjusted_return = raw_return - liq_adjustment

        if liquidation_events:
            logger.info(
                "Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
                len(liquidation_events), total_liq_loss, liq_adjustment
            )

        logger.info(
            "Backtest completed: %s market, %dx leverage, fees=%.4f%%",
            market_type.value, effective_leverage, effective_fees * 100
        )

        return BacktestResult(
            portfolio=portfolio,
            market_type=market_type,
            leverage=effective_leverage,
            total_funding_paid=total_funding,
            liquidation_count=len(liquidation_events),
            liquidation_events=liquidation_events,
            total_liquidation_loss=total_liq_loss,
            adjusted_return=adjusted_return
        )

    def _resolve_leverage(
        self,
        leverage: int | None,
        strategy: BaseStrategy,
        market_type: MarketType
    ) -> int:
        """Resolve effective leverage from CLI, strategy default, or market type."""
        # NOTE(review): `or` treats leverage=0 as "not given" and falls back to
        # the strategy default; presumably intended since 0x is meaningless —
        # TODO confirm.
        effective = leverage or strategy.default_leverage
        if market_type == MarketType.SPOT:
            return 1  # Spot cannot have leverage
        return effective

    def _load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType,
        start_date: str | None,
        end_date: str | None
    ) -> pd.DataFrame:
        """Load and filter OHLCV data, downloading it on a local cache miss."""
        try:
            df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
        except FileNotFoundError:
            logger.warning("Data not found locally. Attempting download...")
            df = self.dm.download_data(
                exchange_id, symbol, timeframe,
                start_date, end_date, market_type
            )

        # Date filters are applied against the (tz-aware UTC) index.
        if start_date:
            df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
        if end_date:
            df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]

        return df

    def _normalize_signals(
        self,
        signals: tuple,
        close: pd.Series,
        market_config
    ) -> tuple:
        """
        Normalize strategy signals to 4-tuple format.

        Handles backward compatibility with 2-tuple (long-only) returns.

        Raises:
            ValueError: if the strategy returned neither 2 nor 4 arrays.
        """
        if len(signals) == 2:
            # Long-only strategy: synthesize empty short legs.
            long_entries, long_exits = signals
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)
            return long_entries, long_exits, short_entries, short_exits

        if len(signals) == 4:
            long_entries, long_exits, short_entries, short_exits = signals

            # Warn and clear short signals on spot markets
            if not market_config.supports_short:
                # `.any().any()` handles DataFrame signals (grid runs);
                # plain `.any()` handles Series.
                has_shorts = (
                    short_entries.any().any()
                    if hasattr(short_entries, 'any')
                    else short_entries.any()
                )
                if has_shorts:
                    logger.warning(
                        "Short signals detected but market type is SPOT. "
                        "Short signals will be ignored."
                    )
                short_entries = BaseStrategy.create_empty_signals(long_entries)
                short_exits = BaseStrategy.create_empty_signals(long_entries)

            return long_entries, long_exits, short_entries, short_exits

        raise ValueError(
            f"Strategy must return 2 or 4 signal arrays, got {len(signals)}"
        )

    def _run_portfolio(
        self,
        close: pd.Series,
        market_config,
        long_entries, long_exits,
        short_entries, short_exits,
        init_cash: float,
        fees: float,
        slippage: float,
        freq: str,
        sl_stop: float | None,
        tp_stop: float | None,
        sl_trail: bool,
        leverage: int
    ) -> vbt.Portfolio:
        """Select and run appropriate portfolio simulation."""
        has_shorts = (
            short_entries.any().any()
            if hasattr(short_entries, 'any')
            else short_entries.any()
        )

        # Long/short simulation only when both the market allows it and the
        # strategy actually emitted short signals.
        if market_config.supports_short and has_shorts:
            return run_long_short_portfolio(
                close,
                long_entries, long_exits,
                short_entries, short_exits,
                init_cash, fees, slippage, freq,
                sl_stop, tp_stop, sl_trail, leverage
            )

        return run_long_only_portfolio(
            close,
            long_entries, long_exits,
            init_cash, fees, slippage, freq,
            sl_stop, tp_stop, sl_trail, leverage
        )

    def run_wfa(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        param_grid: dict,
        n_windows: int = 10,
        timeframe: str = '1m'
    ):
        """
        Execute Walk-Forward Analysis.

        Args:
            strategy: Strategy instance to optimize
            exchange_id: Exchange identifier
            symbol: Trading pair symbol
            param_grid: Parameter grid for optimization
            n_windows: Number of walk-forward windows
            timeframe: Data timeframe to load

        Returns:
            Tuple of (results DataFrame, stitched equity curve)
        """
        # Unlike run_strategy, this loads directly (no download fallback,
        # no date filtering) — the optimizer slices windows itself.
        market_type = strategy.default_market_type
        df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)

        wfa = WalkForwardOptimizer(self, strategy, param_grid)

        results, stitched_curve = wfa.run(
            df['close'],
            high=df['high'],
            low=df['low'],
            n_windows=n_windows
        )

        return results, stitched_curve
|
||||
243
engine/cli.py
Normal file
243
engine/cli.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
CLI handler for Lowkey Backtest.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import get_logger, setup_logging
|
||||
from engine.market import MarketType
|
||||
from engine.reporting import Reporter
|
||||
from strategies.factory import get_strategy, get_strategy_names
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
    """Build the top-level argument parser with all subcommands attached."""
    root = argparse.ArgumentParser(
        description="Lowkey Backtest CLI (VectorBT Edition)"
    )
    sub = root.add_subparsers(dest="command", help="Command to run")

    # Attach each subcommand in a fixed order (order affects --help output).
    for attach in (_add_download_parser, _add_backtest_parser, _add_wfa_parser):
        attach(sub)

    return root
|
||||
|
||||
|
||||
def _add_download_parser(subparsers) -> None:
    """Add download command parser."""
    dl = subparsers.add_parser("download", help="Download historical data")
    dl.add_argument("--exchange", "-e", type=str, default="okx")
    dl.add_argument("--pair", "-p", type=str, required=True)
    dl.add_argument("--timeframe", "-t", type=str, default="1m")
    dl.add_argument("--start", type=str, help="Start Date (YYYY-MM-DD)")
    # Market type decides which venue (spot vs perp) the data comes from.
    dl.add_argument("--market", "-m", type=str, choices=["spot", "perpetual"], default="spot")
|
||||
|
||||
|
||||
def _add_backtest_parser(subparsers) -> None:
    """Add backtest command parser."""
    strategy_choices = get_strategy_names()

    bt = subparsers.add_parser("backtest", help="Run a backtest")
    bt.add_argument("--strategy", "-s", type=str, choices=strategy_choices, required=True)
    bt.add_argument("--exchange", "-e", type=str, default="okx")
    bt.add_argument("--pair", "-p", type=str, required=True)
    bt.add_argument("--timeframe", "-t", type=str, default="1m")
    bt.add_argument("--start", type=str)
    bt.add_argument("--end", type=str)
    bt.add_argument("--grid", "-g", action="store_true")
    bt.add_argument("--plot", action="store_true")

    # Risk parameters.
    bt.add_argument("--sl", type=float, help="Stop Loss %%")
    bt.add_argument("--tp", type=float, help="Take Profit %%")
    bt.add_argument("--trail", action="store_true")
    bt.add_argument("--no-bear-exit", action="store_true")

    # Cost parameters.
    bt.add_argument("--fees", type=float, default=None)
    bt.add_argument("--slippage", type=float, default=0.001)
    bt.add_argument("--leverage", "-l", type=int, default=None)
|
||||
|
||||
|
||||
def _add_wfa_parser(subparsers) -> None:
    """Add walk-forward analysis command parser."""
    strategy_choices = get_strategy_names()

    wfa = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")
    wfa.add_argument("--strategy", "-s", type=str, choices=strategy_choices, required=True)
    wfa.add_argument("--pair", "-p", type=str, required=True)
    wfa.add_argument("--timeframe", "-t", type=str, default="1d")
    wfa.add_argument("--windows", "-w", type=int, default=10)
    wfa.add_argument("--plot", action="store_true")
|
||||
|
||||
|
||||
def run_download(args) -> None:
    """Execute download command."""
    manager = DataManager()
    manager.download_data(
        args.exchange,
        args.pair,
        args.timeframe,
        start_date=args.start,
        market_type=MarketType(args.market),
    )
|
||||
|
||||
|
||||
def run_backtest(args) -> None:
    """Execute backtest command."""
    manager = DataManager()
    engine = Backtester(manager)
    reporter = Reporter()

    strategy, params = get_strategy(args.strategy, args.grid)

    # Apply CLI overrides for meta_st strategy.
    params = _apply_strategy_overrides(args, strategy, params)

    if args.grid and args.strategy == "meta_st":
        logger.info("Running Grid Search for Meta Supertrend...")

    try:
        result = engine.run_strategy(
            strategy,
            args.exchange,
            args.pair,
            timeframe=args.timeframe,
            start_date=args.start,
            end_date=args.end,
            fees=args.fees,
            slippage=args.slippage,
            sl_stop=args.sl,
            tp_stop=args.tp,
            sl_trail=args.trail,
            leverage=args.leverage,
            **params
        )

        reporter.print_summary(result)
        reporter.save_reports(result, f"{args.strategy}_{args.pair.replace('/','-')}")

        if args.plot:
            if args.grid:
                logger.info("Plotting skipped for Grid Search. Check CSV results.")
            else:
                reporter.plot(result.portfolio)

    except Exception as e:
        # Any failure (data, strategy, reporting) is logged, not propagated.
        logger.error("Backtest failed: %s", e, exc_info=True)
|
||||
|
||||
|
||||
def run_wfa(args) -> None:
    """Execute walk-forward analysis command."""
    manager = DataManager()
    engine = Backtester(manager)
    reporter = Reporter()

    # WFA always runs the grid variant of the strategy.
    strategy, param_grid = get_strategy(args.strategy, is_grid=True)

    logger.info(
        "Running WFA on %s for %s (%s) with %d windows...",
        args.strategy, args.pair, args.timeframe, args.windows
    )

    try:
        results, stitched_curve = engine.run_wfa(
            strategy,
            "okx",
            args.pair,
            param_grid,
            n_windows=args.windows,
            timeframe=args.timeframe
        )

        _log_wfa_results(results)
        _save_wfa_results(args, results, stitched_curve, reporter)

    except Exception as e:
        logger.error("WFA failed: %s", e, exc_info=True)
|
||||
|
||||
|
||||
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
|
||||
"""Apply CLI argument overrides to strategy parameters."""
|
||||
if args.strategy != "meta_st":
|
||||
return params
|
||||
|
||||
if args.no_bear_exit:
|
||||
params['exit_on_bearish_flip'] = False
|
||||
|
||||
if args.sl is None:
|
||||
args.sl = strategy.default_sl_stop
|
||||
|
||||
if not args.trail:
|
||||
args.trail = strategy.default_sl_trail
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _log_wfa_results(results) -> None:
    """Log a human-readable summary of walk-forward analysis results."""
    logger.info("Walk-Forward Analysis Results:")

    has_windows = not results.empty and 'window' in results.columns
    if not has_windows:
        logger.warning("No valid WFA results. All windows may have failed.")
        return

    display_cols = ['window', 'train_score', 'test_score', 'test_return']
    logger.info("\n%s", results[display_cols].to_string(index=False))

    # Headline averages across all test windows.
    logger.info("Average Test Sharpe: %.2f", results['test_score'].mean())
    logger.info("Average Test Return: %.2f%%", results['test_return'].mean() * 100)
|
||||
|
||||
|
||||
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
    """Persist WFA results to CSV and optionally plot them."""
    if results.empty:
        return

    pair_slug = args.pair.replace('/','-')
    output_path = f"backtest_logs/wfa_{args.strategy}_{pair_slug}.csv"
    results.to_csv(output_path)
    logger.info("Saved full results to %s", output_path)

    if args.plot:
        reporter.plot_wfa(results, stitched_curve)
|
||||
|
||||
|
||||
def main():
    """Main entry point."""
    setup_logging()

    parser = create_parser()
    args = parser.parse_args()

    # Dispatch table: subcommand name -> handler.
    dispatch = {
        "download": run_download,
        "backtest": run_backtest,
        "wfa": run_wfa,
    }

    handler = dispatch.get(args.command)
    if handler is None:
        parser.print_help()
    else:
        handler(args)
|
||||
156
engine/cryptoquant.py
Normal file
156
engine/cryptoquant.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load env vars from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Fix path for direct execution
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class CryptoQuantClient:
    """
    Client for fetching data from CryptoQuant API.

    Requires an API key, passed explicitly or via the CRYPTOQUANT_API_KEY
    environment variable. All fetch methods are best-effort: on any error
    they log and return an empty DataFrame rather than raising.
    """
    BASE_URL = "https://api.cryptoquant.com/v1"

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.getenv("CRYPTOQUANT_API_KEY")
        if not self.api_key:
            raise ValueError("CryptoQuant API Key not found. Set CRYPTOQUANT_API_KEY env var.")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}"
        }

    def fetch_metric(
        self,
        metric_path: str,
        symbol: str,
        start_date: str,
        end_date: str,
        exchange: str | None = "all_exchange",
        window: str = "day",
        timeout: float = 30.0
    ) -> pd.DataFrame:
        """
        Fetch a specific metric from CryptoQuant.

        Args:
            metric_path: API path under BASE_URL (e.g. "btc/exchange-flows/inflow").
            symbol: Asset symbol, used for logging only.
            start_date: "from" parameter (YYYYMMDD).
            end_date: "to" parameter (YYYYMMDD).
            exchange: Exchange filter; omitted from the query when falsy.
            window: Aggregation window (e.g. "day").
            timeout: Request timeout in seconds (prevents hanging forever).

        Returns:
            DataFrame indexed by timestamp when a 'date' column is present;
            an empty DataFrame on any error or when no data was returned.
        """
        url = f"{self.BASE_URL}/{metric_path}"

        params = {
            "window": window,
            "from": start_date,
            "to": end_date,
            "limit": 100000
        }

        if exchange:
            params["exchange"] = exchange

        logger.info("Fetching %s for %s (%s-%s)...", metric_path, symbol, start_date, end_date)

        # Initialized before the try so the except block can safely inspect it.
        response = None
        try:
            # FIX: a timeout is mandatory — requests.get without one can block forever.
            response = requests.get(url, headers=self.headers, params=params, timeout=timeout)
            response.raise_for_status()
            data = response.json()

            if 'result' in data and 'data' in data['result']:
                df = pd.DataFrame(data['result']['data'])
                if not df.empty:
                    if 'date' in df.columns:
                        df['timestamp'] = pd.to_datetime(df['date'])
                        df.set_index('timestamp', inplace=True)
                    df.sort_index(inplace=True)
                    return df

            return pd.DataFrame()

        except Exception as e:
            # Best-effort: log (including the body when available) and return empty.
            logger.error("Error fetching CQ data %s: %s", metric_path, e)
            if response is not None and hasattr(response, 'text'):
                logger.error("Response: %s", response.text)
            return pd.DataFrame()

    def fetch_multi_metrics(self, symbols: list[str], metrics: dict, start_date: str, end_date: str):
        """
        Fetch multiple metrics for multiple symbols and combine them.

        Args:
            symbols: Asset symbols (lower-cased before use, e.g. ["btc", "eth"]).
            metrics: Mapping of output column suffix -> API path fragment.
            start_date / end_date: Date range forwarded to fetch_metric.

        Returns:
            One wide DataFrame, outer-joined on timestamp, with columns named
            "{asset}_{metric_name}". Empty if nothing could be fetched.
        """
        combined_df = pd.DataFrame()

        for symbol in symbols:
            asset = symbol.lower()

            for metric_name, api_path in metrics.items():
                full_path = f"{asset}/{api_path}"

                # Some metrics (like funding rates) might need specific exchange vs all_exchange.
                # Defaulting to all_exchange is usually safe for flows; based on
                # testing, 'all_exchange' is standard, so funding-rates gets no
                # special-casing here.
                exchange_param = "all_exchange"

                df = self.fetch_metric(full_path, asset, start_date, end_date, exchange=exchange_param)

                if not df.empty:
                    # Heuristic to find the value column: prefer known names,
                    # else the first column that isn't date-like metadata.
                    candidates = ['funding_rate', 'reserve', 'inflow_total', 'outflow_total', 'open_interest', 'ratio', 'value']
                    target_col = None

                    for col in df.columns:
                        if col in candidates:
                            target_col = col
                            break

                    if not target_col:
                        for col in df.columns:
                            if col not in ['date', 'datetime', 'timestamp_str', 'block_height']:
                                target_col = col
                                break

                    if target_col:
                        col_name = f"{asset}_{metric_name}"
                        subset = df[[target_col]].rename(columns={target_col: col_name})

                        if combined_df.empty:
                            combined_df = subset
                        else:
                            combined_df = combined_df.join(subset, how='outer')

                # Crude rate limiting between API calls.
                time.sleep(0.2)

        return combined_df
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc driver: pull a fixed date range of on-chain metrics for BTC/ETH
    # and dump them to a CSV for downstream training. Requires
    # CRYPTOQUANT_API_KEY in the environment (or .env) and network access.
    cq = CryptoQuantClient()

    # 3 Months Data (Oct 1 2025 - Dec 31 2025), YYYYMMDD format.
    start = "20251001"
    end = "20251231"

    # Output-column suffix -> CryptoQuant API path fragment.
    metrics = {
        "reserves": "exchange-flows/exchange-reserve",
        "inflow": "exchange-flows/inflow",
        "funding": "market-data/funding-rates"
    }

    print(f"Fetching training data from {start} to {end}...")
    df = cq.fetch_multi_metrics(["btc", "eth"], metrics, start, end)

    output_file = "data/cq_training_data.csv"
    os.makedirs("data", exist_ok=True)
    df.to_csv(output_file)
    print(f"\nSaved {len(df)} rows to {output_file}")
    print(df.head())
|
||||
209
engine/data_manager.py
Normal file
209
engine/data_manager.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Data management for OHLCV data download and storage.
|
||||
|
||||
Handles data retrieval from exchanges and local file management.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import ccxt
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketType, get_ccxt_symbol
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataManager:
    """
    Manages OHLCV data download and storage for different market types.

    Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
    """

    def __init__(self, data_dir: str = "data/ccxt"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # CCXT clients are created lazily and cached per exchange id.
        self.exchanges: dict[str, ccxt.Exchange] = {}

    def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
        """Get or create a CCXT exchange instance (cached, rate-limited)."""
        if exchange_id not in self.exchanges:
            exchange_class = getattr(ccxt, exchange_id)
            self.exchanges[exchange_id] = exchange_class({
                'enableRateLimit': True,
            })
        return self.exchanges[exchange_id]

    def _get_data_path(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType
    ) -> Path:
        """
        Get the file path for storing/loading data.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            Path to the CSV file
        """
        # '/' would introduce an extra directory level; store BTC/USDT as BTC-USDT.
        safe_symbol = symbol.replace('/', '-')
        return (
            self.data_dir
            / exchange_id
            / market_type.value
            / safe_symbol
            / f"{timeframe}.csv"
        )

    def download_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame | None:
        """
        Download OHLCV data from exchange and save to CSV.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            start_date: Start date string (YYYY-MM-DD); defaults to one year back
            end_date: End date string (YYYY-MM-DD); defaults to "now"
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data, or None if download failed
        """
        exchange = self.get_exchange(exchange_id)

        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Perpetuals need the exchange-specific symbol form (e.g. BTC/USDT:USDT).
        ccxt_symbol = get_ccxt_symbol(symbol, market_type)

        since, until = self._parse_date_range(exchange, start_date, end_date)

        logger.info(
            "Downloading %s (%s) from %s...",
            symbol, market_type.value, exchange_id
        )

        all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)

        if not all_ohlcv:
            logger.warning("No data downloaded.")
            return None

        df = self._convert_to_dataframe(all_ohlcv)
        df.to_csv(file_path)
        logger.info("Saved %d candles to %s", len(df), file_path)
        return df

    def load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame:
        """
        Load saved OHLCV data for vectorbt.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data indexed by timestamp

        Raises:
            FileNotFoundError: If data file does not exist
        """
        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)

        if not file_path.exists():
            raise FileNotFoundError(
                f"Data not found at {file_path}. "
                f"Run: uv run python main.py download --pair {symbol} "
                f"--market {market_type.value}"
            )

        return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)

    def _parse_date_range(
        self,
        exchange: ccxt.Exchange,
        start_date: str | None,
        end_date: str | None
    ) -> tuple[int, int]:
        """Parse date strings into millisecond timestamps.

        Defaults: `since` falls back to one year ago, `until` to now.
        """
        if start_date:
            since = exchange.parse8601(f"{start_date}T00:00:00Z")
        else:
            since = exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000

        if end_date:
            until = exchange.parse8601(f"{end_date}T23:59:59Z")
        else:
            until = exchange.milliseconds()

        return since, until

    def _fetch_all_candles(
        self,
        exchange: ccxt.Exchange,
        symbol: str,
        timeframe: str,
        since: int,
        until: int
    ) -> list:
        """Fetch all candles whose timestamps lie within [since, until].

        Pages through the exchange API; the last page returned by the
        exchange may overshoot `until`, so out-of-range candles are dropped.
        """
        all_ohlcv = []

        while since < until:
            try:
                batch = exchange.fetch_ohlcv(symbol, timeframe, since, limit=100)
                if not batch:
                    break

                # Advance the cursor past the newest received candle *before*
                # range-filtering, so pagination always makes progress.
                since = batch[-1][0] + 1

                # Fix: exclude candles the exchange returned beyond the
                # requested end of the range.
                all_ohlcv.extend(c for c in batch if c[0] <= until)

                current_date = datetime.fromtimestamp(
                    since / 1000, tz=timezone.utc
                ).strftime('%Y-%m-%d')
                logger.debug("Fetched up to %s", current_date)

                # Respect the exchange's advertised rate limit (ms -> s).
                time.sleep(exchange.rateLimit / 1000)

            except Exception as e:
                # Best-effort: keep whatever was fetched so far.
                logger.error("Error fetching data: %s", e)
                break

        return all_ohlcv

    def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
        """Convert raw OHLCV rows into a UTC-timestamp-indexed DataFrame."""
        df = pd.DataFrame(
            ohlcv,
            columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
        )
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df.set_index('timestamp', inplace=True)
        return df
|
||||
124
engine/logging_config.py
Normal file
124
engine/logging_config.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Centralized logging configuration for the backtest engine.
|
||||
|
||||
Provides colored console output and rotating file logs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
class Colors:
    """ANSI escape codes for colored terminal output."""

    RESET = "\033[0m"
    BOLD = "\033[1m"

    # One color per logging level.
    DEBUG = "\033[36m"     # Cyan
    INFO = "\033[32m"      # Green
    WARNING = "\033[33m"   # Yellow
    ERROR = "\033[31m"     # Red
    CRITICAL = "\033[35m"  # Magenta
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
    """
    Custom formatter that adds colors to log level names in terminal output.

    The record's level name is temporarily wrapped in ANSI codes and then
    restored, so other handlers sharing the same record see the plain name.
    """

    LEVEL_COLORS = {
        logging.DEBUG: Colors.DEBUG,
        logging.INFO: Colors.INFO,
        logging.WARNING: Colors.WARNING,
        logging.ERROR: Colors.ERROR,
        logging.CRITICAL: Colors.CRITICAL,
    }

    # Fix: both parameters default to None, so the annotations must be
    # Optional (`str | None`), not bare `str`.
    def __init__(self, fmt: str | None = None, datefmt: str | None = None):
        super().__init__(fmt, datefmt)

    def format(self, record: logging.LogRecord) -> str:
        """Format *record*, coloring its level name for terminal display."""
        # Save original levelname so the mutation below does not leak into
        # other handlers that format the same record.
        original_levelname = record.levelname

        # Unknown levels fall back to RESET (i.e. effectively uncolored).
        color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
        record.levelname = f"{color}{record.levelname}{Colors.RESET}"

        # Format the message with the colored level name in place.
        result = super().format(record)

        # Restore original levelname
        record.levelname = original_levelname

        return result
|
||||
|
||||
|
||||
def setup_logging(
    log_dir: str = "logs",
    log_level: int = logging.INFO,
    console_level: int = logging.INFO,
    max_bytes: int = 5 * 1024 * 1024,  # 5MB per file
    backup_count: int = 3
) -> None:
    """
    Configure logging for the application.

    Installs two handlers on the root logger: a colored console handler on
    stdout and a rotating file handler under *log_dir*.

    Args:
        log_dir: Directory for log files
        log_level: File logging level
        console_level: Console logging level
        max_bytes: Max size per log file before rotation
        backup_count: Number of backup files to keep
    """
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    # Root captures everything; each handler applies its own threshold.
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    root.handlers.clear()

    # Colored console output.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(console_level)
    stream_handler.setFormatter(ColoredFormatter(
        fmt="[%(asctime)s] %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    ))
    root.addHandler(stream_handler)

    # Plain-text rotating file output.
    rotating_handler = RotatingFileHandler(
        log_path / "backtest.log",
        maxBytes=max_bytes,
        backupCount=backup_count,
        encoding="utf-8",
    )
    rotating_handler.setLevel(log_level)
    rotating_handler.setFormatter(logging.Formatter(
        fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    ))
    root.addHandler(rotating_handler)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """
    Return the logger registered under *name* (typically ``__name__``).

    Thin wrapper around :func:`logging.getLogger` so modules depend on this
    package's logging setup rather than the stdlib directly.
    """
    return logging.getLogger(name)
|
||||
156
engine/market.py
Normal file
156
engine/market.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Market type definitions and configuration for backtesting.
|
||||
|
||||
Supports different market types with their specific trading conditions:
|
||||
- SPOT: No leverage, no funding, long-only
|
||||
- PERPETUAL: Leverage, funding rates, long/short
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MarketType(Enum):
    """Supported market types for backtesting."""

    # String values double as on-disk directory names and CLI arguments.
    SPOT = "spot"
    PERPETUAL = "perpetual"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MarketConfig:
    """
    Configuration for a specific market type.

    Immutable value object; instances are shared module-level constants
    (SPOT_CONFIG, PERPETUAL_CONFIG).

    Attributes:
        market_type: The market type enum value
        maker_fee: Maker fee as decimal (e.g., 0.0008 = 0.08%)
        taker_fee: Taker fee as decimal (e.g., 0.001 = 0.1%)
        max_leverage: Maximum allowed leverage
        funding_rate: Funding rate per 8 hours as decimal (perpetuals only)
        funding_interval_hours: Hours between funding payments
        maintenance_margin_rate: Rate for liquidation calculation
        supports_short: Whether short-selling is supported
    """
    # NOTE: field order is part of the public interface — the dataclass
    # constructor accepts these positionally. Do not reorder.
    market_type: MarketType
    maker_fee: float
    taker_fee: float
    max_leverage: int
    funding_rate: float
    funding_interval_hours: int
    maintenance_margin_rate: float
    supports_short: bool
|
||||
|
||||
|
||||
# OKX-based default configurations
|
||||
SPOT_CONFIG = MarketConfig(
    market_type=MarketType.SPOT,
    maker_fee=0.0008,  # 0.08%
    taker_fee=0.0010,  # 0.10%
    max_leverage=1,
    funding_rate=0.0,  # spot positions never pay funding
    funding_interval_hours=0,
    maintenance_margin_rate=0.0,  # no liquidation on unleveraged spot
    supports_short=False,
)

PERPETUAL_CONFIG = MarketConfig(
    market_type=MarketType.PERPETUAL,
    maker_fee=0.0002,  # 0.02%
    taker_fee=0.0005,  # 0.05%
    max_leverage=125,
    funding_rate=0.0001,  # 0.01% per 8 hours (simplified average)
    funding_interval_hours=8,
    maintenance_margin_rate=0.004,  # 0.4% for BTC on OKX
    supports_short=True,
)
|
||||
|
||||
|
||||
def get_market_config(market_type: MarketType) -> MarketConfig:
    """
    Get the configuration for a specific market type.

    Args:
        market_type: The market type to get configuration for

    Returns:
        MarketConfig with default values for that market type
    """
    if market_type is MarketType.SPOT:
        return SPOT_CONFIG
    if market_type is MarketType.PERPETUAL:
        return PERPETUAL_CONFIG
    # Mirror the dict-lookup failure mode of the original implementation.
    raise KeyError(market_type)
|
||||
|
||||
|
||||
def get_ccxt_symbol(symbol: str, market_type: MarketType) -> str:
    """
    Convert a standard symbol to CCXT format for the given market type.

    Args:
        symbol: Standard symbol (e.g., 'BTC/USDT')
        market_type: The market type

    Returns:
        CCXT-formatted symbol (e.g., 'BTC/USDT:USDT' for perpetuals)
    """
    if market_type != MarketType.PERPETUAL:
        # Spot symbols are already in the expected form.
        return symbol

    # OKX perpetual format: BTC/USDT:USDT — the suffix is the quote currency.
    parts = symbol.split('/')
    quote = parts[1] if len(parts) > 1 else 'USDT'
    return f"{symbol}:{quote}"
|
||||
|
||||
|
||||
def calculate_leverage_stop_loss(
    leverage: int,
    maintenance_margin_rate: float = 0.004
) -> float:
    """
    Calculate the implicit stop-loss percentage from leverage.

    At a given leverage, liquidation occurs when the position loses
    approximately (1/leverage - maintenance_margin_rate) of its value.

    Args:
        leverage: Position leverage multiplier
        maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)

    Returns:
        Stop-loss percentage as decimal (e.g., 0.196 for 19.6%)
    """
    # Spot / unleveraged positions cannot be force-liquidated.
    if leverage <= 1:
        return 1.0

    return (1 / leverage) - maintenance_margin_rate
|
||||
|
||||
|
||||
def calculate_liquidation_price(
    entry_price: float,
    leverage: float,
    is_long: bool,
    maintenance_margin_rate: float = 0.004
) -> float:
    """
    Calculate the liquidation price for a leveraged position.

    Uses the simplified formulas:
      long:  entry * (1 - 1/leverage + maintenance_margin_rate)
      short: entry * (1 + 1/leverage - maintenance_margin_rate)

    Args:
        entry_price: Position entry price
        leverage: Position leverage
        is_long: True for long positions, False for short
        maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)

    Returns:
        Liquidation price
    """
    # Unleveraged positions have no reachable liquidation level.
    if leverage <= 1:
        return 0.0 if is_long else float('inf')

    margin_ratio = 1 / leverage
    if is_long:
        return entry_price * (1 - margin_ratio + maintenance_margin_rate)
    return entry_price * (1 + margin_ratio - maintenance_margin_rate)
|
||||
245
engine/optimizer.py
Normal file
245
engine/optimizer.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Walk-Forward Analysis optimizer for strategy parameter optimization.
|
||||
|
||||
Implements expanding window walk-forward analysis with train/test splits.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_rolling_windows(
    index: pd.Index,
    n_windows: int,
    train_split: float = 0.7
):
    """
    Create rolling train/test split indices using expanding window approach.

    The index is cut into ``n_windows + 1`` roughly equal chunks. Window *i*
    trains on chunks ``0..i`` (expanding) and tests on chunk ``i + 1``.

    Args:
        index: DataFrame index to split
        n_windows: Number of walk-forward windows
        train_split: Unused, kept for API compatibility

    Yields:
        Tuples of (train_idx, test_idx) numpy arrays
    """
    chunks = np.array_split(index, n_windows + 1)

    for i in range(n_windows):
        # Fix: np.concatenate accepts the chunk list directly; the previous
        # `[c for c in chunks[:i+1]]` comprehension was a redundant copy.
        train_idx = np.concatenate(chunks[:i + 1])
        test_idx = chunks[i + 1]
        yield train_idx, test_idx
|
||||
|
||||
|
||||
class WalkForwardOptimizer:
    """
    Walk-Forward Analysis optimizer for strategy backtesting.

    Optimizes strategy parameters on training windows and validates
    on out-of-sample test windows.
    """

    def __init__(
        self,
        backtester,
        strategy,
        param_grid: dict,
        metric: str = 'Sharpe Ratio',
        fees: float = 0.001,
        freq: str = '1m'
    ):
        """
        Initialize the optimizer.

        Args:
            backtester: Backtester instance
            strategy: Strategy instance to optimize
            param_grid: Parameter grid for optimization
            metric: Performance metric to optimize
            fees: Transaction fees for simulation
            freq: Data frequency for portfolio simulation
        """
        self.bt = backtester
        self.strategy = strategy
        self.param_grid = param_grid
        # NOTE(review): `metric` is stored but never consulted —
        # _optimize_train hard-codes Sharpe ratio. Confirm intent.
        self.metric = metric
        self.fees = fees
        self.freq = freq

        # Separate grid params (lists) from fixed params (scalars)
        self.grid_keys = []
        self.fixed_params = {}
        for k, v in param_grid.items():
            if isinstance(v, (list, np.ndarray)):
                self.grid_keys.append(k)
            else:
                self.fixed_params[k] = v

    def run(
        self,
        close_price: pd.Series,
        high: pd.Series | None = None,
        low: pd.Series | None = None,
        n_windows: int = 10
    ) -> tuple[pd.DataFrame, pd.Series | None]:
        """
        Execute walk-forward analysis.

        Args:
            close_price: Close price series
            high: High price series (optional)
            low: Low price series (optional)
            n_windows: Number of walk-forward windows

        Returns:
            Tuple of (results DataFrame, stitched equity curve)
        """
        results = []
        equity_curves = []

        logger.info(
            "Starting Walk-Forward Analysis with %d windows (Expanding Train)...",
            n_windows
        )

        splitter = create_rolling_windows(close_price.index, n_windows)

        for i, (train_idx, test_idx) in enumerate(splitter):
            logger.info("Processing Window %d/%d...", i + 1, n_windows)

            window_result = self._process_window(
                i, train_idx, test_idx, close_price, high, low
            )

            # Windows that raised inside _process_window return None and
            # are simply skipped.
            if window_result is not None:
                result_dict, eq_curve = window_result
                results.append(result_dict)
                equity_curves.append(eq_curve)

        stitched_series = self._stitch_equity_curves(equity_curves)
        return pd.DataFrame(results), stitched_series

    def _process_window(
        self,
        window_idx: int,
        train_idx: np.ndarray,
        test_idx: np.ndarray,
        close_price: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None
    ) -> tuple[dict, pd.Series] | None:
        """Process a single WFA window.

        Returns a (summary dict, test-window equity curve) pair, or None if
        the window failed (the error is logged with a traceback).
        """
        try:
            # Slice data for train/test
            train_close = close_price.loc[train_idx]
            train_high = high.loc[train_idx] if high is not None else None
            train_low = low.loc[train_idx] if low is not None else None

            # Train phase: find best parameters
            best_params, best_score = self._optimize_train(
                train_close, train_high, train_low
            )

            # Test phase: validate with best params
            test_close = close_price.loc[test_idx]
            test_high = high.loc[test_idx] if high is not None else None
            test_low = low.loc[test_idx] if low is not None else None

            # Grid winners override nothing in fixed_params (disjoint keys);
            # merge the two into one concrete parameter set.
            test_params = {**self.fixed_params, **best_params}
            test_score, test_return, eq_curve = self._run_test(
                test_close, test_high, test_low, test_params
            )

            return {
                'window': window_idx + 1,
                'train_start': train_idx[0],
                'train_end': train_idx[-1],
                'test_start': test_idx[0],
                'test_end': test_idx[-1],
                'best_params': best_params,
                'train_score': best_score,
                'test_score': test_score,
                'test_return': test_return
            }, eq_curve

        except Exception as e:
            logger.error("Error in window %d: %s", window_idx + 1, e, exc_info=True)
            return None

    def _optimize_train(
        self,
        close: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None
    ) -> tuple[dict, float]:
        """Run grid search on training data to find best parameters."""
        # Passing the full grid to the strategy produces one signal column
        # per parameter combination.
        entries, exits = self.strategy.run(
            close, high=high, low=low, **self.param_grid
        )

        pf_train = vbt.Portfolio.from_signals(
            close, entries, exits,
            fees=self.fees,
            freq=self.freq
        )

        perf_stats = pf_train.sharpe_ratio()
        # Replace NaN scores with a large negative so idxmax never picks them.
        perf_stats = perf_stats.fillna(-999)

        best_idx = perf_stats.idxmax()
        best_score = perf_stats.max()

        # Extract best params from grid search
        if len(self.grid_keys) == 1:
            # Single grid dimension: idxmax yields a scalar label.
            best_params = {self.grid_keys[0]: best_idx}
        elif len(self.grid_keys) > 1:
            # Multiple dimensions: idxmax presumably yields a tuple label
            # aligned with grid_keys order — confirm against strategy output.
            best_params = dict(zip(self.grid_keys, best_idx))
        else:
            best_params = {}

        return best_params, best_score

    def _run_test(
        self,
        close: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None,
        params: dict
    ) -> tuple[float, float, pd.Series]:
        """Run test phase with given parameters.

        Returns:
            Tuple of (sharpe ratio, total return, equity curve).
        """
        entries, exits = self.strategy.run(
            close, high=high, low=low, **params
        )

        pf_test = vbt.Portfolio.from_signals(
            close, entries, exits,
            fees=self.fees,
            freq=self.freq
        )

        return pf_test.sharpe_ratio(), pf_test.total_return(), pf_test.value()

    def _stitch_equity_curves(
        self,
        equity_curves: list[pd.Series]
    ) -> pd.Series | None:
        """Stitch multiple equity curves into a continuous series.

        Each window's curve is rescaled so it starts where the previous
        window ended, yielding one compounded out-of-sample curve.
        """
        if not equity_curves:
            return None

        stitched = [equity_curves[0]]
        for j in range(1, len(equity_curves)):
            prev_end_val = stitched[-1].iloc[-1]
            curr_curve = equity_curves[j]
            init_cash = curr_curve.iloc[0]

            # Scale curve to continue from previous end value
            scaled_curve = (curr_curve / init_cash) * prev_end_val
            stitched.append(scaled_curve)

        return pd.concat(stitched)
|
||||
148
engine/portfolio.py
Normal file
148
engine/portfolio.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Portfolio simulation utilities for backtesting.
|
||||
|
||||
Handles long-only and long/short portfolio creation using VectorBT.
|
||||
"""
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_long_only_portfolio(
    close: pd.Series,
    entries: pd.DataFrame,
    exits: pd.DataFrame,
    init_cash: float,
    fees: float,
    slippage: float,
    freq: str,
    sl_stop: float | None,
    tp_stop: float | None,
    sl_trail: bool,
    leverage: int
) -> vbt.Portfolio:
    """
    Run a long-only portfolio simulation.

    Leverage is modelled by scaling the starting capital: the simulation
    trades with ``init_cash * leverage``.

    Args:
        close: Close price series
        entries: Entry signals
        exits: Exit signals
        init_cash: Initial capital
        fees: Transaction fee percentage
        slippage: Slippage percentage
        freq: Data frequency string
        sl_stop: Stop loss percentage
        tp_stop: Take profit percentage
        sl_trail: Enable trailing stop loss
        leverage: Leverage multiplier

    Returns:
        VectorBT Portfolio object
    """
    simulation_kwargs = dict(
        close=close,
        entries=entries,
        exits=exits,
        init_cash=init_cash * leverage,
        fees=fees,
        slippage=slippage,
        freq=freq,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        sl_trail=sl_trail,
        # Deploy the full (leveraged) cash balance on each entry.
        size=1.0,
        size_type='percent',
    )
    return vbt.Portfolio.from_signals(**simulation_kwargs)
|
||||
|
||||
|
||||
def run_long_short_portfolio(
    close: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    init_cash: float,
    fees: float,
    slippage: float,
    freq: str,
    sl_stop: float | None,
    tp_stop: float | None,
    sl_trail: bool,
    leverage: int
) -> vbt.Portfolio:
    """
    Run a portfolio supporting both long and short positions.

    Runs two separate portfolios (long and short), each funded with half of
    the leveraged capital. The short portfolio's performance is only logged;
    the long portfolio is what gets returned.

    Args:
        close: Close price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        init_cash: Initial capital
        fees: Transaction fee percentage
        slippage: Slippage percentage
        freq: Data frequency string
        sl_stop: Stop loss percentage
        tp_stop: Take profit percentage
        sl_trail: Enable trailing stop loss
        leverage: Leverage multiplier

    Returns:
        VectorBT Portfolio object (long portfolio, short stats logged)
    """
    half_cash = (init_cash * leverage) / 2

    # Settings shared by both simulations.
    shared_kwargs = dict(
        close=close,
        init_cash=half_cash,
        fees=fees,
        slippage=slippage,
        freq=freq,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        sl_trail=sl_trail,
        size=1.0,
        size_type='percent',
    )

    long_pf = vbt.Portfolio.from_signals(
        entries=long_entries,
        exits=long_exits,
        direction='longonly',
        **shared_kwargs,
    )

    short_pf = vbt.Portfolio.from_signals(
        entries=short_entries,
        exits=short_exits,
        direction='shortonly',
        **shared_kwargs,
    )

    # TODO: Implement proper portfolio combination
    logger.info(
        "Long portfolio: %.2f%% return, Short portfolio: %.2f%% return",
        long_pf.total_return().mean() * 100,
        short_pf.total_return().mean() * 100
    )

    return long_pf
|
||||
228
engine/reporting.py
Normal file
228
engine/reporting.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Reporting module for backtest results.
|
||||
|
||||
Handles summary printing, CSV exports, and plotting.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import vectorbt as vbt
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Reporter:
|
||||
"""Reporter for backtest results with market-specific metrics."""
|
||||
|
||||
def __init__(self, output_dir: str = "backtest_logs"):
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
    def print_summary(self, result) -> None:
        """
        Print backtest summary to console via logger.

        Args:
            result: BacktestResult or vbt.Portfolio object
        """
        (portfolio, market_type, leverage, funding_paid,
         liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)

        # Extract period info
        idx = portfolio.wrapper.index
        start_date = idx[0].strftime("%Y-%m-%d")
        end_date = idx[-1].strftime("%Y-%m-%d")

        # Extract price info
        close = portfolio.close
        # NOTE(review): the hasattr('mean') branch presumably collapses
        # multi-column (parameter-grid) portfolios to a scalar — confirm.
        start_price = close.iloc[0].mean() if hasattr(close.iloc[0], 'mean') else close.iloc[0]
        end_price = close.iloc[-1].mean() if hasattr(close.iloc[-1], 'mean') else close.iloc[-1]
        price_change = ((end_price - start_price) / start_price) * 100

        # Extract fees
        stats = portfolio.stats()
        total_fees = stats.get('Total Fees Paid', 0)

        raw_return = portfolio.total_return().mean() * 100

        # Build summary
        summary_lines = [
            "",
            "=" * 50,
            "BACKTEST RESULTS",
            "=" * 50,
            f"Market Type: [{market_type.upper()}]",
            f"Leverage: [{leverage}x]",
            f"Period: [{start_date} to {end_date}]",
            f"Price: [{start_price:,.2f} -> {end_price:,.2f} ({price_change:+.2f}%)]",
        ]

        # Show adjusted return if liquidations occurred
        if liq_count > 0 and adjusted_return is not None:
            summary_lines.append(f"Raw Return: [%{raw_return:.2f}] (before liq adjustment)")
            summary_lines.append(f"Adj Return: [%{adjusted_return:.2f}] (after liq losses)")
        else:
            summary_lines.append(f"Total Return: [%{raw_return:.2f}]")

        summary_lines.extend([
            f"Sharpe Ratio: [{portfolio.sharpe_ratio().mean():.2f}]",
            f"Max Drawdown: [%{portfolio.max_drawdown().mean() * 100:.2f}]",
            f"Total Trades: [{portfolio.trades.count().mean():.0f}]",
            f"Win Rate: [%{portfolio.trades.win_rate().mean() * 100:.2f}]",
            f"Total Fees: [{total_fees:,.2f}]",
        ])

        # Optional lines that only make sense for perpetual results.
        if funding_paid != 0:
            summary_lines.append(f"Funding Paid: [{funding_paid:,.2f}]")
        if liq_count > 0:
            summary_lines.append(f"Liquidations: [{liq_count}] (${liq_loss:,.2f} margin lost)")

        summary_lines.append("=" * 50)
        # Emit the whole block as a single log record so lines stay together.
        logger.info("\n".join(summary_lines))
|
||||
|
||||
    def save_reports(self, result, filename_prefix: str) -> None:
        """
        Save trade log, stats, and liquidation events to CSV files.

        Writes up to three files into ``self.output_dir``:
        ``{prefix}_trades.csv``, ``{prefix}_stats.csv`` and, when the result
        carries liquidation events, ``{prefix}_liquidations.csv``.

        Args:
            result: BacktestResult or vbt.Portfolio object
            filename_prefix: Prefix for output filenames
        """
        (portfolio, market_type, leverage, funding_paid,
         liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)

        # Save trades
        self._save_csv(
            data=portfolio.trades.records_readable,
            path=self.output_dir / f"{filename_prefix}_trades.csv",
            description="trade log"
        )

        # Save stats with market-specific additions
        # NOTE(review): assumes portfolio.stats() returns a mutable
        # Series/mapping so new labels can be appended — confirm.
        stats = portfolio.stats()
        stats['Market Type'] = market_type
        stats['Leverage'] = leverage
        stats['Total Funding Paid'] = funding_paid
        stats['Liquidations'] = liq_count
        stats['Liquidation Loss'] = liq_loss
        if adjusted_return is not None:
            stats['Adjusted Return'] = adjusted_return

        self._save_csv(
            data=stats,
            path=self.output_dir / f"{filename_prefix}_stats.csv",
            description="stats"
        )

        # Save liquidation events if any
        if hasattr(result, 'liquidation_events') and result.liquidation_events:
            liq_df = pd.DataFrame([
                {
                    'entry_time': e.entry_time,
                    'entry_price': e.entry_price,
                    'liquidation_time': e.liquidation_time,
                    'liquidation_price': e.liquidation_price,
                    'actual_price': e.actual_price,
                    'direction': e.direction,
                    'margin_lost_pct': e.margin_lost_pct
                }
                for e in result.liquidation_events
            ])
            self._save_csv(
                data=liq_df,
                path=self.output_dir / f"{filename_prefix}_liquidations.csv",
                description="liquidation events"
            )
|
||||
|
||||
def plot(self, portfolio: vbt.Portfolio, show: bool = True) -> None:
    """Render the interactive portfolio plot; no-op when show is False."""
    if not show:
        return
    portfolio.plot().show()
|
||||
|
||||
def plot_wfa(
    self,
    wfa_results: pd.DataFrame,
    stitched_curve: pd.Series | None = None,
    show: bool = True
) -> None:
    """
    Plot Walk-Forward Analysis results.

    Top panel: per-window test Sharpe as bars. Bottom panel: the stitched
    out-of-sample equity curve, when one is supplied.

    Args:
        wfa_results: DataFrame with WFA window results
        stitched_curve: Stitched out-of-sample equity curve
        show: Whether to display the plot
    """
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=False,
        vertical_spacing=0.1,
        subplot_titles=(
            "Walk-Forward Test Scores (Sharpe)",
            "Stitched Out-of-Sample Equity"
        )
    )

    sharpe_bars = go.Bar(
        x=wfa_results['window'],
        y=wfa_results['test_score'],
        name="Test Sharpe"
    )
    fig.add_trace(sharpe_bars, row=1, col=1)

    if stitched_curve is not None:
        equity_line = go.Scatter(
            x=stitched_curve.index,
            y=stitched_curve.values,
            name="Equity",
            mode='lines'
        )
        fig.add_trace(equity_line, row=2, col=1)

    fig.update_layout(height=800, title_text="Walk-Forward Analysis Report")

    if show:
        fig.show()
|
||||
|
||||
def _extract_result_data(self, result) -> tuple:
    """
    Extract data from BacktestResult or raw Portfolio.

    Returns:
        Tuple of (portfolio, market_type, leverage, funding_paid, liq_count,
                  liq_loss, adjusted_return)
    """
    if not hasattr(result, 'portfolio'):
        # Raw vbt.Portfolio: no market metadata available, use neutral defaults.
        return (result, "unknown", 1, 0.0, 0, 0.0, None)

    return (
        result.portfolio,
        result.market_type.value,
        result.leverage,
        result.total_funding_paid,
        result.liquidation_count,
        # Older result objects may predate these optional fields.
        getattr(result, 'total_liquidation_loss', 0.0),
        getattr(result, 'adjusted_return', None)
    )
|
||||
|
||||
def _save_csv(self, data, path: Path, description: str) -> None:
    """
    Save data to CSV with consistent error handling (never raises).

    Args:
        data: DataFrame or Series to save
        path: Output file path
        description: Human-readable description for logging
    """
    try:
        data.to_csv(path)
    except Exception as exc:
        logger.error("Could not save %s: %s", description, exc)
    else:
        logger.info("Saved %s to %s", description, path)
|
||||
395
engine/risk.py
Normal file
395
engine/risk.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Risk management utilities for backtesting.
|
||||
|
||||
Handles funding rate calculations and liquidation detection.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.market import MarketConfig, calculate_liquidation_price
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class LiquidationEvent:
    """
    Record of a single liquidation event observed during backtesting.

    Produced by inject_liquidation_exits and consumed by reporting code
    (e.g. saved to CSV alongside trade logs).

    Attributes:
        entry_time: Timestamp when position was opened
        entry_price: Price at position entry
        liquidation_time: Timestamp when liquidation occurred
        liquidation_price: Calculated liquidation price
        actual_price: Actual price that triggered liquidation (bar high/low)
        direction: 'long' or 'short'
        margin_lost_pct: Fraction of margin lost, 0.0-1.0 (typically 1.0 = 100%)
    """
    entry_time: pd.Timestamp
    entry_price: float
    liquidation_time: pd.Timestamp
    liquidation_price: float
    actual_price: float
    direction: str
    # Default models a full wipe-out of the trade's margin.
    margin_lost_pct: float = 1.0
|
||||
|
||||
|
||||
def calculate_funding(
    close: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    market_config: "MarketConfig",
    leverage: int
) -> float:
    """
    Calculate total funding paid/received for perpetual positions.

    Simplified model: applies the funding rate at every funding interval
    (e.g. every 8 hours) while a position is open. Positive rate means
    longs pay shorts.

    Args:
        close: Price series (DatetimeIndex required)
        long_entries: Long entry signals (Series or DataFrame of bools/ints)
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        market_config: Market configuration with funding parameters
        leverage: Position leverage

    Returns:
        Total funding paid (positive) or received (negative)
    """
    # Spot-style markets carry no funding.
    if market_config.funding_interval_hours == 0:
        return 0.0

    funding_rate = market_config.funding_rate
    interval_hours = market_config.funding_interval_hours

    # Position state per bar: cumulative entries minus cumulative exits,
    # clamped to a 0/1 "in position" flag.
    long_position = ((long_entries.cumsum() - long_exits.cumsum()) > 0).astype(int)
    short_position = ((short_entries.cumsum() - short_exits.cumsum()) > 0).astype(int)

    # Funding timestamps: bars whose hour is a multiple of the interval
    # (00:00, 08:00, 16:00 UTC for an 8-hour interval).
    # NOTE(review): on sub-hourly bars every bar inside those hours matches,
    # which would over-charge funding - assumes bar frequency >= 1h; confirm.
    funding_times = close.index[close.index.hour % interval_hours == 0]

    total_funding = 0.0
    for ts in funding_times:
        # (The previous `if ts not in close.index` guard was dead code:
        # funding_times is drawn from close.index, so membership always holds.)
        price = close.loc[ts]

        # DataFrame signals mean multiple parameter columns: treat the
        # position as open if any column is in a position at ts.
        if isinstance(long_position, pd.DataFrame):
            long_open = long_position.loc[ts].any()
            short_open = short_position.loc[ts].any()
        else:
            long_open = long_position.loc[ts] > 0
            short_open = short_position.loc[ts] > 0

        # Notional exposure controlled by one unit of margin at `price`.
        position_value = price * leverage
        if long_open:
            total_funding += position_value * funding_rate   # longs pay when rate > 0
        if short_open:
            total_funding -= position_value * funding_rate   # shorts receive when rate > 0

    return total_funding
|
||||
|
||||
|
||||
def inject_liquidation_exits(
    close: pd.Series,
    high: pd.Series,
    low: pd.Series,
    long_entries: pd.DataFrame | pd.Series,
    long_exits: pd.DataFrame | pd.Series,
    short_entries: pd.DataFrame | pd.Series,
    short_exits: pd.DataFrame | pd.Series,
    leverage: int,
    maintenance_margin_rate: float
) -> tuple[pd.DataFrame | pd.Series, pd.DataFrame | pd.Series, list[LiquidationEvent]]:
    """
    Modify exit signals to force position closure at liquidation points.

    This function simulates realistic liquidation behavior by:
    1. Finding positions that would be liquidated before their normal exit
    2. Injecting forced exit signals at the liquidation bar
    3. Recording all liquidation events

    Longs are checked against bar lows, shorts against bar highs. Input
    signals are returned in the same shape they were given (Series in,
    Series out; DataFrame in, DataFrame out).

    Args:
        close: Close price series
        high: High price series
        low: Low price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        leverage: Position leverage
        maintenance_margin_rate: Maintenance margin rate for liquidation

    Returns:
        Tuple of (modified_long_exits, modified_short_exits, liquidation_events)
    """
    # No leverage -> price can never reach a liquidation level; pass through.
    if leverage <= 1:
        return long_exits, short_exits, []

    liquidation_events: list[LiquidationEvent] = []

    # Convert to DataFrame if Series for consistent handling below.
    # NOTE(review): only long_entries is inspected to decide the shape;
    # assumes all four signal inputs share the same type - confirm at callers.
    is_series = isinstance(long_entries, pd.Series)
    if is_series:
        long_entries_df = long_entries.to_frame()
        long_exits_df = long_exits.to_frame()
        short_entries_df = short_entries.to_frame()
        short_exits_df = short_exits.to_frame()
    else:
        long_entries_df = long_entries
        long_exits_df = long_exits.copy()
        short_entries_df = short_entries
        short_exits_df = short_exits.copy()

    modified_long_exits = long_exits_df.copy()
    modified_short_exits = short_exits_df.copy()

    # ---- Process long positions ----
    long_mask = long_entries_df.any(axis=1)
    for entry_idx in close.index[long_mask]:
        # Entry is assumed filled at the bar's close.
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=True,
            maintenance_margin_rate=maintenance_margin_rate
        )

        # Find the first normal exit at or after the entry bar; fall back to
        # the last bar if the position is never signalled out.
        # NOTE(review): the slice includes entry_idx itself, so an exit signal
        # on the entry bar would be treated as the normal exit - confirm.
        subsequent_exits = long_exits_df.loc[entry_idx:].any(axis=1)
        exit_indices = subsequent_exits[subsequent_exits].index
        normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]

        # Did the low breach the liquidation level before the normal exit?
        price_range = low.loc[entry_idx:normal_exit_idx]
        if (price_range < liq_price).any():
            # First bar whose low crossed the liquidation price.
            liq_bar = price_range[price_range < liq_price].index[0]

            # Inject a forced exit at the liquidation bar (all columns).
            for col in modified_long_exits.columns:
                modified_long_exits.loc[liq_bar, col] = True

            # Record the liquidation event for downstream reporting.
            liquidation_events.append(LiquidationEvent(
                entry_time=entry_idx,
                entry_price=entry_price,
                liquidation_time=liq_bar,
                liquidation_price=liq_price,
                actual_price=low.loc[liq_bar],
                direction='long',
                margin_lost_pct=1.0
            ))

            logger.warning(
                "LIQUIDATION (Long): Entry %s ($%.2f) -> Liquidated %s "
                "(liq=$%.2f, low=$%.2f)",
                entry_idx.strftime('%Y-%m-%d'), entry_price,
                liq_bar.strftime('%Y-%m-%d'), liq_price, low.loc[liq_bar]
            )

    # ---- Process short positions (mirror of the long pass, using highs) ----
    short_mask = short_entries_df.any(axis=1)
    for entry_idx in close.index[short_mask]:
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=False,
            maintenance_margin_rate=maintenance_margin_rate
        )

        # Find the first normal exit at or after the entry bar.
        subsequent_exits = short_exits_df.loc[entry_idx:].any(axis=1)
        exit_indices = subsequent_exits[subsequent_exits].index
        normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]

        # For shorts, liquidation triggers when the high exceeds the level.
        price_range = high.loc[entry_idx:normal_exit_idx]
        if (price_range > liq_price).any():
            liq_bar = price_range[price_range > liq_price].index[0]

            # Inject a forced exit at the liquidation bar (all columns).
            for col in modified_short_exits.columns:
                modified_short_exits.loc[liq_bar, col] = True

            # Record the liquidation event.
            liquidation_events.append(LiquidationEvent(
                entry_time=entry_idx,
                entry_price=entry_price,
                liquidation_time=liq_bar,
                liquidation_price=liq_price,
                actual_price=high.loc[liq_bar],
                direction='short',
                margin_lost_pct=1.0
            ))

            logger.warning(
                "LIQUIDATION (Short): Entry %s ($%.2f) -> Liquidated %s "
                "(liq=$%.2f, high=$%.2f)",
                entry_idx.strftime('%Y-%m-%d'), entry_price,
                liq_bar.strftime('%Y-%m-%d'), liq_price, high.loc[liq_bar]
            )

    # Restore the caller's original signal shape.
    if is_series:
        modified_long_exits = modified_long_exits.iloc[:, 0]
        modified_short_exits = modified_short_exits.iloc[:, 0]

    return modified_long_exits, modified_short_exits, liquidation_events
|
||||
|
||||
|
||||
def calculate_liquidation_adjustment(
    liquidation_events: list[LiquidationEvent],
    init_cash: float,
    leverage: int
) -> tuple[float, float]:
    """
    Estimate how much worse returns should be once liquidations are
    accounted for at their true cost.

    VectorBT books a liquidated trade's P&L at the exit bar's close, while
    a real liquidation wipes out the full margin for that trade. This
    returns the estimated gap as an absolute dollar amount and as a
    percentage of initial capital.

    Portfolio assumptions (mirrors the backtest setup):
      - long and short sides each receive half of init_cash as margin
      - every trade deploys 100% of its side's allocation (size=1.0, percent)
      - a liquidated trade loses its margin entirely

    Args:
        liquidation_events: List of liquidation events
        init_cash: Initial portfolio cash (before leverage)
        leverage: Position leverage used

    Returns:
        Tuple of (total_margin_lost, adjustment_pct)
        - total_margin_lost: Estimated total margin lost from liquidations
        - adjustment_pct: Percentage adjustment to apply to returns
    """
    if not liquidation_events:
        return 0.0, 0.0

    # Each side (long/short) is funded with half the starting capital.
    margin_per_side = init_cash / 2

    # At N-x leverage a ~(1/N) adverse move exhausts the margin, so the
    # per-liquidation loss is approximated as 1/leverage of the position
    # (e.g. 5x leverage -> ~19.6% adverse move -> 100% margin loss).
    loss_rate_per_trade = 1.0 / leverage

    event_count = len(liquidation_events)

    # Rough total: each liquidation costs roughly one trade's margin share;
    # capped at the capital that actually exists.
    total_margin_lost = min(event_count * margin_per_side * loss_rate_per_trade, init_cash)

    # Express relative to starting capital.
    adjustment_pct = (total_margin_lost / init_cash) * 100

    return total_margin_lost, adjustment_pct
|
||||
|
||||
|
||||
def check_liquidations(
    close: pd.Series,
    high: pd.Series,
    low: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    leverage: int,
    maintenance_margin_rate: float
) -> int:
    """
    Check for liquidation events and log warnings.

    Read-only diagnostic: unlike inject_liquidation_exits, this does not
    modify any signals - it only counts and logs would-be liquidations.
    Longs are checked against bar lows, shorts against bar highs.

    Note: long_exits and short_exits are accepted but never read here; each
    entry is scanned from its bar to the END of the series, so an adverse
    wick long after the position would normally have closed still counts.

    Args:
        close: Close price series
        high: High price series
        low: Low price series
        long_entries: Long entry signals
        long_exits: Long exit signals (currently unused)
        short_entries: Short entry signals
        short_exits: Short exit signals (currently unused)
        leverage: Position leverage
        maintenance_margin_rate: Maintenance margin rate for liquidation

    Returns:
        Count of liquidation warnings
    """
    warnings = 0

    # ---- Long positions: liquidated when the low breaches the level ----
    # Signals may come as a multi-column DataFrame (parameter grid) or a Series.
    long_mask = (
        long_entries.any(axis=1)
        if isinstance(long_entries, pd.DataFrame)
        else long_entries
    )

    for entry_idx in close.index[long_mask]:
        # Entry assumed filled at the bar's close.
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=True,
            maintenance_margin_rate=maintenance_margin_rate
        )

        # Scan lows from the entry bar onward for a breach.
        subsequent = low.loc[entry_idx:]
        if (subsequent < liq_price).any():
            liq_bar = subsequent[subsequent < liq_price].index[0]
            logger.warning(
                "LIQUIDATION WARNING (Long): Entry at %s ($%.2f), "
                "would liquidate at %s (liq_price=$%.2f, low=$%.2f)",
                entry_idx, entry_price, liq_bar, liq_price, low.loc[liq_bar]
            )
            warnings += 1

    # ---- Short positions: liquidated when the high breaches the level ----
    short_mask = (
        short_entries.any(axis=1)
        if isinstance(short_entries, pd.DataFrame)
        else short_entries
    )

    for entry_idx in close.index[short_mask]:
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=False,
            maintenance_margin_rate=maintenance_margin_rate
        )

        # Scan highs from the entry bar onward for a breach.
        subsequent = high.loc[entry_idx:]
        if (subsequent > liq_price).any():
            liq_bar = subsequent[subsequent > liq_price].index[0]
            logger.warning(
                "LIQUIDATION WARNING (Short): Entry at %s ($%.2f), "
                "would liquidate at %s (liq_price=$%.2f, high=$%.2f)",
                entry_idx, entry_price, liq_bar, liq_price, high.loc[liq_bar]
            )
            warnings += 1

    return warnings
|
||||
@@ -1,3 +0,0 @@
|
||||
from .supertrend import add_supertrends, compute_meta_trend
|
||||
|
||||
__all__ = ["add_supertrends", "compute_meta_trend"]
|
||||
@@ -1,58 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _atr(high: pd.Series, low: pd.Series, close: pd.Series, period: int) -> pd.Series:
    """Average True Range using a simple rolling mean over `period` bars."""
    prev_close = close.shift()
    # True range: widest of the three candidate spans for each bar.
    true_range = pd.concat(
        [
            (high - low).abs(),
            (high - prev_close).abs(),
            (low - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    # NaN until a full window of `period` bars is available.
    return true_range.rolling(period, min_periods=period).mean()
|
||||
|
||||
|
||||
def supertrend_series(df: pd.DataFrame, length: int, multiplier: float) -> pd.Series:
    """
    Compute the Supertrend line for `df`.

    Expects "High", "Low" and "Close" columns. Returns a Series aligned to
    df.index: NaN during the ATR warm-up, then the active Supertrend band
    (lower band while bullish, upper band while bearish).
    """
    atr = _atr(df["High"], df["Low"], df["Close"], length)
    hl2 = (df["High"] + df["Low"]) / 2
    # Raw bands around the bar midpoint.
    upper = hl2 + multiplier * atr
    lower = hl2 - multiplier * atr

    trend = pd.Series(index=df.index, dtype=float)
    dir_up = True          # current direction flag (True = bullish)
    prev_upper = np.nan    # previous final (ratcheted) upper band
    prev_lower = np.nan    # previous final (ratcheted) lower band

    for i in range(len(df)):
        # Warm-up: no ATR yet -> no trend value; seed bands with raw values.
        if i == 0 or pd.isna(atr.iat[i]):
            trend.iat[i] = np.nan
            prev_upper = upper.iat[i]
            prev_lower = lower.iat[i]
            continue

        # Band ratcheting: while bullish the upper band may only tighten
        # (take the min with the previous), while bearish the lower band
        # may only tighten (take the max).
        cu = min(upper.iat[i], prev_upper) if dir_up else upper.iat[i]
        cl = max(lower.iat[i], prev_lower) if not dir_up else lower.iat[i]

        # Direction flips when the close breaks through the active band.
        if df["Close"].iat[i] > cu:
            dir_up = True
        elif df["Close"].iat[i] < cl:
            dir_up = False

        # Carry forward the ratcheted band on the active side, the raw band
        # on the inactive side (statement order matters: uses post-flip dir_up).
        prev_upper = cu if dir_up else upper.iat[i]
        prev_lower = lower.iat[i] if dir_up else cl
        # The plotted Supertrend is the band on the protective side.
        trend.iat[i] = cl if dir_up else cu

    return trend
|
||||
|
||||
|
||||
def add_supertrends(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.DataFrame:
    """
    Return a copy of `df` with one Supertrend line and one 0/1 bull flag
    appended per (length, multiplier) setting.

    Column names: supertrend_{length}_{mult} and bull_{length}_{mult}.
    """
    enriched = df.copy()
    for length, mult in settings:
        st_col = f"supertrend_{length}_{mult}"
        enriched[st_col] = supertrend_series(enriched, length, mult)
        # Bullish when the close sits on or above the Supertrend line.
        enriched[f"bull_{length}_{mult}"] = (enriched["Close"] >= enriched[st_col]).astype(int)
    return enriched
|
||||
|
||||
|
||||
def compute_meta_trend(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.Series:
    """Return 1 where every configured Supertrend bull flag is set, else 0."""
    required = [f"bull_{length}_{mult}" for length, mult in settings]
    bull_count = df[required].sum(axis=1)
    return (bull_count == len(required)).astype(int)
|
||||
10
intrabar.py
10
intrabar.py
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def precompute_slices(df: pd.DataFrame) -> pd.DataFrame:
    """Placeholder hook for future intrabar slicing; returns `df` unchanged."""
    return df
|
||||
|
||||
|
||||
def entry_slippage_row(price: float, qty: float, slippage_bps: float) -> float:
    """
    Entry fill price adjusted upward for slippage.

    `qty` is accepted for interface symmetry but does not affect the result.
    """
    slippage = price * (slippage_bps / 1e4)
    return price + slippage
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def write_trade_log(trades: list[dict], path: Path) -> None:
    """Persist executed trades as CSV; does nothing for an empty list."""
    if not trades:
        return
    # Create parent directories on demand so callers can use nested paths.
    path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(trades).to_csv(path, index=False)
|
||||
10
main.py
Normal file
10
main.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Lowkey Backtest CLI - VectorBT Edition
|
||||
|
||||
A backtesting framework supporting multiple market types (spot, perpetual)
|
||||
with realistic trading simulation including leverage, funding, and shorts.
|
||||
"""
|
||||
from engine.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
TAKER_FEE_BPS_DEFAULT = 10.0 # 0.10%
|
||||
|
||||
|
||||
def okx_fee(fee_bps: float, notional_usd: float) -> float:
    """USD fee for a trade of `notional_usd` at `fee_bps` basis points."""
    rate = fee_bps / 1e4
    return notional_usd * rate
|
||||
|
||||
|
||||
def estimate_slippage_rate(slippage_bps: float, notional_usd: float) -> float:
    """Estimated USD slippage cost at `slippage_bps` basis points of notional."""
    rate = slippage_bps / 1e4
    return notional_usd * rate
|
||||
54
metrics.py
54
metrics.py
@@ -1,54 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@dataclass
class Perf:
    """Summary performance statistics for a single backtest run."""
    total_return: float         # equity growth, e.g. 0.25 = +25%
    max_drawdown: float         # worst peak-to-trough, negative fraction
    sharpe_ratio: float         # annualized (assumes minute bars)
    win_rate: float             # winning closes / all closes
    num_trades: int             # number of closing (SELL) events
    final_equity: float
    initial_equity: float
    num_stop_losses: int        # closes whose reason was "stop"
    total_fees: float
    total_slippage_usd: float
    avg_slippage_bps: float


def compute_metrics(equity_curve: pd.Series, trades: list[dict]) -> Perf:
    """
    Derive performance statistics from an equity curve and trade events.

    Args:
        equity_curve: Portfolio equity per bar.
        trades: Trade event dicts; closes carry side="SELL" plus optional
            "pnl" / "reason", and any event may carry "fee" / "slippage" /
            "slippage_bps".

    Returns:
        A populated Perf record.
    """
    returns = equity_curve.pct_change().fillna(0.0)
    growth = equity_curve.iat[-1] / equity_curve.iat[0] - 1.0
    drawdown = (equity_curve / equity_curve.cummax() - 1.0).min()

    vol = returns.std(ddof=0)
    # Annualisation factor assumes one-minute bars (252 * 24 * 60 per year).
    sharpe = (returns.mean() / vol) * np.sqrt(252 * 24 * 60) if vol > 0 else 0.0

    # Only SELL events close a position; wins are closes with positive pnl.
    closes = [t for t in trades if t.get("side") == "SELL"]
    wins = [t for t in closes if t.get("pnl", 0.0) > 0]
    win_rate = (len(wins) / len(closes)) if closes else 0.0

    total_fees = sum(t.get("fee", 0.0) for t in trades)
    total_slippage = sum(t.get("slippage", 0.0) for t in trades)
    slippage_samples = [t.get("slippage_bps", 0.0) for t in trades if "slippage_bps" in t]

    return Perf(
        total_return=growth,
        max_drawdown=drawdown,
        sharpe_ratio=sharpe,
        win_rate=win_rate,
        num_trades=len(closes),
        final_equity=float(equity_curve.iat[-1]),
        initial_equity=float(equity_curve.iat[0]),
        num_stop_losses=sum(1 for t in closes if t.get("reason") == "stop"),
        total_fees=total_fees,
        total_slippage_usd=total_slippage,
        avg_slippage_bps=float(np.mean(slippage_samples)) if slippage_samples else 0.0,
    )
|
||||
@@ -5,5 +5,25 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"ccxt>=4.5.32",
|
||||
"numpy>=2.3.2",
|
||||
"pandas>=2.3.1",
|
||||
"ta>=0.11.0",
|
||||
"vectorbt>=0.28.2",
|
||||
"scikit-learn>=1.6.0",
|
||||
"matplotlib>=3.10.0",
|
||||
"plotly>=5.24.0",
|
||||
"requests>=2.32.5",
|
||||
"python-dotenv>=1.2.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["."]
|
||||
markers = [
|
||||
"network: marks tests as requiring network access",
|
||||
]
|
||||
|
||||
384
research/regime_detection.py
Normal file
384
research/regime_detection.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
|
||||
def prepare_data(symbol_a="BTC-USDT", symbol_b="ETH-USDT", timeframe="1h", limit=None, start_date=None, end_date=None):
    """
    Load and align OHLCV data for two assets to create a pair.

    Loads each symbol from the local DataManager cache (downloading from OKX
    spot when missing or when the cache does not reach back to `start_date`),
    optionally clips both frames to [start_date, end_date] (UTC), and aligns
    them on their common timestamps.

    Args:
        symbol_a: First leg, e.g. "BTC-USDT"
        symbol_b: Second leg, e.g. "ETH-USDT"
        timeframe: Bar size, e.g. "1h"
        limit: If set, keep only the most recent `limit` aligned bars
        start_date: Inclusive lower bound (parsed as UTC), optional
        end_date: Inclusive upper bound (parsed as UTC), optional

    Returns:
        Tuple (df_a, df_b) of index-aligned OHLCV DataFrames.
    """
    dm = DataManager()

    print(f"Loading data for {symbol_a} and {symbol_b}...")

    # Helper: cache-first load with download fallback.
    def get_df(symbol):
        try:
            # Try load first
            df = dm.load_data("okx", symbol, timeframe, MarketType.SPOT)
        except Exception:
            # Cache miss (or unreadable) -> fetch from the exchange.
            df = dm.download_data("okx", symbol, timeframe, market_type=MarketType.SPOT)

        # If the cached history starts after the requested start, re-download
        # the full requested range.
        if start_date:
            mask_start = pd.Timestamp(start_date, tz='UTC')
            if df.index.min() > mask_start:
                print(f"Local data starts {df.index.min()}, need {mask_start}. Downloading...")
                df = dm.download_data("okx", symbol, timeframe, start_date=start_date, end_date=end_date, market_type=MarketType.SPOT)
        return df

    df_a = get_df(symbol_a)
    df_b = get_df(symbol_b)

    # Clip to the requested window (e.g. to match the CryptoQuant data range).
    if start_date:
        df_a = df_a[df_a.index >= pd.Timestamp(start_date, tz='UTC')]
        df_b = df_b[df_b.index >= pd.Timestamp(start_date, tz='UTC')]

    if end_date:
        df_a = df_a[df_a.index <= pd.Timestamp(end_date, tz='UTC')]
        df_b = df_b[df_b.index <= pd.Timestamp(end_date, tz='UTC')]

    # Keep only timestamps present for BOTH assets so the pair is comparable.
    print("Aligning data...")
    common_index = df_a.index.intersection(df_b.index)
    df_a = df_a.loc[common_index].copy()
    df_b = df_b.loc[common_index].copy()

    # Optionally trim to the most recent `limit` bars.
    if limit:
        df_a = df_a.tail(limit)
        df_b = df_b.tail(limit)

    return df_a, df_b
|
||||
|
||||
def load_cryptoquant_data(file_path: str) -> pd.DataFrame | None:
|
||||
"""
|
||||
Load CryptoQuant data and prepare it for merging.
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"Warning: CQ data file {file_path} not found.")
|
||||
return None
|
||||
|
||||
print(f"Loading CryptoQuant data from {file_path}...")
|
||||
df = pd.read_csv(file_path, index_col='timestamp', parse_dates=True)
|
||||
|
||||
# CQ data is usually daily (UTC 00:00).
|
||||
# Ensure index is timezone aware to match market data
|
||||
if df.index.tz is None:
|
||||
df.index = df.index.tz_localize('UTC')
|
||||
|
||||
return df
|
||||
|
||||
def calculate_features(df_a, df_b, cq_df=None, window=24):
    """
    Calculate spread, z-score, and advanced regime features including CQ data.

    Builds a feature matrix for the B/A price-ratio pair plus a binary
    `target` column (1 = a mean-reversion trade opened at this bar would
    have hit a 0.5% profit within the next 6 bars). Rows with any NaN
    (warm-up windows, unshifted future values) are dropped.

    Args:
        df_a: OHLCV frame for leg A (needs 'close' and 'volume')
        df_b: OHLCV frame for leg B, index-aligned with df_a
        cq_df: Optional daily CryptoQuant frame to forward-fill and join
        window: Rolling window (bars) for z-score and volatility features
    """
    # 1. Price ratio ("spread") of leg B over leg A.
    spread = df_b['close'] / df_a['close']

    # 2. Rolling statistics -> z-score of the spread.
    rolling_mean = spread.rolling(window=window).mean()
    rolling_std = spread.rolling(window=window).std()
    z_score = (spread - rolling_mean) / rolling_std

    # 3. Spread momentum / technicals.
    spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
    spread_roc = spread.pct_change(periods=5) * 100

    # 4. Volume dynamics between the two legs.
    vol_ratio = df_b['volume'] / df_a['volume']
    vol_ratio_ma = vol_ratio.rolling(window=12).mean()

    # 5. Volatility regime: relative realized volatility of the two legs.
    ret_a = df_a['close'].pct_change()
    ret_b = df_b['close'].pct_change()
    vol_a = ret_a.rolling(window=window).std()
    vol_b = ret_b.rolling(window=window).std()
    vol_spread_ratio = vol_b / vol_a

    # Assemble the feature frame.
    features = pd.DataFrame(index=spread.index)
    features['spread'] = spread
    features['z_score'] = z_score
    features['spread_rsi'] = spread_rsi
    features['spread_roc'] = spread_roc
    features['vol_ratio'] = vol_ratio
    features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
    features['vol_diff_ratio'] = vol_spread_ratio

    # 6. Merge CryptoQuant data (daily -> forward-filled to bar frequency).
    if cq_df is not None:
        print("Merging CryptoQuant features...")
        cq_aligned = cq_df.reindex(features.index, method='ffill')

        # Funding differential: ETH funding above BTC funding suggests ETH
        # positioning is overheated relative to BTC.
        if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
            cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']

        # Inflow ratio: relatively heavy ETH exchange inflows may precede
        # ETH-side selling. +1 in the denominator guards against div-by-zero.
        if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
            cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)

        features = features.join(cq_aligned)

    # --- Target definition ("anytime profit" within the horizon) ---
    horizon = 6          # bars to look ahead
    threshold = 0.005    # 0.5% profit target
    z_threshold = 1.0    # |z| needed to consider a reversion setup

    # Short-spread setup (z > 1): success if the spread's MINIMUM over the
    # next `horizon` bars drops below the profit target.
    future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
    target_short = features['spread'] * (1 - threshold)
    success_short = (features['z_score'] > z_threshold) & (future_min < target_short)

    # Long-spread setup (z < -1): success if the spread's MAXIMUM over the
    # next `horizon` bars rises above the profit target.
    future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
    target_long = features['spread'] * (1 + threshold)
    success_long = (features['z_score'] < -z_threshold) & (future_max > target_long)

    conditions = [success_short, success_long]

    # Single binary label: 1 when either reversion direction succeeds.
    features['target'] = np.select(conditions, [1, 1], default=0)

    return features.dropna()
|
||||
|
||||
def train_regime_model(features):
    """
    Train a Random Forest to predict mean-reversion success.

    Args:
        features: Output of calculate_features - a frame containing the
            engineered predictors plus a binary 'target' column.

    Returns:
        Tuple (model, X_test, y_test, y_pred, y_prob) where y_prob is the
        predicted probability of target=1 on the (chronological) test split.
    """
    # Columns never fed to the model: the label itself plus raw/intermediate
    # series that would leak the target or duplicate derived features.
    exclude_cols = ['spread', 'horizon_ret', 'target', 'rolling_mean', 'rolling_std']

    # Auto-select every other column as a predictor.
    feature_cols = [c for c in features.columns if c not in exclude_cols]

    # Sanitize: neutralize any inf/NaN that slipped through feature building.
    X = features[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0)
    y = features['target']

    # Chronological split (shuffle=False) to avoid look-ahead leakage.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=False)

    print(f"\nTraining on {len(X_train)} samples, Testing on {len(X_test)} samples...")
    print(f"Features used: {feature_cols}")
    print(f"Class Balance (Target=1): {y.mean():.2%}")

    # Shallow, regularized forest; balanced subsampling compensates for the
    # minority positive class.
    model = RandomForestClassifier(
        n_estimators=200,
        max_depth=6,
        min_samples_leaf=20,
        class_weight='balanced_subsample',
        random_state=42
    )
    model.fit(X_train, y_train)

    # Out-of-sample evaluation on the held-out tail.
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    print("\n--- Model Evaluation ---")
    print(classification_report(y_test, y_pred))

    # Feature importance, most informative first.
    importances = pd.Series(model.feature_importances_, index=feature_cols).sort_values(ascending=False)
    print("\n--- Feature Importance ---")
    print(importances)

    return model, X_test, y_test, y_pred, y_prob
|
||||
|
||||
def plot_interactive_results(features, y_test, y_pred, y_prob):
    """
    Create an interactive HTML report of test-set results using Plotly.

    Args:
        features: Full feature DataFrame (timestamp-indexed).
        y_test: Test-set labels; its index selects the rows to plot.
        y_pred: Model class predictions for the test set.
        y_prob: Model probabilities for the positive class.

    Side effects:
        Writes the figure to research/regime_results.html.
    """
    print("\nGenerating interactive plot...")

    # Restrict plotting to the test period and attach the model outputs.
    test_idx = y_test.index
    test_data = features.loc[test_idx].copy()
    test_data['prob'] = y_prob
    test_data['prediction'] = y_pred
    test_data['actual'] = y_test

    # Three stacked panels sharing the x-axis.
    fig = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        row_heights=[0.5, 0.25, 0.25],
        subplot_titles=('Spread & Signals', 'Exchange Inflows', 'Z-Score & Probability')
    )

    # Top panel: the spread itself.
    fig.add_trace(
        go.Scatter(x=test_data.index, y=test_data['spread'], mode='lines', name='Spread', line=dict(color='gray')),
        row=1, col=1
    )

    # Signal markers. A high z-score (> 0) means the bet was SHORT the
    # spread (reversion down); a low z-score (< 0) means LONG the spread
    # (reversion up).

    # Correct Short Signals (Green Triangle Down)
    tp_short = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 1) & (test_data['z_score'] > 0)]
    fig.add_trace(
        go.Scatter(x=tp_short.index, y=tp_short['spread'], mode='markers', name='Win: Short Spread',
                   marker=dict(symbol='triangle-down', size=12, color='green')),
        row=1, col=1
    )

    # Correct Long Signals (Green Triangle Up)
    tp_long = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 1) & (test_data['z_score'] < 0)]
    fig.add_trace(
        go.Scatter(x=tp_long.index, y=tp_long['spread'], mode='markers', name='Win: Long Spread',
                   marker=dict(symbol='triangle-up', size=12, color='green')),
        row=1, col=1
    )

    # False Short Signals (Red Triangle Down)
    fp_short = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 0) & (test_data['z_score'] > 0)]
    fig.add_trace(
        go.Scatter(x=fp_short.index, y=fp_short['spread'], mode='markers', name='Loss: Short Spread',
                   marker=dict(symbol='triangle-down', size=10, color='red')),
        row=1, col=1
    )

    # False Long Signals (Red Triangle Up)
    fp_long = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 0) & (test_data['z_score'] < 0)]
    fig.add_trace(
        go.Scatter(x=fp_long.index, y=fp_long['spread'], mode='markers', name='Loss: Long Spread',
                   marker=dict(symbol='triangle-up', size=10, color='red')),
        row=1, col=1
    )

    # Middle panel: exchange inflows (columns are optional in the data).
    if 'btc_inflow' in test_data.columns:
        fig.add_trace(
            go.Bar(x=test_data.index, y=test_data['btc_inflow'], name='BTC Inflow', marker_color='orange', opacity=0.6),
            row=2, col=1
        )
    if 'eth_inflow' in test_data.columns:
        fig.add_trace(
            go.Bar(x=test_data.index, y=test_data['eth_inflow'], name='ETH Inflow', marker_color='purple', opacity=0.6),
            row=2, col=1
        )

    # Bottom panel: z-score with +/-2 sigma reference lines.
    fig.add_trace(
        go.Scatter(x=test_data.index, y=test_data['z_score'], mode='lines', name='Z-Score', line=dict(color='blue'), opacity=0.5),
        row=3, col=1
    )
    fig.add_hline(y=2, line_dash="dash", line_color="red", row=3, col=1)
    fig.add_hline(y=-2, line_dash="dash", line_color="green", row=3, col=1)

    # Probability trace on a secondary y-axis for row 3 (configured as
    # yaxis4 in the layout below).
    fig.add_trace(
        go.Scatter(x=test_data.index, y=test_data['prob'], mode='lines', name='Prob', line=dict(color='cyan', width=1.5), yaxis='y4'),
        row=3, col=1
    )

    # Single layout pass — the original applied an identical
    # update_layout twice; the second call was redundant and removed.
    fig.update_layout(
        title='Regime Detection Analysis (with CryptoQuant)',
        autosize=True,
        height=None,
        hovermode='x unified',
        yaxis4=dict(title='Probability', overlaying='y3', side='right', range=[0, 1], showgrid=False),
        template="plotly_dark",
        margin=dict(l=10, r=10, t=40, b=10),
        barmode='group'
    )

    # Make hover spikes span all subplots.
    fig.update_xaxes(
        showspikes=True,
        spikemode='across',
        spikesnap='cursor',
        showline=False,
        showgrid=True,
        spikedash='dot',
        spikecolor='white',  # Make it bright to see
        spikethickness=1,
    )

    output_path = "research/regime_results.html"
    fig.write_html(
        output_path,
        config={'responsive': True, 'scrollZoom': True},
        include_plotlyjs='cdn',
        full_html=True,
        default_height='100vh',
        default_width='100%'
    )
    print(f"Interactive plot saved to {output_path}")
|
||||
|
||||
def main():
    """Entry point: load data, engineer features, train, and plot."""
    # 1. The CryptoQuant file defines the usable date range, so load it
    #    before downloading market data.
    cq_path = "data/cq_training_data.csv"
    cq_df = load_cryptoquant_data(cq_path)

    start_date = None
    end_date = None
    if cq_df is not None and not cq_df.empty:
        start_date = cq_df.index.min().strftime('%Y-%m-%d')
        end_date = cq_df.index.max().strftime('%Y-%m-%d')
        print(f"CryptoQuant Data Range: {start_date} to {end_date}")

    # 2. Market data, aligned to the CQ range when one is available.
    df_btc, df_eth = prepare_data(
        "BTC-USDT", "ETH-USDT",
        timeframe="1h",
        start_date=start_date,
        end_date=end_date
    )

    # 3. Feature engineering.
    print("Calculating advanced regime features...")
    data = calculate_features(df_btc, df_eth, cq_df=cq_df, window=24)
    if data.empty:
        print("Error: No overlapping data found between Price and CryptoQuant data.")
        return

    # 4. Train and evaluate the regime model.
    model, X_test, y_test, y_pred, y_prob = train_regime_model(data)

    # 5. Interactive HTML report.
    plot_interactive_results(data, y_test, y_pred, y_prob)


if __name__ == "__main__":
    main()
|
||||
BIN
research/regime_results.png
Normal file
BIN
research/regime_results.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 289 KiB |
80
strategies/base.py
Normal file
80
strategies/base.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Base strategy class for all trading strategies.
|
||||
|
||||
Strategies should inherit from BaseStrategy and implement the run() method.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
    """
    Abstract base class for trading strategies.

    Class Attributes:
        default_market_type: Default market type for this strategy
        default_leverage: Default leverage (only applies to perpetuals)
        default_sl_stop: Default stop-loss percentage
        default_tp_stop: Default take-profit percentage
        default_sl_trail: Whether stop-loss is trailing by default
    """
    # Market configuration defaults
    default_market_type: MarketType = MarketType.SPOT
    default_leverage: int = 1

    # Risk management defaults (can be overridden per strategy)
    default_sl_stop: float | None = None
    default_tp_stop: float | None = None
    default_sl_trail: bool = False

    def __init__(self, **kwargs):
        # Keep the raw parameters around for later inspection/reporting.
        self.params = kwargs

    @abstractmethod
    def run(
        self,
        close: pd.Series,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Run the strategy logic.

        Args:
            close: Price series (can be multiple columns for grid search)
            **kwargs: Additional data (high, low, open, volume) and parameters

        Returns:
            Tuple of 4 DataFrames/Series:
            - long_entries: Boolean signals to open long positions
            - long_exits: Boolean signals to close long positions
            - short_entries: Boolean signals to open short positions
            - short_exits: Boolean signals to close short positions

        Note:
            For spot markets, short signals will be ignored.
            For backward compatibility, strategies can return 2-tuple (entries, exits)
            which will be interpreted as long-only signals.
        """
        pass

    def get_indicator(self, ind_cls, *args, **kwargs):
        """Helper to run a vectorbt indicator."""
        return ind_cls.run(*args, **kwargs)

    @staticmethod
    def create_empty_signals(reference: pd.Series | pd.DataFrame) -> pd.Series | pd.DataFrame:
        """
        Create an empty (all False) signal object matching the reference shape.

        Args:
            reference: Series or DataFrame to match shape/index

        Returns:
            An all-False object of the same kind as the reference:
            a DataFrame for DataFrame input, a Series for Series input.
            (The previous annotation claimed DataFrame only, which was
            incorrect for Series input.)
        """
        if isinstance(reference, pd.DataFrame):
            return pd.DataFrame(False, index=reference.index, columns=reference.columns)
        return pd.Series(False, index=reference.index)
|
||||
97
strategies/examples.py
Normal file
97
strategies/examples.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Example trading strategies for backtesting.
|
||||
|
||||
These are simple strategies demonstrating the framework usage.
|
||||
"""
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class RsiStrategy(BaseStrategy):
    """
    RSI mean-reversion strategy.

    Opens a long when RSI crosses below the oversold level and closes it
    when RSI crosses above the overbought level. The short side is unused.
    """
    default_market_type = MarketType.SPOT
    default_leverage = 1

    def run(
        self,
        close: pd.Series,
        period: int = 14,
        rsi_lower: int = 30,
        rsi_upper: int = 70,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Generate RSI-based trading signals.

        Args:
            close: Price series
            period: RSI calculation period
            rsi_lower: Oversold threshold (buy signal)
            rsi_upper: Overbought threshold (sell signal)

        Returns:
            4-tuple of (long_entries, long_exits, short_entries, short_exits)
        """
        # Compute RSI via the shared indicator helper.
        rsi_ind = self.get_indicator(vbt.RSI, close, window=period)

        # Long side: enter on oversold cross, exit on overbought cross.
        entries = rsi_ind.rsi_crossed_below(rsi_lower)
        exits = rsi_ind.rsi_crossed_above(rsi_upper)

        # Spot-focused strategy: the short legs stay permanently flat.
        return (
            entries,
            exits,
            BaseStrategy.create_empty_signals(entries),
            BaseStrategy.create_empty_signals(entries),
        )
|
||||
|
||||
|
||||
class MaCrossStrategy(BaseStrategy):
    """
    Moving Average crossover strategy.

    Opens a long when the fast MA crosses above the slow MA and closes it
    on the opposite cross. The short side is unused.
    """
    default_market_type = MarketType.SPOT
    default_leverage = 1

    def run(
        self,
        close: pd.Series,
        fast_window: int = 10,
        slow_window: int = 20,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Generate MA crossover trading signals.

        Args:
            close: Price series
            fast_window: Fast MA period
            slow_window: Slow MA period

        Returns:
            4-tuple of (long_entries, long_exits, short_entries, short_exits)
        """
        # Compute both moving averages via the shared indicator helper.
        fast = self.get_indicator(vbt.MA, close, window=fast_window)
        slow = self.get_indicator(vbt.MA, close, window=slow_window)

        # Long side: golden cross in, death cross out.
        entries = fast.ma_crossed_above(slow)
        exits = fast.ma_crossed_below(slow)

        # No short leg for this strategy.
        return (
            entries,
            exits,
            BaseStrategy.create_empty_signals(entries),
            BaseStrategy.create_empty_signals(entries),
        )
|
||||
128
strategies/factory.py
Normal file
128
strategies/factory.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Strategy factory for creating strategy instances with their parameters.
|
||||
|
||||
Centralizes strategy creation and parameter configuration.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
|
||||
@dataclass
class StrategyConfig:
    """
    Configuration for a strategy including default and grid parameters.

    Attributes:
        strategy_class: The strategy class to instantiate
        default_params: Parameters for single backtest runs
        grid_params: Parameters for grid search optimization
    """
    # Concrete BaseStrategy subclass to instantiate (not an instance).
    strategy_class: type[BaseStrategy]
    # Parameters used for a single (non-optimized) backtest run.
    default_params: dict[str, Any] = field(default_factory=dict)
    # Parameter ranges/lists used during grid-search optimization;
    # scalar values pin a parameter to a fixed value.
    grid_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _build_registry() -> dict[str, StrategyConfig]:
    """
    Construct the strategy registry.

    Strategy modules are imported inside the function body to avoid
    circular imports between the factory and strategy packages.

    Returns:
        Mapping of strategy name -> StrategyConfig.
    """
    # Import here to avoid circular imports
    from strategies.examples import MaCrossStrategy, RsiStrategy
    from strategies.supertrend import MetaSupertrendStrategy

    registry: dict[str, StrategyConfig] = {}

    registry["rsi"] = StrategyConfig(
        strategy_class=RsiStrategy,
        default_params={'period': 14, 'rsi_lower': 30, 'rsi_upper': 70},
        grid_params={
            'period': np.arange(10, 25, 2),
            'rsi_lower': [20, 30, 40],
            'rsi_upper': [60, 70, 80],
        },
    )

    registry["macross"] = StrategyConfig(
        strategy_class=MaCrossStrategy,
        default_params={'fast_window': 10, 'slow_window': 20},
        grid_params={
            'fast_window': np.arange(5, 20, 5),
            'slow_window': np.arange(20, 60, 10),
        },
    )

    # NOTE(review): the fixed grid values for ST2/ST3 below (11/2.0,
    # 12/1.0) do not match default_params (10/1.0, 11/2.0) — confirm
    # this asymmetry is intentional.
    registry["meta_st"] = StrategyConfig(
        strategy_class=MetaSupertrendStrategy,
        default_params={
            'period1': 12, 'multiplier1': 3.0,
            'period2': 10, 'multiplier2': 1.0,
            'period3': 11, 'multiplier3': 2.0,
        },
        grid_params={
            'multiplier1': [2.0, 3.0, 4.0],
            'period1': [10, 12, 14],
            'period2': 11, 'multiplier2': 2.0,
            'period3': 12, 'multiplier3': 1.0,
        },
    )

    return registry
|
||||
|
||||
|
||||
# Module-level cache for the registry
|
||||
_REGISTRY_CACHE: dict[str, StrategyConfig] | None = None
|
||||
|
||||
|
||||
def get_registry() -> dict[str, StrategyConfig]:
    """Return the strategy registry, constructing it on first access."""
    global _REGISTRY_CACHE
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE
    # First call: build and memoize at module level.
    _REGISTRY_CACHE = _build_registry()
    return _REGISTRY_CACHE
|
||||
|
||||
|
||||
def get_strategy_names() -> list[str]:
    """
    List the names of all registered strategies.

    Returns:
        List of strategy name strings.
    """
    # Iterating a dict yields its keys; unpack into a list.
    return [*get_registry()]
|
||||
|
||||
|
||||
def get_strategy(name: str, is_grid: bool = False) -> tuple[BaseStrategy, dict[str, Any]]:
    """
    Create a strategy instance with appropriate parameters.

    Args:
        name: Strategy identifier (e.g., 'rsi', 'macross', 'meta_st')
        is_grid: If True, return grid search parameters

    Returns:
        Tuple of (strategy instance, parameters dict)

    Raises:
        KeyError: If strategy name is not found in registry
    """
    registry = get_registry()

    # EAFP lookup: a missing key becomes a friendlier KeyError listing
    # the valid names.
    try:
        config = registry[name]
    except KeyError:
        available = ", ".join(registry.keys())
        raise KeyError(f"Unknown strategy '{name}'. Available: {available}") from None

    # Copy so callers cannot mutate the registry's parameter dicts.
    params = config.grid_params if is_grid else config.default_params
    return config.strategy_class(), params.copy()
|
||||
6
strategies/supertrend/__init__.py
Normal file
6
strategies/supertrend/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Meta Supertrend strategy package.
|
||||
"""
|
||||
from .strategy import MetaSupertrendStrategy
|
||||
|
||||
__all__ = ['MetaSupertrendStrategy']
|
||||
128
strategies/supertrend/indicators.py
Normal file
128
strategies/supertrend/indicators.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Supertrend indicators and helper functions.
|
||||
"""
|
||||
import numpy as np
|
||||
import vectorbt as vbt
|
||||
from numba import njit
|
||||
|
||||
# --- Numba Compiled Helper Functions ---
|
||||
|
||||
@njit(cache=False)  # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
    """Calculate True Range (Numba compiled).

    TR[i] = max(high-low, |high-prev_close|, |low-prev_close|); the first
    bar has no previous close, so TR[0] falls back to high[0] - low[0].
    """
    # Ensure 1D arrays (inputs may arrive as (n, 1) columns)
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    tr = np.empty_like(close)
    tr[0] = high[0] - low[0]
    for i in range(1, len(close)):
        tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
    return tr
|
||||
|
||||
@njit(cache=False)
def get_atr_nb(high, low, close, period):
    """Calculate ATR using Wilder's Smoothing (Numba compiled).

    Returns an array the same length as close; entries before index
    period-1 remain NaN (not enough data for the seed average).
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure period is native Python int (critical for Numba array indexing)
    n = len(close)
    p = int(period)

    tr = get_tr_nb(high, low, close)
    atr = np.full(n, np.nan, dtype=np.float64)

    # Too few bars to seed the average: all-NaN result.
    if n < p:
        return atr

    # Initial ATR is simple average of TR over the first p bars
    sum_tr = 0.0
    for i in range(p):
        sum_tr += tr[i]
    atr[p - 1] = sum_tr / p

    # Subsequent ATR is Wilder's smoothed: ATR_t = (ATR_{t-1}*(p-1) + TR_t) / p
    for i in range(p, n):
        atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p

    return atr
|
||||
|
||||
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
    """Calculate SuperTrend completely in Numba.

    Returns only the trend direction array (int8): +1 bullish, -1 bearish.
    Bars before the first valid ATR value default to +1.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure params are native Python types (critical for Numba)
    n = len(close)
    p = int(period)
    m = float(multiplier)

    atr = get_atr_nb(high, low, close, p)

    final_upper = np.full(n, np.nan, dtype=np.float64)
    final_lower = np.full(n, np.nan, dtype=np.float64)
    trend = np.ones(n, dtype=np.int8)  # 1 Bull, -1 Bear

    # Skip until we have valid ATR
    start_idx = p
    if start_idx >= n:
        return trend

    # Init first valid point: bands at hl2 +/- m * ATR
    hl2 = (high[start_idx] + low[start_idx]) / 2
    final_upper[start_idx] = hl2 + m * atr[start_idx]
    final_lower[start_idx] = hl2 - m * atr[start_idx]

    # Main loop over remaining bars
    for i in range(start_idx + 1, n):
        cur_hl2 = (high[i] + low[i]) / 2
        cur_atr = atr[i]
        basic_upper = cur_hl2 + m * cur_atr
        basic_lower = cur_hl2 - m * cur_atr

        # Upper band only moves down, unless the prior close broke above it
        if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
            final_upper[i] = basic_upper
        else:
            final_upper[i] = final_upper[i-1]

        # Lower band only moves up, unless the prior close broke below it
        if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
            final_lower[i] = basic_lower
        else:
            final_lower[i] = final_lower[i-1]

        # Trend flips when close breaches the opposite band of the prior bar
        if trend[i-1] == 1:
            if close[i] < final_lower[i-1]:
                trend[i] = -1
            else:
                trend[i] = 1
        else:
            if close[i] > final_upper[i-1]:
                trend[i] = 1
            else:
                trend[i] = -1

    return trend
|
||||
|
||||
# --- VectorBT Indicator Factory ---
|
||||
|
||||
# vectorbt wrapper around get_supertrend_nb; exposes only the +1/-1
# trend output under `st.trend`.
SuperTrendIndicator = vbt.IndicatorFactory(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['trend']
).from_apply_func(
    get_supertrend_nb,
    keep_pd=False,  # Disable automatic Pandas wrapping of inputs
    param_product=True  # Enable Cartesian product for list params
)
|
||||
142
strategies/supertrend/strategy.py
Normal file
142
strategies/supertrend/strategy.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Meta Supertrend strategy implementation.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
from .indicators import SuperTrendIndicator
|
||||
|
||||
|
||||
class MetaSupertrendStrategy(BaseStrategy):
    """
    Meta Supertrend Strategy using 3 Supertrend indicators.

    Enters long when all 3 Supertrends are bullish.
    Enters short when all 3 Supertrends are bearish.

    Designed for perpetual futures with leverage and short-selling support.
    """
    # Market configuration
    default_market_type = MarketType.PERPETUAL
    default_leverage = 5

    # Risk management parameters
    default_sl_stop = 0.02  # 2% stop loss
    default_sl_trail = True  # Trailing stop enabled
    default_exit_on_bearish_flip = False  # Rely on SL/TP, not bearish flip

    def run(
        self,
        close: pd.Series,
        high: pd.Series = None,
        low: pd.Series = None,
        period1: int = 10,
        multiplier1: float = 3.0,
        period2: int = 11,
        multiplier2: float = 2.0,
        period3: int = 12,
        multiplier3: float = 1.0,
        exit_on_bearish_flip: bool = None,
        enable_short: bool = True,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Compute entry/exit signals from three stacked Supertrends.

        NOTE(review): these signature defaults (10/3.0, 11/2.0, 12/1.0)
        differ from the factory registry's default_params — confirm which
        set is authoritative.

        Raises:
            ValueError: if high or low series are not provided.
        """
        # 1. Validation & Setup
        if exit_on_bearish_flip is None:
            exit_on_bearish_flip = self.default_exit_on_bearish_flip

        if high is None or low is None:
            raise ValueError("MetaSupertrendStrategy requires High and Low prices.")

        # 2. Calculate Supertrends
        t1, t2, t3 = self._calculate_supertrends(
            high, low, close,
            period1, multiplier1,
            period2, multiplier2,
            period3, multiplier3
        )

        # 3. Meta Signals
        bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)

        # 4. Generate Entry/Exit Signals
        return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)

    def _calculate_supertrends(
        self, high, low, close, p1, m1, p2, m2, p3, m3
    ):
        """Run the 3 Supertrend indicators."""
        # Pass NumPy arrays explicitly to avoid Numba typing errors
        h_vals = high.values
        l_vals = low.values
        c_vals = close.values

        def run_st(p, m):
            # Re-attach the original index, since raw arrays were passed in.
            st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
            trend = st.trend
            if isinstance(trend, pd.DataFrame):
                trend.index = close.index
                # Collapse a single-parameter result to a Series.
                if trend.shape[1] == 1:
                    trend = trend.iloc[:, 0]
            elif isinstance(trend, pd.Series):
                trend.index = close.index
            return trend

        t1 = run_st(p1, m1)
        t2 = run_st(p2, m2)
        t3 = run_st(p3, m3)
        return t1, t2, t3

    def _calculate_meta_signals(self, t1, t2, t3, close_series):
        """Combine 3 Supertrends into boolean Bullish/Bearish signals."""
        # Use NumPy broadcasting: t1 may be a grid DataFrame, t2/t3 are
        # forced to column vectors so they broadcast against it.
        t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
        # Force column vectors for broadcasting if scalar result
        t2_vals = t2.values.reshape(-1, 1)
        t3_vals = t3.values.reshape(-1, 1)

        # Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
        bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
        bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)

        # Reconstruct Pandas objects with t1's index/columns
        if isinstance(t1, pd.DataFrame):
            bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
            bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
        else:
            bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
            bearish = pd.Series(bearish_vals.flatten(), index=t1.index)

        return bullish, bearish

    def _generate_signals(
        self, bullish, bearish, exit_on_bearish_flip, enable_short
    ):
        """Generate long/short entry/exit signals based on meta trend."""
        # Long Entries: Change from Not Bullish to Bullish
        prev_bullish = bullish.shift(1).fillna(False)
        long_entries = bullish & (~prev_bullish)

        # Long Exits
        if exit_on_bearish_flip:
            prev_bearish = bearish.shift(1).fillna(False)
            long_exits = bearish & (~prev_bearish)
        else:
            # Exits are delegated to stop-loss/take-profit handling.
            long_exits = BaseStrategy.create_empty_signals(long_entries)

        # Short signals
        if enable_short:
            prev_bearish = bearish.shift(1).fillna(False)
            short_entries = bearish & (~prev_bearish)

            if exit_on_bearish_flip:
                # Short exits on a fresh bullish flip.
                short_exits = bullish & (~prev_bullish)
            else:
                short_exits = BaseStrategy.create_empty_signals(long_entries)
        else:
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)

        return long_entries, long_exits, short_entries, short_exits
|
||||
6
strategies/supertrend_pkg/__init__.py
Normal file
6
strategies/supertrend_pkg/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Meta Supertrend strategy package.
|
||||
"""
|
||||
from .strategy import MetaSupertrendStrategy
|
||||
|
||||
__all__ = ['MetaSupertrendStrategy']
|
||||
128
strategies/supertrend_pkg/indicators.py
Normal file
128
strategies/supertrend_pkg/indicators.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Supertrend indicators and helper functions.
|
||||
"""
|
||||
import numpy as np
|
||||
import vectorbt as vbt
|
||||
from numba import njit
|
||||
|
||||
# --- Numba Compiled Helper Functions ---
|
||||
|
||||
@njit(cache=False)  # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
    """Calculate True Range (Numba compiled).

    TR[i] = max(high-low, |high-prev_close|, |low-prev_close|); the first
    bar has no previous close, so TR[0] falls back to high[0] - low[0].
    """
    # Ensure 1D arrays (inputs may arrive as (n, 1) columns)
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    tr = np.empty_like(close)
    tr[0] = high[0] - low[0]
    for i in range(1, len(close)):
        tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
    return tr
|
||||
|
||||
@njit(cache=False)
def get_atr_nb(high, low, close, period):
    """Calculate ATR using Wilder's Smoothing (Numba compiled).

    Returns an array the same length as close; entries before index
    period-1 remain NaN (not enough data for the seed average).
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure period is native Python int (critical for Numba array indexing)
    n = len(close)
    p = int(period)

    tr = get_tr_nb(high, low, close)
    atr = np.full(n, np.nan, dtype=np.float64)

    # Too few bars to seed the average: all-NaN result.
    if n < p:
        return atr

    # Initial ATR is simple average of TR over the first p bars
    sum_tr = 0.0
    for i in range(p):
        sum_tr += tr[i]
    atr[p - 1] = sum_tr / p

    # Subsequent ATR is Wilder's smoothed: ATR_t = (ATR_{t-1}*(p-1) + TR_t) / p
    for i in range(p, n):
        atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p

    return atr
|
||||
|
||||
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
    """Calculate SuperTrend completely in Numba.

    Returns only the trend direction array (int8): +1 bullish, -1 bearish.
    Bars before the first valid ATR value default to +1.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure params are native Python types (critical for Numba)
    n = len(close)
    p = int(period)
    m = float(multiplier)

    atr = get_atr_nb(high, low, close, p)

    final_upper = np.full(n, np.nan, dtype=np.float64)
    final_lower = np.full(n, np.nan, dtype=np.float64)
    trend = np.ones(n, dtype=np.int8)  # 1 Bull, -1 Bear

    # Skip until we have valid ATR
    start_idx = p
    if start_idx >= n:
        return trend

    # Init first valid point: bands at hl2 +/- m * ATR
    hl2 = (high[start_idx] + low[start_idx]) / 2
    final_upper[start_idx] = hl2 + m * atr[start_idx]
    final_lower[start_idx] = hl2 - m * atr[start_idx]

    # Main loop over remaining bars
    for i in range(start_idx + 1, n):
        cur_hl2 = (high[i] + low[i]) / 2
        cur_atr = atr[i]
        basic_upper = cur_hl2 + m * cur_atr
        basic_lower = cur_hl2 - m * cur_atr

        # Upper band only moves down, unless the prior close broke above it
        if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
            final_upper[i] = basic_upper
        else:
            final_upper[i] = final_upper[i-1]

        # Lower band only moves up, unless the prior close broke below it
        if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
            final_lower[i] = basic_lower
        else:
            final_lower[i] = final_lower[i-1]

        # Trend flips when close breaches the opposite band of the prior bar
        if trend[i-1] == 1:
            if close[i] < final_lower[i-1]:
                trend[i] = -1
            else:
                trend[i] = 1
        else:
            if close[i] > final_upper[i-1]:
                trend[i] = 1
            else:
                trend[i] = -1

    return trend
|
||||
|
||||
# --- VectorBT Indicator Factory ---
|
||||
|
||||
# vectorbt wrapper around get_supertrend_nb; exposes only the +1/-1
# trend output under `st.trend`.
SuperTrendIndicator = vbt.IndicatorFactory(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['trend']
).from_apply_func(
    get_supertrend_nb,
    keep_pd=False,  # Disable automatic Pandas wrapping of inputs
    param_product=True  # Enable Cartesian product for list params
)
|
||||
142
strategies/supertrend_pkg/strategy.py
Normal file
142
strategies/supertrend_pkg/strategy.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Meta Supertrend strategy implementation.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
from .indicators import SuperTrendIndicator
|
||||
|
||||
|
||||
class MetaSupertrendStrategy(BaseStrategy):
    """
    Meta Supertrend Strategy using 3 Supertrend indicators.

    Enters long when all 3 Supertrends are bullish.
    Enters short when all 3 Supertrends are bearish.

    Designed for perpetual futures with leverage and short-selling support.
    """
    # Market configuration
    default_market_type = MarketType.PERPETUAL
    default_leverage = 5

    # Risk management parameters
    default_sl_stop = 0.02  # 2% stop loss
    default_sl_trail = True  # Trailing stop enabled
    default_exit_on_bearish_flip = False  # Rely on SL/TP, not bearish flip

    def run(
        self,
        close: pd.Series,
        high: pd.Series = None,
        low: pd.Series = None,
        period1: int = 10,
        multiplier1: float = 3.0,
        period2: int = 11,
        multiplier2: float = 2.0,
        period3: int = 12,
        multiplier3: float = 1.0,
        exit_on_bearish_flip: bool = None,
        enable_short: bool = True,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Compute meta-Supertrend entry/exit signals.

        Args:
            close: Close prices.
            high: High prices (required).
            low: Low prices (required).
            period1..period3 / multiplier1..multiplier3: Parameters of the
                three Supertrend instances (scalars or lists for grids).
            exit_on_bearish_flip: Exit longs on a full bearish flip;
                falls back to the class default when None.
            enable_short: Generate short signals on bearish flips.

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits),
            each a boolean Series/DataFrame aligned with ``close``.

        Raises:
            ValueError: If ``high`` or ``low`` is not provided.
        """
        # 1. Validation & Setup
        if exit_on_bearish_flip is None:
            exit_on_bearish_flip = self.default_exit_on_bearish_flip

        if high is None or low is None:
            raise ValueError("MetaSupertrendStrategy requires High and Low prices.")

        # 2. Calculate Supertrends
        t1, t2, t3 = self._calculate_supertrends(
            high, low, close,
            period1, multiplier1,
            period2, multiplier2,
            period3, multiplier3
        )

        # 3. Meta Signals
        bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)

        # 4. Generate Entry/Exit Signals
        return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)

    def _calculate_supertrends(
        self, high, low, close, p1, m1, p2, m2, p3, m3
    ):
        """Run the 3 Supertrend indicators and return their trend outputs."""
        # Pass NumPy arrays explicitly to avoid Numba typing errors
        h_vals = high.values
        l_vals = low.values
        c_vals = close.values

        def run_st(p, m):
            st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
            trend = st.trend
            if isinstance(trend, pd.DataFrame):
                trend.index = close.index
                if trend.shape[1] == 1:
                    trend = trend.iloc[:, 0]
            elif isinstance(trend, pd.Series):
                trend.index = close.index
            return trend

        t1 = run_st(p1, m1)
        t2 = run_st(p2, m2)
        t3 = run_st(p3, m3)
        return t1, t2, t3

    def _calculate_meta_signals(self, t1, t2, t3, close_series):
        """Combine 3 Supertrends into boolean Bullish/Bearish signals."""
        def as_2d(trend):
            # Keep DataFrames (param grids) as-is; promote Series to column
            # vectors so NumPy broadcasting works uniformly.
            # BUGFIX: previously t2/t3 were unconditionally reshaped to
            # (-1, 1), which mis-shaped multi-column grid outputs.
            vals = trend.values
            return vals if vals.ndim == 2 else vals.reshape(-1, 1)

        t1_vals = as_2d(t1)
        t2_vals = as_2d(t2)
        t3_vals = as_2d(t3)

        # Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
        bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
        bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)

        # Reconstruct Pandas objects on t1's index (and columns for grids)
        if isinstance(t1, pd.DataFrame):
            bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
            bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
        else:
            bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
            bearish = pd.Series(bearish_vals.flatten(), index=t1.index)

        return bullish, bearish

    def _generate_signals(
        self, bullish, bearish, exit_on_bearish_flip, enable_short
    ):
        """Generate long/short entry/exit signals based on meta trend."""
        # Previous-bar states, computed once (was duplicated per branch)
        prev_bullish = bullish.shift(1).fillna(False)
        prev_bearish = bearish.shift(1).fillna(False)

        # Long Entries: change from Not Bullish to Bullish
        long_entries = bullish & (~prev_bullish)

        # Long Exits: bearish flip, or rely purely on SL/TP
        if exit_on_bearish_flip:
            long_exits = bearish & (~prev_bearish)
        else:
            long_exits = BaseStrategy.create_empty_signals(long_entries)

        # Short signals
        if enable_short:
            short_entries = bearish & (~prev_bearish)
            if exit_on_bearish_flip:
                short_exits = bullish & (~prev_bullish)
            else:
                short_exits = BaseStrategy.create_empty_signals(long_entries)
        else:
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)

        return long_entries, long_exits, short_entries, short_exits
|
||||
295
tasks/prd-market-type-selection.md
Normal file
295
tasks/prd-market-type-selection.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# PRD: Market Type Selection for Backtesting
|
||||
|
||||
## Introduction/Overview
|
||||
|
||||
Currently, the backtesting system operates with a single, implicit market type assumption. This PRD defines the implementation of **market type selection** (Spot vs. USDT-M Perpetual Futures) to enable realistic simulation of different trading conditions.
|
||||
|
||||
**Problem Statement:**
|
||||
- Strategies cannot be backtested against different market mechanics (leverage, funding, short-selling)
|
||||
- Fee structures are uniform regardless of market type
|
||||
- No support for short-selling strategies
|
||||
- Data fetching doesn't distinguish between spot and futures markets
|
||||
|
||||
**Goal:**
|
||||
Enable users to backtest strategies against specific market types (Spot or USDT-M Perpetual) with realistic trading conditions matching OKX's live environment.
|
||||
|
||||
---
|
||||
|
||||
## Goals
|
||||
|
||||
1. **Support two market types:** Spot and USDT-M Perpetual Futures
|
||||
2. **Realistic fee simulation:** Match OKX's fee structure per market type
|
||||
3. **Leverage support:** Per-strategy configurable leverage (perpetuals only)
|
||||
4. **Funding rate simulation:** Simplified funding rate model for perpetuals
|
||||
5. **Short-selling support:** Enable strategies to generate short signals
|
||||
6. **Liquidation awareness:** Warn when positions would be liquidated (no full simulation)
|
||||
7. **Separate data storage:** Download and store data per market type
|
||||
8. **Grid search integration:** Allow leverage optimization in parameter grids
|
||||
|
||||
---
|
||||
|
||||
## User Stories
|
||||
|
||||
1. **As a trader**, I want to backtest my strategy on perpetual futures so that I can simulate leveraged trading with funding costs.
|
||||
|
||||
2. **As a trader**, I want to backtest on spot markets so that I can compare performance without leverage or funding overhead.
|
||||
|
||||
3. **As a strategy developer**, I want to define a default market type for my strategy so that it runs with appropriate settings by default.
|
||||
|
||||
4. **As a trader**, I want to test different leverage levels so that I can find the optimal risk/reward balance.
|
||||
|
||||
5. **As a trader**, I want to see warnings when my position would have been liquidated so that I can adjust my risk parameters.
|
||||
|
||||
6. **As a strategy developer**, I want to create strategies that can go short so that I can profit from downward price movements.
|
||||
|
||||
---
|
||||
|
||||
## Functional Requirements
|
||||
|
||||
### FR1: Market Type Enum and Configuration
|
||||
|
||||
1.1. Create a `MarketType` enum with values: `SPOT`, `PERPETUAL`
|
||||
|
||||
1.2. Each strategy class must have a `default_market_type` class attribute
|
||||
|
||||
1.3. Market type can be overridden via CLI (optional, for testing)
|
||||
|
||||
### FR2: Data Management
|
||||
|
||||
2.1. Modify `DataManager` to support market type in data paths:
|
||||
- Spot: `data/ccxt/{exchange}/spot/{symbol}/{timeframe}.csv`
|
||||
- Perpetual: `data/ccxt/{exchange}/perpetual/{symbol}/{timeframe}.csv`
|
||||
|
||||
2.2. Update `download` command to accept `--market` flag:
|
||||
```bash
|
||||
uv run python main.py download --pair BTC/USDT --market perpetual
|
||||
```
|
||||
|
||||
2.3. Use CCXT's market type parameter when fetching data:
|
||||
- Spot: `exchange.fetch_ohlcv(symbol, timeframe, ...)`
|
||||
- Perpetual: `exchange.fetch_ohlcv(symbol + ':USDT', timeframe, ...)`
|
||||
|
||||
### FR3: Fee Structure
|
||||
|
||||
3.1. Define default fees per market type (matching OKX):
|
||||
|
||||
| Market Type | Maker Fee | Taker Fee | Notes |
|
||||
|-------------|-----------|-----------|-------|
|
||||
| Spot | 0.08% | 0.10% | No funding |
|
||||
| Perpetual | 0.02% | 0.05% | + funding |
|
||||
|
||||
3.2. Allow fee override via CLI (existing `--fees` flag)
|
||||
|
||||
### FR4: Leverage Support (Perpetual Only)
|
||||
|
||||
4.1. Add `default_leverage` class attribute to strategies (default: 1 for spot, configurable for perpetual)
|
||||
|
||||
4.2. Add `--leverage` CLI flag for backtest command
|
||||
|
||||
4.3. Leverage affects:
|
||||
- Position sizing (notional = cash * leverage)
|
||||
- PnL calculation (multiplied by leverage)
|
||||
- Liquidation threshold calculation
|
||||
|
||||
4.4. Support leverage in grid search parameter grids
|
||||
|
||||
### FR5: Funding Rate Simulation (Perpetual Only)
|
||||
|
||||
5.1. Implement simplified funding rate model:
|
||||
- Default rate: 0.01% per 8 hours (configurable)
|
||||
- Applied every 8 hours to open positions
|
||||
- Positive rate: Longs pay shorts
|
||||
- Negative rate: Shorts pay longs
|
||||
|
||||
5.2. Add `--funding-rate` CLI flag to override default
|
||||
|
||||
5.3. Track cumulative funding paid/received in backtest stats
|
||||
|
||||
### FR6: Short-Selling Support
|
||||
|
||||
6.1. Modify `BaseStrategy.run()` signature to return 4 signal arrays:
|
||||
```python
|
||||
def run(self, close, **kwargs) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Returns:
|
||||
long_entries: Boolean signals to open long positions
|
||||
long_exits: Boolean signals to close long positions
|
||||
short_entries: Boolean signals to open short positions
|
||||
short_exits: Boolean signals to close short positions
|
||||
"""
|
||||
```
|
||||
|
||||
6.2. Update `Backtester` to use VectorBT's `direction` parameter or dual portfolio simulation
|
||||
|
||||
6.3. For spot market: Ignore short signals (log warning if present)
|
||||
|
||||
6.4. For perpetual market: Process both long and short signals
|
||||
|
||||
### FR7: Liquidation Warning
|
||||
|
||||
7.1. Calculate liquidation price based on:
|
||||
- Entry price
|
||||
- Leverage
|
||||
- Maintenance margin rate (OKX: ~0.4% for BTC)
|
||||
|
||||
7.2. During backtest, check if price crosses liquidation threshold
|
||||
|
||||
7.3. Log warning with details:
|
||||
```
|
||||
WARNING: Position would be liquidated at bar 1234 (price: $45,000, liq_price: $44,820)
|
||||
```
|
||||
|
||||
7.4. Include liquidation event count in backtest summary stats
|
||||
|
||||
### FR8: Backtester Integration
|
||||
|
||||
8.1. Modify `Backtester.run_strategy()` to accept market type from strategy
|
||||
|
||||
8.2. Apply market-specific simulation parameters:
|
||||
- Fees (if not overridden)
|
||||
- Leverage
|
||||
- Funding rate calculation
|
||||
- Short-selling capability
|
||||
|
||||
8.3. Update portfolio simulation to handle leveraged positions
|
||||
|
||||
### FR9: Reporting Updates
|
||||
|
||||
9.1. Add market type to backtest summary output
|
||||
|
||||
9.2. Add new stats for perpetual backtests:
|
||||
- Total funding paid/received
|
||||
- Number of liquidation warnings
|
||||
- Effective leverage used
|
||||
|
||||
9.3. Update CSV exports to include market-specific columns
|
||||
|
||||
---
|
||||
|
||||
## Non-Goals (Out of Scope)
|
||||
|
||||
- **Coin-M (Inverse) Perpetuals:** Not included in v1
|
||||
- **Spot Margin Trading:** Not included in v1
|
||||
- **Expiry Futures:** Not included in v1
|
||||
- **Full Liquidation Simulation:** Only warnings, no automatic position closure
|
||||
- **Real Funding Rate Data:** Use simplified model; historical funding API integration is future work
|
||||
- **Cross-Margin Mode:** Assume isolated margin for simplicity
|
||||
- **Partial Liquidation:** Assume full liquidation threshold only
|
||||
|
||||
---
|
||||
|
||||
## Design Considerations
|
||||
|
||||
### Data Directory Structure (New)
|
||||
|
||||
```
|
||||
data/ccxt/
|
||||
okx/
|
||||
spot/
|
||||
BTC-USDT/
|
||||
1m.csv
|
||||
1d.csv
|
||||
perpetual/
|
||||
BTC-USDT/
|
||||
1m.csv
|
||||
1d.csv
|
||||
```
|
||||
|
||||
### Strategy Class Example
|
||||
|
||||
```python
|
||||
class MetaSupertrendStrategy(BaseStrategy):
|
||||
default_market_type = MarketType.PERPETUAL
|
||||
default_leverage = 5
|
||||
default_sl_stop = 0.02
|
||||
|
||||
def run(self, close, **kwargs):
|
||||
# ... indicator logic ...
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
```
|
||||
|
||||
### CLI Usage Examples
|
||||
|
||||
```bash
|
||||
# Download perpetual data
|
||||
uv run python main.py download --pair BTC/USDT --market perpetual
|
||||
|
||||
# Backtest with strategy defaults (uses strategy's default_market_type)
|
||||
uv run python main.py backtest --strategy meta_st --pair BTC/USDT
|
||||
|
||||
# Override leverage
|
||||
uv run python main.py backtest --strategy meta_st --pair BTC/USDT --leverage 10
|
||||
|
||||
# Grid search including leverage
|
||||
uv run python main.py backtest --strategy meta_st --pair BTC/USDT --grid
|
||||
# (leverage can be part of param grid in strategy factory)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Technical Considerations
|
||||
|
||||
1. **VectorBT Compatibility:**
|
||||
- VectorBT's `Portfolio.from_signals()` supports `direction` parameter for long/short
|
||||
- Alternatively, run two portfolios (long-only, short-only) and combine
|
||||
- Leverage can be simulated via `size` parameter or post-processing returns
|
||||
|
||||
2. **CCXT Market Type Handling:**
|
||||
- OKX perpetual symbols use format: `BTC/USDT:USDT`
|
||||
- Need to handle symbol conversion in DataManager
|
||||
|
||||
3. **Funding Rate Timing:**
|
||||
- OKX funding at 00:00, 08:00, 16:00 UTC
|
||||
- Need to identify these timestamps in the data and apply funding
|
||||
|
||||
4. **Backward Compatibility:**
|
||||
- Existing strategies should work with minimal changes
|
||||
- Default to `MarketType.SPOT` if not specified
|
||||
- Existing 2-tuple return from `run()` should be interpreted as long-only
|
||||
|
||||
---
|
||||
|
||||
## Success Metrics
|
||||
|
||||
1. **Functional:** All existing backtests produce same results when run with `MarketType.SPOT`
|
||||
2. **Functional:** Perpetual backtests correctly apply funding every 8 hours
|
||||
3. **Functional:** Leverage multiplies both gains and losses correctly
|
||||
4. **Functional:** Short signals are processed for perpetual, ignored for spot
|
||||
5. **Usability:** Users can switch market types with minimal configuration
|
||||
6. **Accuracy:** Fee structures match OKX's published rates
|
||||
|
||||
---
|
||||
|
||||
## Open Questions
|
||||
|
||||
1. **Position Sizing with Leverage:**
|
||||
- Should leverage affect `init_cash` interpretation (notional value) or position size directly?
|
||||
- Recommendation: Affect position size; `init_cash` remains the actual margin deposited.
|
||||
|
||||
2. **Multiple Positions:**
|
||||
- Can strategies hold both long and short simultaneously (hedging)?
|
||||
- Recommendation: No for v1; only one direction at a time.
|
||||
|
||||
3. **Funding Rate Sign:**
|
||||
- When funding is positive, longs pay shorts. Should we assume the user is always the "taker" of funding?
|
||||
- Recommendation: Yes, apply funding based on position direction.
|
||||
|
||||
4. **Migration Path:**
|
||||
- Should we migrate existing data to new directory structure?
|
||||
- Recommendation: No auto-migration; users re-download with `--market` flag.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
| Priority | Component | Complexity |
|
||||
|----------|-----------|------------|
|
||||
| 1 | MarketType enum + strategy defaults | Low |
|
||||
| 2 | DataManager market type support | Medium |
|
||||
| 3 | Fee structure per market type | Low |
|
||||
| 4 | Short-selling signal support | Medium |
|
||||
| 5 | Leverage simulation | Medium |
|
||||
| 6 | Funding rate simulation | Medium |
|
||||
| 7 | Liquidation warnings | Low |
|
||||
| 8 | Reporting updates | Low |
|
||||
| 9 | Grid search leverage support | Low |
|
||||
76
tasks/prd-vectorbt-migration.md
Normal file
76
tasks/prd-vectorbt-migration.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# PRD: VectorBT Migration & CCXT Integration
|
||||
|
||||
## 1. Introduction
|
||||
The goal of this project is to refactor the current backtesting infrastructure to a professional-grade stack using **VectorBT** for high-performance backtesting and **CCXT** for robust historical data acquisition. The system will support rapid prototyping of "many simple strategies," parameter optimization (Grid Search), and stability testing (Walk-Forward Analysis).
|
||||
|
||||
## 2. Goals
|
||||
- **Replace Custom Backtester:** Retire the existing loop-based backtesting logic in favor of vectorized operations using `vectorbt`.
|
||||
- **Automate Data Collection:** Implement a `ccxt` based downloader to fetch and cache OHLCV data from OKX (and other exchanges) automatically.
|
||||
- **Enable Optimization:** Built-in support for Grid Search to find optimal strategy parameters.
|
||||
- **Validation:** Implement Walk-Forward Analysis (WFA) to validate strategy robustness and prevent overfitting.
|
||||
- **Standardized Reporting:** Generate consistent outputs: Console summaries, CSV logs, and VectorBT interactive plots.
|
||||
|
||||
## 3. User Stories
|
||||
- **Data Acquisition:** "As a user, I want to run a command `download_data --pair BTC/USDT --exchange okx` and have the system fetch historical 1-minute candles and save them to `data/ccxt/okx/BTC-USDT/1m.csv`."
|
||||
- **Strategy Dev:** "As a researcher, I want to define a new strategy by simply writing a class/function that defines entry/exit signals, without worrying about the backtesting loop."
|
||||
- **Optimization:** "As a researcher, I want to say 'Optimize RSI period between 10 and 20' and get a heatmap of results."
|
||||
- **Validation:** "As a researcher, I want to verify if my 'best' parameters work on unseen data using Walk-Forward Analysis."
|
||||
- **Analysis:** "As a user, I want to see an equity curve and key metrics (Sharpe, Drawdown) immediately after a test run."
|
||||
|
||||
## 4. Functional Requirements
|
||||
|
||||
### 4.1 Data Module (`data_manager`)
|
||||
- **Exchange Interface:** Use `ccxt` to connect to exchanges (initially OKX).
|
||||
- **Fetching Logic:** Fetch OHLCV data in chunks to handle rate limits and long histories.
|
||||
- **Storage:** Save data to standardized paths: `data/ccxt/{exchange}/{pair}/{timeframe}.csv` (consistent with the Data Acquisition user story above).
|
||||
- **Loading:** Utility to load saved CSVs into a Pandas DataFrame compatible with `vectorbt`.
|
||||
|
||||
### 4.2 Strategy Interface (`strategies/`)
|
||||
- **Base Protocol:** Define a standard structure for strategies. A strategy should return/define:
|
||||
- Indicator calculations (Vectorized).
|
||||
- Entry signals (Boolean Series).
|
||||
- Exit signals (Boolean Series).
|
||||
- **Parameterization:** Strategies must accept dynamic parameters to support Grid Search.
|
||||
|
||||
### 4.3 Backtest Engine (`engine.py`)
|
||||
- **Simulation:** Use `vectorbt.Portfolio.from_signals` (or similar) for fast simulation.
|
||||
- **Cost Model:** Support configurable fees (maker/taker) and slippage estimates.
|
||||
- **Grid Search:** Utilize `vectorbt`'s parameter broadcasting to run many variations simultaneously.
|
||||
- **Walk-Forward Analysis:**
|
||||
- Implement a splitting mechanism (e.g., `vectorbt.Splitter`) to divide data into In-Sample (Train) and Out-of-Sample (Test) sets.
|
||||
- Execute optimization on Train, validate on Test.
|
||||
|
||||
### 4.4 Reporting (`reporting.py`)
|
||||
- **Console:** Print key metrics: Total Return, Sharpe Ratio, Max Drawdown, Win Rate, Count of Trades.
|
||||
- **Files:** Save detailed trade logs and metrics summaries to `backtest_logs/`.
|
||||
- **Visuals:** Generate and save/show `vectorbt` plots (Equity curve, Drawdowns).
|
||||
|
||||
## 5. Non-Goals
|
||||
- Real-time live trading execution (this is strictly for research/backtesting).
|
||||
- Complex Machine Learning models (initially focusing on indicator-based logic).
|
||||
- High-frequency tick-level backtesting (1-minute granularity is the target).
|
||||
|
||||
## 6. Technical Architecture Proposal
|
||||
```text
|
||||
project_root/
|
||||
├── data/
|
||||
│ └── ccxt/ # New data storage structure
|
||||
├── strategies/ # Strategy definitions
|
||||
│ ├── __init__.py
|
||||
│ ├── base.py # Abstract Base Class
|
||||
│ └── ma_cross.py # Example strategy
|
||||
├── engine/
|
||||
│ ├── data_loader.py # CCXT wrapper
|
||||
│ ├── backtester.py # VBT runner
|
||||
│ └── optimizer.py # Grid Search & WFA logic
|
||||
├── main.py # CLI entry point
|
||||
└── pyproject.toml
|
||||
```
|
||||
|
||||
## 7. Success Metrics
|
||||
- Can download 1 year of 1m BTC/USDT data from OKX in < 2 minutes.
|
||||
- Can run a 100-parameter grid search on 1 year of 1m data in < 10 seconds.
|
||||
- Walk-forward analysis produces a clear "Robustness Score" or visual comparison of Train vs Test performance.
|
||||
|
||||
## 8. Open Questions
|
||||
- Do we need to handle funding rates for perp futures in the PnL calculation immediately? (Assumed NO for V1, stick to spot/simple futures price action).
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test suite for lowkey_backtest."""
|
||||
69
tests/test_data_manager.py
Normal file
69
tests/test_data_manager.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for DataManager functionality."""
|
||||
import pytest
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
|
||||
|
||||
class TestDataManager:
    """Test suite for DataManager class."""

    def test_init_creates_data_dir(self, tmp_path):
        """Test that DataManager creates data directory on init."""
        data_dir = tmp_path / "test_data"
        # Constructing the manager is the action under test; the instance
        # itself is not needed afterwards (was an unused local).
        DataManager(data_dir=str(data_dir))
        assert data_dir.exists()

    def test_get_data_path_spot(self, tmp_path):
        """Test data path generation for spot market."""
        dm = DataManager(data_dir=str(tmp_path))
        path = dm._get_data_path("okx", "BTC/USDT", "1m", MarketType.SPOT)

        expected = tmp_path / "okx" / "spot" / "BTC-USDT" / "1m.csv"
        assert path == expected

    def test_get_data_path_perpetual(self, tmp_path):
        """Test data path generation for perpetual market."""
        dm = DataManager(data_dir=str(tmp_path))
        path = dm._get_data_path("okx", "BTC/USDT", "1h", MarketType.PERPETUAL)

        expected = tmp_path / "okx" / "perpetual" / "BTC-USDT" / "1h.csv"
        assert path == expected

    def test_load_data_file_not_found(self, tmp_path):
        """Test that load_data raises FileNotFoundError for missing data."""
        dm = DataManager(data_dir=str(tmp_path))

        with pytest.raises(FileNotFoundError):
            dm.load_data("okx", "BTC/USDT", "1m", MarketType.SPOT)
|
||||
|
||||
|
||||
@pytest.mark.network
class TestDataManagerDownload:
    """Tests requiring network access (marked for selective running)."""

    def test_download_spot_data(self):
        """Test downloading spot market data."""
        manager = DataManager()
        frame = manager.download_data(
            "okx", "BTC/USDT", "1m",
            start_date="2025-01-01",
            market_type=MarketType.SPOT
        )

        assert frame is not None
        # OHLCV frame must expose at least the open/close price columns.
        for column in ('open', 'close'):
            assert column in frame.columns

    def test_download_perpetual_data(self):
        """Test downloading perpetual market data."""
        manager = DataManager()
        frame = manager.download_data(
            "okx", "BTC/USDT", "1m",
            start_date="2025-01-01",
            market_type=MarketType.PERPETUAL
        )

        assert frame is not None
        for column in ('open', 'close'):
            assert column in frame.columns
||||
52
trade.py
52
trade.py
@@ -1,52 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import pandas as pd
|
||||
from market_costs import okx_fee, estimate_slippage_rate
|
||||
from intrabar import entry_slippage_row
|
||||
|
||||
|
||||
@dataclass
class TradeState:
    """Mutable state of a single long-only position plus its cost model."""
    cash: float = 1000.0  # free cash in quote currency while flat
    qty: float = 0.0  # position size in base units; 0.0 means flat
    entry_px: float | None = None  # fill price of the open position, None when flat
    max_px: float | None = None  # highest price seen since entry (trailing-stop anchor)
    stop_loss_frac: float = 0.02  # trailing-stop distance as a fraction of max_px
    fee_bps: float = 10.0  # exchange fee in basis points of notional
    slippage_bps: float = 2.0  # assumed slippage in basis points
|
||||
|
||||
|
||||
def enter_long(state: TradeState, price: float) -> dict:
    """Open a long position with all available cash.

    Applies entry slippage to the fill price and pays the fee out of the
    bought quantity. Returns a fill event dict, or ``{}`` when a position
    is already open (no-op).
    """
    if state.qty > 0:
        # Already holding: nothing to do.
        return {}

    fill_px = entry_slippage_row(price, 0.0, state.slippage_bps)
    fee = okx_fee(state.fee_bps, state.cash)
    gross_qty = state.cash / fill_px

    # Fee is converted to base units and deducted; never go negative.
    state.qty = max(gross_qty - fee / fill_px, 0.0)
    state.cash = 0.0
    state.entry_px = fill_px
    state.max_px = fill_px

    return {"side": "BUY", "price": fill_px, "qty": state.qty, "fee": fee}
|
||||
|
||||
|
||||
def maybe_trailing_stop(state: TradeState, price: float) -> float:
    """Update the position's high-water mark and return the trailing-stop price.

    Returns ``+inf`` when flat so a "price <= stop" check can never fire.
    """
    if state.qty <= 0:
        return float("inf")

    # Seed the high-water mark with the current price when unset, then ratchet.
    high_water = state.max_px if state.max_px else price
    if price > high_water:
        high_water = price
    state.max_px = high_water

    return high_water * (1.0 - state.stop_loss_frac)
|
||||
|
||||
|
||||
def exit_long(state: TradeState, price: float) -> dict:
    """Close the open long at *price*, crediting net proceeds to cash.

    Subtracts slippage and fee from the gross notional. Returns a fill
    event dict, or ``{}`` when there is no position to close.
    """
    if state.qty <= 0:
        return {}

    gross = state.qty * price
    slippage_cost = estimate_slippage_rate(state.slippage_bps, gross)
    fee_cost = okx_fee(state.fee_bps, gross)

    fill = {
        "side": "SELL",
        "price": price,
        "qty": state.qty,
        "fee": fee_cost,
        "slippage": slippage_cost,
    }

    # Credit net proceeds, then reset the position to flat.
    state.cash = gross - slippage_cost - fee_cost
    state.qty = 0.0
    state.entry_px = None
    state.max_px = None

    return fill
|
||||
Reference in New Issue
Block a user