Remove deprecated modules and files related to the backtesting framework, including backtest.py, cli.py, config.py, data.py, intrabar.py, logging_utils.py, market_costs.py, metrics.py, trade.py, and supertrend indicators. Introduce a new structure for the backtesting engine with improved organization and functionality, including a CLI handler, data manager, and reporting capabilities. Update dependencies in pyproject.toml to support the new architecture.

This commit is contained in:
2026-01-12 21:11:39 +08:00
parent c4aa965a98
commit 44fac1ed25
37 changed files with 5253 additions and 393 deletions

View File

@@ -1 +0,0 @@
../OHLCVPredictor

View File

@@ -1,70 +0,0 @@
from __future__ import annotations
import pandas as pd
from pathlib import Path
from trade import TradeState, enter_long, exit_long, maybe_trailing_stop
from indicators import add_supertrends, compute_meta_trend
from metrics import compute_metrics
from logging_utils import write_trade_log
DEFAULT_ST_SETTINGS = [(12, 3.0), (10, 1.0), (11, 2.0)]
def backtest(
    df: pd.DataFrame,
    df_1min: pd.DataFrame,
    timeframe_minutes: int,
    stop_loss: float,
    exit_on_bearish_flip: bool,
    fee_bps: float,
    slippage_bps: float,
    log_path: Path | None = None,
):
    """Run a long-only meta-supertrend backtest with an intrabar trailing stop.

    Enters long when ``meta_bull`` flips to 1; while in a position, walks the
    1-minute bars inside each trading bar and exits at the trailing-stop price
    when the 1-minute low breaches it. Optionally also exits when ``meta_bull``
    flips back to 0 on the trading bar.

    Args:
        df: Bars at the trading timeframe; must contain Timestamp/High/Low/Close
            columns (supertrend columns are added here).
        df_1min: 1-minute bars used to evaluate the trailing stop within a bar.
        timeframe_minutes: Bar size of ``df``; bounds the final bar's window.
        stop_loss: Trailing-stop distance as a fraction of the running high.
        exit_on_bearish_flip: Also exit when the meta-trend turns bearish.
        fee_bps: Fee assumption in basis points, forwarded to TradeState.
        slippage_bps: Slippage assumption in basis points, forwarded to TradeState.
        log_path: Optional CSV path; when set, the trade log is written there.

    Returns:
        Tuple of (performance metrics, equity-curve Series indexed by Timestamp,
        list of trade event dicts).
    """
    df = add_supertrends(df, DEFAULT_ST_SETTINGS)
    df["meta_bull"] = compute_meta_trend(df, DEFAULT_ST_SETTINGS)
    state = TradeState(stop_loss_frac=stop_loss, fee_bps=fee_bps, slippage_bps=slippage_bps)
    equity, trades = [], []
    for i, row in df.iterrows():
        price = float(row["Close"])
        ts = pd.Timestamp(row["Timestamp"])
        # Entry: flat (qty <= 0) and the meta-trend is bullish on this bar.
        if state.qty <= 0 and row["meta_bull"] == 1:
            evt = enter_long(state, price)
            if evt:
                evt.update({"t": ts.isoformat(), "reason": "bull_flip"})
                trades.append(evt)
        # 1-minute window covered by this trading bar; the last bar is padded
        # out by timeframe_minutes since there is no next bar to bound it.
        start = ts
        end = df["Timestamp"].iat[i + 1] if i + 1 < len(df) else ts + pd.Timedelta(minutes=timeframe_minutes)
        if state.qty > 0:
            win = df_1min[(df_1min["Timestamp"] >= start) & (df_1min["Timestamp"] < end)]
            for _, m in win.iterrows():
                hi = float(m["High"])
                lo = float(m["Low"])
                # Ratchet the running high; `or hi` seeds it on the first bar
                # after entry (assumes max_px starts falsy — TODO confirm in TradeState).
                state.max_px = max(state.max_px or hi, hi)
                trail = state.max_px * (1.0 - state.stop_loss_frac)
                if lo <= trail:
                    evt = exit_long(state, trail)
                    if evt:
                        prev = trades[-1]
                        # Realized PnL vs. the most recent entry event.
                        pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
                        evt.update({"t": pd.Timestamp(m["Timestamp"]).isoformat(), "reason": "stop", "pnl": pnl})
                        trades.append(evt)
                    # Stop was touched: abandon the remaining minutes of this bar.
                    break
        # Optional bar-close exit when the meta-trend has flipped bearish.
        if state.qty > 0 and exit_on_bearish_flip and row["meta_bull"] == 0:
            evt = exit_long(state, price)
            if evt:
                prev = trades[-1]
                pnl = (evt["price"] - (prev.get("price") or evt["price"])) * (prev.get("qty") or 0.0)
                evt.update({"t": ts.isoformat(), "reason": "bearish_flip", "pnl": pnl})
                trades.append(evt)
        # Mark-to-market equity at this bar's close.
        equity.append(state.cash + state.qty * price)
    equity_curve = pd.Series(equity, index=df["Timestamp"])
    if log_path:
        write_trade_log(trades, log_path)
    perf = compute_metrics(equity_curve, trades)
    return perf, equity_curve, trades

80
cli.py
View File

@@ -1,80 +0,0 @@
from __future__ import annotations
import argparse
from pathlib import Path
import pandas as pd
from config import CLIConfig
from data import load_data
from backtest import backtest
def parse_args() -> CLIConfig:
    """Parse command-line options and package them into a CLIConfig."""
    parser = argparse.ArgumentParser(prog="bt", description="Simple supertrend backtester")
    parser.add_argument("start")
    parser.add_argument("end")
    parser.add_argument("--timeframe-minutes", type=int, default=15)  # single TF
    parser.add_argument("--timeframes-minutes", nargs="+", type=int)  # multi TF: e.g. 5 15 60 240
    parser.add_argument("--stop-loss", dest="stop_losses", type=float, nargs="+", default=[0.02, 0.05])
    parser.add_argument("--exit-on-bearish-flip", action="store_true")
    parser.add_argument("--csv", dest="csv_path", type=Path, required=True)
    parser.add_argument("--out-csv", type=Path, default=Path("summary.csv"))
    parser.add_argument("--log-dir", type=Path, default=Path("./logs"))
    parser.add_argument("--fee-bps", type=float, default=10.0)
    parser.add_argument("--slippage-bps", type=float, default=2.0)
    ns = parser.parse_args()
    # CLIConfig field names mirror the parsed attribute names exactly,
    # so the config can be built by straight attribute transfer.
    field_names = (
        "start", "end", "timeframe_minutes", "timeframes_minutes",
        "stop_losses", "exit_on_bearish_flip", "csv_path", "out_csv",
        "log_dir", "fee_bps", "slippage_bps",
    )
    return CLIConfig(**{name: getattr(ns, name) for name in field_names})
def main():
    """Run a backtest for every (timeframe, stop-loss) combination and write
    a one-row-per-run summary CSV, printing it to stdout as well."""
    cfg = parse_args()
    # Multi-timeframe flag wins; otherwise fall back to the single timeframe.
    frames = cfg.timeframes_minutes or [cfg.timeframe_minutes]
    rows: list[dict] = []
    for tfm in frames:
        df_1min, df = load_data(cfg.start, cfg.end, tfm, cfg.csv_path)
        for sl in cfg.stop_losses:
            # One trade log per (timeframe, stop-loss) run.
            log_path = cfg.log_dir / f"{tfm}m_sl{sl:.2%}.csv"
            perf, equity, _ = backtest(
                df=df,
                df_1min=df_1min,
                timeframe_minutes=tfm,
                stop_loss=sl,
                exit_on_bearish_flip=cfg.exit_on_bearish_flip,
                fee_bps=cfg.fee_bps,
                slippage_bps=cfg.slippage_bps,
                log_path=log_path,
            )
            # Pre-formatted strings: the summary CSV is for human consumption.
            rows.append({
                "timeframe": f"{tfm}min",
                "stop_loss": sl,
                "exit_on_bearish_flip": cfg.exit_on_bearish_flip,
                "total_return": f"{perf.total_return:.2%}",
                "max_drawdown": f"{perf.max_drawdown:.2%}",
                "sharpe_ratio": f"{perf.sharpe_ratio:.2f}",
                "win_rate": f"{perf.win_rate:.2%}",
                "num_trades": perf.num_trades,
                "final_equity": f"${perf.final_equity:.2f}",
                "initial_equity": f"${perf.initial_equity:.2f}",
                "num_stop_losses": perf.num_stop_losses,
                "total_fees": perf.total_fees,
                "total_slippage_usd": perf.total_slippage_usd,
                "avg_slippage_bps": perf.avg_slippage_bps,
            })
    out = pd.DataFrame(rows)
    out.to_csv(cfg.out_csv, index=False)
    print(out.to_string(index=False))
if __name__ == "__main__":
    main()

View File

@@ -1,18 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence
@dataclass
class CLIConfig:
    """Parsed command-line configuration for the supertrend backtester CLI."""
    start: str  # inclusive start of the data window (date/time string)
    end: str  # inclusive end of the data window (date/time string)
    timeframe_minutes: int  # single trading timeframe, in minutes
    timeframes_minutes: list[int] | None  # optional multi-timeframe sweep; overrides the single TF
    stop_losses: Sequence[float]  # trailing-stop fractions to sweep (e.g. 0.02 = 2%)
    exit_on_bearish_flip: bool  # also exit when the meta-trend flips bearish
    csv_path: Path | None  # source CSV of 1-minute OHLCV bars
    out_csv: Path  # where the run summary table is written
    log_dir: Path  # directory for per-run trade logs
    fee_bps: float  # fee assumption, basis points
    slippage_bps: float  # slippage assumption, basis points

24
data.py
View File

