129 lines
3.6 KiB
Python
129 lines
3.6 KiB
Python
"""
|
|
Supertrend indicators and helper functions.
|
|
"""
|
|
import numpy as np
|
|
import vectorbt as vbt
|
|
from numba import njit
|
|
|
|
# --- Numba Compiled Helper Functions ---
|
|
|
|
@njit(cache=False)  # Disable cache to avoid stale compilation issues
def get_tr_nb(high, low, close):
    """Calculate True Range (Numba compiled).

    TR[0] is ``high[0] - low[0]`` because no previous close exists; for
    ``i >= 1``::

        TR[i] = max(high[i] - low[i],
                    |high[i] - close[i-1]|,
                    |low[i] - close[i-1]|)

    Args:
        high: High prices; flattened with ``ravel()``.
        low: Low prices; flattened with ``ravel()``.
        close: Close prices; flattened with ``ravel()``.

    Returns:
        A 1-D array with the same dtype as ``close`` holding the True Range
        of each bar. Empty when ``close`` is empty.

    Raises:
        ValueError: If ``high``, ``low`` and ``close`` differ in length.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    n = len(close)
    # Numba does not bounds-check indexing by default, so a length mismatch
    # would silently read past the end of the shorter array.
    if len(high) != n or len(low) != n:
        raise ValueError("high, low and close must have the same length")

    tr = np.empty_like(close)
    if n == 0:
        # Without this guard, tr[0] below would be an unchecked
        # out-of-bounds write.
        return tr

    tr[0] = high[0] - low[0]
    for i in range(1, n):
        prev_close = close[i - 1]
        tr[i] = max(
            high[i] - low[i],
            abs(high[i] - prev_close),
            abs(low[i] - prev_close),
        )

    return tr
|
|
|
|
@njit(cache=False)
def get_atr_nb(high, low, close, period):
    """Calculate ATR using Wilder's Smoothing (Numba compiled).

    The first ATR value, at index ``period - 1``, is the simple mean of the
    first ``period`` True Range values. Each later value follows Wilder's
    recursion ``ATR[i] = (ATR[i-1] * (period - 1) + TR[i]) / period``.

    Args:
        high: High prices; flattened with ``ravel()``.
        low: Low prices; flattened with ``ravel()``.
        close: Close prices; flattened with ``ravel()``.
        period: Smoothing length; converted with ``int()`` and must be >= 1.

    Returns:
        A float64 array of length ``len(close)``. Entries before index
        ``period - 1`` are NaN. If fewer than ``period`` bars are
        available, every entry is NaN.

    Raises:
        ValueError: If ``period`` is smaller than 1.
    """
    # Ensure 1D arrays
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure period is native Python int (critical for Numba array indexing)
    n = len(close)
    p = int(period)
    # p == 0 would divide by zero below; a negative p would write through a
    # negative index.
    if p < 1:
        raise ValueError("period must be a positive integer")

    atr = np.full(n, np.nan, dtype=np.float64)
    if n < p:
        # Not enough bars to seed the average. Returning before computing
        # True Range also keeps empty input away from get_tr_nb.
        return atr

    tr = get_tr_nb(high, low, close)

    # Initial ATR is simple average of TR
    sum_tr = 0.0
    for i in range(p):
        sum_tr += tr[i]
    atr[p - 1] = sum_tr / p

    # Subsequent ATR is Wilder's smoothed
    for i in range(p, n):
        atr[i] = (atr[i - 1] * (p - 1) + tr[i]) / p

    return atr
|
|
|
|
@njit(cache=False)
def get_supertrend_nb(high, low, close, period, multiplier):
    """Calculate the SuperTrend direction series completely in Numba.

    Basic bands are ``(high + low) / 2 +/- multiplier * ATR``. The final
    upper band only moves down (or resets when the previous close broke
    above it). The final lower band only moves up (or resets when the
    previous close broke below it). The trend flips when the close crosses
    the *previous* bar's opposite band.

    Args:
        high: High prices; flattened with ``ravel()``.
        low: Low prices; flattened with ``ravel()``.
        close: Close prices; flattened with ``ravel()``.
        period: ATR look-back; converted with ``int()`` and must be >= 1.
        multiplier: ATR multiple used for the band offset.

    Returns:
        An int8 array of length ``len(close)``: 1 = bullish, -1 = bearish.
        Bars before the first valid ATR (index ``period - 1``) are reported
        as 1, because int8 cannot represent NaN.

    Raises:
        ValueError: If ``period`` is smaller than 1, or if the inputs differ
            in length.
    """
    # NOTE(review): ravel() on a 2-D input with more than one column
    # interleaves the columns into one meaningless series. Confirm that the
    # caller only ever passes a single column here.
    high = high.ravel()
    low = low.ravel()
    close = close.ravel()

    # Ensure params are native Python types (critical for Numba)
    n = len(close)
    p = int(period)
    m = float(multiplier)

    if p < 1:
        raise ValueError("period must be a positive integer")
    if len(high) != n or len(low) != n:
        raise ValueError("high, low and close must have the same length")

    trend = np.ones(n, dtype=np.int8)  # 1 Bull, -1 Bear

    # The Wilder ATR seed is written at index p - 1, so that is the first bar
    # with a valid ATR and the first bar at which bands can be built.
    start_idx = p - 1
    if start_idx >= n:
        return trend

    atr = get_atr_nb(high, low, close, p)

    final_upper = np.full(n, np.nan, dtype=np.float64)
    final_lower = np.full(n, np.nan, dtype=np.float64)

    # Init first valid point
    hl2 = (high[start_idx] + low[start_idx]) / 2
    final_upper[start_idx] = hl2 + m * atr[start_idx]
    final_lower[start_idx] = hl2 - m * atr[start_idx]

    for i in range(start_idx + 1, n):
        cur_hl2 = (high[i] + low[i]) / 2
        cur_atr = atr[i]
        basic_upper = cur_hl2 + m * cur_atr
        basic_lower = cur_hl2 - m * cur_atr

        prev_upper = final_upper[i - 1]
        prev_lower = final_lower[i - 1]

        # Upper band: tighten downwards, or reset once price closed above it.
        if basic_upper < prev_upper or close[i - 1] > prev_upper:
            final_upper[i] = basic_upper
        else:
            final_upper[i] = prev_upper

        # Lower band: tighten upwards, or reset once price closed below it.
        if basic_lower > prev_lower or close[i - 1] < prev_lower:
            final_lower[i] = basic_lower
        else:
            final_lower[i] = prev_lower

        # Trend: flip only when the close crosses the previous bar's band.
        if trend[i - 1] == 1:
            trend[i] = -1 if close[i] < prev_lower else 1
        else:
            trend[i] = 1 if close[i] > prev_upper else -1

    return trend
|
|
|
|
# --- VectorBT Indicator Factory ---
|
|
|
|
# Indicator class wrapping get_supertrend_nb. It takes the high/low/close
# inputs and the period/multiplier parameters and exposes a single "trend"
# output.
SuperTrendIndicator = vbt.IndicatorFactory(
    class_name="SuperTrend",
    short_name="st",
    input_names=["high", "low", "close"],
    param_names=["period", "multiplier"],
    output_names=["trend"],
).from_apply_func(
    get_supertrend_nb,
    # When lists are given for period and multiplier, evaluate every combination.
    param_product=True,
    # Give the Numba kernel plain arrays rather than pandas-wrapped inputs.
    keep_pd=False,
)
|