"""
Supertrend indicators and helper functions.
"""
import numpy as np
import vectorbt as vbt
from numba import njit


# --- Numba Compiled Helper Functions ---

@njit(cache=False)  # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
    """Calculate True Range (Numba compiled).

    For bar 0 the true range is ``high[0] - low[0]``; for every later bar it
    is the largest of ``high - low``, ``|high - prev_close|`` and
    ``|low - prev_close|``.

    Args:
        high: Array of bar highs (flattened to 1D internally).
        low: Array of bar lows (flattened to 1D internally).
        close: Array of bar closes (flattened to 1D internally).

    Returns:
        1D array with the same length and dtype as the flattened ``close``.
        An empty input yields an empty array.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    n = len(close)
    tr = np.empty_like(close)

    # Nopython code does not bounds-check by default, so writing tr[0] on an
    # empty array would be an out-of-bounds write rather than an IndexError.
    if n == 0:
        return tr

    tr[0] = high[0] - low[0]
    for i in range(1, n):
        tr[i] = max(
            high[i] - low[i],
            abs(high[i] - close[i - 1]),
            abs(low[i] - close[i - 1]),
        )
    return tr


@njit(cache=False)
def get_atr_nb(high, low, close, period):
    """Calculate ATR using Wilder's Smoothing (Numba compiled).

    The first ATR value, at index ``period - 1``, is the simple average of
    the first ``period`` true ranges. Every later value is
    ``(prev_atr * (period - 1) + tr) / period``.

    Args:
        high: Array of bar highs (flattened to 1D internally).
        low: Array of bar lows (flattened to 1D internally).
        close: Array of bar closes (flattened to 1D internally).
        period: Smoothing window length; must be >= 1.

    Returns:
        1D float64 array of the same length as ``close``. Entries before
        index ``period - 1`` are NaN; the whole array is NaN when there are
        fewer than ``period`` bars.

    Raises:
        ValueError: If ``period`` is smaller than 1.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure period is native Python int (critical for Numba array indexing)
    n = len(close)
    p = int(period)

    # A non-positive period would divide by zero (p == 0) or silently wrap
    # around through negative indexing (p < 0).
    if p < 1:
        raise ValueError("period must be >= 1")

    atr = np.full(n, np.nan, dtype=np.float64)
    if n < p:
        return atr

    # Computed only after the length guard so empty input never reaches it.
    tr = get_tr_nb(high, low, close)

    # Initial ATR is simple average of TR
    sum_tr = 0.0
    for i in range(p):
        sum_tr += tr[i]
    atr[p - 1] = sum_tr / p

    # Subsequent ATR is Wilder's smoothed
    for i in range(p, n):
        atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p

    return atr


@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
    """Calculate SuperTrend completely in Numba.

    Bands are built around the bar midpoint ``(high + low) / 2`` offset by
    ``multiplier * ATR``. The trend flips to bearish when the close drops
    below the previous bar's final lower band, and back to bullish when the
    close rises above the previous bar's final upper band.

    Args:
        high: Array of bar highs (flattened to 1D internally).
        low: Array of bar lows (flattened to 1D internally).
        close: Array of bar closes (flattened to 1D internally).
        period: ATR window length; must be >= 1.
        multiplier: Band width expressed in ATR multiples.

    Returns:
        1D int8 array of the same length as ``close``: 1 for bull, -1 for
        bear. Indices ``0..period`` are left at the initial value 1, and the
        whole array is 1 when there are not more than ``period`` bars.

    Raises:
        ValueError: If ``period`` is smaller than 1.

    NOTE(review): inputs are flattened with ``ravel()``. If a caller passes
    multi-column 2D arrays, the columns would be merged into one series --
    confirm that only single-column data reaches this function.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure params are native Python types (critical for Numba)
    n = len(close)
    p = int(period)
    m = float(multiplier)

    # Reject periods that would otherwise index with a negative start offset.
    if p < 1:
        raise ValueError("period must be >= 1")

    atr = get_atr_nb(high, low, close, p)

    final_upper = np.full(n, np.nan, dtype=np.float64)
    final_lower = np.full(n, np.nan, dtype=np.float64)
    trend = np.ones(n, dtype=np.int8)  # 1 Bull, -1 Bear

    # Skip until we have valid ATR
    start_idx = p
    if start_idx >= n:
        return trend

    # Init first valid point
    hl2 = (high[start_idx] + low[start_idx]) / 2
    final_upper[start_idx] = hl2 + m * atr[start_idx]
    final_lower[start_idx] = hl2 - m * atr[start_idx]

    # Loop
    for i in range(start_idx + 1, n):
        cur_hl2 = (high[i] + low[i]) / 2
        cur_atr = atr[i]

        basic_upper = cur_hl2 + m * cur_atr
        basic_lower = cur_hl2 - m * cur_atr

        # Upper Band Logic: the band may only tighten (move down) unless the
        # previous close already broke above it.
        if basic_upper < final_upper[i - 1] or close[i - 1] > final_upper[i - 1]:
            final_upper[i] = basic_upper
        else:
            final_upper[i] = final_upper[i - 1]

        # Lower Band Logic: the band may only tighten (move up) unless the
        # previous close already broke below it.
        if basic_lower > final_lower[i - 1] or close[i - 1] < final_lower[i - 1]:
            final_lower[i] = basic_lower
        else:
            final_lower[i] = final_lower[i - 1]

        # Trend Logic: flips are tested against the previous bar's bands.
        if trend[i - 1] == 1:
            if close[i] < final_lower[i - 1]:
                trend[i] = -1
            else:
                trend[i] = 1
        else:
            if close[i] > final_upper[i - 1]:
                trend[i] = 1
            else:
                trend[i] = -1

    return trend


# --- VectorBT Indicator Factory ---

SuperTrendIndicator = vbt.IndicatorFactory(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['trend']
).from_apply_func(
    get_supertrend_nb,
    keep_pd=False,  # Disable automatic Pandas wrapping of inputs
    param_product=True  # Enable Cartesian product for list params
)