@@ -1,24 +0,0 @@
from __future__ import annotations
import pandas as pd
from pathlib import Path
def load_data(start: str, end: str, timeframe_minutes: int, csv_path: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load 1-minute OHLCV bars from CSV and derive the trading timeframe.

    The CSV's ``Timestamp`` column is interpreted as epoch seconds (UTC).
    Rows are filtered to the inclusive [start, end] window and sorted.

    Args:
        start: Inclusive window start (anything ``pd.Timestamp`` accepts).
        end: Inclusive window end.
        timeframe_minutes: Target bar size; 1 returns a copy of the minute bars.
        csv_path: CSV with Timestamp/Open/High/Low/Close/Volume columns.

    Returns:
        Tuple of (filtered 1-minute frame, trading-timeframe frame).
    """
    minute_bars = pd.read_csv(csv_path)
    minute_bars["Timestamp"] = pd.to_datetime(minute_bars["Timestamp"], unit="s", utc=True)
    window_lo = pd.Timestamp(start, tz="UTC")
    window_hi = pd.Timestamp(end, tz="UTC")
    in_window = (minute_bars["Timestamp"] >= window_lo) & (minute_bars["Timestamp"] <= window_hi)
    minute_bars = minute_bars.loc[in_window].sort_values("Timestamp").reset_index(drop=True)
    if timeframe_minutes == 1:
        bars = minute_bars.copy()
    else:
        grouped = minute_bars.set_index("Timestamp").resample(f"{timeframe_minutes}min")
        # Standard OHLCV aggregation; dropna removes bins with no source rows.
        bars = pd.DataFrame({
            "Open": grouped["Open"].first(),
            "High": grouped["High"].max(),
            "Low": grouped["Low"].min(),
            "Close": grouped["Close"].last(),
            "Volume": grouped["Volume"].sum(),
        }).dropna().reset_index()
    return minute_bars, bars

352
engine/backtester.py Normal file
View File

@@ -0,0 +1,352 @@
"""
Core backtesting engine for running strategy simulations.
Supports multiple market types with realistic trading conditions.
"""
from dataclasses import dataclass
import pandas as pd
import vectorbt as vbt
from engine.data_manager import DataManager
from engine.logging_config import get_logger
from engine.market import MarketType, get_market_config
from engine.optimizer import WalkForwardOptimizer
from engine.portfolio import run_long_only_portfolio, run_long_short_portfolio
from engine.risk import (
LiquidationEvent,
calculate_funding,
calculate_liquidation_adjustment,
inject_liquidation_exits,
)
from strategies.base import BaseStrategy
logger = get_logger(__name__)
@dataclass
class BacktestResult:
    """
    Container for backtest results with market-specific metrics.

    Attributes:
        portfolio: VectorBT Portfolio object
        market_type: Market type used for the backtest
        leverage: Effective leverage used
        total_funding_paid: Total funding fees paid (perpetuals only)
        liquidation_count: Number of positions that were liquidated
        liquidation_events: Detailed list of liquidation events
        total_liquidation_loss: Total margin lost from liquidations
        adjusted_return: Return adjusted for liquidation losses (percentage)
    """
    portfolio: vbt.Portfolio
    market_type: MarketType
    leverage: int
    total_funding_paid: float = 0.0
    liquidation_count: int = 0
    # None (not an empty list) avoids a mutable dataclass default.
    liquidation_events: list[LiquidationEvent] | None = None
    total_liquidation_loss: float = 0.0
    adjusted_return: float | None = None
class Backtester:
    """
    Backtester supporting multiple market types with realistic simulation.

    Features:
    - Spot and Perpetual market support
    - Long and short position handling
    - Leverage simulation
    - Funding rate calculation (perpetuals)
    - Liquidation warnings
    """

    def __init__(self, data_manager: DataManager):
        # DataManager handles loading (and, as a fallback, downloading) OHLCV data.
        self.dm = data_manager

    def run_strategy(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        init_cash: float = 10000,
        fees: float | None = None,
        slippage: float = 0.001,
        sl_stop: float | None = None,
        tp_stop: float | None = None,
        sl_trail: bool = False,
        leverage: int | None = None,
        **strategy_params
    ) -> BacktestResult:
        """
        Run a backtest with market-type-aware simulation.

        Args:
            strategy: Strategy instance to backtest
            exchange_id: Exchange identifier (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Data timeframe (e.g., '1m', '1h', '1d')
            start_date: Start date filter (YYYY-MM-DD)
            end_date: End date filter (YYYY-MM-DD)
            init_cash: Initial capital (margin for leveraged)
            fees: Transaction fee override (uses market default if None)
            slippage: Slippage percentage
            sl_stop: Stop loss percentage
            tp_stop: Take profit percentage
            sl_trail: Enable trailing stop loss
            leverage: Leverage override (uses strategy default if None)
            **strategy_params: Additional strategy parameters

        Returns:
            BacktestResult with portfolio and market-specific metrics
        """
        # Get market configuration from strategy
        market_type = strategy.default_market_type
        market_config = get_market_config(market_type)
        # Resolve leverage and fees (explicit args win over strategy/market defaults)
        effective_leverage = self._resolve_leverage(leverage, strategy, market_type)
        effective_fees = fees if fees is not None else market_config.taker_fee
        # Load and filter data
        df = self._load_data(
            exchange_id, symbol, timeframe, market_type, start_date, end_date
        )
        close_price = df['close']
        high_price = df['high']
        low_price = df['low']
        open_price = df['open']
        volume = df['volume']
        # Run strategy logic
        signals = strategy.run(
            close_price,
            high=high_price,
            low=low_price,
            open=open_price,
            volume=volume,
            **strategy_params
        )
        # Normalize signals to 4-tuple format
        signals = self._normalize_signals(signals, close_price, market_config)
        long_entries, long_exits, short_entries, short_exits = signals
        # Process liquidations - inject forced exits at liquidation points
        liquidation_events: list[LiquidationEvent] = []
        if effective_leverage > 1:
            long_exits, short_exits, liquidation_events = inject_liquidation_exits(
                close_price, high_price, low_price,
                long_entries, long_exits,
                short_entries, short_exits,
                effective_leverage,
                market_config.maintenance_margin_rate
            )
        # Calculate perpetual-specific metrics (after liquidation processing)
        total_funding = 0.0
        if market_type == MarketType.PERPETUAL:
            total_funding = calculate_funding(
                close_price,
                long_entries, long_exits,
                short_entries, short_exits,
                market_config,
                effective_leverage
            )
        # Run portfolio simulation with liquidation-aware exits
        portfolio = self._run_portfolio(
            close_price, market_config,
            long_entries, long_exits,
            short_entries, short_exits,
            init_cash, effective_fees, slippage, timeframe,
            sl_stop, tp_stop, sl_trail, effective_leverage
        )
        # Calculate adjusted returns accounting for liquidation losses
        total_liq_loss, liq_adjustment = calculate_liquidation_adjustment(
            liquidation_events, init_cash, effective_leverage
        )
        # .mean() collapses parameter-grid portfolios to a single scalar return.
        raw_return = portfolio.total_return().mean() * 100
        adjusted_return = raw_return - liq_adjustment
        if liquidation_events:
            logger.info(
                "Liquidation impact: %d events, $%.2f margin lost, %.2f%% adjustment",
                len(liquidation_events), total_liq_loss, liq_adjustment
            )
        logger.info(
            "Backtest completed: %s market, %dx leverage, fees=%.4f%%",
            market_type.value, effective_leverage, effective_fees * 100
        )
        return BacktestResult(
            portfolio=portfolio,
            market_type=market_type,
            leverage=effective_leverage,
            total_funding_paid=total_funding,
            liquidation_count=len(liquidation_events),
            liquidation_events=liquidation_events,
            total_liquidation_loss=total_liq_loss,
            adjusted_return=adjusted_return
        )

    def _resolve_leverage(
        self,
        leverage: int | None,
        strategy: BaseStrategy,
        market_type: MarketType
    ) -> int:
        """Resolve effective leverage from CLI, strategy default, or market type."""
        effective = leverage or strategy.default_leverage
        if market_type == MarketType.SPOT:
            return 1  # Spot cannot have leverage
        return effective

    def _load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType,
        start_date: str | None,
        end_date: str | None
    ) -> pd.DataFrame:
        """Load and filter OHLCV data."""
        try:
            df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
        except FileNotFoundError:
            # Fall back to a live download when no local file exists.
            logger.warning("Data not found locally. Attempting download...")
            df = self.dm.download_data(
                exchange_id, symbol, timeframe,
                start_date, end_date, market_type
            )
        # Inclusive date filtering on the UTC DatetimeIndex.
        if start_date:
            df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
        if end_date:
            df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
        return df

    def _normalize_signals(
        self,
        signals: tuple,
        close: pd.Series,
        market_config
    ) -> tuple:
        """
        Normalize strategy signals to 4-tuple format.

        Handles backward compatibility with 2-tuple (long-only) returns.
        """
        if len(signals) == 2:
            long_entries, long_exits = signals
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)
            return long_entries, long_exits, short_entries, short_exits
        if len(signals) == 4:
            long_entries, long_exits, short_entries, short_exits = signals
            # Warn and clear short signals on spot markets
            if not market_config.supports_short:
                # .any().any() handles DataFrame (grid) signals; .any() handles Series.
                has_shorts = (
                    short_entries.any().any()
                    if hasattr(short_entries, 'any')
                    else short_entries.any()
                )
                if has_shorts:
                    logger.warning(
                        "Short signals detected but market type is SPOT. "
                        "Short signals will be ignored."
                    )
                    short_entries = BaseStrategy.create_empty_signals(long_entries)
                    short_exits = BaseStrategy.create_empty_signals(long_entries)
            return long_entries, long_exits, short_entries, short_exits
        raise ValueError(
            f"Strategy must return 2 or 4 signal arrays, got {len(signals)}"
        )

    def _run_portfolio(
        self,
        close: pd.Series,
        market_config,
        long_entries, long_exits,
        short_entries, short_exits,
        init_cash: float,
        fees: float,
        slippage: float,
        freq: str,
        sl_stop: float | None,
        tp_stop: float | None,
        sl_trail: bool,
        leverage: int
    ) -> vbt.Portfolio:
        """Select and run appropriate portfolio simulation."""
        has_shorts = (
            short_entries.any().any()
            if hasattr(short_entries, 'any')
            else short_entries.any()
        )
        if market_config.supports_short and has_shorts:
            return run_long_short_portfolio(
                close,
                long_entries, long_exits,
                short_entries, short_exits,
                init_cash, fees, slippage, freq,
                sl_stop, tp_stop, sl_trail, leverage
            )
        return run_long_only_portfolio(
            close,
            long_entries, long_exits,
            init_cash, fees, slippage, freq,
            sl_stop, tp_stop, sl_trail, leverage
        )

    def run_wfa(
        self,
        strategy: BaseStrategy,
        exchange_id: str,
        symbol: str,
        param_grid: dict,
        n_windows: int = 10,
        timeframe: str = '1m'
    ):
        """
        Execute Walk-Forward Analysis.

        Args:
            strategy: Strategy instance to optimize
            exchange_id: Exchange identifier
            symbol: Trading pair symbol
            param_grid: Parameter grid for optimization
            n_windows: Number of walk-forward windows
            timeframe: Data timeframe to load

        Returns:
            Tuple of (results DataFrame, stitched equity curve)
        """
        market_type = strategy.default_market_type
        # NOTE(review): unlike run_strategy, this requires local data — no
        # download fallback here; confirm that is intentional.
        df = self.dm.load_data(exchange_id, symbol, timeframe, market_type)
        wfa = WalkForwardOptimizer(self, strategy, param_grid)
        results, stitched_curve = wfa.run(
            df['close'],
            high=df['high'],
            low=df['low'],
            n_windows=n_windows
        )
        return results, stitched_curve

243
engine/cli.py Normal file
View File

@@ -0,0 +1,243 @@
"""
CLI handler for Lowkey Backtest.
"""
import argparse
import sys
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import get_logger, setup_logging
from engine.market import MarketType
from engine.reporting import Reporter
from strategies.factory import get_strategy, get_strategy_names
logger = get_logger(__name__)
def create_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with all sub-commands registered."""
    parser = argparse.ArgumentParser(
        description="Lowkey Backtest CLI (VectorBT Edition)"
    )
    subparsers = parser.add_subparsers(dest="command", help="Command to run")
    # Each registrar attaches one sub-command to the shared subparsers object.
    for register in (_add_download_parser, _add_backtest_parser, _add_wfa_parser):
        register(subparsers)
    return parser
def _add_download_parser(subparsers) -> None:
"""Add download command parser."""
dl_parser = subparsers.add_parser("download", help="Download historical data")
dl_parser.add_argument("--exchange", "-e", type=str, default="okx")
dl_parser.add_argument("--pair", "-p", type=str, required=True)
dl_parser.add_argument("--timeframe", "-t", type=str, default="1m")
dl_parser.add_argument("--start", type=str, help="Start Date (YYYY-MM-DD)")
dl_parser.add_argument(
"--market", "-m",
type=str,
choices=["spot", "perpetual"],
default="spot"
)
def _add_backtest_parser(subparsers) -> None:
    """Register the ``backtest`` sub-command and its options."""
    p = subparsers.add_parser("backtest", help="Run a backtest")
    # Strategy selection is constrained to the registered factory names.
    p.add_argument("--strategy", "-s", type=str, choices=get_strategy_names(), required=True)
    # Data selection
    p.add_argument("--exchange", "-e", type=str, default="okx")
    p.add_argument("--pair", "-p", type=str, required=True)
    p.add_argument("--timeframe", "-t", type=str, default="1m")
    p.add_argument("--start", type=str)
    p.add_argument("--end", type=str)
    # Run mode / output
    p.add_argument("--grid", "-g", action="store_true")
    p.add_argument("--plot", action="store_true")
    # Risk parameters
    p.add_argument("--sl", type=float, help="Stop Loss %%")
    p.add_argument("--tp", type=float, help="Take Profit %%")
    p.add_argument("--trail", action="store_true")
    p.add_argument("--no-bear-exit", action="store_true")
    # Cost parameters
    p.add_argument("--fees", type=float, default=None)
    p.add_argument("--slippage", type=float, default=0.001)
    p.add_argument("--leverage", "-l", type=int, default=None)
def _add_wfa_parser(subparsers) -> None:
    """Register the ``wfa`` (walk-forward analysis) sub-command."""
    p = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")
    p.add_argument("--strategy", "-s", type=str, choices=get_strategy_names(), required=True)
    p.add_argument("--pair", "-p", type=str, required=True)
    p.add_argument("--timeframe", "-t", type=str, default="1d")
    p.add_argument("--windows", "-w", type=int, default=10)
    p.add_argument("--plot", action="store_true")
def run_download(args) -> None:
    """Handle ``download``: fetch historical OHLCV for the requested market."""
    market = MarketType(args.market)
    manager = DataManager()
    manager.download_data(
        args.exchange,
        args.pair,
        args.timeframe,
        start_date=args.start,
        market_type=market,
    )
def run_backtest(args) -> None:
    """Handle ``backtest``: run one strategy and print/save/plot the results."""
    data_manager = DataManager()
    backtester = Backtester(data_manager)
    reporter = Reporter()
    strategy, params = get_strategy(args.strategy, args.grid)
    # CLI flags may override meta_st defaults (stop-loss, trailing, bear exit).
    params = _apply_strategy_overrides(args, strategy, params)
    if args.grid and args.strategy == "meta_st":
        logger.info("Running Grid Search for Meta Supertrend...")
    try:
        result = backtester.run_strategy(
            strategy,
            args.exchange,
            args.pair,
            timeframe=args.timeframe,
            start_date=args.start,
            end_date=args.end,
            fees=args.fees,
            slippage=args.slippage,
            sl_stop=args.sl,
            tp_stop=args.tp,
            sl_trail=args.trail,
            leverage=args.leverage,
            **params
        )
        reporter.print_summary(result)
        reporter.save_reports(result, f"{args.strategy}_{args.pair.replace('/','-')}")
        if args.plot:
            if args.grid:
                # Grid portfolios are multi-column; plotting is not supported.
                logger.info("Plotting skipped for Grid Search. Check CSV results.")
            else:
                reporter.plot(result.portfolio)
    except Exception as e:
        # Boundary handler: log with traceback, keep the CLI process alive.
        logger.error("Backtest failed: %s", e, exc_info=True)
def run_wfa(args) -> None:
    """Handle ``wfa``: run walk-forward optimization, then log and save results."""
    backtester = Backtester(DataManager())
    reporter = Reporter()
    # WFA always uses the grid form of the strategy parameters.
    strategy, grid = get_strategy(args.strategy, is_grid=True)
    logger.info(
        "Running WFA on %s for %s (%s) with %d windows...",
        args.strategy, args.pair, args.timeframe, args.windows
    )
    try:
        results, stitched_curve = backtester.run_wfa(
            strategy,
            "okx",
            args.pair,
            grid,
            n_windows=args.windows,
            timeframe=args.timeframe
        )
        _log_wfa_results(results)
        _save_wfa_results(args, results, stitched_curve, reporter)
    except Exception as e:
        # Boundary handler: log with traceback, keep the CLI process alive.
        logger.error("WFA failed: %s", e, exc_info=True)
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
"""Apply CLI argument overrides to strategy parameters."""
if args.strategy != "meta_st":
return params
if args.no_bear_exit:
params['exit_on_bearish_flip'] = False
if args.sl is None:
args.sl = strategy.default_sl_stop
if not args.trail:
args.trail = strategy.default_sl_trail
return params
def _log_wfa_results(results) -> None:
    """Log the per-window WFA table plus aggregate Sharpe and return."""
    logger.info("Walk-Forward Analysis Results:")
    # An empty frame or one missing 'window' means every window failed upstream.
    if results.empty or 'window' not in results.columns:
        logger.warning("No valid WFA results. All windows may have failed.")
        return
    summary_cols = ['window', 'train_score', 'test_score', 'test_return']
    logger.info("\n%s", results[summary_cols].to_string(index=False))
    logger.info("Average Test Sharpe: %.2f", results['test_score'].mean())
    logger.info("Average Test Return: %.2f%%", results['test_return'].mean() * 100)
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
"""Save WFA results to file and optionally plot."""
if results.empty:
return
output_path = f"backtest_logs/wfa_{args.strategy}_{args.pair.replace('/','-')}.csv"
results.to_csv(output_path)
logger.info("Saved full results to %s", output_path)
if args.plot:
reporter.plot_wfa(results, stitched_curve)
def main():
    """CLI entry point: configure logging, parse args, dispatch the command."""
    setup_logging()
    parser = create_parser()
    args = parser.parse_args()
    dispatch = {
        "download": run_download,
        "backtest": run_backtest,
        "wfa": run_wfa,
    }
    handler = dispatch.get(args.command)
    if handler is not None:
        handler(args)
    else:
        # No (or unknown) sub-command: show usage instead of failing silently.
        parser.print_help()

209
engine/data_manager.py Normal file
View File

@@ -0,0 +1,209 @@
"""
Data management for OHLCV data download and storage.
Handles data retrieval from exchanges and local file management.
"""
import time
from datetime import datetime, timezone
from pathlib import Path
import ccxt
import pandas as pd
from engine.logging_config import get_logger
from engine.market import MarketType, get_ccxt_symbol
logger = get_logger(__name__)
class DataManager:
    """
    Manages OHLCV data download and storage for different market types.

    Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
    """

    def __init__(self, data_dir: str = "data/ccxt"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Cache of instantiated CCXT clients, keyed by exchange id.
        self.exchanges: dict[str, ccxt.Exchange] = {}

    def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
        """Get or create a CCXT exchange instance."""
        if exchange_id not in self.exchanges:
            # CCXT exposes each exchange as a class attribute named by its id.
            exchange_class = getattr(ccxt, exchange_id)
            self.exchanges[exchange_id] = exchange_class({
                'enableRateLimit': True,
            })
        return self.exchanges[exchange_id]

    def _get_data_path(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType
    ) -> Path:
        """
        Get the file path for storing/loading data.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            Path to the CSV file
        """
        # '/' in pair symbols would create unwanted subdirectories.
        safe_symbol = symbol.replace('/', '-')
        return (
            self.data_dir
            / exchange_id
            / market_type.value
            / safe_symbol
            / f"{timeframe}.csv"
        )

    def download_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame | None:
        """
        Download OHLCV data from exchange and save to CSV.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            start_date: Start date string (YYYY-MM-DD)
            end_date: End date string (YYYY-MM-DD)
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data, or None if download failed
        """
        exchange = self.get_exchange(exchange_id)
        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Perpetuals use a different CCXT symbol format (e.g. 'BTC/USDT:USDT').
        ccxt_symbol = get_ccxt_symbol(symbol, market_type)
        since, until = self._parse_date_range(exchange, start_date, end_date)
        logger.info(
            "Downloading %s (%s) from %s...",
            symbol, market_type.value, exchange_id
        )
        all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)
        if not all_ohlcv:
            logger.warning("No data downloaded.")
            return None
        df = self._convert_to_dataframe(all_ohlcv)
        df.to_csv(file_path)
        logger.info("Saved %d candles to %s", len(df), file_path)
        return df

    def load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame:
        """
        Load saved OHLCV data for vectorbt.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data indexed by timestamp

        Raises:
            FileNotFoundError: If data file does not exist
        """
        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
        if not file_path.exists():
            raise FileNotFoundError(
                f"Data not found at {file_path}. "
                f"Run: uv run python main.py download --pair {symbol} "
                f"--market {market_type.value}"
            )
        return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)

    def _parse_date_range(
        self,
        exchange: ccxt.Exchange,
        start_date: str | None,
        end_date: str | None
    ) -> tuple[int, int]:
        """Parse date strings into millisecond timestamps."""
        if start_date:
            since = exchange.parse8601(f"{start_date}T00:00:00Z")
        else:
            # Default lookback: one year before now (milliseconds).
            since = exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
        if end_date:
            until = exchange.parse8601(f"{end_date}T23:59:59Z")
        else:
            until = exchange.milliseconds()
        return since, until

    def _fetch_all_candles(
        self,
        exchange: ccxt.Exchange,
        symbol: str,
        timeframe: str,
        since: int,
        until: int
    ) -> list:
        """Fetch all candles in the date range."""
        all_ohlcv = []
        while since < until:
            try:
                # NOTE(review): limit=100 is conservative; many exchanges allow
                # larger pages — confirm before raising it.
                ohlcv = exchange.fetch_ohlcv(symbol, timeframe, since, limit=100)
                if not ohlcv:
                    break
                all_ohlcv.extend(ohlcv)
                # Advance just past the last returned candle's open time.
                since = ohlcv[-1][0] + 1
                current_date = datetime.fromtimestamp(
                    since/1000, tz=timezone.utc
                ).strftime('%Y-%m-%d')
                logger.debug("Fetched up to %s", current_date)
                # Respect the exchange's advertised rate limit (ms -> s).
                time.sleep(exchange.rateLimit / 1000)
            except Exception as e:
                # Best-effort download: keep whatever was fetched so far.
                logger.error("Error fetching data: %s", e)
                break
        return all_ohlcv

    def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
        """Convert OHLCV list to DataFrame."""
        df = pd.DataFrame(
            ohlcv,
            columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
        )
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df.set_index('timestamp', inplace=True)
        return df

124
engine/logging_config.py Normal file
View File

@@ -0,0 +1,124 @@
"""
Centralized logging configuration for the backtest engine.
Provides colored console output and rotating file logs.
"""
import logging
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path
# ANSI color codes for terminal output
class Colors:
    """ANSI escape codes for colored terminal output."""
    RESET = "\033[0m"
    BOLD = "\033[1m"
    # Log level colors
    DEBUG = "\033[36m"     # Cyan
    INFO = "\033[32m"      # Green
    WARNING = "\033[33m"   # Yellow
    ERROR = "\033[31m"     # Red
    CRITICAL = "\033[35m"  # Magenta


class ColoredFormatter(logging.Formatter):
    """
    Formatter that wraps the record's level name in ANSI color codes.

    The level name is rewritten only for the duration of formatting and then
    restored, so other handlers sharing the same LogRecord see plain text.
    (The redundant pass-through ``__init__`` was removed: it narrowed the
    base ``logging.Formatter`` signature by hiding ``style``/``validate``.)
    """

    # Map numeric log levels to their display color.
    LEVEL_COLORS = {
        logging.DEBUG: Colors.DEBUG,
        logging.INFO: Colors.INFO,
        logging.WARNING: Colors.WARNING,
        logging.ERROR: Colors.ERROR,
        logging.CRITICAL: Colors.CRITICAL,
    }

    def format(self, record: logging.LogRecord) -> str:
        """Return the formatted message with a colorized level name."""
        original_levelname = record.levelname
        color = self.LEVEL_COLORS.get(record.levelno, Colors.RESET)
        record.levelname = f"{color}{original_levelname}{Colors.RESET}"
        try:
            return super().format(record)
        finally:
            # Always restore, even if super().format() raises — the record
            # may be handed to other handlers afterwards.
            record.levelname = original_levelname
def setup_logging(
    log_dir: str = "logs",
    log_level: int = logging.INFO,
    console_level: int = logging.INFO,
    max_bytes: int = 5 * 1024 * 1024,  # 5MB
    backup_count: int = 3
) -> None:
    """
    Configure application-wide logging on the root logger.

    Installs two handlers: a colored console handler on stdout and a
    rotating file handler under *log_dir*. Any previously installed
    root handlers are removed first, so repeated calls are safe.

    Args:
        log_dir: Directory for log files (created if missing).
        log_level: File logging level.
        console_level: Console logging level.
        max_bytes: Max size per log file before rotation.
        backup_count: Number of rotated backup files to keep.
    """
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    # Root captures everything; the individual handlers do the filtering.
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    root.handlers.clear()

    console = logging.StreamHandler(sys.stdout)
    console.setLevel(console_level)
    console.setFormatter(ColoredFormatter(
        fmt="[%(asctime)s] %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    ))
    root.addHandler(console)

    file_handler = RotatingFileHandler(
        log_path / "backtest.log",
        maxBytes=max_bytes,
        backupCount=backup_count,
        encoding="utf-8",
    )
    file_handler.setLevel(log_level)
    file_handler.setFormatter(logging.Formatter(
        fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    ))
    root.addHandler(file_handler)
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under *name* (typically ``__name__``).

    Thin wrapper over :func:`logging.getLogger` so modules depend on this
    package's logging entry point rather than the stdlib directly.
    """
    return logging.getLogger(name)

156
engine/market.py Normal file
View File

@@ -0,0 +1,156 @@
"""
Market type definitions and configuration for backtesting.
Supports different market types with their specific trading conditions:
- SPOT: No leverage, no funding, long-only
- PERPETUAL: Leverage, funding rates, long/short
"""
from dataclasses import dataclass
from enum import Enum
class MarketType(Enum):
    """Supported market types for backtesting."""
    SPOT = "spot"            # cash market: no leverage/funding, long-only
    PERPETUAL = "perpetual"  # perpetual swap: leverage, funding, long/short
@dataclass(frozen=True)
class MarketConfig:
    """
    Configuration for a specific market type.

    Frozen: instances are shared module-level constants and must stay immutable.

    Attributes:
        market_type: The market type enum value
        maker_fee: Maker fee as decimal (e.g., 0.0008 = 0.08%)
        taker_fee: Taker fee as decimal (e.g., 0.001 = 0.1%)
        max_leverage: Maximum allowed leverage
        funding_rate: Funding rate per 8 hours as decimal (perpetuals only)
        funding_interval_hours: Hours between funding payments
        maintenance_margin_rate: Rate for liquidation calculation
        supports_short: Whether short-selling is supported
    """
    market_type: MarketType
    maker_fee: float
    taker_fee: float
    max_leverage: int
    funding_rate: float
    funding_interval_hours: int
    maintenance_margin_rate: float
    supports_short: bool
# OKX-based default configurations.
# NOTE(review): fee and margin numbers are exchange snapshots — verify against
# OKX's current fee schedule before relying on them for live cost modeling.
SPOT_CONFIG = MarketConfig(
    market_type=MarketType.SPOT,
    maker_fee=0.0008,  # 0.08%
    taker_fee=0.0010,  # 0.10%
    max_leverage=1,
    funding_rate=0.0,  # spot has no funding mechanism
    funding_interval_hours=0,
    maintenance_margin_rate=0.0,  # no liquidation on unleveraged spot
    supports_short=False,
)
PERPETUAL_CONFIG = MarketConfig(
    market_type=MarketType.PERPETUAL,
    maker_fee=0.0002,  # 0.02%
    taker_fee=0.0005,  # 0.05%
    max_leverage=125,
    funding_rate=0.0001,  # 0.01% per 8 hours (simplified average)
    funding_interval_hours=8,
    maintenance_margin_rate=0.004,  # 0.4% for BTC on OKX
    supports_short=True,
)
def get_market_config(market_type: MarketType) -> MarketConfig:
    """
    Look up the default configuration for a market type.

    Args:
        market_type: The market type to get configuration for

    Returns:
        MarketConfig with default values for that market type

    Raises:
        KeyError: If the market type has no registered configuration.
    """
    return {
        MarketType.SPOT: SPOT_CONFIG,
        MarketType.PERPETUAL: PERPETUAL_CONFIG,
    }[market_type]
def get_ccxt_symbol(symbol: str, market_type: MarketType) -> str:
    """
    Convert a standard symbol to CCXT format for the given market type.

    Args:
        symbol: Standard symbol (e.g., 'BTC/USDT')
        market_type: The market type

    Returns:
        CCXT-formatted symbol (e.g., 'BTC/USDT:USDT' for perpetuals;
        spot symbols are returned unchanged).
    """
    if market_type != MarketType.PERPETUAL:
        return symbol
    # OKX perpetual contracts are written BASE/QUOTE:SETTLE; the settle
    # currency defaults to USDT when the symbol has no quote part.
    settle = symbol.split('/')[1] if '/' in symbol else 'USDT'
    return f"{symbol}:{settle}"
def calculate_leverage_stop_loss(
    leverage: int,
    maintenance_margin_rate: float = 0.004
) -> float:
    """
    Derive the implicit stop-loss fraction imposed by leverage.

    A position at `leverage` is force-closed once it loses roughly
    (1/leverage - maintenance_margin_rate) of its notional value.

    Args:
        leverage: Position leverage multiplier
        maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)

    Returns:
        Stop-loss percentage as decimal (e.g., 0.196 for 19.6%)
    """
    # Spot / unlevered positions can never be liquidated.
    if leverage <= 1:
        return 1.0
    margin_fraction = 1 / leverage
    return margin_fraction - maintenance_margin_rate
def calculate_liquidation_price(
    entry_price: float,
    leverage: float,
    is_long: bool,
    maintenance_margin_rate: float = 0.004
) -> float:
    """
    Calculate the liquidation price for a leveraged position.

    Uses the simplified formulas:
        long:  entry * (1 - 1/leverage + maintenance_margin_rate)
        short: entry * (1 + 1/leverage - maintenance_margin_rate)

    Args:
        entry_price: Position entry price
        leverage: Position leverage
        is_long: True for long positions, False for short
        maintenance_margin_rate: Maintenance margin rate (default OKX BTC: 0.4%)

    Returns:
        Liquidation price (0.0 / +inf for unlevered long / short, which
        can never be liquidated).
    """
    if leverage <= 1:
        return 0.0 if is_long else float('inf')
    margin_ratio = 1 / leverage
    if is_long:
        return entry_price * (1 - margin_ratio + maintenance_margin_rate)
    return entry_price * (1 + margin_ratio - maintenance_margin_rate)

245
engine/optimizer.py Normal file
View File

@@ -0,0 +1,245 @@
"""
Walk-Forward Analysis optimizer for strategy parameter optimization.
Implements expanding window walk-forward analysis with train/test splits.
"""
import numpy as np
import pandas as pd
import vectorbt as vbt
from engine.logging_config import get_logger
logger = get_logger(__name__)
def create_rolling_windows(
    index: pd.Index,
    n_windows: int,
    train_split: float = 0.7
):
    """
    Yield expanding-window train/test splits over *index*.

    The index is cut into ``n_windows + 1`` contiguous chunks; window *i*
    trains on chunks ``0..i`` and tests on chunk ``i + 1``, so the training
    span grows while the test span is always the next unseen chunk.

    Args:
        index: DataFrame index to split
        n_windows: Number of walk-forward windows
        train_split: Unused, kept for API compatibility

    Yields:
        Tuples of (train_idx, test_idx) numpy arrays
    """
    chunks = np.array_split(index, n_windows + 1)
    for w in range(n_windows):
        yield np.concatenate(chunks[:w + 1]), chunks[w + 1]
class WalkForwardOptimizer:
    """
    Walk-Forward Analysis optimizer for strategy backtesting.
    Optimizes strategy parameters on training windows and validates
    on out-of-sample test windows.
    """
    def __init__(
        self,
        backtester,
        strategy,
        param_grid: dict,
        metric: str = 'Sharpe Ratio',
        fees: float = 0.001,
        freq: str = '1m'
    ):
        """
        Initialize the optimizer.
        Args:
            backtester: Backtester instance
            strategy: Strategy instance to optimize
            param_grid: Parameter grid for optimization
            metric: Performance metric to optimize
            fees: Transaction fees for simulation
            freq: Data frequency for portfolio simulation
        """
        # NOTE(review): self.bt and self.metric are stored but not read in this
        # class — _optimize_train hard-codes Sharpe. Confirm whether `metric`
        # selection is still intended.
        self.bt = backtester
        self.strategy = strategy
        self.param_grid = param_grid
        self.metric = metric
        self.fees = fees
        self.freq = freq
        # Separate grid params (lists) from fixed params (scalars)
        self.grid_keys = []
        self.fixed_params = {}
        for k, v in param_grid.items():
            if isinstance(v, (list, np.ndarray)):
                self.grid_keys.append(k)
            else:
                self.fixed_params[k] = v
    def run(
        self,
        close_price: pd.Series,
        high: pd.Series | None = None,
        low: pd.Series | None = None,
        n_windows: int = 10
    ) -> tuple[pd.DataFrame, pd.Series | None]:
        """
        Execute walk-forward analysis.
        Args:
            close_price: Close price series
            high: High price series (optional)
            low: Low price series (optional)
            n_windows: Number of walk-forward windows
        Returns:
            Tuple of (results DataFrame, stitched equity curve)
        """
        results = []
        equity_curves = []
        logger.info(
            "Starting Walk-Forward Analysis with %d windows (Expanding Train)...",
            n_windows
        )
        splitter = create_rolling_windows(close_price.index, n_windows)
        # Windows that raise inside _process_window return None and are
        # silently skipped, so the results frame may have fewer rows than
        # n_windows.
        for i, (train_idx, test_idx) in enumerate(splitter):
            logger.info("Processing Window %d/%d...", i + 1, n_windows)
            window_result = self._process_window(
                i, train_idx, test_idx, close_price, high, low
            )
            if window_result is not None:
                result_dict, eq_curve = window_result
                results.append(result_dict)
                equity_curves.append(eq_curve)
        stitched_series = self._stitch_equity_curves(equity_curves)
        return pd.DataFrame(results), stitched_series
    def _process_window(
        self,
        window_idx: int,
        train_idx: np.ndarray,
        test_idx: np.ndarray,
        close_price: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None
    ) -> tuple[dict, pd.Series] | None:
        """Process a single WFA window."""
        try:
            # Slice data for train/test
            train_close = close_price.loc[train_idx]
            train_high = high.loc[train_idx] if high is not None else None
            train_low = low.loc[train_idx] if low is not None else None
            # Train phase: find best parameters
            best_params, best_score = self._optimize_train(
                train_close, train_high, train_low
            )
            # Test phase: validate with best params
            test_close = close_price.loc[test_idx]
            test_high = high.loc[test_idx] if high is not None else None
            test_low = low.loc[test_idx] if low is not None else None
            # Fixed params are re-merged so the strategy receives scalars
            # alongside the winning grid values.
            test_params = {**self.fixed_params, **best_params}
            test_score, test_return, eq_curve = self._run_test(
                test_close, test_high, test_low, test_params
            )
            return {
                'window': window_idx + 1,
                'train_start': train_idx[0],
                'train_end': train_idx[-1],
                'test_start': test_idx[0],
                'test_end': test_idx[-1],
                'best_params': best_params,
                'train_score': best_score,
                'test_score': test_score,
                'test_return': test_return
            }, eq_curve
        except Exception as e:
            logger.error("Error in window %d: %s", window_idx + 1, e, exc_info=True)
            return None
    def _optimize_train(
        self,
        close: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None
    ) -> tuple[dict, float]:
        """Run grid search on training data to find best parameters."""
        entries, exits = self.strategy.run(
            close, high=high, low=low, **self.param_grid
        )
        pf_train = vbt.Portfolio.from_signals(
            close, entries, exits,
            fees=self.fees,
            freq=self.freq
        )
        perf_stats = pf_train.sharpe_ratio()
        # -999 sentinel keeps NaN columns from ever winning idxmax().
        perf_stats = perf_stats.fillna(-999)
        best_idx = perf_stats.idxmax()
        best_score = perf_stats.max()
        # Extract best params from grid search
        # (idxmax yields a scalar for one grid key, a tuple for several).
        if len(self.grid_keys) == 1:
            best_params = {self.grid_keys[0]: best_idx}
        elif len(self.grid_keys) > 1:
            best_params = dict(zip(self.grid_keys, best_idx))
        else:
            best_params = {}
        return best_params, best_score
    def _run_test(
        self,
        close: pd.Series,
        high: pd.Series | None,
        low: pd.Series | None,
        params: dict
    ) -> tuple[float, float, pd.Series]:
        """Run test phase with given parameters."""
        entries, exits = self.strategy.run(
            close, high=high, low=low, **params
        )
        pf_test = vbt.Portfolio.from_signals(
            close, entries, exits,
            fees=self.fees,
            freq=self.freq
        )
        return pf_test.sharpe_ratio(), pf_test.total_return(), pf_test.value()
    def _stitch_equity_curves(
        self,
        equity_curves: list[pd.Series]
    ) -> pd.Series | None:
        """Stitch multiple equity curves into a continuous series."""
        if not equity_curves:
            return None
        stitched = [equity_curves[0]]
        for j in range(1, len(equity_curves)):
            prev_end_val = stitched[-1].iloc[-1]
            curr_curve = equity_curves[j]
            init_cash = curr_curve.iloc[0]
            # Scale curve to continue from previous end value
            scaled_curve = (curr_curve / init_cash) * prev_end_val
            stitched.append(scaled_curve)
        return pd.concat(stitched)

148
engine/portfolio.py Normal file
View File

@@ -0,0 +1,148 @@
"""
Portfolio simulation utilities for backtesting.
Handles long-only and long/short portfolio creation using VectorBT.
"""
import pandas as pd
import vectorbt as vbt
from engine.logging_config import get_logger
logger = get_logger(__name__)
def run_long_only_portfolio(
    close: pd.Series,
    entries: pd.DataFrame,
    exits: pd.DataFrame,
    init_cash: float,
    fees: float,
    slippage: float,
    freq: str,
    sl_stop: float | None,
    tp_stop: float | None,
    sl_trail: bool,
    leverage: int
) -> vbt.Portfolio:
    """
    Simulate a long-only portfolio with VectorBT.

    Leverage is modelled by scaling the starting cash, and each entry
    deploys 100% of available cash (size=1.0 with percent sizing).

    Args:
        close: Close price series
        entries: Entry signals
        exits: Exit signals
        init_cash: Initial capital
        fees: Transaction fee percentage
        slippage: Slippage percentage
        freq: Data frequency string
        sl_stop: Stop loss percentage
        tp_stop: Take profit percentage
        sl_trail: Enable trailing stop loss
        leverage: Leverage multiplier

    Returns:
        VectorBT Portfolio object
    """
    return vbt.Portfolio.from_signals(
        close=close,
        entries=entries,
        exits=exits,
        init_cash=init_cash * leverage,
        fees=fees,
        slippage=slippage,
        freq=freq,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        sl_trail=sl_trail,
        size=1.0,
        size_type='percent',
    )
def run_long_short_portfolio(
    close: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    init_cash: float,
    fees: float,
    slippage: float,
    freq: str,
    sl_stop: float | None,
    tp_stop: float | None,
    sl_trail: bool,
    leverage: int
) -> vbt.Portfolio:
    """
    Run a portfolio supporting both long and short positions.

    Two independent single-direction portfolios are simulated, each
    receiving half of the leveraged capital; their returns are logged
    and the long portfolio is returned.

    Args:
        close: Close price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        init_cash: Initial capital
        fees: Transaction fee percentage
        slippage: Slippage percentage
        freq: Data frequency string
        sl_stop: Stop loss percentage
        tp_stop: Take profit percentage
        sl_trail: Enable trailing stop loss
        leverage: Leverage multiplier

    Returns:
        VectorBT Portfolio object (long portfolio, short stats logged)
    """
    per_side_cash = (init_cash * leverage) / 2

    def _simulate(direction, entry_signals, exit_signals):
        # One single-direction simulation funded with half the leveraged capital.
        return vbt.Portfolio.from_signals(
            close=close,
            entries=entry_signals,
            exits=exit_signals,
            direction=direction,
            init_cash=per_side_cash,
            fees=fees,
            slippage=slippage,
            freq=freq,
            sl_stop=sl_stop,
            tp_stop=tp_stop,
            sl_trail=sl_trail,
            size=1.0,
            size_type='percent',
        )

    long_pf = _simulate('longonly', long_entries, long_exits)
    short_pf = _simulate('shortonly', short_entries, short_exits)
    # Log both portfolio stats
    # TODO: Implement proper portfolio combination
    logger.info(
        "Long portfolio: %.2f%% return, Short portfolio: %.2f%% return",
        long_pf.total_return().mean() * 100,
        short_pf.total_return().mean() * 100
    )
    return long_pf

228
engine/reporting.py Normal file
View File

@@ -0,0 +1,228 @@
"""
Reporting module for backtest results.
Handles summary printing, CSV exports, and plotting.
"""
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
import vectorbt as vbt
from plotly.subplots import make_subplots
from engine.logging_config import get_logger
logger = get_logger(__name__)
class Reporter:
    """Reporter for backtest results with market-specific metrics."""
    def __init__(self, output_dir: str = "backtest_logs"):
        # All CSV artifacts land under this directory; created eagerly.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
    def print_summary(self, result) -> None:
        """
        Print backtest summary to console via logger.
        Args:
            result: BacktestResult or vbt.Portfolio object
        """
        (portfolio, market_type, leverage, funding_paid,
         liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
        # Extract period info
        idx = portfolio.wrapper.index
        start_date = idx[0].strftime("%Y-%m-%d")
        end_date = idx[-1].strftime("%Y-%m-%d")
        # Extract price info
        # (.mean() collapses multi-column close frames from grid runs to one number)
        close = portfolio.close
        start_price = close.iloc[0].mean() if hasattr(close.iloc[0], 'mean') else close.iloc[0]
        end_price = close.iloc[-1].mean() if hasattr(close.iloc[-1], 'mean') else close.iloc[-1]
        price_change = ((end_price - start_price) / start_price) * 100
        # Extract fees
        stats = portfolio.stats()
        total_fees = stats.get('Total Fees Paid', 0)
        raw_return = portfolio.total_return().mean() * 100
        # Build summary
        summary_lines = [
            "",
            "=" * 50,
            "BACKTEST RESULTS",
            "=" * 50,
            f"Market Type: [{market_type.upper()}]",
            f"Leverage: [{leverage}x]",
            f"Period: [{start_date} to {end_date}]",
            f"Price: [{start_price:,.2f} -> {end_price:,.2f} ({price_change:+.2f}%)]",
        ]
        # Show adjusted return if liquidations occurred
        if liq_count > 0 and adjusted_return is not None:
            summary_lines.append(f"Raw Return: [%{raw_return:.2f}] (before liq adjustment)")
            summary_lines.append(f"Adj Return: [%{adjusted_return:.2f}] (after liq losses)")
        else:
            summary_lines.append(f"Total Return: [%{raw_return:.2f}]")
        summary_lines.extend([
            f"Sharpe Ratio: [{portfolio.sharpe_ratio().mean():.2f}]",
            f"Max Drawdown: [%{portfolio.max_drawdown().mean() * 100:.2f}]",
            f"Total Trades: [{portfolio.trades.count().mean():.0f}]",
            f"Win Rate: [%{portfolio.trades.win_rate().mean() * 100:.2f}]",
            f"Total Fees: [{total_fees:,.2f}]",
        ])
        if funding_paid != 0:
            summary_lines.append(f"Funding Paid: [{funding_paid:,.2f}]")
        if liq_count > 0:
            summary_lines.append(f"Liquidations: [{liq_count}] (${liq_loss:,.2f} margin lost)")
        summary_lines.append("=" * 50)
        logger.info("\n".join(summary_lines))
    def save_reports(self, result, filename_prefix: str) -> None:
        """
        Save trade log, stats, and liquidation events to CSV files.
        Args:
            result: BacktestResult or vbt.Portfolio object
            filename_prefix: Prefix for output filenames
        """
        (portfolio, market_type, leverage, funding_paid,
         liq_count, liq_loss, adjusted_return) = self._extract_result_data(result)
        # Save trades
        self._save_csv(
            data=portfolio.trades.records_readable,
            path=self.output_dir / f"{filename_prefix}_trades.csv",
            description="trade log"
        )
        # Save stats with market-specific additions
        stats = portfolio.stats()
        stats['Market Type'] = market_type
        stats['Leverage'] = leverage
        stats['Total Funding Paid'] = funding_paid
        stats['Liquidations'] = liq_count
        stats['Liquidation Loss'] = liq_loss
        if adjusted_return is not None:
            stats['Adjusted Return'] = adjusted_return
        self._save_csv(
            data=stats,
            path=self.output_dir / f"{filename_prefix}_stats.csv",
            description="stats"
        )
        # Save liquidation events if any
        if hasattr(result, 'liquidation_events') and result.liquidation_events:
            liq_df = pd.DataFrame([
                {
                    'entry_time': e.entry_time,
                    'entry_price': e.entry_price,
                    'liquidation_time': e.liquidation_time,
                    'liquidation_price': e.liquidation_price,
                    'actual_price': e.actual_price,
                    'direction': e.direction,
                    'margin_lost_pct': e.margin_lost_pct
                }
                for e in result.liquidation_events
            ])
            self._save_csv(
                data=liq_df,
                path=self.output_dir / f"{filename_prefix}_liquidations.csv",
                description="liquidation events"
            )
    def plot(self, portfolio: vbt.Portfolio, show: bool = True) -> None:
        """Display portfolio plot."""
        if show:
            portfolio.plot().show()
    def plot_wfa(
        self,
        wfa_results: pd.DataFrame,
        stitched_curve: pd.Series | None = None,
        show: bool = True
    ) -> None:
        """
        Plot Walk-Forward Analysis results.
        Args:
            wfa_results: DataFrame with WFA window results
            stitched_curve: Stitched out-of-sample equity curve
            show: Whether to display the plot
        """
        # Two stacked panels: per-window test Sharpe bars on top,
        # continuous out-of-sample equity below.
        fig = make_subplots(
            rows=2, cols=1,
            shared_xaxes=False,
            vertical_spacing=0.1,
            subplot_titles=(
                "Walk-Forward Test Scores (Sharpe)",
                "Stitched Out-of-Sample Equity"
            )
        )
        fig.add_trace(
            go.Bar(
                x=wfa_results['window'],
                y=wfa_results['test_score'],
                name="Test Sharpe"
            ),
            row=1, col=1
        )
        if stitched_curve is not None:
            fig.add_trace(
                go.Scatter(
                    x=stitched_curve.index,
                    y=stitched_curve.values,
                    name="Equity",
                    mode='lines'
                ),
                row=2, col=1
            )
        fig.update_layout(height=800, title_text="Walk-Forward Analysis Report")
        if show:
            fig.show()
    def _extract_result_data(self, result) -> tuple:
        """
        Extract data from BacktestResult or raw Portfolio.
        Returns:
            Tuple of (portfolio, market_type, leverage, funding_paid, liq_count,
            liq_loss, adjusted_return)
        """
        # Duck-typed: anything with a .portfolio attribute is treated as a
        # BacktestResult; a bare Portfolio falls back to neutral defaults.
        if hasattr(result, 'portfolio'):
            return (
                result.portfolio,
                result.market_type.value,
                result.leverage,
                result.total_funding_paid,
                result.liquidation_count,
                getattr(result, 'total_liquidation_loss', 0.0),
                getattr(result, 'adjusted_return', None)
            )
        return (result, "unknown", 1, 0.0, 0, 0.0, None)
    def _save_csv(self, data, path: Path, description: str) -> None:
        """
        Save data to CSV with consistent error handling.
        Args:
            data: DataFrame or Series to save
            path: Output file path
            description: Human-readable description for logging
        """
        # Best-effort: a failed export is logged, never raised, so one bad
        # artifact does not abort the whole report run.
        try:
            data.to_csv(path)
            logger.info("Saved %s to %s", description, path)
        except Exception as e:
            logger.error("Could not save %s: %s", description, e)

395
engine/risk.py Normal file
View File

@@ -0,0 +1,395 @@
"""
Risk management utilities for backtesting.
Handles funding rate calculations and liquidation detection.
"""
from dataclasses import dataclass
import pandas as pd
from engine.logging_config import get_logger
from engine.market import MarketConfig, calculate_liquidation_price
logger = get_logger(__name__)
@dataclass
class LiquidationEvent:
    """
    Record of a liquidation event during backtesting.
    Attributes:
        entry_time: Timestamp when position was opened
        entry_price: Price at position entry
        liquidation_time: Timestamp when liquidation occurred
        liquidation_price: Calculated liquidation price
        actual_price: Actual price that triggered liquidation (high/low)
        direction: 'long' or 'short'
        margin_lost_pct: Percentage of margin lost (typically 100%)
    """
    entry_time: pd.Timestamp
    entry_price: float
    liquidation_time: pd.Timestamp
    liquidation_price: float
    actual_price: float
    direction: str
    # 1.0 == full margin wiped out, the normal outcome of a liquidation.
    margin_lost_pct: float = 1.0
def calculate_funding(
    close: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    market_config: MarketConfig,
    leverage: int
) -> float:
    """
    Calculate total funding paid/received for perpetual positions.

    Simplified model: at every bar whose hour lands on the funding grid
    (e.g. 00:00/08:00/16:00 UTC for an 8-hour interval), open longs pay
    and open shorts receive ``notional * funding_rate``.

    Args:
        close: Price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        market_config: Market configuration with funding parameters
        leverage: Position leverage

    Returns:
        Total funding paid (positive) or received (negative)
    """
    interval_hours = market_config.funding_interval_hours
    if interval_hours == 0:
        # Spot markets never exchange funding.
        return 0.0
    rate = market_config.funding_rate
    # 0/1 flags: a side counts as open once cumulative entries exceed exits.
    in_long = ((long_entries.cumsum() - long_exits.cumsum()) > 0).astype(int)
    in_short = ((short_entries.cumsum() - short_exits.cumsum()) > 0).astype(int)
    settlement_bars = close.index[close.index.hour % interval_hours == 0]
    total = 0.0
    for ts in settlement_bars:
        if ts not in close.index:
            continue
        # Notional is priced per unit; leverage scales exposure.
        notional = close.loc[ts] * leverage
        if isinstance(in_long, pd.DataFrame):
            long_open = in_long.loc[ts].any()
            short_open = in_short.loc[ts].any()
        else:
            long_open = in_long.loc[ts] > 0
            short_open = in_short.loc[ts] > 0
        # Positive rate: longs pay, shorts receive.
        if long_open:
            total += notional * rate
        if short_open:
            total -= notional * rate
    return total
def inject_liquidation_exits(
    close: pd.Series,
    high: pd.Series,
    low: pd.Series,
    long_entries: pd.DataFrame | pd.Series,
    long_exits: pd.DataFrame | pd.Series,
    short_entries: pd.DataFrame | pd.Series,
    short_exits: pd.DataFrame | pd.Series,
    leverage: int,
    maintenance_margin_rate: float
) -> tuple[pd.DataFrame | pd.Series, pd.DataFrame | pd.Series, list[LiquidationEvent]]:
    """
    Modify exit signals to force position closure at liquidation points.
    This function simulates realistic liquidation behavior by:
    1. Finding positions that would be liquidated before their normal exit
    2. Injecting forced exit signals at the liquidation bar
    3. Recording all liquidation events
    Args:
        close: Close price series
        high: High price series
        low: Low price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        leverage: Position leverage
        maintenance_margin_rate: Maintenance margin rate for liquidation
    Returns:
        Tuple of (modified_long_exits, modified_short_exits, liquidation_events)
    """
    # Unlevered positions cannot be liquidated; return signals untouched.
    if leverage <= 1:
        return long_exits, short_exits, []
    liquidation_events: list[LiquidationEvent] = []
    # Convert to DataFrame if Series for consistent handling
    is_series = isinstance(long_entries, pd.Series)
    if is_series:
        long_entries_df = long_entries.to_frame()
        long_exits_df = long_exits.to_frame()
        short_entries_df = short_entries.to_frame()
        short_exits_df = short_exits.to_frame()
    else:
        long_entries_df = long_entries
        long_exits_df = long_exits.copy()
        short_entries_df = short_entries
        short_exits_df = short_exits.copy()
    modified_long_exits = long_exits_df.copy()
    modified_short_exits = short_exits_df.copy()
    # Process long positions
    # NOTE(review): entries are evaluated at the bar's close price; intrabar
    # entry prices are not modelled — confirm this approximation is intended.
    long_mask = long_entries_df.any(axis=1)
    for entry_idx in close.index[long_mask]:
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=True,
            maintenance_margin_rate=maintenance_margin_rate
        )
        # Find the normal exit for this entry
        # (first exit signal at or after the entry bar; falls back to the
        # final bar when the position is never closed by the strategy)
        subsequent_exits = long_exits_df.loc[entry_idx:].any(axis=1)
        exit_indices = subsequent_exits[subsequent_exits].index
        normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
        # Check if liquidation occurs before normal exit
        # (longs liquidate when the bar LOW breaches the liquidation price)
        price_range = low.loc[entry_idx:normal_exit_idx]
        if (price_range < liq_price).any():
            liq_bar = price_range[price_range < liq_price].index[0]
            # Inject forced exit at liquidation bar
            for col in modified_long_exits.columns:
                modified_long_exits.loc[liq_bar, col] = True
            # Record the liquidation event
            liquidation_events.append(LiquidationEvent(
                entry_time=entry_idx,
                entry_price=entry_price,
                liquidation_time=liq_bar,
                liquidation_price=liq_price,
                actual_price=low.loc[liq_bar],
                direction='long',
                margin_lost_pct=1.0
            ))
            logger.warning(
                "LIQUIDATION (Long): Entry %s ($%.2f) -> Liquidated %s "
                "(liq=$%.2f, low=$%.2f)",
                entry_idx.strftime('%Y-%m-%d'), entry_price,
                liq_bar.strftime('%Y-%m-%d'), liq_price, low.loc[liq_bar]
            )
    # Process short positions
    # (mirror of the long loop: shorts liquidate when the bar HIGH breaches)
    short_mask = short_entries_df.any(axis=1)
    for entry_idx in close.index[short_mask]:
        entry_price = close.loc[entry_idx]
        liq_price = calculate_liquidation_price(
            entry_price, leverage, is_long=False,
            maintenance_margin_rate=maintenance_margin_rate
        )
        # Find the normal exit for this entry
        subsequent_exits = short_exits_df.loc[entry_idx:].any(axis=1)
        exit_indices = subsequent_exits[subsequent_exits].index
        normal_exit_idx = exit_indices[0] if len(exit_indices) > 0 else close.index[-1]
        # Check if liquidation occurs before normal exit
        price_range = high.loc[entry_idx:normal_exit_idx]
        if (price_range > liq_price).any():
            liq_bar = price_range[price_range > liq_price].index[0]
            # Inject forced exit at liquidation bar
            for col in modified_short_exits.columns:
                modified_short_exits.loc[liq_bar, col] = True
            # Record the liquidation event
            liquidation_events.append(LiquidationEvent(
                entry_time=entry_idx,
                entry_price=entry_price,
                liquidation_time=liq_bar,
                liquidation_price=liq_price,
                actual_price=high.loc[liq_bar],
                direction='short',
                margin_lost_pct=1.0
            ))
            logger.warning(
                "LIQUIDATION (Short): Entry %s ($%.2f) -> Liquidated %s "
                "(liq=$%.2f, high=$%.2f)",
                entry_idx.strftime('%Y-%m-%d'), entry_price,
                liq_bar.strftime('%Y-%m-%d'), liq_price, high.loc[liq_bar]
            )
    # Convert back to Series if input was Series
    if is_series:
        modified_long_exits = modified_long_exits.iloc[:, 0]
        modified_short_exits = modified_short_exits.iloc[:, 0]
    return modified_long_exits, modified_short_exits, liquidation_events
def calculate_liquidation_adjustment(
    liquidation_events: list[LiquidationEvent],
    init_cash: float,
    leverage: int
) -> tuple[float, float]:
    """
    Estimate the return adjustment owed to liquidated positions.

    VectorBT books each exit at the bar's close price, but a liquidation
    actually wipes out the full margin of the trade. Under this engine's
    setup — long and short sides each funded with half the capital, every
    trade sized at 100% of its side's allocation — each liquidation is
    approximated as losing ``(init_cash / 2) * (1 / leverage)``, with the
    grand total capped at the initial capital.

    Args:
        liquidation_events: List of liquidation events
        init_cash: Initial portfolio cash (before leverage)
        leverage: Position leverage used

    Returns:
        Tuple of (total_margin_lost, adjustment_pct)
        - total_margin_lost: Estimated total margin lost from liquidations
        - adjustment_pct: Percentage adjustment to apply to returns
    """
    if not liquidation_events:
        return 0.0, 0.0
    # Half the (unleveraged) capital backs each trading side.
    margin_per_side = init_cash / 2
    # At N-times leverage a full liquidation wipes ~1/N of position notional.
    liq_loss_rate = 1.0 / leverage
    n_liquidations = len(liquidation_events)
    # Cap at the initial capital: the account cannot lose more than it holds.
    total_margin_lost = min(n_liquidations * margin_per_side * liq_loss_rate, init_cash)
    adjustment_pct = (total_margin_lost / init_cash) * 100
    return total_margin_lost, adjustment_pct
def check_liquidations(
    close: pd.Series,
    high: pd.Series,
    low: pd.Series,
    long_entries: pd.DataFrame,
    long_exits: pd.DataFrame,
    short_entries: pd.DataFrame,
    short_exits: pd.DataFrame,
    leverage: int,
    maintenance_margin_rate: float
) -> int:
    """
    Scan entries for positions that would have been liquidated and log warnings.

    Unlike inject_liquidation_exits, this is advisory only: signals are not
    modified, and the scan runs from each entry to the end of the data
    (ignoring normal exits).

    Args:
        close: Close price series
        high: High price series
        low: Low price series
        long_entries: Long entry signals
        long_exits: Long exit signals
        short_entries: Short entry signals
        short_exits: Short exit signals
        leverage: Position leverage
        maintenance_margin_rate: Maintenance margin rate for liquidation

    Returns:
        Count of liquidation warnings
    """
    warn_count = 0
    # Longs: liquidated when a later bar LOW breaches the liquidation price.
    if isinstance(long_entries, pd.DataFrame):
        long_mask = long_entries.any(axis=1)
    else:
        long_mask = long_entries
    for ts in close.index[long_mask]:
        px = close.loc[ts]
        liq_px = calculate_liquidation_price(
            px, leverage, is_long=True,
            maintenance_margin_rate=maintenance_margin_rate
        )
        lows_after = low.loc[ts:]
        breached = lows_after < liq_px
        if breached.any():
            first_hit = lows_after[breached].index[0]
            logger.warning(
                "LIQUIDATION WARNING (Long): Entry at %s ($%.2f), "
                "would liquidate at %s (liq_price=$%.2f, low=$%.2f)",
                ts, px, first_hit, liq_px, low.loc[first_hit]
            )
            warn_count += 1
    # Shorts: liquidated when a later bar HIGH breaches the liquidation price.
    if isinstance(short_entries, pd.DataFrame):
        short_mask = short_entries.any(axis=1)
    else:
        short_mask = short_entries
    for ts in close.index[short_mask]:
        px = close.loc[ts]
        liq_px = calculate_liquidation_price(
            px, leverage, is_long=False,
            maintenance_margin_rate=maintenance_margin_rate
        )
        highs_after = high.loc[ts:]
        breached = highs_after > liq_px
        if breached.any():
            first_hit = highs_after[breached].index[0]
            logger.warning(
                "LIQUIDATION WARNING (Short): Entry at %s ($%.2f), "
                "would liquidate at %s (liq_price=$%.2f, high=$%.2f)",
                ts, px, first_hit, liq_px, high.loc[first_hit]
            )
            warn_count += 1
    return warn_count

View File

@@ -1,3 +0,0 @@
from .supertrend import add_supertrends, compute_meta_trend
__all__ = ["add_supertrends", "compute_meta_trend"]

View File

@@ -1,58 +0,0 @@
from __future__ import annotations
import pandas as pd
import numpy as np
def _atr(high: pd.Series, low: pd.Series, close: pd.Series, period: int) -> pd.Series:
    """Average True Range: rolling mean of the true range over *period* bars.

    True range per bar is the largest of high-low, |high - prev close| and
    |low - prev close|; the first *period - 1* values are NaN.
    """
    prev_close = close.shift()
    true_range = pd.concat(
        [(high - low).abs(), (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    ).max(axis=1)
    return true_range.rolling(period, min_periods=period).mean()
def supertrend_series(df: pd.DataFrame, length: int, multiplier: float) -> pd.Series:
    """Compute a Supertrend line from OHLC columns High/Low/Close.

    Bands are hl2 +/- multiplier * ATR(length). The returned series holds the
    active band per bar: the (clamped) lower band while trending up, the
    (clamped) upper band while trending down; NaN until ATR is available.
    """
    atr = _atr(df["High"], df["Low"], df["Close"], length)
    hl2 = (df["High"] + df["Low"]) / 2
    upper = hl2 + multiplier * atr
    lower = hl2 - multiplier * atr
    trend = pd.Series(index=df.index, dtype=float)
    dir_up = True
    prev_upper = np.nan
    prev_lower = np.nan
    for i in range(len(df)):
        # Warm-up: no ATR yet -> no trend value, but seed the band memory.
        if i == 0 or pd.isna(atr.iat[i]):
            trend.iat[i] = np.nan
            prev_upper = upper.iat[i]
            prev_lower = lower.iat[i]
            continue
        # Ratchet: in an uptrend the upper band may only fall; in a downtrend
        # the lower band may only rise.
        # NOTE(review): conventional Supertrend ratchets the LOWER band during
        # uptrends as well; here the lower band is taken raw while dir_up is
        # True — confirm this asymmetry is intended.
        cu = min(upper.iat[i], prev_upper) if dir_up else upper.iat[i]
        cl = max(lower.iat[i], prev_lower) if not dir_up else lower.iat[i]
        if df["Close"].iat[i] > cu:
            dir_up = True
        elif df["Close"].iat[i] < cl:
            dir_up = False
        prev_upper = cu if dir_up else upper.iat[i]
        prev_lower = lower.iat[i] if dir_up else cl
        trend.iat[i] = cl if dir_up else cu
    return trend
def add_supertrends(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.DataFrame:
    """Return a copy of *df* with one supertrend and bull-flag column per setting.

    For each (length, multiplier) pair, adds ``supertrend_{length}_{mult}``
    (the trend line) and ``bull_{length}_{mult}`` (1 while Close holds at or
    above the line, else 0). The input frame is not modified.
    """
    result = df.copy()
    for length, mult in settings:
        line_col = f"supertrend_{length}_{mult}"
        result[line_col] = supertrend_series(result, length, mult)
        result[f"bull_{length}_{mult}"] = (result["Close"] >= result[line_col]).astype(int)
    return result
def compute_meta_trend(df: pd.DataFrame, settings: list[tuple[int, float]]) -> pd.Series:
    """Unanimous vote across supertrend bull flags: 1 only when every
    configured ``bull_{length}_{mult}`` column equals 1 for the bar, else 0."""
    cols = [f"bull_{length}_{mult}" for length, mult in settings]
    votes = df[cols].sum(axis=1)
    return (votes == len(cols)).astype(int)

View File

@@ -1,10 +0,0 @@
from __future__ import annotations
import pandas as pd
def precompute_slices(df: pd.DataFrame) -> pd.DataFrame:
    # Identity pass-through today; kept as an extension point for caching
    # per-window slices of the intrabar frame later.
    return df # hook for future use
def entry_slippage_row(price: float, qty: float, slippage_bps: float) -> float:
    """Fill price after adverse entry slippage of *slippage_bps* basis points.

    *qty* is currently unused; it is kept for signature stability.
    """
    slip = price * (slippage_bps / 1e4)
    return price + slip

View File

@@ -1,11 +0,0 @@
from __future__ import annotations
from pathlib import Path
import pandas as pd
def write_trade_log(trades: list[dict], path: Path) -> None:
    """Persist executed trades as CSV at *path*.

    No file is written (and no directories created) when *trades* is empty;
    otherwise parent directories are created as needed.
    """
    if not trades:
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(trades).to_csv(path, index=False)

10
main.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Lowkey Backtest CLI - VectorBT Edition
A backtesting framework supporting multiple market types (spot, perpetual)
with realistic trading simulation including leverage, funding, and shorts.
"""
from engine.cli import main
# Script entry point: delegate argument parsing and execution to engine.cli.main.
if __name__ == "__main__":
    main()

View File

@@ -1,11 +0,0 @@
from __future__ import annotations
# Default OKX spot taker fee, expressed in basis points.
TAKER_FEE_BPS_DEFAULT = 10.0 # 0.10%
def okx_fee(fee_bps: float, notional_usd: float) -> float:
    """USD fee charged on *notional_usd* at *fee_bps* basis points."""
    rate = fee_bps / 1e4
    return notional_usd * rate
def estimate_slippage_rate(slippage_bps: float, notional_usd: float) -> float:
    """Estimated slippage cost in USD on *notional_usd* at *slippage_bps*.

    NOTE(review): despite the "_rate" suffix this returns a USD amount, not a
    rate — confirm the name against call sites.
    """
    rate = slippage_bps / 1e4
    return notional_usd * rate

View File

@@ -1,54 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import pandas as pd
@dataclass
class Perf:
    """Summary performance statistics for a single backtest run."""
    total_return: float
    max_drawdown: float
    sharpe_ratio: float
    win_rate: float
    num_trades: int
    final_equity: float
    initial_equity: float
    num_stop_losses: int
    total_fees: float
    total_slippage_usd: float
    avg_slippage_bps: float


def compute_metrics(equity_curve: pd.Series, trades: list[dict]) -> Perf:
    """Derive a Perf record from an equity curve and a list of trade events.

    Only "SELL" events count as closed trades; wins are closes with positive
    ``pnl``. Fees/slippage are summed over every event, and slippage bps are
    averaged over events that carry a ``slippage_bps`` key.
    """
    returns = equity_curve.pct_change().fillna(0.0)
    vol = returns.std(ddof=0)
    # Minute bars -> annualized (252 trading days * 24h * 60min).
    sharpe = (returns.mean() / vol) * np.sqrt(252 * 24 * 60) if vol > 0 else 0.0
    drawdown = (equity_curve / equity_curve.cummax() - 1.0).min()
    sells = [t for t in trades if t.get("side") == "SELL"]
    winners = [t for t in sells if t.get("pnl", 0.0) > 0]
    bps_samples = [t.get("slippage_bps", 0.0) for t in trades if "slippage_bps" in t]
    return Perf(
        total_return=equity_curve.iat[-1] / equity_curve.iat[0] - 1.0,
        max_drawdown=drawdown,
        sharpe_ratio=sharpe,
        win_rate=(len(winners) / len(sells)) if sells else 0.0,
        num_trades=len(sells),
        final_equity=float(equity_curve.iat[-1]),
        initial_equity=float(equity_curve.iat[0]),
        num_stop_losses=sum(1 for t in sells if t.get("reason") == "stop"),
        total_fees=sum(t.get("fee", 0.0) for t in trades),
        total_slippage_usd=sum(t.get("slippage", 0.0) for t in trades),
        avg_slippage_bps=float(np.mean(bps_samples)) if bps_samples else 0.0,
    )

View File

@@ -5,5 +5,20 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"ccxt>=4.5.32",
"numpy>=2.3.2",
"pandas>=2.3.1",
"ta>=0.11.0", "ta>=0.11.0",
"vectorbt>=0.28.2",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
]
[tool.pytest.ini_options]
pythonpath = ["."]
markers = [
"network: marks tests as requiring network access",
] ]

