# lowkey_backtest/engine/cli.py
# (File-browser header captured with this source: 244 lines, 7.5 KiB, Python.)
"""
CLI handler for Lowkey Backtest.
"""
import argparse
import sys
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import get_logger, setup_logging
from engine.market import MarketType
from engine.reporting import Reporter
from strategies.factory import get_strategy, get_strategy_names
# Module-level logger, obtained from the project's logging_config helper and
# named after this module.
logger = get_logger(__name__)
def create_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser and attach every sub-command.

    The sub-commands ``download``, ``backtest`` and ``wfa`` are registered in
    that order; the chosen one is stored on the parsed namespace as
    ``command`` (``None`` when no sub-command is given).

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(
        description="Lowkey Backtest CLI (VectorBT Edition)"
    )
    commands = root.add_subparsers(dest="command", help="Command to run")
    for register in (_add_download_parser, _add_backtest_parser, _add_wfa_parser):
        register(commands)
    return root
def _add_download_parser(subparsers) -> None:
    """Register the ``download`` sub-command (fetch historical data).

    Args:
        subparsers: The action object returned by ``add_subparsers()`` on the
            root parser.
    """
    download = subparsers.add_parser("download", help="Download historical data")
    # (flags, add_argument keyword options), in the order they appear in --help.
    option_specs = (
        (("--exchange", "-e"), {"type": str, "default": "okx"}),
        (("--pair", "-p"), {"type": str, "required": True}),
        (("--timeframe", "-t"), {"type": str, "default": "1m"}),
        (("--start",), {"type": str, "help": "Start Date (YYYY-MM-DD)"}),
        (
            ("--market", "-m"),
            {"type": str, "choices": ["spot", "perpetual"], "default": "spot"},
        ),
    )
    for flags, options in option_specs:
        download.add_argument(*flags, **options)
def _add_backtest_parser(subparsers) -> None:
    """Register the ``backtest`` sub-command.

    The options fall into three groups:
    data selection and run mode, risk controls (stop loss, take profit,
    trailing stop, bearish-flip exit), and trading costs (fees, slippage,
    leverage). ``--strategy`` is restricted to the names returned by
    ``get_strategy_names()``.

    Args:
        subparsers: The action object returned by ``add_subparsers()``.
    """
    known_strategies = get_strategy_names()
    backtest = subparsers.add_parser("backtest", help="Run a backtest")

    general_opts = (
        (("--strategy", "-s"), dict(type=str, choices=known_strategies, required=True)),
        (("--exchange", "-e"), dict(type=str, default="okx")),
        (("--pair", "-p"), dict(type=str, required=True)),
        (("--timeframe", "-t"), dict(type=str, default="1m")),
        (("--start",), dict(type=str)),
        (("--end",), dict(type=str)),
        (("--grid", "-g"), dict(action="store_true")),
        (("--plot",), dict(action="store_true")),
    )
    # "%%" is argparse's escape for a literal "%" in help text.
    risk_opts = (
        (("--sl",), dict(type=float, help="Stop Loss %%")),
        (("--tp",), dict(type=float, help="Take Profit %%")),
        (("--trail",), dict(action="store_true")),
        (("--no-bear-exit",), dict(action="store_true")),
    )
    cost_opts = (
        (("--fees",), dict(type=float, default=None)),
        (("--slippage",), dict(type=float, default=0.001)),
        (("--leverage", "-l"), dict(type=int, default=None)),
    )
    for flags, options in (*general_opts, *risk_opts, *cost_opts):
        backtest.add_argument(*flags, **options)
def _add_wfa_parser(subparsers) -> None:
    """Register the ``wfa`` (walk-forward analysis) sub-command.

    ``--strategy`` is restricted to the names returned by
    ``get_strategy_names()``. ``--exchange`` mirrors the option of the same
    name on the ``download`` and ``backtest`` sub-commands. It defaults to
    ``"okx"``, so invocations that omit it keep targeting that exchange.

    Args:
        subparsers: The action object returned by ``add_subparsers()``.
    """
    strategy_choices = get_strategy_names()
    wfa_parser = subparsers.add_parser("wfa", help="Run Walk-Forward Analysis")
    wfa_parser.add_argument(
        "--strategy", "-s",
        type=str,
        choices=strategy_choices,
        required=True,
    )
    # Same flag/default as the download and backtest sub-commands.
    wfa_parser.add_argument("--exchange", "-e", type=str, default="okx")
    wfa_parser.add_argument("--pair", "-p", type=str, required=True)
    wfa_parser.add_argument("--timeframe", "-t", type=str, default="1d")
    wfa_parser.add_argument("--windows", "-w", type=int, default=10)
    wfa_parser.add_argument("--plot", action="store_true")
