Update date filtering in main.py and enhance TrendDetectorSimple with SuperTrend calculations and improved plotting functionality. Introduce color configurations for better visualization and streamline trend analysis methods.
This commit is contained in:
parent
c7732881c5
commit
f316571a3c
5
main.py
5
main.py
@ -8,7 +8,7 @@ data = pd.read_csv('data/btcusd_1-day_data.csv')
|
||||
|
||||
|
||||
# Convert datetime column to datetime type
|
||||
start_date = pd.to_datetime('2025-04-01')
|
||||
start_date = pd.to_datetime('2024-04-06')
|
||||
stop_date = pd.to_datetime('2025-05-06')
|
||||
|
||||
daily_data = data[(pd.to_datetime(data['datetime']) >= start_date) &
|
||||
@ -17,13 +17,12 @@ print(f"Number of data points: {len(daily_data)}")
|
||||
|
||||
trend_detector = TrendDetectorSimple(daily_data, verbose=True)
|
||||
trends, analysis_results = trend_detector.detect_trends()
|
||||
trend_detector.plot_trends(trends, analysis_results)
|
||||
trend_detector.plot_trends(trends, analysis_results, "supertrend")
|
||||
|
||||
#trend_detector = TrendDetectorMACD(daily_data, True)
|
||||
#trends = trend_detector.detect_trends_MACD_signal()
|
||||
#trend_detector.plot_trends(trends)
|
||||
|
||||
|
||||
# # Cycle detection (new code)
|
||||
# print("\n===== CYCLE DETECTION =====")
|
||||
|
||||
|
||||
@ -2,10 +2,35 @@ import pandas as pd
|
||||
import numpy as np
|
||||
import logging
|
||||
from scipy.signal import find_peaks
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.patches import Rectangle
|
||||
from scipy import stats
|
||||
from scipy import stats
|
||||
|
||||
# Color configuration
|
||||
# Plot colors
|
||||
DARK_BG_COLOR = '#181C27'
|
||||
LEGEND_BG_COLOR = '#333333'
|
||||
TITLE_COLOR = 'white'
|
||||
AXIS_LABEL_COLOR = 'white'
|
||||
|
||||
# Candlestick colors
|
||||
CANDLE_UP_COLOR = '#089981' # Green
|
||||
CANDLE_DOWN_COLOR = '#F23645' # Red
|
||||
|
||||
# Marker colors
|
||||
MIN_COLOR = 'red'
|
||||
MAX_COLOR = 'green'
|
||||
|
||||
# Line style colors
|
||||
MIN_LINE_STYLE = 'g--' # Green dashed
|
||||
MAX_LINE_STYLE = 'r--' # Red dashed
|
||||
SMA7_LINE_STYLE = 'y-' # Yellow solid
|
||||
SMA15_LINE_STYLE = 'm-' # Magenta solid
|
||||
|
||||
# SuperTrend colors
|
||||
ST_COLOR_UP = 'g-'
|
||||
ST_COLOR_DOWN = 'r-'
|
||||
|
||||
class TrendDetectorSimple:
|
||||
def __init__(self, data, verbose=False):
|
||||
"""
|
||||
@ -21,41 +46,41 @@ class TrendDetectorSimple:
|
||||
|
||||
# Plot style configuration
|
||||
self.plot_style = 'dark_background'
|
||||
self.bg_color = '#181C27'
|
||||
self.bg_color = DARK_BG_COLOR
|
||||
self.plot_size = (12, 8)
|
||||
|
||||
# Candlestick configuration
|
||||
self.candle_width = 0.6
|
||||
self.candle_up_color = '#089981'
|
||||
self.candle_down_color = '#F23645'
|
||||
self.candle_up_color = CANDLE_UP_COLOR
|
||||
self.candle_down_color = CANDLE_DOWN_COLOR
|
||||
self.candle_alpha = 0.8
|
||||
self.wick_width = 1
|
||||
|
||||
# Marker configuration
|
||||
self.min_marker = '^'
|
||||
self.min_color = 'red'
|
||||
self.min_color = MIN_COLOR
|
||||
self.min_size = 100
|
||||
self.max_marker = 'v'
|
||||
self.max_color = 'green'
|
||||
self.max_color = MAX_COLOR
|
||||
self.max_size = 100
|
||||
self.marker_zorder = 100
|
||||
|
||||
# Line configuration
|
||||
self.line_width = 2
|
||||
self.min_line_style = 'g--' # green dashed
|
||||
self.max_line_style = 'r--' # red dashed
|
||||
self.sma7_line_style = 'y-' # yellow solid
|
||||
self.sma15_line_style = 'm-' # magenta solid
|
||||
self.line_width = 1
|
||||
self.min_line_style = MIN_LINE_STYLE
|
||||
self.max_line_style = MAX_LINE_STYLE
|
||||
self.sma7_line_style = SMA7_LINE_STYLE
|
||||
self.sma15_line_style = SMA15_LINE_STYLE
|
||||
|
||||
# Text configuration
|
||||
self.title_size = 14
|
||||
self.title_color = 'white'
|
||||
self.title_color = TITLE_COLOR
|
||||
self.axis_label_size = 12
|
||||
self.axis_label_color = 'white'
|
||||
self.axis_label_color = AXIS_LABEL_COLOR
|
||||
|
||||
# Legend configuration
|
||||
self.legend_loc = 'best'
|
||||
self.legend_bg_color = '#333333'
|
||||
self.legend_bg_color = LEGEND_BG_COLOR
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
|
||||
@ -73,6 +98,66 @@ class TrendDetectorSimple:
|
||||
|
||||
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
|
||||
|
||||
def calculate_tr(self):
|
||||
"""
|
||||
Calculate True Range (TR) for the price data.
|
||||
|
||||
True Range is the greatest of:
|
||||
1. Current high - current low
|
||||
2. |Current high - previous close|
|
||||
3. |Current low - previous close|
|
||||
|
||||
Returns:
|
||||
- Numpy array of TR values
|
||||
"""
|
||||
df = self.data.copy()
|
||||
high = df['high'].values
|
||||
low = df['low'].values
|
||||
close = df['close'].values
|
||||
|
||||
tr = np.zeros_like(close)
|
||||
tr[0] = high[0] - low[0] # First TR is just the first day's range
|
||||
|
||||
for i in range(1, len(close)):
|
||||
# Current high - current low
|
||||
hl_range = high[i] - low[i]
|
||||
# |Current high - previous close|
|
||||
hc_range = abs(high[i] - close[i-1])
|
||||
# |Current low - previous close|
|
||||
lc_range = abs(low[i] - close[i-1])
|
||||
|
||||
# TR is the maximum of these three values
|
||||
tr[i] = max(hl_range, hc_range, lc_range)
|
||||
|
||||
return tr
|
||||
|
||||
def calculate_atr(self, period=14):
|
||||
"""
|
||||
Calculate Average True Range (ATR) for the price data.
|
||||
|
||||
ATR is the exponential moving average of the True Range over a specified period.
|
||||
|
||||
Parameters:
|
||||
- period: int, the period for the ATR calculation (default: 14)
|
||||
|
||||
Returns:
|
||||
- Numpy array of ATR values
|
||||
"""
|
||||
|
||||
tr = self.calculate_tr()
|
||||
atr = np.zeros_like(tr)
|
||||
|
||||
# First ATR value is just the first TR
|
||||
atr[0] = tr[0]
|
||||
|
||||
# Calculate exponential moving average (EMA) of TR
|
||||
multiplier = 2.0 / (period + 1)
|
||||
|
||||
for i in range(1, len(tr)):
|
||||
atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))
|
||||
|
||||
return atr
|
||||
|
||||
def detect_trends(self):
|
||||
"""
|
||||
Detect trends by identifying local minima and maxima in the price data
|
||||
@ -84,15 +169,13 @@ class TrendDetectorSimple:
|
||||
|
||||
Returns:
|
||||
- DataFrame with columns for timestamps, prices, and trend indicators
|
||||
- Dictionary containing analysis results including linear regression, SMAs, and SuperTrend indicators
|
||||
"""
|
||||
self.logger.info(f"Detecting trends")
|
||||
self.logger.info(f"Detecting trends")
|
||||
|
||||
df = self.data.copy()
|
||||
close_prices = df['close'].values
|
||||
|
||||
max_peaks, _ = find_peaks(close_prices)
|
||||
min_peaks, _ = find_peaks(-close_prices)
|
||||
max_peaks, _ = find_peaks(close_prices)
|
||||
min_peaks, _ = find_peaks(-close_prices)
|
||||
|
||||
@ -118,9 +201,7 @@ class TrendDetectorSimple:
|
||||
# Linear regression for max peaks if we have at least 2 points
|
||||
max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
|
||||
|
||||
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
|
||||
self.logger.info("Calculating SMA-7 and SMA-15")
|
||||
|
||||
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
|
||||
sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
|
||||
sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values
|
||||
|
||||
@ -142,37 +223,182 @@ class TrendDetectorSimple:
|
||||
'15': sma_15
|
||||
}
|
||||
|
||||
self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}")
|
||||
self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}")
|
||||
|
||||
# Calculate SuperTrend indicators
|
||||
supertrend_results_list = self._calculate_supertrend_indicators()
|
||||
analysis_results['supertrend'] = supertrend_results_list
|
||||
|
||||
return result, analysis_results
|
||||
|
||||
def plot_trends(self, trend_data, analysis_results):
|
||||
|
||||
def _calculate_supertrend_indicators(self):
|
||||
"""
|
||||
Plot the price data with detected trends using a candlestick chart.
|
||||
Calculate SuperTrend indicators with different parameter sets.
|
||||
|
||||
Returns:
|
||||
- list, the SuperTrend results
|
||||
"""
|
||||
supertrend_params = [
|
||||
{"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
|
||||
{"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
|
||||
{"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
|
||||
]
|
||||
|
||||
supertrend_results_list = []
|
||||
for params in supertrend_params:
|
||||
supertrend_results = self.calculate_supertrend(
|
||||
period=params["period"],
|
||||
multiplier=params["multiplier"]
|
||||
)
|
||||
supertrend_results_list.append({
|
||||
"results": supertrend_results,
|
||||
"params": params
|
||||
})
|
||||
|
||||
return supertrend_results_list
|
||||
|
||||
def calculate_supertrend(self, period, multiplier):
|
||||
"""
|
||||
Calculate SuperTrend indicator for the price data.
|
||||
|
||||
SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.
|
||||
|
||||
Parameters:
|
||||
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
|
||||
- period: int, the period for the ATR calculation (default: 10)
|
||||
- multiplier: float, the multiplier for the ATR (default: 3.0)
|
||||
|
||||
Returns:
|
||||
- Dictionary containing SuperTrend values, trend direction, and upper/lower bands
|
||||
"""
|
||||
df = self.data.copy()
|
||||
high = df['high'].values
|
||||
low = df['low'].values
|
||||
close = df['close'].values
|
||||
|
||||
# Calculate ATR
|
||||
atr = self.calculate_atr(period)
|
||||
|
||||
# Calculate basic upper and lower bands
|
||||
upper_band = np.zeros_like(close)
|
||||
lower_band = np.zeros_like(close)
|
||||
|
||||
for i in range(len(close)):
|
||||
# Calculate the basic bands
|
||||
hl_avg = (high[i] + low[i]) / 2
|
||||
upper_band[i] = hl_avg + (multiplier * atr[i])
|
||||
lower_band[i] = hl_avg - (multiplier * atr[i])
|
||||
|
||||
# Calculate final upper and lower bands with trend logic
|
||||
final_upper = np.zeros_like(close)
|
||||
final_lower = np.zeros_like(close)
|
||||
supertrend = np.zeros_like(close)
|
||||
trend = np.zeros_like(close) # 1 for uptrend, -1 for downtrend
|
||||
|
||||
# Initialize first values
|
||||
final_upper[0] = upper_band[0]
|
||||
final_lower[0] = lower_band[0]
|
||||
|
||||
# If close price is above upper band, we're in a downtrend (ST = upper band)
|
||||
# If close price is below lower band, we're in an uptrend (ST = lower band)
|
||||
if close[0] <= upper_band[0]:
|
||||
supertrend[0] = upper_band[0]
|
||||
trend[0] = -1 # Downtrend
|
||||
else:
|
||||
supertrend[0] = lower_band[0]
|
||||
trend[0] = 1 # Uptrend
|
||||
|
||||
# Calculate SuperTrend for the rest of the data
|
||||
for i in range(1, len(close)):
|
||||
# Calculate final upper band
|
||||
if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
|
||||
final_upper[i] = upper_band[i]
|
||||
else:
|
||||
final_upper[i] = final_upper[i-1]
|
||||
|
||||
# Calculate final lower band
|
||||
if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
|
||||
final_lower[i] = lower_band[i]
|
||||
else:
|
||||
final_lower[i] = final_lower[i-1]
|
||||
|
||||
# Determine trend and SuperTrend value
|
||||
if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
|
||||
# Continuing downtrend
|
||||
supertrend[i] = final_upper[i]
|
||||
trend[i] = -1
|
||||
elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
|
||||
# Switching to uptrend
|
||||
supertrend[i] = final_lower[i]
|
||||
trend[i] = 1
|
||||
elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
|
||||
# Continuing uptrend
|
||||
supertrend[i] = final_lower[i]
|
||||
trend[i] = 1
|
||||
elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
|
||||
# Switching to downtrend
|
||||
supertrend[i] = final_upper[i]
|
||||
trend[i] = -1
|
||||
|
||||
# Prepare result
|
||||
supertrend_results = {
|
||||
'supertrend': supertrend,
|
||||
'trend': trend,
|
||||
'upper_band': final_upper,
|
||||
'lower_band': final_lower
|
||||
}
|
||||
|
||||
return supertrend_results
|
||||
|
||||
def plot_trends(self, trend_data, analysis_results, view="both"):
|
||||
"""
|
||||
Plot the price data with detected trends using a candlestick chart.
|
||||
Also plots SuperTrend indicators with three different parameter sets.
|
||||
|
||||
Parameters:
|
||||
- trend_data: DataFrame, the output from detect_trends()
|
||||
- analysis_results: Dictionary containing analysis results from detect_trends()
|
||||
- view: str, one of 'both', 'trend', 'supertrend'; determines which plot(s) to display
|
||||
|
||||
Returns:
|
||||
- None (displays the plot)
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
# Create the figure and axis with specified background
|
||||
|
||||
plt.style.use(self.plot_style)
|
||||
fig, ax = plt.subplots(figsize=self.plot_size)
|
||||
|
||||
# Set the custom background color
|
||||
|
||||
if view == "both":
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(self.plot_size[0]*2, self.plot_size[1]))
|
||||
else:
|
||||
fig, ax = plt.subplots(figsize=self.plot_size)
|
||||
ax1 = ax2 = None
|
||||
if view == "trend":
|
||||
ax1 = ax
|
||||
elif view == "supertrend":
|
||||
ax2 = ax
|
||||
|
||||
fig.patch.set_facecolor(self.bg_color)
|
||||
ax.set_facecolor(self.bg_color)
|
||||
|
||||
# Create a copy of the data
|
||||
if ax1: ax1.set_facecolor(self.bg_color)
|
||||
if ax2: ax2.set_facecolor(self.bg_color)
|
||||
|
||||
df = self.data.copy()
|
||||
|
||||
if ax1:
|
||||
self._plot_trend_analysis(ax1, df, trend_data, analysis_results)
|
||||
|
||||
if ax2:
|
||||
self._plot_supertrend_analysis(ax2, df, analysis_results['supertrend'])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
def _plot_candlesticks(self, ax, df):
|
||||
"""
|
||||
Plot candlesticks on the given axis.
|
||||
|
||||
# Draw candlesticks manually
|
||||
x_values = range(len(df))
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
"""
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
for i in range(len(df)):
|
||||
# Get OHLC values for this candle
|
||||
@ -193,7 +419,39 @@ class TrendDetectorSimple:
|
||||
|
||||
# Plot candle wicks
|
||||
ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)
|
||||
|
||||
def _plot_trend_analysis(self, ax, df, trend_data, analysis_results):
|
||||
"""
|
||||
Plot trend analysis on the given axis.
|
||||
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
- trend_data: pandas.DataFrame, the trend data
|
||||
- analysis_results: dict, the analysis results
|
||||
"""
|
||||
# Draw candlesticks
|
||||
self._plot_candlesticks(ax, df)
|
||||
|
||||
# Plot minima and maxima points
|
||||
self._plot_min_max_points(ax, df, trend_data)
|
||||
|
||||
# Plot trend lines and moving averages
|
||||
if analysis_results:
|
||||
self._plot_trend_lines(ax, df, analysis_results)
|
||||
|
||||
# Configure the subplot
|
||||
self._configure_subplot(ax, 'Price Chart with Trend Analysis', len(df))
|
||||
|
||||
def _plot_min_max_points(self, ax, df, trend_data):
|
||||
"""
|
||||
Plot minimum and maximum points on the given axis.
|
||||
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
- trend_data: pandas.DataFrame, the trend data
|
||||
"""
|
||||
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
|
||||
if min_indices:
|
||||
min_y = [df['close'].iloc[i] for i in min_indices]
|
||||
@ -205,49 +463,188 @@ class TrendDetectorSimple:
|
||||
max_y = [df['close'].iloc[i] for i in max_indices]
|
||||
ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size,
|
||||
marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)
|
||||
|
||||
def _plot_trend_lines(self, ax, df, analysis_results):
|
||||
"""
|
||||
Plot trend lines on the given axis.
|
||||
|
||||
if analysis_results:
|
||||
x_vals = np.arange(len(df))
|
||||
# Minima regression line (support)
|
||||
min_slope = analysis_results['linear_regression']['min']['slope']
|
||||
min_intercept = analysis_results['linear_regression']['min']['intercept']
|
||||
min_line = min_slope * x_vals + min_intercept
|
||||
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
|
||||
label='Minima Regression')
|
||||
|
||||
# Maxima regression line (resistance)
|
||||
max_slope = analysis_results['linear_regression']['max']['slope']
|
||||
max_intercept = analysis_results['linear_regression']['max']['intercept']
|
||||
max_line = max_slope * x_vals + max_intercept
|
||||
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
|
||||
label='Maxima Regression')
|
||||
|
||||
# SMA-7 line
|
||||
sma_7 = analysis_results['sma']['7']
|
||||
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
|
||||
label='SMA-7')
|
||||
|
||||
# SMA-15 line
|
||||
sma_15 = analysis_results['sma']['15']
|
||||
valid_idx_15 = ~np.isnan(sma_15)
|
||||
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
|
||||
linewidth=self.line_width, label='SMA-15')
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
- analysis_results: dict, the analysis results
|
||||
"""
|
||||
x_vals = np.arange(len(df))
|
||||
|
||||
# Minima regression line (support)
|
||||
min_slope = analysis_results['linear_regression']['min']['slope']
|
||||
min_intercept = analysis_results['linear_regression']['min']['intercept']
|
||||
min_line = min_slope * x_vals + min_intercept
|
||||
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
|
||||
label='Minima Regression')
|
||||
|
||||
# Maxima regression line (resistance)
|
||||
max_slope = analysis_results['linear_regression']['max']['slope']
|
||||
max_intercept = analysis_results['linear_regression']['max']['intercept']
|
||||
max_line = max_slope * x_vals + max_intercept
|
||||
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
|
||||
label='Maxima Regression')
|
||||
|
||||
# SMA-7 line
|
||||
sma_7 = analysis_results['sma']['7']
|
||||
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
|
||||
label='SMA-7')
|
||||
|
||||
# SMA-15 line
|
||||
sma_15 = analysis_results['sma']['15']
|
||||
valid_idx_15 = ~np.isnan(sma_15)
|
||||
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
|
||||
linewidth=self.line_width, label='SMA-15')
|
||||
|
||||
def _configure_subplot(self, ax, title, data_length):
|
||||
"""
|
||||
Configure the subplot with title, labels, limits, and legend.
|
||||
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to configure
|
||||
- title: str, the title of the subplot
|
||||
- data_length: int, the length of the data
|
||||
"""
|
||||
# Set title and labels
|
||||
ax.set_title('Price Candlestick Chart with Local Minima and Maxima',
|
||||
fontsize=self.title_size, color=self.title_color)
|
||||
ax.set_title(title, fontsize=self.title_size, color=self.title_color)
|
||||
ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
|
||||
ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)
|
||||
|
||||
# Set appropriate x-axis limits
|
||||
ax.set_xlim(-0.5, len(df) - 0.5)
|
||||
ax.set_xlim(-0.5, data_length - 0.5)
|
||||
|
||||
# Add a legend
|
||||
ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)
|
||||
|
||||
def _plot_supertrend_analysis(self, ax, df, supertrend_results_list=None):
|
||||
"""
|
||||
Plot SuperTrend analysis on the given axis.
|
||||
|
||||
# Adjust layout
|
||||
plt.tight_layout()
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
- supertrend_results_list: list, the SuperTrend results (optional)
|
||||
"""
|
||||
self._plot_candlesticks(ax, df)
|
||||
self._plot_supertrend_lines(ax, df, supertrend_results_list, style='Both')
|
||||
self._configure_subplot(ax, 'Multiple SuperTrend Indicators', len(df))
|
||||
|
||||
# Show the plot
|
||||
plt.show()
|
||||
|
||||
def _plot_supertrend_lines(self, ax, df, supertrend_results_list, style="Horizontal"):
|
||||
"""
|
||||
Plot SuperTrend lines on the given axis.
|
||||
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
- supertrend_results_list: list, the SuperTrend results
|
||||
"""
|
||||
x_vals = np.arange(len(df))
|
||||
|
||||
if style == 'Horizontal' or style == 'Both':
|
||||
if len(supertrend_results_list) != 3:
|
||||
raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")
|
||||
|
||||
trends = [st["results"]["trend"] for st in supertrend_results_list]
|
||||
|
||||
band_height = 0.02 * (df["high"].max() - df["low"].min())
|
||||
y_base = df["low"].min() - band_height * 1.5
|
||||
|
||||
prev_color = None
|
||||
for i in range(1, len(x_vals)):
|
||||
t_vals = [t[i] for t in trends]
|
||||
up_count = t_vals.count(1)
|
||||
down_count = t_vals.count(-1)
|
||||
|
||||
if down_count == 3:
|
||||
color = "red"
|
||||
elif down_count == 2 and up_count == 1:
|
||||
color = "orange"
|
||||
elif down_count == 1 and up_count == 2:
|
||||
color = "yellow"
|
||||
elif up_count == 3:
|
||||
color = "green"
|
||||
else:
|
||||
continue # skip if unknown or inconsistent values
|
||||
|
||||
ax.add_patch(Rectangle(
|
||||
(x_vals[i-1], y_base),
|
||||
1,
|
||||
band_height,
|
||||
color=color,
|
||||
linewidth=0,
|
||||
alpha=0.6
|
||||
))
|
||||
# Draw a vertical line at the change of color
|
||||
if prev_color and prev_color != color:
|
||||
ax.axvline(x_vals[i-1], color="grey", alpha=0.3, linewidth=1)
|
||||
prev_color = color
|
||||
|
||||
ax.set_ylim(bottom=y_base - band_height * 0.5)
|
||||
if style == 'Curves' or style == 'Both':
|
||||
for st in supertrend_results_list:
|
||||
params = st["params"]
|
||||
results = st["results"]
|
||||
supertrend = results["supertrend"]
|
||||
trend = results["trend"]
|
||||
|
||||
# Plot SuperTrend line with color based on trend
|
||||
for i in range(1, len(x_vals)):
|
||||
if trend[i] == 1: # Uptrend
|
||||
ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_up"], linewidth=self.line_width)
|
||||
else: # Downtrend
|
||||
ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_down"], linewidth=self.line_width)
|
||||
self._plot_metasupertrend_lines(ax, df, supertrend_results_list)
|
||||
self._add_supertrend_legend(ax, supertrend_results_list)
|
||||
|
||||
def _plot_metasupertrend_lines(self, ax, df, supertrend_results_list):
|
||||
"""
|
||||
Plot a Meta SuperTrend line where all individual SuperTrends agree on trend.
|
||||
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to plot on
|
||||
- df: pandas.DataFrame, the data to plot
|
||||
- supertrend_results_list: list, each item contains SuperTrend 'results' and 'params'
|
||||
"""
|
||||
x_vals = np.arange(len(df))
|
||||
|
||||
if len(supertrend_results_list) != 3:
|
||||
raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")
|
||||
|
||||
trends = [st["results"]["trend"] for st in supertrend_results_list]
|
||||
supertrends = [st["results"]["supertrend"] for st in supertrend_results_list]
|
||||
params = supertrend_results_list[0]["params"] # Use first config for styling
|
||||
|
||||
for i in range(1, len(x_vals)):
|
||||
t1, t2, t3 = trends[0][i], trends[1][i], trends[2][i]
|
||||
if t1 == t2 == t3:
|
||||
meta_trend = t1
|
||||
# Average the 3 supertrend values
|
||||
st_avg_prev = np.mean([s[i-1] for s in supertrends])
|
||||
st_avg_curr = np.mean([s[i] for s in supertrends])
|
||||
color = params["color_up"] if meta_trend == 1 else params["color_down"]
|
||||
ax.plot(x_vals[i-1:i+1], [st_avg_prev, st_avg_curr], color, linewidth=self.line_width)
|
||||
|
||||
def _add_supertrend_legend(self, ax, supertrend_results_list):
|
||||
"""
|
||||
Add SuperTrend legend entries to the given axis.
|
||||
|
||||
Parameters:
|
||||
- ax: matplotlib.axes.Axes, the axis to add legend entries to
|
||||
- supertrend_results_list: list, the SuperTrend results
|
||||
"""
|
||||
for st in supertrend_results_list:
|
||||
params = st["params"]
|
||||
period = params["period"]
|
||||
multiplier = params["multiplier"]
|
||||
color_up = params["color_up"]
|
||||
color_down = params["color_down"]
|
||||
|
||||
ax.plot([], [], color_up, linewidth=self.line_width,
|
||||
label=f'ST (P:{period}, M:{multiplier}) Up')
|
||||
ax.plot([], [], color_down, linewidth=self.line_width,
|
||||
label=f'ST (P:{period}, M:{multiplier}) Down')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user