80
strategies/base.py Normal file
View File

@@ -0,0 +1,80 @@
"""
Base strategy class for all trading strategies.
Strategies should inherit from BaseStrategy and implement the run() method.
"""
from abc import ABC, abstractmethod
import pandas as pd
from engine.market import MarketType
class BaseStrategy(ABC):
    """
    Abstract base class for trading strategies.

    Class Attributes:
        default_market_type: Default market type for this strategy
        default_leverage: Default leverage (only applies to perpetuals)
        default_sl_stop: Default stop-loss percentage
        default_tp_stop: Default take-profit percentage
        default_sl_trail: Whether stop-loss is trailing by default
    """
    # Market configuration defaults
    default_market_type: MarketType = MarketType.SPOT
    default_leverage: int = 1
    # Risk management defaults (can be overridden per strategy)
    default_sl_stop: float | None = None
    default_tp_stop: float | None = None
    default_sl_trail: bool = False

    def __init__(self, **kwargs):
        # Arbitrary per-strategy parameters; concrete subclasses read what they need.
        self.params = kwargs

    @abstractmethod
    def run(
        self,
        close: pd.Series,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Run the strategy logic.

        Args:
            close: Price series (can be multiple columns for grid search)
            **kwargs: Additional data (high, low, open, volume) and parameters

        Returns:
            Tuple of 4 DataFrames/Series:
            - long_entries: Boolean signals to open long positions
            - long_exits: Boolean signals to close long positions
            - short_entries: Boolean signals to open short positions
            - short_exits: Boolean signals to close short positions

        Note:
            For spot markets, short signals will be ignored.
            For backward compatibility, strategies can return 2-tuple (entries, exits)
            which will be interpreted as long-only signals.
        """
        pass

    def get_indicator(self, ind_cls, *args, **kwargs):
        """Helper to run a vectorbt indicator."""
        return ind_cls.run(*args, **kwargs)

    @staticmethod
    def create_empty_signals(reference: pd.Series | pd.DataFrame) -> pd.DataFrame | pd.Series:
        """
        Create an empty (all False) signal container matching the reference shape.

        Args:
            reference: Series or DataFrame to match shape/index

        Returns:
            DataFrame of False values when ``reference`` is a DataFrame,
            otherwise a Series of False values. (Fixed: the previous return
            annotation and docstring claimed DataFrame unconditionally, but
            a Series input yields a Series.)
        """
        if isinstance(reference, pd.DataFrame):
            return pd.DataFrame(False, index=reference.index, columns=reference.columns)
        return pd.Series(False, index=reference.index)

97
strategies/examples.py Normal file
View File

@@ -0,0 +1,97 @@
"""
Example trading strategies for backtesting.
These are simple strategies demonstrating the framework usage.
"""
import pandas as pd
import vectorbt as vbt
from engine.market import MarketType
from strategies.base import BaseStrategy
class RsiStrategy(BaseStrategy):
    """
    RSI mean-reversion strategy.

    Opens a long when RSI crosses down through the oversold level and
    closes it when RSI crosses up through the overbought level.
    """
    default_market_type = MarketType.SPOT
    default_leverage = 1

    def run(
        self,
        close: pd.Series,
        period: int = 14,
        rsi_lower: int = 30,
        rsi_upper: int = 70,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Generate RSI-based trading signals.

        Args:
            close: Price series
            period: RSI calculation period
            rsi_lower: Oversold threshold (buy signal)
            rsi_upper: Overbought threshold (sell signal)

        Returns:
            4-tuple of (long_entries, long_exits, short_entries, short_exits);
            the short legs are always all-False for this spot-focused strategy.
        """
        indicator = vbt.RSI.run(close, window=period)
        long_entries = indicator.rsi_crossed_below(rsi_lower)
        long_exits = indicator.rsi_crossed_above(rsi_upper)
        # Spot-focused: emit empty short legs so callers get a uniform 4-tuple.
        short_entries = BaseStrategy.create_empty_signals(long_entries)
        short_exits = BaseStrategy.create_empty_signals(long_entries)
        return long_entries, long_exits, short_entries, short_exits
class MaCrossStrategy(BaseStrategy):
    """
    Moving Average crossover strategy.

    Opens a long when the fast MA crosses above the slow MA and closes it
    on the opposite cross.
    """
    default_market_type = MarketType.SPOT
    default_leverage = 1

    def run(
        self,
        close: pd.Series,
        fast_window: int = 10,
        slow_window: int = 20,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Generate MA crossover trading signals.

        Args:
            close: Price series
            fast_window: Fast MA period
            slow_window: Slow MA period

        Returns:
            4-tuple of (long_entries, long_exits, short_entries, short_exits);
            the short legs are always all-False for this strategy.
        """
        fast = vbt.MA.run(close, window=fast_window)
        slow = vbt.MA.run(close, window=slow_window)
        long_entries = fast.ma_crossed_above(slow)
        long_exits = fast.ma_crossed_below(slow)
        # Long-only: emit empty short legs so callers get a uniform 4-tuple.
        short_entries = BaseStrategy.create_empty_signals(long_entries)
        short_exits = BaseStrategy.create_empty_signals(long_entries)
        return long_entries, long_exits, short_entries, short_exits

128
strategies/factory.py Normal file
View File

@@ -0,0 +1,128 @@
"""
Strategy factory for creating strategy instances with their parameters.
Centralizes strategy creation and parameter configuration.
"""
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from strategies.base import BaseStrategy
@dataclass
class StrategyConfig:
    """
    Configuration for a strategy including default and grid parameters.
    Attributes:
        strategy_class: The strategy class to instantiate
        default_params: Parameters for single backtest runs
        grid_params: Parameters for grid search optimization
    """
    strategy_class: type[BaseStrategy]
    # default_factory avoids the shared-mutable-default pitfall for dict fields.
    default_params: dict[str, Any] = field(default_factory=dict)
    grid_params: dict[str, Any] = field(default_factory=dict)
def _build_registry() -> dict[str, StrategyConfig]:
    """
    Build the strategy registry lazily to avoid circular imports.
    Returns:
        Dictionary mapping strategy names to their configurations
    """
    # Import here to avoid circular imports
    from strategies.examples import MaCrossStrategy, RsiStrategy
    from strategies.supertrend import MetaSupertrendStrategy
    return {
        "rsi": StrategyConfig(
            strategy_class=RsiStrategy,
            default_params={
                'period': 14,
                'rsi_lower': 30,
                'rsi_upper': 70
            },
            grid_params={
                'period': np.arange(10, 25, 2),
                'rsi_lower': [20, 30, 40],
                'rsi_upper': [60, 70, 80]
            }
        ),
        "macross": StrategyConfig(
            strategy_class=MaCrossStrategy,
            default_params={
                'fast_window': 10,
                'slow_window': 20
            },
            grid_params={
                'fast_window': np.arange(5, 20, 5),
                'slow_window': np.arange(20, 60, 10)
            }
        ),
        "meta_st": StrategyConfig(
            strategy_class=MetaSupertrendStrategy,
            # These pairs match the legacy DEFAULT_ST_SETTINGS:
            # (12, 3.0), (10, 1.0), (11, 2.0).
            default_params={
                'period1': 12, 'multiplier1': 3.0,
                'period2': 10, 'multiplier2': 1.0,
                'period3': 11, 'multiplier3': 2.0
            },
            # NOTE(review): only the first supertrend's params are varied here;
            # period2/multiplier2/period3/multiplier3 are pinned scalars —
            # presumably to keep the grid small. Confirm this is intentional.
            grid_params={
                'multiplier1': [2.0, 3.0, 4.0],
                'period1': [10, 12, 14],
                'period2': 11, 'multiplier2': 2.0,
                'period3': 12, 'multiplier3': 1.0
            }
        ),
    }
# Module-level cache for the registry
_REGISTRY_CACHE: dict[str, StrategyConfig] | None = None


def get_registry() -> dict[str, StrategyConfig]:
    """Return the strategy registry, constructing and caching it on first call."""
    global _REGISTRY_CACHE
    registry = _REGISTRY_CACHE
    if registry is None:
        registry = _build_registry()
        _REGISTRY_CACHE = registry
    return registry
def get_strategy_names() -> list[str]:
    """
    Get list of available strategy names.

    Returns:
        List of strategy name strings
    """
    return [name for name in get_registry()]
def get_strategy(name: str, is_grid: bool = False) -> tuple[BaseStrategy, dict[str, Any]]:
    """
    Create a strategy instance with appropriate parameters.

    Args:
        name: Strategy identifier (e.g., 'rsi', 'macross', 'meta_st')
        is_grid: If True, return grid search parameters

    Returns:
        Tuple of (strategy instance, parameters dict)

    Raises:
        KeyError: If strategy name is not found in registry
    """
    registry = get_registry()
    if name not in registry:
        available = ", ".join(registry.keys())
        raise KeyError(f"Unknown strategy '{name}'. Available: {available}")
    cfg = registry[name]
    # Copy so callers can mutate their params without touching the registry.
    chosen = cfg.grid_params if is_grid else cfg.default_params
    return cfg.strategy_class(), chosen.copy()

View File

@@ -0,0 +1,6 @@
"""
Meta Supertrend strategy package.
"""
from .strategy import MetaSupertrendStrategy
__all__ = ['MetaSupertrendStrategy']

View File

@@ -0,0 +1,128 @@
"""
Supertrend indicators and helper functions.
"""
import numpy as np
import vectorbt as vbt
from numba import njit
# --- Numba Compiled Helper Functions ---
@njit(cache=False)  # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
    """Calculate True Range (Numba compiled).

    TR[0] falls back to high-low (no prior close exists); for i > 0,
    TR[i] = max(H-L, |H - prev_close|, |L - prev_close|).
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()
    tr = np.empty_like(close)
    tr[0] = high[0] - low[0]
    for i in range(1, len(close)):
        tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
    return tr
@njit(cache=False)
def get_atr_nb(high, low, close, period):
    """Calculate ATR using Wilder's Smoothing (Numba compiled).

    Entries before index period-1 are NaN; atr[period-1] seeds the series
    as a simple average of the first ``period`` true ranges, after which
    Wilder's recursive smoothing (prev*(p-1) + TR)/p is applied.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()
    # Ensure period is native Python int (critical for Numba array indexing)
    n = len(close)
    p = int(period)
    tr = get_tr_nb(high, low, close)
    atr = np.full(n, np.nan, dtype=np.float64)
    if n < p:
        # Not enough bars to seed the average: all-NaN result.
        return atr
    # Initial ATR is simple average of TR
    sum_tr = 0.0
    for i in range(p):
        sum_tr += tr[i]
    atr[p - 1] = sum_tr / p
    # Subsequent ATR is Wilder's smoothed
    for i in range(p, n):
        atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
    return atr
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
    """Calculate SuperTrend completely in Numba.

    Returns only the trend direction array (1 bull, -1 bear); the band
    arrays are internal. NOTE(review): indices before the first valid ATR
    keep the initial value 1, i.e. the warm-up region reads as bullish —
    confirm downstream signal logic tolerates this.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()
    # Ensure params are native Python types (critical for Numba)
    n = len(close)
    p = int(period)
    m = float(multiplier)
    atr = get_atr_nb(high, low, close, p)
    final_upper = np.full(n, np.nan, dtype=np.float64)
    final_lower = np.full(n, np.nan, dtype=np.float64)
    trend = np.ones(n, dtype=np.int8)  # 1 Bull, -1 Bear
    # Skip until we have valid ATR
    start_idx = p
    if start_idx >= n:
        return trend
    # Init first valid point
    hl2 = (high[start_idx] + low[start_idx]) / 2
    final_upper[start_idx] = hl2 + m * atr[start_idx]
    final_lower[start_idx] = hl2 - m * atr[start_idx]
    # Loop
    for i in range(start_idx + 1, n):
        cur_hl2 = (high[i] + low[i]) / 2
        cur_atr = atr[i]
        basic_upper = cur_hl2 + m * cur_atr
        basic_lower = cur_hl2 - m * cur_atr
        # Upper Band Logic: band only ratchets down unless price closed above it
        if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
            final_upper[i] = basic_upper
        else:
            final_upper[i] = final_upper[i-1]
        # Lower Band Logic: band only ratchets up unless price closed below it
        if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
            final_lower[i] = basic_lower
        else:
            final_lower[i] = final_lower[i-1]
        # Trend Logic: flip only when close crosses the opposing prior band
        if trend[i-1] == 1:
            if close[i] < final_lower[i-1]:
                trend[i] = -1
            else:
                trend[i] = 1
        else:
            if close[i] > final_upper[i-1]:
                trend[i] = 1
            else:
                trend[i] = -1
    return trend
# --- VectorBT Indicator Factory ---
# Wrap the Numba kernel as a vectorbt indicator; `trend` is the only output.
SuperTrendIndicator = vbt.IndicatorFactory(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['trend']
).from_apply_func(
    get_supertrend_nb,
    keep_pd=False,  # Disable automatic Pandas wrapping of inputs
    param_product=True  # Enable Cartesian product for list params
)

View File

@@ -0,0 +1,142 @@
"""
Meta Supertrend strategy implementation.
"""
import numpy as np
import pandas as pd
from engine.market import MarketType
from strategies.base import BaseStrategy
from .indicators import SuperTrendIndicator
class MetaSupertrendStrategy(BaseStrategy):
    """
    Meta Supertrend Strategy using 3 Supertrend indicators.
    Enters long when all 3 Supertrends are bullish.
    Enters short when all 3 Supertrends are bearish.
    Designed for perpetual futures with leverage and short-selling support.
    """
    # Market configuration
    default_market_type = MarketType.PERPETUAL
    default_leverage = 5
    # Risk management parameters
    default_sl_stop = 0.02  # 2% stop loss
    default_sl_trail = True  # Trailing stop enabled
    default_exit_on_bearish_flip = False  # Rely on SL/TP, not bearish flip

    def run(
        self,
        close: pd.Series,
        high: pd.Series | None = None,
        low: pd.Series | None = None,
        period1: int = 10,
        multiplier1: float = 3.0,
        period2: int = 11,
        multiplier2: float = 2.0,
        period3: int = 12,
        multiplier3: float = 1.0,
        exit_on_bearish_flip: bool | None = None,
        enable_short: bool = True,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Generate meta-supertrend entry/exit signals.

        Args:
            close: Close price series.
            high, low: High/low price series; both are required.
            period1..3, multiplier1..3: Parameters for the three Supertrends.
            exit_on_bearish_flip: Exit longs when the meta trend turns bearish;
                None falls back to the class default (False: rely on SL/TP).
            enable_short: Emit short entries/exits on bearish consensus.

        Returns:
            4-tuple of (long_entries, long_exits, short_entries, short_exits).

        Raises:
            ValueError: If ``high`` or ``low`` is not provided.
        """
        # 1. Validation & Setup
        if exit_on_bearish_flip is None:
            exit_on_bearish_flip = self.default_exit_on_bearish_flip
        if high is None or low is None:
            raise ValueError("MetaSupertrendStrategy requires High and Low prices.")
        # 2. Calculate Supertrends
        t1, t2, t3 = self._calculate_supertrends(
            high, low, close,
            period1, multiplier1,
            period2, multiplier2,
            period3, multiplier3
        )
        # 3. Meta Signals
        bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)
        # 4. Generate Entry/Exit Signals
        return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)

    def _calculate_supertrends(
        self, high, low, close, p1, m1, p2, m2, p3, m3
    ):
        """Run the 3 Supertrend indicators."""
        # Pass NumPy arrays explicitly to avoid Numba typing errors
        h_vals = high.values
        l_vals = low.values
        c_vals = close.values

        def run_st(p, m):
            # Re-attach the caller's index; grid params yield a DataFrame,
            # and a single-column result is collapsed to a Series.
            st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
            trend = st.trend
            if isinstance(trend, pd.DataFrame):
                trend.index = close.index
                if trend.shape[1] == 1:
                    trend = trend.iloc[:, 0]
            elif isinstance(trend, pd.Series):
                trend.index = close.index
            return trend

        t1 = run_st(p1, m1)
        t2 = run_st(p2, m2)
        t3 = run_st(p3, m3)
        return t1, t2, t3

    def _calculate_meta_signals(self, t1, t2, t3, close_series):
        """Combine 3 Supertrends into boolean Bullish/Bearish signals."""
        # Use NumPy broadcasting: t1 may be (n, k) for grid runs, while
        # t2/t3 are forced to (n, 1) column vectors so they broadcast over k.
        t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
        # Force column vectors for broadcasting if scalar result
        t2_vals = t2.values.reshape(-1, 1)
        t3_vals = t3.values.reshape(-1, 1)
        # Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
        bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
        bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)
        # Reconstruct Pandas objects with the shape/labels of t1
        if isinstance(t1, pd.DataFrame):
            bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
            bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
        else:
            bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
            bearish = pd.Series(bearish_vals.flatten(), index=t1.index)
        return bullish, bearish

    def _generate_signals(
        self, bullish, bearish, exit_on_bearish_flip, enable_short
    ):
        """Generate long/short entry/exit signals based on meta trend."""
        # Long Entries: Change from Not Bullish to Bullish
        prev_bullish = bullish.shift(1).fillna(False)
        long_entries = bullish & (~prev_bullish)
        # Long Exits
        if exit_on_bearish_flip:
            prev_bearish = bearish.shift(1).fillna(False)
            long_exits = bearish & (~prev_bearish)
        else:
            # No signal-based long exit: positions rely on SL/TP instead.
            long_exits = BaseStrategy.create_empty_signals(long_entries)
        # Short signals
        if enable_short:
            prev_bearish = bearish.shift(1).fillna(False)
            short_entries = bearish & (~prev_bearish)
            if exit_on_bearish_flip:
                # Close shorts when the meta trend flips back to bullish.
                short_exits = bullish & (~prev_bullish)
            else:
                short_exits = BaseStrategy.create_empty_signals(long_entries)
        else:
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)
        return long_entries, long_exits, short_entries, short_exits

View File

@@ -0,0 +1,6 @@
"""
Meta Supertrend strategy package.
"""
from .strategy import MetaSupertrendStrategy
__all__ = ['MetaSupertrendStrategy']

View File

@@ -0,0 +1,128 @@
"""
Supertrend indicators and helper functions.
"""
import numpy as np
import vectorbt as vbt
from numba import njit
# --- Numba Compiled Helper Functions ---
@njit(cache=False)  # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
    """Calculate True Range (Numba compiled).

    TR[0] falls back to high-low (no prior close exists); for i > 0,
    TR[i] = max(H-L, |H - prev_close|, |L - prev_close|).
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()
    tr = np.empty_like(close)
    tr[0] = high[0] - low[0]
    for i in range(1, len(close)):
        tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
    return tr
@njit(cache=False)
def get_atr_nb(high, low, close, period):
    """Calculate ATR using Wilder's Smoothing (Numba compiled).

    Entries before index period-1 are NaN; atr[period-1] seeds the series
    as a simple average of the first ``period`` true ranges, after which
    Wilder's recursive smoothing (prev*(p-1) + TR)/p is applied.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()
    # Ensure period is native Python int (critical for Numba array indexing)
    n = len(close)
    p = int(period)
    tr = get_tr_nb(high, low, close)
    atr = np.full(n, np.nan, dtype=np.float64)
    if n < p:
        # Not enough bars to seed the average: all-NaN result.
        return atr
    # Initial ATR is simple average of TR
    sum_tr = 0.0
    for i in range(p):
        sum_tr += tr[i]
    atr[p - 1] = sum_tr / p
    # Subsequent ATR is Wilder's smoothed
    for i in range(p, n):
        atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
    return atr
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
    """Calculate SuperTrend completely in Numba.

    Returns only the trend direction array (1 bull, -1 bear); the band
    arrays are internal. NOTE(review): indices before the first valid ATR
    keep the initial value 1, i.e. the warm-up region reads as bullish —
    confirm downstream signal logic tolerates this.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()
    # Ensure params are native Python types (critical for Numba)
    n = len(close)
    p = int(period)
    m = float(multiplier)
    atr = get_atr_nb(high, low, close, p)
    final_upper = np.full(n, np.nan, dtype=np.float64)
    final_lower = np.full(n, np.nan, dtype=np.float64)
    trend = np.ones(n, dtype=np.int8)  # 1 Bull, -1 Bear
    # Skip until we have valid ATR
    start_idx = p
    if start_idx >= n:
        return trend
    # Init first valid point
    hl2 = (high[start_idx] + low[start_idx]) / 2
    final_upper[start_idx] = hl2 + m * atr[start_idx]
    final_lower[start_idx] = hl2 - m * atr[start_idx]
    # Loop
    for i in range(start_idx + 1, n):
        cur_hl2 = (high[i] + low[i]) / 2
        cur_atr = atr[i]
        basic_upper = cur_hl2 + m * cur_atr
        basic_lower = cur_hl2 - m * cur_atr
        # Upper Band Logic: band only ratchets down unless price closed above it
        if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
            final_upper[i] = basic_upper
        else:
            final_upper[i] = final_upper[i-1]
        # Lower Band Logic: band only ratchets up unless price closed below it
        if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
            final_lower[i] = basic_lower
        else:
            final_lower[i] = final_lower[i-1]
        # Trend Logic: flip only when close crosses the opposing prior band
        if trend[i-1] == 1:
            if close[i] < final_lower[i-1]:
                trend[i] = -1
            else:
                trend[i] = 1
        else:
            if close[i] > final_upper[i-1]:
                trend[i] = 1
            else:
                trend[i] = -1
    return trend
# --- VectorBT Indicator Factory ---
# Wrap the Numba kernel as a vectorbt indicator; `trend` is the only output.
SuperTrendIndicator = vbt.IndicatorFactory(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['trend']
).from_apply_func(
    get_supertrend_nb,
    keep_pd=False,  # Disable automatic Pandas wrapping of inputs
    param_product=True  # Enable Cartesian product for list params
)

View File

@@ -0,0 +1,142 @@
"""
Meta Supertrend strategy implementation.
"""
import numpy as np
import pandas as pd
from engine.market import MarketType
from strategies.base import BaseStrategy
from .indicators import SuperTrendIndicator
class MetaSupertrendStrategy(BaseStrategy):
    """
    Meta Supertrend Strategy using 3 Supertrend indicators.
    Enters long when all 3 Supertrends are bullish.
    Enters short when all 3 Supertrends are bearish.
    Designed for perpetual futures with leverage and short-selling support.
    """
    # Market configuration
    default_market_type = MarketType.PERPETUAL
    default_leverage = 5
    # Risk management parameters
    default_sl_stop = 0.02  # 2% stop loss
    default_sl_trail = True  # Trailing stop enabled
    default_exit_on_bearish_flip = False  # Rely on SL/TP, not bearish flip

    def run(
        self,
        close: pd.Series,
        high: pd.Series | None = None,
        low: pd.Series | None = None,
        period1: int = 10,
        multiplier1: float = 3.0,
        period2: int = 11,
        multiplier2: float = 2.0,
        period3: int = 12,
        multiplier3: float = 1.0,
        exit_on_bearish_flip: bool | None = None,
        enable_short: bool = True,
        **kwargs
    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Generate meta-supertrend entry/exit signals.

        Args:
            close: Close price series.
            high, low: High/low price series; both are required.
            period1..3, multiplier1..3: Parameters for the three Supertrends.
            exit_on_bearish_flip: Exit longs when the meta trend turns bearish;
                None falls back to the class default (False: rely on SL/TP).
            enable_short: Emit short entries/exits on bearish consensus.

        Returns:
            4-tuple of (long_entries, long_exits, short_entries, short_exits).

        Raises:
            ValueError: If ``high`` or ``low`` is not provided.
        """
        # 1. Validation & Setup
        if exit_on_bearish_flip is None:
            exit_on_bearish_flip = self.default_exit_on_bearish_flip
        if high is None or low is None:
            raise ValueError("MetaSupertrendStrategy requires High and Low prices.")
        # 2. Calculate Supertrends
        t1, t2, t3 = self._calculate_supertrends(
            high, low, close,
            period1, multiplier1,
            period2, multiplier2,
            period3, multiplier3
        )
        # 3. Meta Signals
        bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)
        # 4. Generate Entry/Exit Signals
        return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)

    def _calculate_supertrends(
        self, high, low, close, p1, m1, p2, m2, p3, m3
    ):
        """Run the 3 Supertrend indicators."""
        # Pass NumPy arrays explicitly to avoid Numba typing errors
        h_vals = high.values
        l_vals = low.values
        c_vals = close.values

        def run_st(p, m):
            # Re-attach the caller's index; grid params yield a DataFrame,
            # and a single-column result is collapsed to a Series.
            st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
            trend = st.trend
            if isinstance(trend, pd.DataFrame):
                trend.index = close.index
                if trend.shape[1] == 1:
                    trend = trend.iloc[:, 0]
            elif isinstance(trend, pd.Series):
                trend.index = close.index
            return trend

        t1 = run_st(p1, m1)
        t2 = run_st(p2, m2)
        t3 = run_st(p3, m3)
        return t1, t2, t3

    def _calculate_meta_signals(self, t1, t2, t3, close_series):
        """Combine 3 Supertrends into boolean Bullish/Bearish signals."""
        # Use NumPy broadcasting: t1 may be (n, k) for grid runs, while
        # t2/t3 are forced to (n, 1) column vectors so they broadcast over k.
        t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
        # Force column vectors for broadcasting if scalar result
        t2_vals = t2.values.reshape(-1, 1)
        t3_vals = t3.values.reshape(-1, 1)
        # Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
        bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
        bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)
        # Reconstruct Pandas objects with the shape/labels of t1
        if isinstance(t1, pd.DataFrame):
            bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
            bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
        else:
            bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
            bearish = pd.Series(bearish_vals.flatten(), index=t1.index)
        return bullish, bearish

    def _generate_signals(
        self, bullish, bearish, exit_on_bearish_flip, enable_short
    ):
        """Generate long/short entry/exit signals based on meta trend."""
        # Long Entries: Change from Not Bullish to Bullish
        prev_bullish = bullish.shift(1).fillna(False)
        long_entries = bullish & (~prev_bullish)
        # Long Exits
        if exit_on_bearish_flip:
            prev_bearish = bearish.shift(1).fillna(False)
            long_exits = bearish & (~prev_bearish)
        else:
            # No signal-based long exit: positions rely on SL/TP instead.
            long_exits = BaseStrategy.create_empty_signals(long_entries)
        # Short signals
        if enable_short:
            prev_bearish = bearish.shift(1).fillna(False)
            short_entries = bearish & (~prev_bearish)
            if exit_on_bearish_flip:
                # Close shorts when the meta trend flips back to bullish.
                short_exits = bullish & (~prev_bullish)
            else:
                short_exits = BaseStrategy.create_empty_signals(long_entries)
        else:
            short_entries = BaseStrategy.create_empty_signals(long_entries)
            short_exits = BaseStrategy.create_empty_signals(long_entries)
        return long_entries, long_exits, short_entries, short_exits

View File

@@ -0,0 +1,295 @@
# PRD: Market Type Selection for Backtesting
## Introduction/Overview
Currently, the backtesting system operates with a single, implicit market type assumption. This PRD defines the implementation of **market type selection** (Spot vs. USDT-M Perpetual Futures) to enable realistic simulation of different trading conditions.
**Problem Statement:**
- Strategies cannot be backtested against different market mechanics (leverage, funding, short-selling)
- Fee structures are uniform regardless of market type
- No support for short-selling strategies
- Data fetching doesn't distinguish between spot and futures markets
**Goal:**
Enable users to backtest strategies against specific market types (Spot or USDT-M Perpetual) with realistic trading conditions matching OKX's live environment.
---
## Goals
1. **Support two market types:** Spot and USDT-M Perpetual Futures
2. **Realistic fee simulation:** Match OKX's fee structure per market type
3. **Leverage support:** Per-strategy configurable leverage (perpetuals only)
4. **Funding rate simulation:** Simplified funding rate model for perpetuals
5. **Short-selling support:** Enable strategies to generate short signals
6. **Liquidation awareness:** Warn when positions would be liquidated (no full simulation)
7. **Separate data storage:** Download and store data per market type
8. **Grid search integration:** Allow leverage optimization in parameter grids
---
## User Stories
1. **As a trader**, I want to backtest my strategy on perpetual futures so that I can simulate leveraged trading with funding costs.
2. **As a trader**, I want to backtest on spot markets so that I can compare performance without leverage or funding overhead.
3. **As a strategy developer**, I want to define a default market type for my strategy so that it runs with appropriate settings by default.
4. **As a trader**, I want to test different leverage levels so that I can find the optimal risk/reward balance.
5. **As a trader**, I want to see warnings when my position would have been liquidated so that I can adjust my risk parameters.
6. **As a strategy developer**, I want to create strategies that can go short so that I can profit from downward price movements.
---
## Functional Requirements
### FR1: Market Type Enum and Configuration
1.1. Create a `MarketType` enum with values: `SPOT`, `PERPETUAL`
1.2. Each strategy class must have a `default_market_type` class attribute
1.3. Market type can be overridden via CLI (optional, for testing)
### FR2: Data Management
2.1. Modify `DataManager` to support market type in data paths:
- Spot: `data/ccxt/{exchange}/spot/{symbol}/{timeframe}.csv`
- Perpetual: `data/ccxt/{exchange}/perpetual/{symbol}/{timeframe}.csv`
2.2. Update `download` command to accept `--market` flag:
```bash
uv run python main.py download --pair BTC/USDT --market perpetual
```
2.3. Use CCXT's market type parameter when fetching data:
- Spot: `exchange.fetch_ohlcv(symbol, timeframe, ...)`
- Perpetual: `exchange.fetch_ohlcv(symbol + ':USDT', timeframe, ...)`
### FR3: Fee Structure
3.1. Define default fees per market type (matching OKX):
| Market Type | Maker Fee | Taker Fee | Notes |
|-------------|-----------|-----------|-------|
| Spot | 0.08% | 0.10% | No funding |
| Perpetual | 0.02% | 0.05% | + funding |
3.2. Allow fee override via CLI (existing `--fees` flag)
### FR4: Leverage Support (Perpetual Only)
4.1. Add `default_leverage` class attribute to strategies (default: 1 for spot, configurable for perpetual)
4.2. Add `--leverage` CLI flag for backtest command
4.3. Leverage affects:
- Position sizing (notional = cash * leverage)
- PnL calculation (multiplied by leverage)
- Liquidation threshold calculation
4.4. Support leverage in grid search parameter grids
### FR5: Funding Rate Simulation (Perpetual Only)
5.1. Implement simplified funding rate model:
- Default rate: 0.01% per 8 hours (configurable)
- Applied every 8 hours to open positions
- Positive rate: Longs pay shorts
- Negative rate: Shorts pay longs
5.2. Add `--funding-rate` CLI flag to override default
5.3. Track cumulative funding paid/received in backtest stats
### FR6: Short-Selling Support
6.1. Modify `BaseStrategy.run()` signature to return 4 signal arrays:
```python
def run(self, close, **kwargs) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Returns:
long_entries: Boolean signals to open long positions
long_exits: Boolean signals to close long positions
short_entries: Boolean signals to open short positions
short_exits: Boolean signals to close short positions
"""
```
6.2. Update `Backtester` to use VectorBT's `direction` parameter or dual portfolio simulation
6.3. For spot market: Ignore short signals (log warning if present)
6.4. For perpetual market: Process both long and short signals
### FR7: Liquidation Warning
7.1. Calculate liquidation price based on:
- Entry price
- Leverage
- Maintenance margin rate (OKX: ~0.4% for BTC)
7.2. During backtest, check if price crosses liquidation threshold
7.3. Log warning with details:
```
WARNING: Position would be liquidated at bar 1234 (price: $45,000, liq_price: $44,820)
```
7.4. Include liquidation event count in backtest summary stats
### FR8: Backtester Integration
8.1. Modify `Backtester.run_strategy()` to accept market type from strategy
8.2. Apply market-specific simulation parameters:
- Fees (if not overridden)
- Leverage
- Funding rate calculation
- Short-selling capability
8.3. Update portfolio simulation to handle leveraged positions
### FR9: Reporting Updates
9.1. Add market type to backtest summary output
9.2. Add new stats for perpetual backtests:
- Total funding paid/received
- Number of liquidation warnings
- Effective leverage used
9.3. Update CSV exports to include market-specific columns
---
## Non-Goals (Out of Scope)
- **Coin-M (Inverse) Perpetuals:** Not included in v1
- **Spot Margin Trading:** Not included in v1
- **Expiry Futures:** Not included in v1
- **Full Liquidation Simulation:** Only warnings, no automatic position closure
- **Real Funding Rate Data:** Use simplified model; historical funding API integration is future work
- **Cross-Margin Mode:** Assume isolated margin for simplicity
- **Partial Liquidation:** Assume full liquidation threshold only
---
## Design Considerations
### Data Directory Structure (New)
```
data/ccxt/
okx/
spot/
BTC-USDT/
1m.csv
1d.csv
perpetual/
BTC-USDT/
1m.csv
1d.csv
```
### Strategy Class Example
```python
class MetaSupertrendStrategy(BaseStrategy):
default_market_type = MarketType.PERPETUAL
default_leverage = 5
default_sl_stop = 0.02
def run(self, close, **kwargs):
# ... indicator logic ...
return long_entries, long_exits, short_entries, short_exits
```
### CLI Usage Examples
```bash
# Download perpetual data
uv run python main.py download --pair BTC/USDT --market perpetual
# Backtest with strategy defaults (uses strategy's default_market_type)
uv run python main.py backtest --strategy meta_st --pair BTC/USDT
# Override leverage
uv run python main.py backtest --strategy meta_st --pair BTC/USDT --leverage 10
# Grid search including leverage
uv run python main.py backtest --strategy meta_st --pair BTC/USDT --grid
# (leverage can be part of param grid in strategy factory)
```
---
## Technical Considerations
1. **VectorBT Compatibility:**
- VectorBT's `Portfolio.from_signals()` supports `direction` parameter for long/short
- Alternatively, run two portfolios (long-only, short-only) and combine
- Leverage can be simulated via `size` parameter or post-processing returns
2. **CCXT Market Type Handling:**
- OKX perpetual symbols use format: `BTC/USDT:USDT`
- Need to handle symbol conversion in DataManager
3. **Funding Rate Timing:**
- OKX funding at 00:00, 08:00, 16:00 UTC
- Need to identify these timestamps in the data and apply funding
4. **Backward Compatibility:**
- Existing strategies should work with minimal changes
- Default to `MarketType.SPOT` if not specified
- Existing 2-tuple return from `run()` should be interpreted as long-only
---
## Success Metrics
1. **Functional:** All existing backtests produce same results when run with `MarketType.SPOT`
2. **Functional:** Perpetual backtests correctly apply funding every 8 hours
3. **Functional:** Leverage multiplies both gains and losses correctly
4. **Functional:** Short signals are processed for perpetual, ignored for spot
5. **Usability:** Users can switch market types with minimal configuration
6. **Accuracy:** Fee structures match OKX's published rates
---
## Open Questions
1. **Position Sizing with Leverage:**
- Should leverage affect `init_cash` interpretation (notional value) or position size directly?
- Recommendation: Affect position size; `init_cash` remains the actual margin deposited.
2. **Multiple Positions:**
- Can strategies hold both long and short simultaneously (hedging)?
- Recommendation: No for v1; only one direction at a time.
3. **Funding Rate Sign:**
- When funding is positive, longs pay shorts. Should we assume the user is always the "taker" of funding?
- Recommendation: Yes, apply funding based on position direction.
4. **Migration Path:**
- Should we migrate existing data to new directory structure?
- Recommendation: No auto-migration; users re-download with `--market` flag.
---
## Implementation Priority
| Priority | Component | Complexity |
|----------|-----------|------------|
| 1 | MarketType enum + strategy defaults | Low |
| 2 | DataManager market type support | Medium |
| 3 | Fee structure per market type | Low |
| 4 | Short-selling signal support | Medium |
| 5 | Leverage simulation | Medium |
| 6 | Funding rate simulation | Medium |
| 7 | Liquidation warnings | Low |
| 8 | Reporting updates | Low |
| 9 | Grid search leverage support | Low |

View File

@@ -0,0 +1,76 @@
# PRD: VectorBT Migration & CCXT Integration
## 1. Introduction
The goal of this project is to refactor the current backtesting infrastructure to a professional-grade stack using **VectorBT** for high-performance backtesting and **CCXT** for robust historical data acquisition. The system will support rapid prototyping of "many simple strategies," parameter optimization (Grid Search), and stability testing (Walk-Forward Analysis).
## 2. Goals
- **Replace Custom Backtester:** Retire the existing loop-based backtesting logic in favor of vectorized operations using `vectorbt`.
- **Automate Data Collection:** Implement a `ccxt` based downloader to fetch and cache OHLCV data from OKX (and other exchanges) automatically.
- **Enable Optimization:** Built-in support for Grid Search to find optimal strategy parameters.
- **Validation:** Implement Walk-Forward Analysis (WFA) to validate strategy robustness and prevent overfitting.
- **Standardized Reporting:** Generate consistent outputs: Console summaries, CSV logs, and VectorBT interactive plots.
## 3. User Stories
- **Data Acquisition:** "As a user, I want to run a command `download_data --pair BTC/USDT --exchange okx` and have the system fetch historical 1-minute candles and save them to `data/ccxt/okx/BTC-USDT/1m.csv`."
- **Strategy Dev:** "As a researcher, I want to define a new strategy by simply writing a class/function that defines entry/exit signals, without worrying about the backtesting loop."
- **Optimization:** "As a researcher, I want to say 'Optimize RSI period between 10 and 20' and get a heatmap of results."
- **Validation:** "As a researcher, I want to verify if my 'best' parameters work on unseen data using Walk-Forward Analysis."
- **Analysis:** "As a user, I want to see an equity curve and key metrics (Sharpe, Drawdown) immediately after a test run."
## 4. Functional Requirements
### 4.1 Data Module (`data_manager`)
- **Exchange Interface:** Use `ccxt` to connect to exchanges (initially OKX).
- **Fetching Logic:** Fetch OHLCV data in chunks to handle rate limits and long histories.
- **Storage:** Save data to standardized paths: `data/ccxt/{exchange}/{pair}_{timeframe}.csv`.
- **Loading:** Utility to load saved CSVs into a Pandas DataFrame compatible with `vectorbt`.
### 4.2 Strategy Interface (`strategies/`)
- **Base Protocol:** Define a standard structure for strategies. A strategy should return/define:
- Indicator calculations (Vectorized).
- Entry signals (Boolean Series).
- Exit signals (Boolean Series).
- **Parameterization:** Strategies must accept dynamic parameters to support Grid Search.
### 4.3 Backtest Engine (`engine.py`)
- **Simulation:** Use `vectorbt.Portfolio.from_signals` (or similar) for fast simulation.
- **Cost Model:** Support configurable fees (maker/taker) and slippage estimates.
- **Grid Search:** Utilize `vectorbt`'s parameter broadcasting to run many variations simultaneously.
- **Walk-Forward Analysis:**
- Implement a splitting mechanism (e.g., `vectorbt.Splitter`) to divide data into In-Sample (Train) and Out-of-Sample (Test) sets.
- Execute optimization on Train, validate on Test.
### 4.4 Reporting (`reporting.py`)
- **Console:** Print key metrics: Total Return, Sharpe Ratio, Max Drawdown, Win Rate, Count of Trades.
- **Files:** Save detailed trade logs and metrics summaries to `backtest_logs/`.
- **Visuals:** Generate and save/show `vectorbt` plots (Equity curve, Drawdowns).
## 5. Non-Goals
- Real-time live trading execution (this is strictly for research/backtesting).
- Complex Machine Learning models (initially focusing on indicator-based logic).
- High-frequency tick-level backtesting (1-minute granularity is the target).
## 6. Technical Architecture Proposal
```text
project_root/
├── data/
│ └── ccxt/ # New data storage structure
├── strategies/ # Strategy definitions
│ ├── __init__.py
│ ├── base.py # Abstract Base Class
│ └── ma_cross.py # Example strategy
├── engine/
│ ├── data_loader.py # CCXT wrapper
│ ├── backtester.py # VBT runner
│ └── optimizer.py # Grid Search & WFA logic
├── main.py # CLI entry point
└── pyproject.toml
```
## 7. Success Metrics
- Can download 1 year of 1m BTC/USDT data from OKX in < 2 minutes.
- Can run a 100-parameter grid search on 1 year of 1m data in < 10 seconds.
- Walk-forward analysis produces a clear "Robustness Score" or visual comparison of Train vs Test performance.
## 8. Open Questions
- Do we need to handle funding rates for perp futures in the PnL calculation immediately? (Assumed NO for V1, stick to spot/simple futures price action).

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Test suite for lowkey_backtest."""

View File

@@ -0,0 +1,69 @@
"""Tests for DataManager functionality."""
import pytest
from engine.data_manager import DataManager
from engine.market import MarketType
class TestDataManager:
    """Unit tests for DataManager initialization and path resolution."""

    def test_init_creates_data_dir(self, tmp_path):
        """Constructing a DataManager must create its data directory."""
        target = tmp_path / "test_data"
        DataManager(data_dir=str(target))
        assert target.exists()

    def test_get_data_path_spot(self, tmp_path):
        """Spot paths follow {exchange}/spot/{pair}/{timeframe}.csv."""
        manager = DataManager(data_dir=str(tmp_path))
        result = manager._get_data_path("okx", "BTC/USDT", "1m", MarketType.SPOT)
        assert result == tmp_path / "okx" / "spot" / "BTC-USDT" / "1m.csv"

    def test_get_data_path_perpetual(self, tmp_path):
        """Perpetual paths follow {exchange}/perpetual/{pair}/{timeframe}.csv."""
        manager = DataManager(data_dir=str(tmp_path))
        result = manager._get_data_path("okx", "BTC/USDT", "1h", MarketType.PERPETUAL)
        assert result == tmp_path / "okx" / "perpetual" / "BTC-USDT" / "1h.csv"

    def test_load_data_file_not_found(self, tmp_path):
        """Loading data that was never downloaded must raise FileNotFoundError."""
        manager = DataManager(data_dir=str(tmp_path))
        with pytest.raises(FileNotFoundError):
            manager.load_data("okx", "BTC/USDT", "1m", MarketType.SPOT)
@pytest.mark.network
class TestDataManagerDownload:
    """Network-dependent download tests (run selectively via the marker)."""

    def test_download_spot_data(self):
        """Spot OHLCV download returns a frame with the expected columns."""
        frame = DataManager().download_data(
            "okx", "BTC/USDT", "1m",
            start_date="2025-01-01",
            market_type=MarketType.SPOT
        )
        assert frame is not None
        assert 'open' in frame.columns
        assert 'close' in frame.columns

    def test_download_perpetual_data(self):
        """Perpetual OHLCV download returns a frame with the expected columns."""
        frame = DataManager().download_data(
            "okx", "BTC/USDT", "1m",
            start_date="2025-01-01",
            market_type=MarketType.PERPETUAL
        )
        assert frame is not None
        assert 'open' in frame.columns
        assert 'close' in frame.columns

View File

@@ -1,52 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
import pandas as pd
from market_costs import okx_fee, estimate_slippage_rate
from intrabar import entry_slippage_row
@dataclass
class TradeState:
    """Mutable state of a single long-only position plus its cost settings.

    Either ``cash`` or ``qty`` is non-zero at a time: entering a long
    converts all cash into quantity, exiting converts back.
    """

    # Free cash available for the next entry (quote currency).
    cash: float = 1000.0
    # Base-asset quantity currently held; 0.0 means flat.
    qty: float = 0.0
    # Fill price of the open position, or None when flat.
    entry_px: float | None = None
    # Highest price seen since entry; drives the trailing stop. None when flat.
    max_px: float | None = None
    # Trailing-stop distance as a fraction of the peak price (0.02 = 2%).
    stop_loss_frac: float = 0.02
    # Trading fee in basis points, applied on notional at entry and exit.
    fee_bps: float = 10.0
    # Slippage estimate in basis points, applied to fills.
    slippage_bps: float = 2.0
def enter_long(state: TradeState, price: float) -> dict:
    """Open a long with all available cash at a slippage-adjusted fill.

    Returns a fill event dict, or an empty dict if a position is already open.
    Fees are deducted from the acquired quantity (never below zero).
    """
    if state.qty > 0:
        return {}
    fill_px = entry_slippage_row(price, 0.0, state.slippage_bps)
    fee = okx_fee(state.fee_bps, state.cash)
    raw_qty = state.cash / fill_px
    # Pay the fee in base-asset terms; clamp so qty never goes negative.
    state.qty = max(raw_qty - fee / fill_px, 0.0)
    state.cash = 0.0
    state.entry_px = fill_px
    state.max_px = fill_px
    return {"side": "BUY", "price": fill_px, "qty": state.qty, "fee": fee}
def maybe_trailing_stop(state: TradeState, price: float) -> float:
    """Advance the running high-water mark and return the trailing-stop price.

    Returns +inf when flat so the stop can never be triggered.
    """
    if state.qty <= 0:
        return float("inf")
    # Seed the peak from the current price when no high has been recorded yet.
    peak = state.max_px if state.max_px else price
    if price > peak:
        peak = price
    state.max_px = peak
    return peak * (1.0 - state.stop_loss_frac)
def exit_long(state: TradeState, price: float) -> dict:
    """Close the open long at `price`, settling slippage and fees into cash.

    Returns a fill event dict, or an empty dict when there is no position.
    """
    if state.qty <= 0:
        return {}
    gross = state.qty * price
    slippage_cost = estimate_slippage_rate(state.slippage_bps, gross)
    exit_fee = okx_fee(state.fee_bps, gross)
    fill = {
        "side": "SELL",
        "price": price,
        "qty": state.qty,
        "fee": exit_fee,
        "slippage": slippage_cost,
    }
    # Net proceeds return to cash; clear all position-tracking fields.
    state.cash = gross - slippage_cost - exit_fee
    state.qty = 0.0
    state.entry_px = None
    state.max_px = None
    return fill

1831
uv.lock generated

File diff suppressed because it is too large Load Diff