Cycles/cycles/Analysis/supertrend.py
2025-05-22 17:09:29 +08:00

337 lines
12 KiB
Python

import pandas as pd
import numpy as np
import logging
from scipy.signal import find_peaks
from matplotlib.patches import Rectangle
from scipy import stats
import concurrent.futures
from functools import partial
from functools import lru_cache
import matplotlib.pyplot as plt
# Color configuration: hex strings, matplotlib color names and matplotlib
# format strings (color letter + line style) used by the plotting code.

# Plot background / legend colors
DARK_BG_COLOR, LEGEND_BG_COLOR = '#181C27', '#333333'
# Title and axis-label text colors
TITLE_COLOR = AXIS_LABEL_COLOR = 'white'

# Candlestick body colors (up = green, down = red)
CANDLE_UP_COLOR, CANDLE_DOWN_COLOR = '#089981', '#F23645'

# Marker colors for local minima / maxima
# NOTE(review): minima markers are red and maxima markers green, while the
# MIN/MAX line styles below use green / red respectively -- confirm intended.
MIN_COLOR, MAX_COLOR = 'red', 'green'

# Line styles (matplotlib fmt strings)
MIN_LINE_STYLE, MAX_LINE_STYLE = 'g--', 'r--'    # green dashed / red dashed
SMA7_LINE_STYLE, SMA15_LINE_STYLE = 'y-', 'm-'   # yellow solid / magenta solid

# SuperTrend line styles: solid green (up) / solid red (down)
ST_COLOR_UP, ST_COLOR_DOWN = 'g-', 'r-'
# Cache the calculation results by function parameters
@lru_cache(maxsize=32)
def cached_supertrend_calculation(period, multiplier, data_tuple):
    """
    Compute the SuperTrend indicator; results are LRU-cached on the arguments.

    Parameters:
    - period: int, span of the exponential moving average used for the ATR
      (smoothing factor 2 / (period + 1))
    - multiplier: float, ATR multiple added to / subtracted from the bar
      midpoint (high + low) / 2 to form the basic bands
    - data_tuple: (high, low, close) -- a tuple of three equal-length tuples.
      It must be hashable so that lru_cache can use it as a key.

    Returns:
    - dict with float arrays:
      - 'supertrend': the SuperTrend line (final upper or lower band per bar)
      - 'trend': -1 where the line sits on the upper band, 1 where it sits on
        the lower band
      - 'upper_band' / 'lower_band': the final (ratcheted) bands

    The returned dict and arrays ARE the cached entry: every call with equal
    arguments receives the same objects, so callers must not modify them in
    place (copy first).
    """
    # Always work in float64. Integer price tuples would otherwise give
    # integer arrays (np.zeros_like inherits the dtype) and silently truncate
    # ATR and band values. lru_cache also treats (1, 2) and (1.0, 2.0) as the
    # same key, so the result must not depend on the element type.
    high = np.asarray(data_tuple[0], dtype=float)
    low = np.asarray(data_tuple[1], dtype=float)
    close = np.asarray(data_tuple[2], dtype=float)
    n = len(close)

    if n == 0:
        # Nothing to compute; avoid the IndexError on tr[0] below.
        return {
            'supertrend': np.zeros(0),
            'trend': np.zeros(0),
            'upper_band': np.zeros(0),
            'lower_band': np.zeros(0),
        }

    # True Range: max(high - low, |high - prev close|, |low - prev close|).
    # The first bar has no previous close, so it uses high - low only.
    tr = np.empty(n)
    tr[0] = high[0] - low[0]
    prev_close = close[:-1]
    tr[1:] = np.maximum.reduce([
        high[1:] - low[1:],
        np.abs(high[1:] - prev_close),
        np.abs(low[1:] - prev_close),
    ])

    # ATR: exponential moving average of TR, seeded with the first TR.
    alpha = 2.0 / (period + 1)
    atr = np.empty(n)
    atr[0] = tr[0]
    for i in range(1, n):
        atr[i] = (tr[i] * alpha) + (atr[i - 1] * (1 - alpha))

    # Basic bands around the bar midpoint (vectorized).
    hl_avg = (high + low) / 2
    upper_band = hl_avg + (multiplier * atr)
    lower_band = hl_avg - (multiplier * atr)

    final_upper = np.zeros(n)
    final_lower = np.zeros(n)
    supertrend = np.zeros(n)
    trend = np.zeros(n)

    final_upper[0] = upper_band[0]
    final_lower[0] = lower_band[0]
    if close[0] <= upper_band[0]:
        supertrend[0] = upper_band[0]
        trend[0] = -1
    else:
        supertrend[0] = lower_band[0]
        trend[0] = 1

    for i in range(1, n):
        # The upper band may only move down, unless the previous close broke
        # above it; in that case it resets to the basic band.
        if (upper_band[i] < final_upper[i - 1]) or (close[i - 1] > final_upper[i - 1]):
            final_upper[i] = upper_band[i]
        else:
            final_upper[i] = final_upper[i - 1]
        # Mirror rule for the lower band: it may only move up, unless the
        # previous close broke below it.
        if (lower_band[i] > final_lower[i - 1]) or (close[i - 1] < final_lower[i - 1]):
            final_lower[i] = lower_band[i]
        else:
            final_lower[i] = final_lower[i - 1]

        # Stay on the current band until the close crosses it, then flip.
        if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1
        elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
            supertrend[i] = final_lower[i]
            trend[i] = 1
        elif supertrend[i - 1] == final_lower[i - 1] and close[i] < final_lower[i]:
            supertrend[i] = final_upper[i]
            trend[i] = -1

    return {
        'supertrend': supertrend,
        'trend': trend,
        'upper_band': final_upper,
        'lower_band': final_lower
    }
