Remove deprecated modules and files related to the backtesting framework, including backtest.py, cli.py, config.py, data.py, intrabar.py, logging_utils.py, market_costs.py, metrics.py, trade.py, and supertrend indicators. Introduce a new structure for the backtesting engine with improved organization and functionality, including a CLI handler, data manager, and reporting capabilities. Update dependencies in pyproject.toml to support the new architecture.
This commit is contained in:
80
strategies/base.py
Normal file
80
strategies/base.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Base strategy class for all trading strategies.
|
||||
|
||||
Strategies should inherit from BaseStrategy and implement the run() method.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""
|
||||
Abstract base class for trading strategies.
|
||||
|
||||
Class Attributes:
|
||||
default_market_type: Default market type for this strategy
|
||||
default_leverage: Default leverage (only applies to perpetuals)
|
||||
default_sl_stop: Default stop-loss percentage
|
||||
default_tp_stop: Default take-profit percentage
|
||||
default_sl_trail: Whether stop-loss is trailing by default
|
||||
"""
|
||||
# Market configuration defaults
|
||||
default_market_type: MarketType = MarketType.SPOT
|
||||
default_leverage: int = 1
|
||||
|
||||
# Risk management defaults (can be overridden per strategy)
|
||||
default_sl_stop: float | None = None
|
||||
default_tp_stop: float | None = None
|
||||
default_sl_trail: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.params = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Run the strategy logic.
|
||||
|
||||
Args:
|
||||
close: Price series (can be multiple columns for grid search)
|
||||
**kwargs: Additional data (high, low, open, volume) and parameters
|
||||
|
||||
Returns:
|
||||
Tuple of 4 DataFrames/Series:
|
||||
- long_entries: Boolean signals to open long positions
|
||||
- long_exits: Boolean signals to close long positions
|
||||
- short_entries: Boolean signals to open short positions
|
||||
- short_exits: Boolean signals to close short positions
|
||||
|
||||
Note:
|
||||
For spot markets, short signals will be ignored.
|
||||
For backward compatibility, strategies can return 2-tuple (entries, exits)
|
||||
which will be interpreted as long-only signals.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_indicator(self, ind_cls, *args, **kwargs):
|
||||
"""Helper to run a vectorbt indicator."""
|
||||
return ind_cls.run(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_empty_signals(reference: pd.Series | pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Create an empty (all False) signal DataFrame matching the reference shape.
|
||||
|
||||
Args:
|
||||
reference: Series or DataFrame to match shape/index
|
||||
|
||||
Returns:
|
||||
DataFrame of False values with same shape as reference
|
||||
"""
|
||||
if isinstance(reference, pd.DataFrame):
|
||||
return pd.DataFrame(False, index=reference.index, columns=reference.columns)
|
||||
return pd.Series(False, index=reference.index)
|
||||
97
strategies/examples.py
Normal file
97
strategies/examples.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Example trading strategies for backtesting.
|
||||
|
||||
These are simple strategies demonstrating the framework usage.
|
||||
"""
|
||||
import pandas as pd
|
||||
import vectorbt as vbt
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class RsiStrategy(BaseStrategy):
|
||||
"""
|
||||
RSI mean-reversion strategy.
|
||||
|
||||
Long entry when RSI crosses below oversold level.
|
||||
Long exit when RSI crosses above overbought level.
|
||||
"""
|
||||
default_market_type = MarketType.SPOT
|
||||
default_leverage = 1
|
||||
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
period: int = 14,
|
||||
rsi_lower: int = 30,
|
||||
rsi_upper: int = 70,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Generate RSI-based trading signals.
|
||||
|
||||
Args:
|
||||
close: Price series
|
||||
period: RSI calculation period
|
||||
rsi_lower: Oversold threshold (buy signal)
|
||||
rsi_upper: Overbought threshold (sell signal)
|
||||
|
||||
Returns:
|
||||
4-tuple of (long_entries, long_exits, short_entries, short_exits)
|
||||
"""
|
||||
# Calculate RSI
|
||||
rsi = vbt.RSI.run(close, window=period)
|
||||
|
||||
# Long signals: buy oversold, sell overbought
|
||||
long_entries = rsi.rsi_crossed_below(rsi_lower)
|
||||
long_exits = rsi.rsi_crossed_above(rsi_upper)
|
||||
|
||||
# No short signals for this strategy (spot-focused)
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
|
||||
|
||||
class MaCrossStrategy(BaseStrategy):
|
||||
"""
|
||||
Moving Average crossover strategy.
|
||||
|
||||
Long entry when fast MA crosses above slow MA.
|
||||
Long exit when fast MA crosses below slow MA.
|
||||
"""
|
||||
default_market_type = MarketType.SPOT
|
||||
default_leverage = 1
|
||||
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
fast_window: int = 10,
|
||||
slow_window: int = 20,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
Generate MA crossover trading signals.
|
||||
|
||||
Args:
|
||||
close: Price series
|
||||
fast_window: Fast MA period
|
||||
slow_window: Slow MA period
|
||||
|
||||
Returns:
|
||||
4-tuple of (long_entries, long_exits, short_entries, short_exits)
|
||||
"""
|
||||
# Calculate Moving Averages
|
||||
fast_ma = vbt.MA.run(close, window=fast_window)
|
||||
slow_ma = vbt.MA.run(close, window=slow_window)
|
||||
|
||||
# Long signals
|
||||
long_entries = fast_ma.ma_crossed_above(slow_ma)
|
||||
long_exits = fast_ma.ma_crossed_below(slow_ma)
|
||||
|
||||
# No short signals for this strategy
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
128
strategies/factory.py
Normal file
128
strategies/factory.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Strategy factory for creating strategy instances with their parameters.
|
||||
|
||||
Centralizes strategy creation and parameter configuration.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyConfig:
|
||||
"""
|
||||
Configuration for a strategy including default and grid parameters.
|
||||
|
||||
Attributes:
|
||||
strategy_class: The strategy class to instantiate
|
||||
default_params: Parameters for single backtest runs
|
||||
grid_params: Parameters for grid search optimization
|
||||
"""
|
||||
strategy_class: type[BaseStrategy]
|
||||
default_params: dict[str, Any] = field(default_factory=dict)
|
||||
grid_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _build_registry() -> dict[str, StrategyConfig]:
|
||||
"""
|
||||
Build the strategy registry lazily to avoid circular imports.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping strategy names to their configurations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from strategies.examples import MaCrossStrategy, RsiStrategy
|
||||
from strategies.supertrend import MetaSupertrendStrategy
|
||||
|
||||
return {
|
||||
"rsi": StrategyConfig(
|
||||
strategy_class=RsiStrategy,
|
||||
default_params={
|
||||
'period': 14,
|
||||
'rsi_lower': 30,
|
||||
'rsi_upper': 70
|
||||
},
|
||||
grid_params={
|
||||
'period': np.arange(10, 25, 2),
|
||||
'rsi_lower': [20, 30, 40],
|
||||
'rsi_upper': [60, 70, 80]
|
||||
}
|
||||
),
|
||||
"macross": StrategyConfig(
|
||||
strategy_class=MaCrossStrategy,
|
||||
default_params={
|
||||
'fast_window': 10,
|
||||
'slow_window': 20
|
||||
},
|
||||
grid_params={
|
||||
'fast_window': np.arange(5, 20, 5),
|
||||
'slow_window': np.arange(20, 60, 10)
|
||||
}
|
||||
),
|
||||
"meta_st": StrategyConfig(
|
||||
strategy_class=MetaSupertrendStrategy,
|
||||
default_params={
|
||||
'period1': 12, 'multiplier1': 3.0,
|
||||
'period2': 10, 'multiplier2': 1.0,
|
||||
'period3': 11, 'multiplier3': 2.0
|
||||
},
|
||||
grid_params={
|
||||
'multiplier1': [2.0, 3.0, 4.0],
|
||||
'period1': [10, 12, 14],
|
||||
'period2': 11, 'multiplier2': 2.0,
|
||||
'period3': 12, 'multiplier3': 1.0
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Module-level cache for the registry
|
||||
_REGISTRY_CACHE: dict[str, StrategyConfig] | None = None
|
||||
|
||||
|
||||
def get_registry() -> dict[str, StrategyConfig]:
|
||||
"""Get the strategy registry, building it on first access."""
|
||||
global _REGISTRY_CACHE
|
||||
if _REGISTRY_CACHE is None:
|
||||
_REGISTRY_CACHE = _build_registry()
|
||||
return _REGISTRY_CACHE
|
||||
|
||||
|
||||
def get_strategy_names() -> list[str]:
|
||||
"""
|
||||
Get list of available strategy names.
|
||||
|
||||
Returns:
|
||||
List of strategy name strings
|
||||
"""
|
||||
return list(get_registry().keys())
|
||||
|
||||
|
||||
def get_strategy(name: str, is_grid: bool = False) -> tuple[BaseStrategy, dict[str, Any]]:
|
||||
"""
|
||||
Create a strategy instance with appropriate parameters.
|
||||
|
||||
Args:
|
||||
name: Strategy identifier (e.g., 'rsi', 'macross', 'meta_st')
|
||||
is_grid: If True, return grid search parameters
|
||||
|
||||
Returns:
|
||||
Tuple of (strategy instance, parameters dict)
|
||||
|
||||
Raises:
|
||||
KeyError: If strategy name is not found in registry
|
||||
"""
|
||||
registry = get_registry()
|
||||
|
||||
if name not in registry:
|
||||
available = ", ".join(registry.keys())
|
||||
raise KeyError(f"Unknown strategy '{name}'. Available: {available}")
|
||||
|
||||
config = registry[name]
|
||||
strategy = config.strategy_class()
|
||||
params = config.grid_params if is_grid else config.default_params
|
||||
|
||||
return strategy, params.copy()
|
||||
6
strategies/supertrend/__init__.py
Normal file
6
strategies/supertrend/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Meta Supertrend strategy package.
|
||||
"""
|
||||
from .strategy import MetaSupertrendStrategy
|
||||
|
||||
__all__ = ['MetaSupertrendStrategy']
|
||||
128
strategies/supertrend/indicators.py
Normal file
128
strategies/supertrend/indicators.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Supertrend indicators and helper functions.
|
||||
"""
|
||||
import numpy as np
|
||||
import vectorbt as vbt
|
||||
from numba import njit
|
||||
|
||||
# --- Numba Compiled Helper Functions ---
|
||||
|
||||
@njit(cache=False) # Disable cache to avoid stale compilation issues
|
||||
def get_tr_nb(high, low, close):
|
||||
"""Calculate True Range (Numba compiled)."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
tr = np.empty_like(close)
|
||||
tr[0] = high[0] - low[0]
|
||||
for i in range(1, len(close)):
|
||||
tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
|
||||
return tr
|
||||
|
||||
@njit(cache=False)
|
||||
def get_atr_nb(high, low, close, period):
|
||||
"""Calculate ATR using Wilder's Smoothing (Numba compiled)."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
# Ensure period is native Python int (critical for Numba array indexing)
|
||||
n = len(close)
|
||||
p = int(period)
|
||||
|
||||
tr = get_tr_nb(high, low, close)
|
||||
atr = np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
if n < p:
|
||||
return atr
|
||||
|
||||
# Initial ATR is simple average of TR
|
||||
sum_tr = 0.0
|
||||
for i in range(p):
|
||||
sum_tr += tr[i]
|
||||
atr[p - 1] = sum_tr / p
|
||||
|
||||
# Subsequent ATR is Wilder's smoothed
|
||||
for i in range(p, n):
|
||||
atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
|
||||
|
||||
return atr
|
||||
|
||||
@njit(cache=False)
|
||||
def get_supertrend_nb(high, low, close, period, multiplier):
|
||||
"""Calculate SuperTrend completely in Numba."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
# Ensure params are native Python types (critical for Numba)
|
||||
n = len(close)
|
||||
p = int(period)
|
||||
m = float(multiplier)
|
||||
|
||||
atr = get_atr_nb(high, low, close, p)
|
||||
|
||||
final_upper = np.full(n, np.nan, dtype=np.float64)
|
||||
final_lower = np.full(n, np.nan, dtype=np.float64)
|
||||
trend = np.ones(n, dtype=np.int8) # 1 Bull, -1 Bear
|
||||
|
||||
# Skip until we have valid ATR
|
||||
start_idx = p
|
||||
if start_idx >= n:
|
||||
return trend
|
||||
|
||||
# Init first valid point
|
||||
hl2 = (high[start_idx] + low[start_idx]) / 2
|
||||
final_upper[start_idx] = hl2 + m * atr[start_idx]
|
||||
final_lower[start_idx] = hl2 - m * atr[start_idx]
|
||||
|
||||
# Loop
|
||||
for i in range(start_idx + 1, n):
|
||||
cur_hl2 = (high[i] + low[i]) / 2
|
||||
cur_atr = atr[i]
|
||||
basic_upper = cur_hl2 + m * cur_atr
|
||||
basic_lower = cur_hl2 - m * cur_atr
|
||||
|
||||
# Upper Band Logic
|
||||
if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
|
||||
final_upper[i] = basic_upper
|
||||
else:
|
||||
final_upper[i] = final_upper[i-1]
|
||||
|
||||
# Lower Band Logic
|
||||
if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
|
||||
final_lower[i] = basic_lower
|
||||
else:
|
||||
final_lower[i] = final_lower[i-1]
|
||||
|
||||
# Trend Logic
|
||||
if trend[i-1] == 1:
|
||||
if close[i] < final_lower[i-1]:
|
||||
trend[i] = -1
|
||||
else:
|
||||
trend[i] = 1
|
||||
else:
|
||||
if close[i] > final_upper[i-1]:
|
||||
trend[i] = 1
|
||||
else:
|
||||
trend[i] = -1
|
||||
|
||||
return trend
|
||||
|
||||
# --- VectorBT Indicator Factory ---
|
||||
|
||||
SuperTrendIndicator = vbt.IndicatorFactory(
|
||||
class_name='SuperTrend',
|
||||
short_name='st',
|
||||
input_names=['high', 'low', 'close'],
|
||||
param_names=['period', 'multiplier'],
|
||||
output_names=['trend']
|
||||
).from_apply_func(
|
||||
get_supertrend_nb,
|
||||
keep_pd=False, # Disable automatic Pandas wrapping of inputs
|
||||
param_product=True # Enable Cartesian product for list params
|
||||
)
|
||||
142
strategies/supertrend/strategy.py
Normal file
142
strategies/supertrend/strategy.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Meta Supertrend strategy implementation.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
from .indicators import SuperTrendIndicator
|
||||
|
||||
|
||||
class MetaSupertrendStrategy(BaseStrategy):
|
||||
"""
|
||||
Meta Supertrend Strategy using 3 Supertrend indicators.
|
||||
|
||||
Enters long when all 3 Supertrends are bullish.
|
||||
Enters short when all 3 Supertrends are bearish.
|
||||
|
||||
Designed for perpetual futures with leverage and short-selling support.
|
||||
"""
|
||||
# Market configuration
|
||||
default_market_type = MarketType.PERPETUAL
|
||||
default_leverage = 5
|
||||
|
||||
# Risk management parameters
|
||||
default_sl_stop = 0.02 # 2% stop loss
|
||||
default_sl_trail = True # Trailing stop enabled
|
||||
default_exit_on_bearish_flip = False # Rely on SL/TP, not bearish flip
|
||||
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
high: pd.Series = None,
|
||||
low: pd.Series = None,
|
||||
period1: int = 10,
|
||||
multiplier1: float = 3.0,
|
||||
period2: int = 11,
|
||||
multiplier2: float = 2.0,
|
||||
period3: int = 12,
|
||||
multiplier3: float = 1.0,
|
||||
exit_on_bearish_flip: bool = None,
|
||||
enable_short: bool = True,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
|
||||
# 1. Validation & Setup
|
||||
if exit_on_bearish_flip is None:
|
||||
exit_on_bearish_flip = self.default_exit_on_bearish_flip
|
||||
|
||||
if high is None or low is None:
|
||||
raise ValueError("MetaSupertrendStrategy requires High and Low prices.")
|
||||
|
||||
# 2. Calculate Supertrends
|
||||
t1, t2, t3 = self._calculate_supertrends(
|
||||
high, low, close,
|
||||
period1, multiplier1,
|
||||
period2, multiplier2,
|
||||
period3, multiplier3
|
||||
)
|
||||
|
||||
# 3. Meta Signals
|
||||
bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)
|
||||
|
||||
# 4. Generate Entry/Exit Signals
|
||||
return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)
|
||||
|
||||
def _calculate_supertrends(
|
||||
self, high, low, close, p1, m1, p2, m2, p3, m3
|
||||
):
|
||||
"""Run the 3 Supertrend indicators."""
|
||||
# Pass NumPy arrays explicitly to avoid Numba typing errors
|
||||
h_vals = high.values
|
||||
l_vals = low.values
|
||||
c_vals = close.values
|
||||
|
||||
def run_st(p, m):
|
||||
st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
|
||||
trend = st.trend
|
||||
if isinstance(trend, pd.DataFrame):
|
||||
trend.index = close.index
|
||||
if trend.shape[1] == 1:
|
||||
trend = trend.iloc[:, 0]
|
||||
elif isinstance(trend, pd.Series):
|
||||
trend.index = close.index
|
||||
return trend
|
||||
|
||||
t1 = run_st(p1, m1)
|
||||
t2 = run_st(p2, m2)
|
||||
t3 = run_st(p3, m3)
|
||||
return t1, t2, t3
|
||||
|
||||
def _calculate_meta_signals(self, t1, t2, t3, close_series):
|
||||
"""Combine 3 Supertrends into boolean Bullish/Bearish signals."""
|
||||
# Use NumPy broadcasting
|
||||
t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
|
||||
# Force column vectors for broadcasting if scalar result
|
||||
t2_vals = t2.values.reshape(-1, 1)
|
||||
t3_vals = t3.values.reshape(-1, 1)
|
||||
|
||||
# Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
|
||||
bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
|
||||
bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)
|
||||
|
||||
# Reconstruct Pandas objects
|
||||
if isinstance(t1, pd.DataFrame):
|
||||
bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
|
||||
bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
|
||||
else:
|
||||
bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
|
||||
bearish = pd.Series(bearish_vals.flatten(), index=t1.index)
|
||||
|
||||
return bullish, bearish
|
||||
|
||||
def _generate_signals(
|
||||
self, bullish, bearish, exit_on_bearish_flip, enable_short
|
||||
):
|
||||
"""Generate long/short entry/exit signals based on meta trend."""
|
||||
# Long Entries: Change from Not Bullish to Bullish
|
||||
prev_bullish = bullish.shift(1).fillna(False)
|
||||
long_entries = bullish & (~prev_bullish)
|
||||
|
||||
# Long Exits
|
||||
if exit_on_bearish_flip:
|
||||
prev_bearish = bearish.shift(1).fillna(False)
|
||||
long_exits = bearish & (~prev_bearish)
|
||||
else:
|
||||
long_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
# Short signals
|
||||
if enable_short:
|
||||
prev_bearish = bearish.shift(1).fillna(False)
|
||||
short_entries = bearish & (~prev_bearish)
|
||||
|
||||
if exit_on_bearish_flip:
|
||||
short_exits = bullish & (~prev_bullish)
|
||||
else:
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
else:
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
6
strategies/supertrend_pkg/__init__.py
Normal file
6
strategies/supertrend_pkg/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Meta Supertrend strategy package.
|
||||
"""
|
||||
from .strategy import MetaSupertrendStrategy
|
||||
|
||||
__all__ = ['MetaSupertrendStrategy']
|
||||
128
strategies/supertrend_pkg/indicators.py
Normal file
128
strategies/supertrend_pkg/indicators.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Supertrend indicators and helper functions.
|
||||
"""
|
||||
import numpy as np
|
||||
import vectorbt as vbt
|
||||
from numba import njit
|
||||
|
||||
# --- Numba Compiled Helper Functions ---
|
||||
|
||||
@njit(cache=False) # Disable cache to avoid stale compilation issues
|
||||
def get_tr_nb(high, low, close):
|
||||
"""Calculate True Range (Numba compiled)."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
tr = np.empty_like(close)
|
||||
tr[0] = high[0] - low[0]
|
||||
for i in range(1, len(close)):
|
||||
tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
|
||||
return tr
|
||||
|
||||
@njit(cache=False)
|
||||
def get_atr_nb(high, low, close, period):
|
||||
"""Calculate ATR using Wilder's Smoothing (Numba compiled)."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
# Ensure period is native Python int (critical for Numba array indexing)
|
||||
n = len(close)
|
||||
p = int(period)
|
||||
|
||||
tr = get_tr_nb(high, low, close)
|
||||
atr = np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
if n < p:
|
||||
return atr
|
||||
|
||||
# Initial ATR is simple average of TR
|
||||
sum_tr = 0.0
|
||||
for i in range(p):
|
||||
sum_tr += tr[i]
|
||||
atr[p - 1] = sum_tr / p
|
||||
|
||||
# Subsequent ATR is Wilder's smoothed
|
||||
for i in range(p, n):
|
||||
atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
|
||||
|
||||
return atr
|
||||
|
||||
@njit(cache=False)
|
||||
def get_supertrend_nb(high, low, close, period, multiplier):
|
||||
"""Calculate SuperTrend completely in Numba."""
|
||||
# Ensure 1D arrays
|
||||
high = high.ravel()
|
||||
low = low.ravel()
|
||||
close = close.ravel()
|
||||
|
||||
# Ensure params are native Python types (critical for Numba)
|
||||
n = len(close)
|
||||
p = int(period)
|
||||
m = float(multiplier)
|
||||
|
||||
atr = get_atr_nb(high, low, close, p)
|
||||
|
||||
final_upper = np.full(n, np.nan, dtype=np.float64)
|
||||
final_lower = np.full(n, np.nan, dtype=np.float64)
|
||||
trend = np.ones(n, dtype=np.int8) # 1 Bull, -1 Bear
|
||||
|
||||
# Skip until we have valid ATR
|
||||
start_idx = p
|
||||
if start_idx >= n:
|
||||
return trend
|
||||
|
||||
# Init first valid point
|
||||
hl2 = (high[start_idx] + low[start_idx]) / 2
|
||||
final_upper[start_idx] = hl2 + m * atr[start_idx]
|
||||
final_lower[start_idx] = hl2 - m * atr[start_idx]
|
||||
|
||||
# Loop
|
||||
for i in range(start_idx + 1, n):
|
||||
cur_hl2 = (high[i] + low[i]) / 2
|
||||
cur_atr = atr[i]
|
||||
basic_upper = cur_hl2 + m * cur_atr
|
||||
basic_lower = cur_hl2 - m * cur_atr
|
||||
|
||||
# Upper Band Logic
|
||||
if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
|
||||
final_upper[i] = basic_upper
|
||||
else:
|
||||
final_upper[i] = final_upper[i-1]
|
||||
|
||||
# Lower Band Logic
|
||||
if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
|
||||
final_lower[i] = basic_lower
|
||||
else:
|
||||
final_lower[i] = final_lower[i-1]
|
||||
|
||||
# Trend Logic
|
||||
if trend[i-1] == 1:
|
||||
if close[i] < final_lower[i-1]:
|
||||
trend[i] = -1
|
||||
else:
|
||||
trend[i] = 1
|
||||
else:
|
||||
if close[i] > final_upper[i-1]:
|
||||
trend[i] = 1
|
||||
else:
|
||||
trend[i] = -1
|
||||
|
||||
return trend
|
||||
|
||||
# --- VectorBT Indicator Factory ---
|
||||
|
||||
SuperTrendIndicator = vbt.IndicatorFactory(
|
||||
class_name='SuperTrend',
|
||||
short_name='st',
|
||||
input_names=['high', 'low', 'close'],
|
||||
param_names=['period', 'multiplier'],
|
||||
output_names=['trend']
|
||||
).from_apply_func(
|
||||
get_supertrend_nb,
|
||||
keep_pd=False, # Disable automatic Pandas wrapping of inputs
|
||||
param_product=True # Enable Cartesian product for list params
|
||||
)
|
||||
142
strategies/supertrend_pkg/strategy.py
Normal file
142
strategies/supertrend_pkg/strategy.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Meta Supertrend strategy implementation.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from engine.market import MarketType
|
||||
from strategies.base import BaseStrategy
|
||||
from .indicators import SuperTrendIndicator
|
||||
|
||||
|
||||
class MetaSupertrendStrategy(BaseStrategy):
|
||||
"""
|
||||
Meta Supertrend Strategy using 3 Supertrend indicators.
|
||||
|
||||
Enters long when all 3 Supertrends are bullish.
|
||||
Enters short when all 3 Supertrends are bearish.
|
||||
|
||||
Designed for perpetual futures with leverage and short-selling support.
|
||||
"""
|
||||
# Market configuration
|
||||
default_market_type = MarketType.PERPETUAL
|
||||
default_leverage = 5
|
||||
|
||||
# Risk management parameters
|
||||
default_sl_stop = 0.02 # 2% stop loss
|
||||
default_sl_trail = True # Trailing stop enabled
|
||||
default_exit_on_bearish_flip = False # Rely on SL/TP, not bearish flip
|
||||
|
||||
def run(
|
||||
self,
|
||||
close: pd.Series,
|
||||
high: pd.Series = None,
|
||||
low: pd.Series = None,
|
||||
period1: int = 10,
|
||||
multiplier1: float = 3.0,
|
||||
period2: int = 11,
|
||||
multiplier2: float = 2.0,
|
||||
period3: int = 12,
|
||||
multiplier3: float = 1.0,
|
||||
exit_on_bearish_flip: bool = None,
|
||||
enable_short: bool = True,
|
||||
**kwargs
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
|
||||
# 1. Validation & Setup
|
||||
if exit_on_bearish_flip is None:
|
||||
exit_on_bearish_flip = self.default_exit_on_bearish_flip
|
||||
|
||||
if high is None or low is None:
|
||||
raise ValueError("MetaSupertrendStrategy requires High and Low prices.")
|
||||
|
||||
# 2. Calculate Supertrends
|
||||
t1, t2, t3 = self._calculate_supertrends(
|
||||
high, low, close,
|
||||
period1, multiplier1,
|
||||
period2, multiplier2,
|
||||
period3, multiplier3
|
||||
)
|
||||
|
||||
# 3. Meta Signals
|
||||
bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)
|
||||
|
||||
# 4. Generate Entry/Exit Signals
|
||||
return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)
|
||||
|
||||
def _calculate_supertrends(
|
||||
self, high, low, close, p1, m1, p2, m2, p3, m3
|
||||
):
|
||||
"""Run the 3 Supertrend indicators."""
|
||||
# Pass NumPy arrays explicitly to avoid Numba typing errors
|
||||
h_vals = high.values
|
||||
l_vals = low.values
|
||||
c_vals = close.values
|
||||
|
||||
def run_st(p, m):
|
||||
st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
|
||||
trend = st.trend
|
||||
if isinstance(trend, pd.DataFrame):
|
||||
trend.index = close.index
|
||||
if trend.shape[1] == 1:
|
||||
trend = trend.iloc[:, 0]
|
||||
elif isinstance(trend, pd.Series):
|
||||
trend.index = close.index
|
||||
return trend
|
||||
|
||||
t1 = run_st(p1, m1)
|
||||
t2 = run_st(p2, m2)
|
||||
t3 = run_st(p3, m3)
|
||||
return t1, t2, t3
|
||||
|
||||
def _calculate_meta_signals(self, t1, t2, t3, close_series):
|
||||
"""Combine 3 Supertrends into boolean Bullish/Bearish signals."""
|
||||
# Use NumPy broadcasting
|
||||
t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
|
||||
# Force column vectors for broadcasting if scalar result
|
||||
t2_vals = t2.values.reshape(-1, 1)
|
||||
t3_vals = t3.values.reshape(-1, 1)
|
||||
|
||||
# Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
|
||||
bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
|
||||
bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)
|
||||
|
||||
# Reconstruct Pandas objects
|
||||
if isinstance(t1, pd.DataFrame):
|
||||
bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
|
||||
bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
|
||||
else:
|
||||
bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
|
||||
bearish = pd.Series(bearish_vals.flatten(), index=t1.index)
|
||||
|
||||
return bullish, bearish
|
||||
|
||||
def _generate_signals(
|
||||
self, bullish, bearish, exit_on_bearish_flip, enable_short
|
||||
):
|
||||
"""Generate long/short entry/exit signals based on meta trend."""
|
||||
# Long Entries: Change from Not Bullish to Bullish
|
||||
prev_bullish = bullish.shift(1).fillna(False)
|
||||
long_entries = bullish & (~prev_bullish)
|
||||
|
||||
# Long Exits
|
||||
if exit_on_bearish_flip:
|
||||
prev_bearish = bearish.shift(1).fillna(False)
|
||||
long_exits = bearish & (~prev_bearish)
|
||||
else:
|
||||
long_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
# Short signals
|
||||
if enable_short:
|
||||
prev_bearish = bearish.shift(1).fillna(False)
|
||||
short_entries = bearish & (~prev_bearish)
|
||||
|
||||
if exit_on_bearish_flip:
|
||||
short_exits = bullish & (~prev_bullish)
|
||||
else:
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
else:
|
||||
short_entries = BaseStrategy.create_empty_signals(long_entries)
|
||||
short_exits = BaseStrategy.create_empty_signals(long_entries)
|
||||
|
||||
return long_entries, long_exits, short_entries, short_exits
|
||||
Reference in New Issue
Block a user