def run_download(args) -> None:
    """Handle the ``download`` sub-command.

    Creates a ``DataManager`` and asks it to download data for the requested
    exchange, pair and timeframe. ``args.market`` is converted to a
    ``MarketType`` and ``args.start`` is forwarded as ``start_date``.
    """
    manager = DataManager()
    manager.download_data(
        args.exchange,
        args.pair,
        args.timeframe,
        start_date=args.start,
        market_type=MarketType(args.market),
    )
def run_backtest(args) -> None:
    """Handle the ``backtest`` sub-command.

    Steps:
    1. Resolve the strategy and its parameters (a grid when ``--grid`` is set).
    2. Apply the meta_st CLI overrides.
    3. Run the backtest.
    4. Print a summary and save reports.
    5. Optionally plot; plotting is skipped for grid searches.

    Any exception raised while running or reporting is logged with its
    traceback instead of propagating.
    """
    data_manager = DataManager()
    backtester = Backtester(data_manager)
    reporter = Reporter()

    strategy, params = get_strategy(args.strategy, args.grid)
    # May fill args.sl / args.trail from strategy defaults (meta_st only),
    # so those attributes are read only after this call.
    params = _apply_strategy_overrides(args, strategy, params)

    if args.strategy == "meta_st" and args.grid:
        logger.info("Running Grid Search for Meta Supertrend...")

    try:
        run_options = dict(
            timeframe=args.timeframe,
            start_date=args.start,
            end_date=args.end,
            fees=args.fees,
            slippage=args.slippage,
            sl_stop=args.sl,
            tp_stop=args.tp,
            sl_trail=args.trail,
            leverage=args.leverage,
        )
        # Two separate ** unpackings: a strategy param that collides with a
        # run option still raises TypeError rather than silently overriding.
        result = backtester.run_strategy(
            strategy, args.exchange, args.pair, **run_options, **params
        )
        reporter.print_summary(result)
        report_stem = f"{args.strategy}_{args.pair.replace('/','-')}"
        reporter.save_reports(result, report_stem)
        if args.plot:
            if args.grid:
                logger.info("Plotting skipped for Grid Search. Check CSV results.")
            else:
                reporter.plot(result.portfolio)
    except Exception as e:
        logger.error("Backtest failed: %s", e, exc_info=True)
def run_wfa(args) -> None:
    """Handle the ``wfa`` (walk-forward analysis) sub-command.

    Loads the strategy's grid parameters, runs ``Backtester.run_wfa`` over
    ``args.windows`` windows, then logs and saves the per-window results.
    Any exception is logged with its traceback instead of propagating.

    The exchange comes from ``args.exchange`` when the namespace has one.
    Otherwise it falls back to ``"okx"``, the value this command always used.
    """
    dm = DataManager()
    bt = Backtester(dm)
    reporter = Reporter()
    strategy, params = get_strategy(args.strategy, is_grid=True)
    # The wfa parser may not define --exchange; default to "okx" in that case
    # (and when the value is empty).
    exchange = getattr(args, "exchange", None) or "okx"
    logger.info(
        "Running WFA on %s for %s (%s) with %d windows...",
        args.strategy, args.pair, args.timeframe, args.windows
    )
    try:
        results, stitched_curve = bt.run_wfa(
            strategy,
            exchange,
            args.pair,
            params,
            n_windows=args.windows,
            timeframe=args.timeframe,
        )
        _log_wfa_results(results)
        _save_wfa_results(args, results, stitched_curve, reporter)
    except Exception as e:
        logger.error("WFA failed: %s", e, exc_info=True)