def calculate_supertrend_external(data, period, multiplier):
    """
    Compute SuperTrend for a price DataFrame using the LRU-cached implementation.

    Parameters:
    - data: DataFrame (or mapping) with 'high', 'low' and 'close' columns
    - period: int, ATR period passed to cached_supertrend_calculation
    - multiplier: float, ATR multiplier passed to cached_supertrend_calculation

    Returns:
    - dict with 'supertrend', 'trend', 'upper_band' and 'lower_band' arrays.
      These are fresh copies of the cached entry, so callers may modify them
      without corrupting results served to later calls.
    """
    # Convert DataFrame columns to hashable tuples so they can form the cache key.
    cache_key = (tuple(data['high']), tuple(data['low']), tuple(data['close']))
    cached = cached_supertrend_calculation(period, multiplier, cache_key)
    # lru_cache hands back the same objects on every hit; copy them so an
    # in-place edit by one caller cannot leak into another caller's results.
    return {name: values.copy() for name, values in cached.items()}
class Supertrends:
    """
    SuperTrend analysis over OHLC price data.

    Holds a price DataFrame and exposes True Range / ATR helpers plus a
    fixed-parameter SuperTrend calculation. Plot-styling attributes are only
    defined when ``display`` is True.
    """

    def __init__(self, data, verbose=False, display=False):
        """
        Initialize the Supertrends analyzer.

        Parameters:
        - data: pandas DataFrame containing price data. The TR/ATR/SuperTrend
          methods read its 'high', 'low' and 'close' columns.
        - verbose: boolean, whether to display detailed logging information
        - display: boolean, whether to enable display/plotting features

        Raises:
        - ValueError: if data is neither a pandas DataFrame nor a list

        Note: a list is wrapped as a single 'close' column, so an instance
        built from a list lacks the 'high'/'low' columns that the indicator
        methods need.
        """
        self.data = data
        self.verbose = verbose
        self.display = display
        # Only define display-related variables if display is True
        if self.display:
            # Plot style configuration
            self.plot_style = 'dark_background'
            self.bg_color = DARK_BG_COLOR
            self.plot_size = (12, 8)
            # Candlestick configuration
            self.candle_width = 0.6
            self.candle_up_color = CANDLE_UP_COLOR
            self.candle_down_color = CANDLE_DOWN_COLOR
            self.candle_alpha = 0.8
            self.wick_width = 1
            # Marker configuration
            self.min_marker = '^'
            self.min_color = MIN_COLOR
            self.min_size = 100
            self.max_marker = 'v'
            self.max_color = MAX_COLOR
            self.max_size = 100
            self.marker_zorder = 100
            # Line configuration
            self.line_width = 1
            self.min_line_style = MIN_LINE_STYLE
            self.max_line_style = MAX_LINE_STYLE
            self.sma7_line_style = SMA7_LINE_STYLE
            self.sma15_line_style = SMA15_LINE_STYLE
            # Text configuration
            self.title_size = 14
            self.title_color = TITLE_COLOR
            self.axis_label_size = 12
            self.axis_label_color = AXIS_LABEL_COLOR
            # Legend configuration
            self.legend_loc = 'best'
            self.legend_bg_color = LEGEND_BG_COLOR
        # Configure logging.
        # NOTE(review): basicConfig configures the root logger and does nothing
        # if the root logger already has handlers; consider moving this to the
        # application entry point.
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        # NOTE(review): the logger name still reflects the older
        # TrendDetectorSimple class name.
        self.logger = logging.getLogger('TrendDetectorSimple')
        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.data = pd.DataFrame({'close': self.data})
            else:
                raise ValueError("Data must be a pandas DataFrame or a list")

    def calculate_tr(self):
        """
        Calculate True Range (TR) for the price data.

        True Range is the greatest of:
        1. Current high - current low
        2. |Current high - previous close|
        3. |Current low - previous close|
        The first bar has no previous close, so its TR is high - low.

        Returns:
        - float numpy array of TR values (empty if the data has no rows)
        """
        # Read as float so integer price columns behave like float ones.
        high = np.asarray(self.data['high'], dtype=float)
        low = np.asarray(self.data['low'], dtype=float)
        close = np.asarray(self.data['close'], dtype=float)

        tr = np.empty(len(close))
        if tr.size == 0:
            return tr
        tr[0] = high[0] - low[0]  # First TR is just the first bar's range
        prev_close = close[:-1]
        tr[1:] = np.maximum.reduce([
            high[1:] - low[1:],
            np.abs(high[1:] - prev_close),
            np.abs(low[1:] - prev_close),
        ])
        return tr

    def calculate_atr(self, period=14):
        """
        Calculate Average True Range (ATR) for the price data.

        The ATR here is an exponential moving average of the True Range with
        smoothing factor 2 / (period + 1), seeded with the first TR value.

        Parameters:
        - period: int, the period for the ATR calculation (default: 14)

        Returns:
        - float numpy array of ATR values (empty if the data has no rows)
        """
        tr = self.calculate_tr()
        # Explicit float buffer: an integer buffer would truncate the EMA.
        atr = np.empty(len(tr), dtype=float)
        if atr.size == 0:
            return atr
        atr[0] = tr[0]
        alpha = 2.0 / (period + 1)
        for i in range(1, len(tr)):
            atr[i] = (tr[i] * alpha) + (atr[i - 1] * (1 - alpha))
        return atr

    def detect_trends(self):
        """
        Run the trend analysis and collect its results.

        Only the SuperTrend indicators are computed at present.

        Returns:
        - dict with key 'supertrend' mapping to the list produced by
          calculate_supertrend_indicators()
        """
        analysis_results = {}
        # Previously called the nonexistent self._calculate_supertrend_indicators,
        # which raised AttributeError on every call.
        analysis_results['supertrend'] = self.calculate_supertrend_indicators()
        return analysis_results

    def calculate_supertrend_indicators(self):
        """
        Calculate SuperTrend indicators for a fixed set of parameter pairs.

        The three (period, multiplier) sets are evaluated sequentially; for
        only three calculations a process pool is not worth its overhead.

        Returns:
        - list of dicts, one per parameter set, in the order listed below:
          - "results": dict with 'supertrend', 'trend', 'upper_band' and
            'lower_band' arrays from calculate_supertrend_external
          - "params": the parameter dict used (period, multiplier, colors)
        """
        supertrend_params = [
            {"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
            {"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
            {"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
        ]
        # self.data is only read (converted to tuples), so no defensive copy is needed.
        return [
            {
                "results": calculate_supertrend_external(self.data, params["period"], params["multiplier"]),
                "params": params,
            }
            for params in supertrend_params
        ]