Remove deprecated modules and files related to the backtesting framework, including backtest.py, cli.py, config.py, data.py, intrabar.py, logging_utils.py, market_costs.py, metrics.py, trade.py, and supertrend indicators. Introduce a new structure for the backtesting engine with improved organization and functionality, including a CLI handler, data manager, and reporting capabilities. Update dependencies in pyproject.toml to support the new architecture.

This commit is contained in:
2026-01-12 21:11:39 +08:00
parent c4aa965a98
commit 44fac1ed25
37 changed files with 5253 additions and 393 deletions

80
strategies/base.py Normal file
View File

@@ -0,0 +1,80 @@
"""
Base strategy class for all trading strategies.
Strategies should inherit from BaseStrategy and implement the run() method.
"""
from abc import ABC, abstractmethod
import pandas as pd
from engine.market import MarketType
class BaseStrategy(ABC):
"""
Abstract base class for trading strategies.
Class Attributes:
default_market_type: Default market type for this strategy
default_leverage: Default leverage (only applies to perpetuals)
default_sl_stop: Default stop-loss percentage
default_tp_stop: Default take-profit percentage
default_sl_trail: Whether stop-loss is trailing by default
"""
# Market configuration defaults
default_market_type: MarketType = MarketType.SPOT
default_leverage: int = 1
# Risk management defaults (can be overridden per strategy)
default_sl_stop: float | None = None
default_tp_stop: float | None = None
default_sl_trail: bool = False
def __init__(self, **kwargs):
self.params = kwargs
@abstractmethod
def run(
self,
close: pd.Series,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Run the strategy logic.
Args:
close: Price series (can be multiple columns for grid search)
**kwargs: Additional data (high, low, open, volume) and parameters
Returns:
Tuple of 4 DataFrames/Series:
- long_entries: Boolean signals to open long positions
- long_exits: Boolean signals to close long positions
- short_entries: Boolean signals to open short positions
- short_exits: Boolean signals to close short positions
Note:
For spot markets, short signals will be ignored.
For backward compatibility, strategies can return 2-tuple (entries, exits)
which will be interpreted as long-only signals.
"""
pass
def get_indicator(self, ind_cls, *args, **kwargs):
"""Helper to run a vectorbt indicator."""
return ind_cls.run(*args, **kwargs)
@staticmethod
def create_empty_signals(reference: pd.Series | pd.DataFrame) -> pd.DataFrame:
"""
Create an empty (all False) signal DataFrame matching the reference shape.
Args:
reference: Series or DataFrame to match shape/index
Returns:
DataFrame of False values with same shape as reference
"""
if isinstance(reference, pd.DataFrame):
return pd.DataFrame(False, index=reference.index, columns=reference.columns)
return pd.Series(False, index=reference.index)

97
strategies/examples.py Normal file
View File

@@ -0,0 +1,97 @@
"""
Example trading strategies for backtesting.
These are simple strategies demonstrating the framework usage.
"""
import pandas as pd
import vectorbt as vbt
from engine.market import MarketType
from strategies.base import BaseStrategy
class RsiStrategy(BaseStrategy):
"""
RSI mean-reversion strategy.
Long entry when RSI crosses below oversold level.
Long exit when RSI crosses above overbought level.
"""
default_market_type = MarketType.SPOT
default_leverage = 1
def run(
self,
close: pd.Series,
period: int = 14,
rsi_lower: int = 30,
rsi_upper: int = 70,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Generate RSI-based trading signals.
Args:
close: Price series
period: RSI calculation period
rsi_lower: Oversold threshold (buy signal)
rsi_upper: Overbought threshold (sell signal)
Returns:
4-tuple of (long_entries, long_exits, short_entries, short_exits)
"""
# Calculate RSI
rsi = vbt.RSI.run(close, window=period)
# Long signals: buy oversold, sell overbought
long_entries = rsi.rsi_crossed_below(rsi_lower)
long_exits = rsi.rsi_crossed_above(rsi_upper)
# No short signals for this strategy (spot-focused)
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits
class MaCrossStrategy(BaseStrategy):
"""
Moving Average crossover strategy.
Long entry when fast MA crosses above slow MA.
Long exit when fast MA crosses below slow MA.
"""
default_market_type = MarketType.SPOT
default_leverage = 1
def run(
self,
close: pd.Series,
fast_window: int = 10,
slow_window: int = 20,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Generate MA crossover trading signals.
Args:
close: Price series
fast_window: Fast MA period
slow_window: Slow MA period
Returns:
4-tuple of (long_entries, long_exits, short_entries, short_exits)
"""
# Calculate Moving Averages
fast_ma = vbt.MA.run(close, window=fast_window)
slow_ma = vbt.MA.run(close, window=slow_window)
# Long signals
long_entries = fast_ma.ma_crossed_above(slow_ma)
long_exits = fast_ma.ma_crossed_below(slow_ma)
# No short signals for this strategy
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits

128
strategies/factory.py Normal file
View File

@@ -0,0 +1,128 @@
"""
Strategy factory for creating strategy instances with their parameters.
Centralizes strategy creation and parameter configuration.
"""
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from strategies.base import BaseStrategy
@dataclass
class StrategyConfig:
"""
Configuration for a strategy including default and grid parameters.
Attributes:
strategy_class: The strategy class to instantiate
default_params: Parameters for single backtest runs
grid_params: Parameters for grid search optimization
"""
strategy_class: type[BaseStrategy]
default_params: dict[str, Any] = field(default_factory=dict)
grid_params: dict[str, Any] = field(default_factory=dict)
def _build_registry() -> dict[str, StrategyConfig]:
"""
Build the strategy registry lazily to avoid circular imports.
Returns:
Dictionary mapping strategy names to their configurations
"""
# Import here to avoid circular imports
from strategies.examples import MaCrossStrategy, RsiStrategy
from strategies.supertrend import MetaSupertrendStrategy
return {
"rsi": StrategyConfig(
strategy_class=RsiStrategy,
default_params={
'period': 14,
'rsi_lower': 30,
'rsi_upper': 70
},
grid_params={
'period': np.arange(10, 25, 2),
'rsi_lower': [20, 30, 40],
'rsi_upper': [60, 70, 80]
}
),
"macross": StrategyConfig(
strategy_class=MaCrossStrategy,
default_params={
'fast_window': 10,
'slow_window': 20
},
grid_params={
'fast_window': np.arange(5, 20, 5),
'slow_window': np.arange(20, 60, 10)
}
),
"meta_st": StrategyConfig(
strategy_class=MetaSupertrendStrategy,
default_params={
'period1': 12, 'multiplier1': 3.0,
'period2': 10, 'multiplier2': 1.0,
'period3': 11, 'multiplier3': 2.0
},
grid_params={
'multiplier1': [2.0, 3.0, 4.0],
'period1': [10, 12, 14],
'period2': 11, 'multiplier2': 2.0,
'period3': 12, 'multiplier3': 1.0
}
),
}
# Module-level cache for the registry
_REGISTRY_CACHE: dict[str, StrategyConfig] | None = None
def get_registry() -> dict[str, StrategyConfig]:
"""Get the strategy registry, building it on first access."""
global _REGISTRY_CACHE
if _REGISTRY_CACHE is None:
_REGISTRY_CACHE = _build_registry()
return _REGISTRY_CACHE
def get_strategy_names() -> list[str]:
"""
Get list of available strategy names.
Returns:
List of strategy name strings
"""
return list(get_registry().keys())
def get_strategy(name: str, is_grid: bool = False) -> tuple[BaseStrategy, dict[str, Any]]:
"""
Create a strategy instance with appropriate parameters.
Args:
name: Strategy identifier (e.g., 'rsi', 'macross', 'meta_st')
is_grid: If True, return grid search parameters
Returns:
Tuple of (strategy instance, parameters dict)
Raises:
KeyError: If strategy name is not found in registry
"""
registry = get_registry()
if name not in registry:
available = ", ".join(registry.keys())
raise KeyError(f"Unknown strategy '{name}'. Available: {available}")
config = registry[name]
strategy = config.strategy_class()
params = config.grid_params if is_grid else config.default_params
return strategy, params.copy()

View File

@@ -0,0 +1,6 @@
"""
Meta Supertrend strategy package.
"""
from .strategy import MetaSupertrendStrategy
__all__ = ['MetaSupertrendStrategy']

View File

@@ -0,0 +1,128 @@
"""
Supertrend indicators and helper functions.
"""
import numpy as np
import vectorbt as vbt
from numba import njit
# --- Numba Compiled Helper Functions ---
@njit(cache=False) # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
"""Calculate True Range (Numba compiled)."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
tr = np.empty_like(close)
tr[0] = high[0] - low[0]
for i in range(1, len(close)):
tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
return tr
@njit(cache=False)
def get_atr_nb(high, low, close, period):
"""Calculate ATR using Wilder's Smoothing (Numba compiled)."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
# Ensure period is native Python int (critical for Numba array indexing)
n = len(close)
p = int(period)
tr = get_tr_nb(high, low, close)
atr = np.full(n, np.nan, dtype=np.float64)
if n < p:
return atr
# Initial ATR is simple average of TR
sum_tr = 0.0
for i in range(p):
sum_tr += tr[i]
atr[p - 1] = sum_tr / p
# Subsequent ATR is Wilder's smoothed
for i in range(p, n):
atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
return atr
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
"""Calculate SuperTrend completely in Numba."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
# Ensure params are native Python types (critical for Numba)
n = len(close)
p = int(period)
m = float(multiplier)
atr = get_atr_nb(high, low, close, p)
final_upper = np.full(n, np.nan, dtype=np.float64)
final_lower = np.full(n, np.nan, dtype=np.float64)
trend = np.ones(n, dtype=np.int8) # 1 Bull, -1 Bear
# Skip until we have valid ATR
start_idx = p
if start_idx >= n:
return trend
# Init first valid point
hl2 = (high[start_idx] + low[start_idx]) / 2
final_upper[start_idx] = hl2 + m * atr[start_idx]
final_lower[start_idx] = hl2 - m * atr[start_idx]
# Loop
for i in range(start_idx + 1, n):
cur_hl2 = (high[i] + low[i]) / 2
cur_atr = atr[i]
basic_upper = cur_hl2 + m * cur_atr
basic_lower = cur_hl2 - m * cur_atr
# Upper Band Logic
if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
final_upper[i] = basic_upper
else:
final_upper[i] = final_upper[i-1]
# Lower Band Logic
if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
final_lower[i] = basic_lower
else:
final_lower[i] = final_lower[i-1]
# Trend Logic
if trend[i-1] == 1:
if close[i] < final_lower[i-1]:
trend[i] = -1
else:
trend[i] = 1
else:
if close[i] > final_upper[i-1]:
trend[i] = 1
else:
trend[i] = -1
return trend
# --- VectorBT Indicator Factory ---
SuperTrendIndicator = vbt.IndicatorFactory(
class_name='SuperTrend',
short_name='st',
input_names=['high', 'low', 'close'],
param_names=['period', 'multiplier'],
output_names=['trend']
).from_apply_func(
get_supertrend_nb,
keep_pd=False, # Disable automatic Pandas wrapping of inputs
param_product=True # Enable Cartesian product for list params
)

View File

@@ -0,0 +1,142 @@
"""
Meta Supertrend strategy implementation.
"""
import numpy as np
import pandas as pd
from engine.market import MarketType
from strategies.base import BaseStrategy
from .indicators import SuperTrendIndicator
class MetaSupertrendStrategy(BaseStrategy):
"""
Meta Supertrend Strategy using 3 Supertrend indicators.
Enters long when all 3 Supertrends are bullish.
Enters short when all 3 Supertrends are bearish.
Designed for perpetual futures with leverage and short-selling support.
"""
# Market configuration
default_market_type = MarketType.PERPETUAL
default_leverage = 5
# Risk management parameters
default_sl_stop = 0.02 # 2% stop loss
default_sl_trail = True # Trailing stop enabled
default_exit_on_bearish_flip = False # Rely on SL/TP, not bearish flip
def run(
self,
close: pd.Series,
high: pd.Series = None,
low: pd.Series = None,
period1: int = 10,
multiplier1: float = 3.0,
period2: int = 11,
multiplier2: float = 2.0,
period3: int = 12,
multiplier3: float = 1.0,
exit_on_bearish_flip: bool = None,
enable_short: bool = True,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
# 1. Validation & Setup
if exit_on_bearish_flip is None:
exit_on_bearish_flip = self.default_exit_on_bearish_flip
if high is None or low is None:
raise ValueError("MetaSupertrendStrategy requires High and Low prices.")
# 2. Calculate Supertrends
t1, t2, t3 = self._calculate_supertrends(
high, low, close,
period1, multiplier1,
period2, multiplier2,
period3, multiplier3
)
# 3. Meta Signals
bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)
# 4. Generate Entry/Exit Signals
return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)
def _calculate_supertrends(
self, high, low, close, p1, m1, p2, m2, p3, m3
):
"""Run the 3 Supertrend indicators."""
# Pass NumPy arrays explicitly to avoid Numba typing errors
h_vals = high.values
l_vals = low.values
c_vals = close.values
def run_st(p, m):
st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
trend = st.trend
if isinstance(trend, pd.DataFrame):
trend.index = close.index
if trend.shape[1] == 1:
trend = trend.iloc[:, 0]
elif isinstance(trend, pd.Series):
trend.index = close.index
return trend
t1 = run_st(p1, m1)
t2 = run_st(p2, m2)
t3 = run_st(p3, m3)
return t1, t2, t3
def _calculate_meta_signals(self, t1, t2, t3, close_series):
"""Combine 3 Supertrends into boolean Bullish/Bearish signals."""
# Use NumPy broadcasting
t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
# Force column vectors for broadcasting if scalar result
t2_vals = t2.values.reshape(-1, 1)
t3_vals = t3.values.reshape(-1, 1)
# Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)
# Reconstruct Pandas objects
if isinstance(t1, pd.DataFrame):
bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
else:
bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
bearish = pd.Series(bearish_vals.flatten(), index=t1.index)
return bullish, bearish
def _generate_signals(
self, bullish, bearish, exit_on_bearish_flip, enable_short
):
"""Generate long/short entry/exit signals based on meta trend."""
# Long Entries: Change from Not Bullish to Bullish
prev_bullish = bullish.shift(1).fillna(False)
long_entries = bullish & (~prev_bullish)
# Long Exits
if exit_on_bearish_flip:
prev_bearish = bearish.shift(1).fillna(False)
long_exits = bearish & (~prev_bearish)
else:
long_exits = BaseStrategy.create_empty_signals(long_entries)
# Short signals
if enable_short:
prev_bearish = bearish.shift(1).fillna(False)
short_entries = bearish & (~prev_bearish)
if exit_on_bearish_flip:
short_exits = bullish & (~prev_bullish)
else:
short_exits = BaseStrategy.create_empty_signals(long_entries)
else:
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits

View File

@@ -0,0 +1,6 @@
"""
Meta Supertrend strategy package.
"""
from .strategy import MetaSupertrendStrategy
__all__ = ['MetaSupertrendStrategy']

View File

@@ -0,0 +1,128 @@
"""
Supertrend indicators and helper functions.
"""
import numpy as np
import vectorbt as vbt
from numba import njit
# --- Numba Compiled Helper Functions ---
@njit(cache=False) # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
"""Calculate True Range (Numba compiled)."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
tr = np.empty_like(close)
tr[0] = high[0] - low[0]
for i in range(1, len(close)):
tr[i] = max(high[i] - low[i], abs(high[i] - close[i-1]), abs(low[i] - close[i-1]))
return tr
@njit(cache=False)
def get_atr_nb(high, low, close, period):
"""Calculate ATR using Wilder's Smoothing (Numba compiled)."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
# Ensure period is native Python int (critical for Numba array indexing)
n = len(close)
p = int(period)
tr = get_tr_nb(high, low, close)
atr = np.full(n, np.nan, dtype=np.float64)
if n < p:
return atr
# Initial ATR is simple average of TR
sum_tr = 0.0
for i in range(p):
sum_tr += tr[i]
atr[p - 1] = sum_tr / p
# Subsequent ATR is Wilder's smoothed
for i in range(p, n):
atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p
return atr
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
"""Calculate SuperTrend completely in Numba."""
# Ensure 1D arrays
high = high.ravel()
low = low.ravel()
close = close.ravel()
# Ensure params are native Python types (critical for Numba)
n = len(close)
p = int(period)
m = float(multiplier)
atr = get_atr_nb(high, low, close, p)
final_upper = np.full(n, np.nan, dtype=np.float64)
final_lower = np.full(n, np.nan, dtype=np.float64)
trend = np.ones(n, dtype=np.int8) # 1 Bull, -1 Bear
# Skip until we have valid ATR
start_idx = p
if start_idx >= n:
return trend
# Init first valid point
hl2 = (high[start_idx] + low[start_idx]) / 2
final_upper[start_idx] = hl2 + m * atr[start_idx]
final_lower[start_idx] = hl2 - m * atr[start_idx]
# Loop
for i in range(start_idx + 1, n):
cur_hl2 = (high[i] + low[i]) / 2
cur_atr = atr[i]
basic_upper = cur_hl2 + m * cur_atr
basic_lower = cur_hl2 - m * cur_atr
# Upper Band Logic
if basic_upper < final_upper[i-1] or close[i-1] > final_upper[i-1]:
final_upper[i] = basic_upper
else:
final_upper[i] = final_upper[i-1]
# Lower Band Logic
if basic_lower > final_lower[i-1] or close[i-1] < final_lower[i-1]:
final_lower[i] = basic_lower
else:
final_lower[i] = final_lower[i-1]
# Trend Logic
if trend[i-1] == 1:
if close[i] < final_lower[i-1]:
trend[i] = -1
else:
trend[i] = 1
else:
if close[i] > final_upper[i-1]:
trend[i] = 1
else:
trend[i] = -1
return trend
# --- VectorBT Indicator Factory ---
SuperTrendIndicator = vbt.IndicatorFactory(
class_name='SuperTrend',
short_name='st',
input_names=['high', 'low', 'close'],
param_names=['period', 'multiplier'],
output_names=['trend']
).from_apply_func(
get_supertrend_nb,
keep_pd=False, # Disable automatic Pandas wrapping of inputs
param_product=True # Enable Cartesian product for list params
)

View File

@@ -0,0 +1,142 @@
"""
Meta Supertrend strategy implementation.
"""
import numpy as np
import pandas as pd
from engine.market import MarketType
from strategies.base import BaseStrategy
from .indicators import SuperTrendIndicator
class MetaSupertrendStrategy(BaseStrategy):
"""
Meta Supertrend Strategy using 3 Supertrend indicators.
Enters long when all 3 Supertrends are bullish.
Enters short when all 3 Supertrends are bearish.
Designed for perpetual futures with leverage and short-selling support.
"""
# Market configuration
default_market_type = MarketType.PERPETUAL
default_leverage = 5
# Risk management parameters
default_sl_stop = 0.02 # 2% stop loss
default_sl_trail = True # Trailing stop enabled
default_exit_on_bearish_flip = False # Rely on SL/TP, not bearish flip
def run(
self,
close: pd.Series,
high: pd.Series = None,
low: pd.Series = None,
period1: int = 10,
multiplier1: float = 3.0,
period2: int = 11,
multiplier2: float = 2.0,
period3: int = 12,
multiplier3: float = 1.0,
exit_on_bearish_flip: bool = None,
enable_short: bool = True,
**kwargs
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
# 1. Validation & Setup
if exit_on_bearish_flip is None:
exit_on_bearish_flip = self.default_exit_on_bearish_flip
if high is None or low is None:
raise ValueError("MetaSupertrendStrategy requires High and Low prices.")
# 2. Calculate Supertrends
t1, t2, t3 = self._calculate_supertrends(
high, low, close,
period1, multiplier1,
period2, multiplier2,
period3, multiplier3
)
# 3. Meta Signals
bullish, bearish = self._calculate_meta_signals(t1, t2, t3, close)
# 4. Generate Entry/Exit Signals
return self._generate_signals(bullish, bearish, exit_on_bearish_flip, enable_short)
def _calculate_supertrends(
self, high, low, close, p1, m1, p2, m2, p3, m3
):
"""Run the 3 Supertrend indicators."""
# Pass NumPy arrays explicitly to avoid Numba typing errors
h_vals = high.values
l_vals = low.values
c_vals = close.values
def run_st(p, m):
st = SuperTrendIndicator.run(h_vals, l_vals, c_vals, period=p, multiplier=m)
trend = st.trend
if isinstance(trend, pd.DataFrame):
trend.index = close.index
if trend.shape[1] == 1:
trend = trend.iloc[:, 0]
elif isinstance(trend, pd.Series):
trend.index = close.index
return trend
t1 = run_st(p1, m1)
t2 = run_st(p2, m2)
t3 = run_st(p3, m3)
return t1, t2, t3
def _calculate_meta_signals(self, t1, t2, t3, close_series):
"""Combine 3 Supertrends into boolean Bullish/Bearish signals."""
# Use NumPy broadcasting
t1_vals = t1.values if isinstance(t1, pd.DataFrame) else t1.values.reshape(-1, 1)
# Force column vectors for broadcasting if scalar result
t2_vals = t2.values.reshape(-1, 1)
t3_vals = t3.values.reshape(-1, 1)
# Boolean logic on numpy arrays (1 = Bull, -1 = Bear)
bullish_vals = (t1_vals == 1) & (t2_vals == 1) & (t3_vals == 1)
bearish_vals = (t1_vals == -1) & (t2_vals == -1) & (t3_vals == -1)
# Reconstruct Pandas objects
if isinstance(t1, pd.DataFrame):
bullish = pd.DataFrame(bullish_vals, index=t1.index, columns=t1.columns)
bearish = pd.DataFrame(bearish_vals, index=t1.index, columns=t1.columns)
else:
bullish = pd.Series(bullish_vals.flatten(), index=t1.index)
bearish = pd.Series(bearish_vals.flatten(), index=t1.index)
return bullish, bearish
def _generate_signals(
self, bullish, bearish, exit_on_bearish_flip, enable_short
):
"""Generate long/short entry/exit signals based on meta trend."""
# Long Entries: Change from Not Bullish to Bullish
prev_bullish = bullish.shift(1).fillna(False)
long_entries = bullish & (~prev_bullish)
# Long Exits
if exit_on_bearish_flip:
prev_bearish = bearish.shift(1).fillna(False)
long_exits = bearish & (~prev_bearish)
else:
long_exits = BaseStrategy.create_empty_signals(long_entries)
# Short signals
if enable_short:
prev_bearish = bearish.shift(1).fillna(False)
short_entries = bearish & (~prev_bearish)
if exit_on_bearish_flip:
short_exits = bullish & (~prev_bullish)
else:
short_exits = BaseStrategy.create_empty_signals(long_entries)
else:
short_entries = BaseStrategy.create_empty_signals(long_entries)
short_exits = BaseStrategy.create_empty_signals(long_entries)
return long_entries, long_exits, short_entries, short_exits