def _apply_strategy_overrides(args, strategy, params: dict) -> dict:
    """Apply CLI argument overrides to strategy parameters.

    Only the ``meta_st`` strategy is affected. For any other strategy,
    ``params`` is returned unchanged and ``args`` is not touched.

    For ``meta_st``:

    * ``--no-bear-exit`` sets ``exit_on_bearish_flip=False`` in the returned
      parameters.
    * When ``--sl`` was not given (``args.sl is None``), ``args.sl`` is set to
      ``strategy.default_sl_stop``.
    * When ``--trail`` was not given, ``args.trail`` is set to
      ``strategy.default_sl_trail``. Because ``--trail`` is a store_true flag,
      a strategy whose default is to trail cannot be switched off from the CLI.

    ``args`` is updated in place. ``params`` is never mutated: when a key must
    be added, a new dict is returned. The caller's dict may be shared, e.g.
    by the strategy factory.

    Args:
        args: Parsed ``backtest`` namespace.
        strategy: Strategy object returned by ``get_strategy``.
        params: Strategy parameters (a parameter grid in grid mode).

    Returns:
        The parameter dict to pass to the backtester.
    """
    if args.strategy != "meta_st":
        return params
    if args.no_bear_exit:
        # Copy instead of writing into the caller's (possibly shared) dict.
        params = {**params, "exit_on_bearish_flip": False}
    if args.sl is None:
        args.sl = strategy.default_sl_stop
    if not args.trail:
        args.trail = strategy.default_sl_trail
    return params
def _log_wfa_results(results) -> None:
    """Log a per-window table and averages for walk-forward analysis results.

    Args:
        results: Per-window results, presumably a pandas DataFrame returned
            by ``Backtester.run_wfa`` (verify against that method). The
            ``window`` column is required. ``train_score``, ``test_score``
            and ``test_return`` are shown and averaged when present.

    When there are no rows or no ``window`` column, a warning is logged
    instead of a table. Missing metric columns are reported and skipped
    rather than raising ``KeyError``. ``run_wfa`` calls this before saving
    the results, so an exception here would also prevent the CSV from being
    written.
    """
    logger.info("Walk-Forward Analysis Results:")
    if results.empty or 'window' not in results.columns:
        logger.warning("No valid WFA results. All windows may have failed.")
        return

    wanted = ['window', 'train_score', 'test_score', 'test_return']
    present = [col for col in wanted if col in results.columns]
    missing = [col for col in wanted if col not in results.columns]
    if missing:
        logger.warning("WFA results are missing column(s): %s", ", ".join(missing))

    logger.info("\n%s", results[present].to_string(index=False))

    if 'test_score' in results.columns:
        logger.info("Average Test Sharpe: %.2f", results['test_score'].mean())
    if 'test_return' in results.columns:
        logger.info("Average Test Return: %.2f%%", results['test_return'].mean() * 100)
def _save_wfa_results(args, results, stitched_curve, reporter) -> None:
    """Write WFA results to ``backtest_logs/`` as CSV and optionally plot them.

    The file is named ``wfa_<strategy>_<pair>.csv``, with ``/`` in the pair
    replaced by ``-``. Nothing is written when ``results`` is empty. When
    ``args.plot`` is set, ``reporter.plot_wfa`` is called with the results
    and the stitched equity curve.

    The ``backtest_logs`` directory is created if it is missing. Otherwise
    ``to_csv`` would raise an OSError, and the caller would report the whole
    WFA run as failed.

    Args:
        args: Parsed ``wfa`` namespace (uses ``strategy``, ``pair``, ``plot``).
        results: Per-window results (presumably a pandas DataFrame).
        stitched_curve: Out-of-sample curve, passed through to the reporter.
        reporter: Reporter used for plotting.
    """
    import os  # function-scope import keeps this block self-contained

    if results.empty:
        return
    log_dir = "backtest_logs"
    os.makedirs(log_dir, exist_ok=True)
    output_path = f"{log_dir}/wfa_{args.strategy}_{args.pair.replace('/','-')}.csv"
    results.to_csv(output_path)
    logger.info("Saved full results to %s", output_path)
    if args.plot:
        reporter.plot_wfa(results, stitched_curve)
def main():
    """CLI entry point: set up logging, parse the command line and dispatch.

    When no sub-command is given, the parser's help text is printed instead.
    """
    setup_logging()
    parser = create_parser()
    args = parser.parse_args()
    handler = {
        "download": run_download,
        "backtest": run_backtest,
        "wfa": run_wfa,
    }.get(args.command)
    if handler is None:
        parser.print_help()
        return
    handler(args)