Update date filtering in main.py and enhance TrendDetectorSimple with SuperTrend calculations and improved plotting functionality. Introduce color configurations for better visualization and streamline trend analysis methods.

This commit is contained in:
Simon Moisy 2025-05-09 15:17:30 +08:00
parent c7732881c5
commit f316571a3c
2 changed files with 470 additions and 74 deletions

View File

@ -8,7 +8,7 @@ data = pd.read_csv('data/btcusd_1-day_data.csv')
# Convert datetime column to datetime type
start_date = pd.to_datetime('2025-04-01')
start_date = pd.to_datetime('2024-04-06')
stop_date = pd.to_datetime('2025-05-06')
daily_data = data[(pd.to_datetime(data['datetime']) >= start_date) &
@ -17,13 +17,12 @@ print(f"Number of data points: {len(daily_data)}")
trend_detector = TrendDetectorSimple(daily_data, verbose=True)
trends, analysis_results = trend_detector.detect_trends()
trend_detector.plot_trends(trends, analysis_results)
trend_detector.plot_trends(trends, analysis_results, "supertrend")
#trend_detector = TrendDetectorMACD(daily_data, True)
#trends = trend_detector.detect_trends_MACD_signal()
#trend_detector.plot_trends(trends)
# # Cycle detection (new code)
# print("\n===== CYCLE DETECTION =====")

View File

@ -2,10 +2,35 @@ import pandas as pd
import numpy as np
import logging
from scipy.signal import find_peaks
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
from scipy import stats
from scipy import stats
# Color configuration
# Plot colors
DARK_BG_COLOR = '#181C27'
LEGEND_BG_COLOR = '#333333'
TITLE_COLOR = 'white'
AXIS_LABEL_COLOR = 'white'
# Candlestick colors
CANDLE_UP_COLOR = '#089981' # Green
CANDLE_DOWN_COLOR = '#F23645' # Red
# Marker colors
MIN_COLOR = 'red'
MAX_COLOR = 'green'
# Line style colors
MIN_LINE_STYLE = 'g--' # Green dashed
MAX_LINE_STYLE = 'r--' # Red dashed
SMA7_LINE_STYLE = 'y-' # Yellow solid
SMA15_LINE_STYLE = 'm-' # Magenta solid
# SuperTrend colors
ST_COLOR_UP = 'g-'
ST_COLOR_DOWN = 'r-'
class TrendDetectorSimple:
def __init__(self, data, verbose=False):
"""
@ -21,41 +46,41 @@ class TrendDetectorSimple:
# Plot style configuration
self.plot_style = 'dark_background'
self.bg_color = '#181C27'
self.bg_color = DARK_BG_COLOR
self.plot_size = (12, 8)
# Candlestick configuration
self.candle_width = 0.6
self.candle_up_color = '#089981'
self.candle_down_color = '#F23645'
self.candle_up_color = CANDLE_UP_COLOR
self.candle_down_color = CANDLE_DOWN_COLOR
self.candle_alpha = 0.8
self.wick_width = 1
# Marker configuration
self.min_marker = '^'
self.min_color = 'red'
self.min_color = MIN_COLOR
self.min_size = 100
self.max_marker = 'v'
self.max_color = 'green'
self.max_color = MAX_COLOR
self.max_size = 100
self.marker_zorder = 100
# Line configuration
self.line_width = 2
self.min_line_style = 'g--' # green dashed
self.max_line_style = 'r--' # red dashed
self.sma7_line_style = 'y-' # yellow solid
self.sma15_line_style = 'm-' # magenta solid
self.line_width = 1
self.min_line_style = MIN_LINE_STYLE
self.max_line_style = MAX_LINE_STYLE
self.sma7_line_style = SMA7_LINE_STYLE
self.sma15_line_style = SMA15_LINE_STYLE
# Text configuration
self.title_size = 14
self.title_color = 'white'
self.title_color = TITLE_COLOR
self.axis_label_size = 12
self.axis_label_color = 'white'
self.axis_label_color = AXIS_LABEL_COLOR
# Legend configuration
self.legend_loc = 'best'
self.legend_bg_color = '#333333'
self.legend_bg_color = LEGEND_BG_COLOR
# Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
@ -73,6 +98,66 @@ class TrendDetectorSimple:
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
def calculate_tr(self):
"""
Calculate True Range (TR) for the price data.
True Range is the greatest of:
1. Current high - current low
2. |Current high - previous close|
3. |Current low - previous close|
Returns:
- Numpy array of TR values
"""
df = self.data.copy()
high = df['high'].values
low = df['low'].values
close = df['close'].values
tr = np.zeros_like(close)
tr[0] = high[0] - low[0] # First TR is just the first day's range
for i in range(1, len(close)):
# Current high - current low
hl_range = high[i] - low[i]
# |Current high - previous close|
hc_range = abs(high[i] - close[i-1])
# |Current low - previous close|
lc_range = abs(low[i] - close[i-1])
# TR is the maximum of these three values
tr[i] = max(hl_range, hc_range, lc_range)
return tr
def calculate_atr(self, period=14):
"""
Calculate Average True Range (ATR) for the price data.
ATR is the exponential moving average of the True Range over a specified period.
Parameters:
- period: int, the period for the ATR calculation (default: 14)
Returns:
- Numpy array of ATR values
"""
tr = self.calculate_tr()
atr = np.zeros_like(tr)
# First ATR value is just the first TR
atr[0] = tr[0]
# Calculate exponential moving average (EMA) of TR
multiplier = 2.0 / (period + 1)
for i in range(1, len(tr)):
atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))
return atr
def detect_trends(self):
"""
Detect trends by identifying local minima and maxima in the price data
@ -84,15 +169,13 @@ class TrendDetectorSimple:
Returns:
- DataFrame with columns for timestamps, prices, and trend indicators
- Dictionary containing analysis results including linear regression, SMAs, and SuperTrend indicators
"""
self.logger.info(f"Detecting trends")
self.logger.info(f"Detecting trends")
df = self.data.copy()
close_prices = df['close'].values
max_peaks, _ = find_peaks(close_prices)
min_peaks, _ = find_peaks(-close_prices)
max_peaks, _ = find_peaks(close_prices)
min_peaks, _ = find_peaks(-close_prices)
@ -118,9 +201,7 @@ class TrendDetectorSimple:
# Linear regression for max peaks if we have at least 2 points
max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
self.logger.info("Calculating SMA-7 and SMA-15")
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values
@ -142,37 +223,182 @@ class TrendDetectorSimple:
'15': sma_15
}
self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}")
self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}")
# Calculate SuperTrend indicators
supertrend_results_list = self._calculate_supertrend_indicators()
analysis_results['supertrend'] = supertrend_results_list
return result, analysis_results
def plot_trends(self, trend_data, analysis_results):
def _calculate_supertrend_indicators(self):
"""
Plot the price data with detected trends using a candlestick chart.
Calculate SuperTrend indicators with different parameter sets.
Returns:
- list, the SuperTrend results
"""
supertrend_params = [
{"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
{"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
{"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
]
supertrend_results_list = []
for params in supertrend_params:
supertrend_results = self.calculate_supertrend(
period=params["period"],
multiplier=params["multiplier"]
)
supertrend_results_list.append({
"results": supertrend_results,
"params": params
})
return supertrend_results_list
def calculate_supertrend(self, period, multiplier):
"""
Calculate SuperTrend indicator for the price data.
SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.
Parameters:
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
- period: int, the period for the ATR calculation (default: 10)
- multiplier: float, the multiplier for the ATR (default: 3.0)
Returns:
- Dictionary containing SuperTrend values, trend direction, and upper/lower bands
"""
df = self.data.copy()
high = df['high'].values
low = df['low'].values
close = df['close'].values
# Calculate ATR
atr = self.calculate_atr(period)
# Calculate basic upper and lower bands
upper_band = np.zeros_like(close)
lower_band = np.zeros_like(close)
for i in range(len(close)):
# Calculate the basic bands
hl_avg = (high[i] + low[i]) / 2
upper_band[i] = hl_avg + (multiplier * atr[i])
lower_band[i] = hl_avg - (multiplier * atr[i])
# Calculate final upper and lower bands with trend logic
final_upper = np.zeros_like(close)
final_lower = np.zeros_like(close)
supertrend = np.zeros_like(close)
trend = np.zeros_like(close) # 1 for uptrend, -1 for downtrend
# Initialize first values
final_upper[0] = upper_band[0]
final_lower[0] = lower_band[0]
# If close price is above upper band, we're in a downtrend (ST = upper band)
# If close price is below lower band, we're in an uptrend (ST = lower band)
if close[0] <= upper_band[0]:
supertrend[0] = upper_band[0]
trend[0] = -1 # Downtrend
else:
supertrend[0] = lower_band[0]
trend[0] = 1 # Uptrend
# Calculate SuperTrend for the rest of the data
for i in range(1, len(close)):
# Calculate final upper band
if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
final_upper[i] = upper_band[i]
else:
final_upper[i] = final_upper[i-1]
# Calculate final lower band
if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
final_lower[i] = lower_band[i]
else:
final_lower[i] = final_lower[i-1]
# Determine trend and SuperTrend value
if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
# Continuing downtrend
supertrend[i] = final_upper[i]
trend[i] = -1
elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
# Switching to uptrend
supertrend[i] = final_lower[i]
trend[i] = 1
elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
# Continuing uptrend
supertrend[i] = final_lower[i]
trend[i] = 1
elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
# Switching to downtrend
supertrend[i] = final_upper[i]
trend[i] = -1
# Prepare result
supertrend_results = {
'supertrend': supertrend,
'trend': trend,
'upper_band': final_upper,
'lower_band': final_lower
}
return supertrend_results
def plot_trends(self, trend_data, analysis_results, view="both"):
"""
Plot the price data with detected trends using a candlestick chart.
Also plots SuperTrend indicators with three different parameter sets.
Parameters:
- trend_data: DataFrame, the output from detect_trends()
- analysis_results: Dictionary containing analysis results from detect_trends()
- view: str, one of 'both', 'trend', 'supertrend'; determines which plot(s) to display
Returns:
- None (displays the plot)
"""
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Create the figure and axis with specified background
plt.style.use(self.plot_style)
fig, ax = plt.subplots(figsize=self.plot_size)
# Set the custom background color
if view == "both":
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(self.plot_size[0]*2, self.plot_size[1]))
else:
fig, ax = plt.subplots(figsize=self.plot_size)
ax1 = ax2 = None
if view == "trend":
ax1 = ax
elif view == "supertrend":
ax2 = ax
fig.patch.set_facecolor(self.bg_color)
ax.set_facecolor(self.bg_color)
# Create a copy of the data
if ax1: ax1.set_facecolor(self.bg_color)
if ax2: ax2.set_facecolor(self.bg_color)
df = self.data.copy()
if ax1:
self._plot_trend_analysis(ax1, df, trend_data, analysis_results)
if ax2:
self._plot_supertrend_analysis(ax2, df, analysis_results['supertrend'])
plt.tight_layout()
plt.show()
def _plot_candlesticks(self, ax, df):
"""
Plot candlesticks on the given axis.
# Draw candlesticks manually
x_values = range(len(df))
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
"""
from matplotlib.patches import Rectangle
for i in range(len(df)):
# Get OHLC values for this candle
@ -193,7 +419,39 @@ class TrendDetectorSimple:
# Plot candle wicks
ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)
def _plot_trend_analysis(self, ax, df, trend_data, analysis_results):
"""
Plot trend analysis on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- trend_data: pandas.DataFrame, the trend data
- analysis_results: dict, the analysis results
"""
# Draw candlesticks
self._plot_candlesticks(ax, df)
# Plot minima and maxima points
self._plot_min_max_points(ax, df, trend_data)
# Plot trend lines and moving averages
if analysis_results:
self._plot_trend_lines(ax, df, analysis_results)
# Configure the subplot
self._configure_subplot(ax, 'Price Chart with Trend Analysis', len(df))
def _plot_min_max_points(self, ax, df, trend_data):
"""
Plot minimum and maximum points on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- trend_data: pandas.DataFrame, the trend data
"""
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
if min_indices:
min_y = [df['close'].iloc[i] for i in min_indices]
@ -205,49 +463,188 @@ class TrendDetectorSimple:
max_y = [df['close'].iloc[i] for i in max_indices]
ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size,
marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)
def _plot_trend_lines(self, ax, df, analysis_results):
"""
Plot trend lines on the given axis.
if analysis_results:
x_vals = np.arange(len(df))
# Minima regression line (support)
min_slope = analysis_results['linear_regression']['min']['slope']
min_intercept = analysis_results['linear_regression']['min']['intercept']
min_line = min_slope * x_vals + min_intercept
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
label='Minima Regression')
# Maxima regression line (resistance)
max_slope = analysis_results['linear_regression']['max']['slope']
max_intercept = analysis_results['linear_regression']['max']['intercept']
max_line = max_slope * x_vals + max_intercept
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
label='Maxima Regression')
# SMA-7 line
sma_7 = analysis_results['sma']['7']
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
label='SMA-7')
# SMA-15 line
sma_15 = analysis_results['sma']['15']
valid_idx_15 = ~np.isnan(sma_15)
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
linewidth=self.line_width, label='SMA-15')
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- analysis_results: dict, the analysis results
"""
x_vals = np.arange(len(df))
# Minima regression line (support)
min_slope = analysis_results['linear_regression']['min']['slope']
min_intercept = analysis_results['linear_regression']['min']['intercept']
min_line = min_slope * x_vals + min_intercept
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
label='Minima Regression')
# Maxima regression line (resistance)
max_slope = analysis_results['linear_regression']['max']['slope']
max_intercept = analysis_results['linear_regression']['max']['intercept']
max_line = max_slope * x_vals + max_intercept
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
label='Maxima Regression')
# SMA-7 line
sma_7 = analysis_results['sma']['7']
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
label='SMA-7')
# SMA-15 line
sma_15 = analysis_results['sma']['15']
valid_idx_15 = ~np.isnan(sma_15)
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
linewidth=self.line_width, label='SMA-15')
def _configure_subplot(self, ax, title, data_length):
"""
Configure the subplot with title, labels, limits, and legend.
Parameters:
- ax: matplotlib.axes.Axes, the axis to configure
- title: str, the title of the subplot
- data_length: int, the length of the data
"""
# Set title and labels
ax.set_title('Price Candlestick Chart with Local Minima and Maxima',
fontsize=self.title_size, color=self.title_color)
ax.set_title(title, fontsize=self.title_size, color=self.title_color)
ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)
# Set appropriate x-axis limits
ax.set_xlim(-0.5, len(df) - 0.5)
ax.set_xlim(-0.5, data_length - 0.5)
# Add a legend
ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)
def _plot_supertrend_analysis(self, ax, df, supertrend_results_list=None):
"""
Plot SuperTrend analysis on the given axis.
# Adjust layout
plt.tight_layout()
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- supertrend_results_list: list, the SuperTrend results (optional)
"""
self._plot_candlesticks(ax, df)
self._plot_supertrend_lines(ax, df, supertrend_results_list, style='Both')
self._configure_subplot(ax, 'Multiple SuperTrend Indicators', len(df))
# Show the plot
plt.show()
def _plot_supertrend_lines(self, ax, df, supertrend_results_list, style="Horizontal"):
"""
Plot SuperTrend lines on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- supertrend_results_list: list, the SuperTrend results
"""
x_vals = np.arange(len(df))
if style == 'Horizontal' or style == 'Both':
if len(supertrend_results_list) != 3:
raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")
trends = [st["results"]["trend"] for st in supertrend_results_list]
band_height = 0.02 * (df["high"].max() - df["low"].min())
y_base = df["low"].min() - band_height * 1.5
prev_color = None
for i in range(1, len(x_vals)):
t_vals = [t[i] for t in trends]
up_count = t_vals.count(1)
down_count = t_vals.count(-1)
if down_count == 3:
color = "red"
elif down_count == 2 and up_count == 1:
color = "orange"
elif down_count == 1 and up_count == 2:
color = "yellow"
elif up_count == 3:
color = "green"
else:
continue # skip if unknown or inconsistent values
ax.add_patch(Rectangle(
(x_vals[i-1], y_base),
1,
band_height,
color=color,
linewidth=0,
alpha=0.6
))
# Draw a vertical line at the change of color
if prev_color and prev_color != color:
ax.axvline(x_vals[i-1], color="grey", alpha=0.3, linewidth=1)
prev_color = color
ax.set_ylim(bottom=y_base - band_height * 0.5)
if style == 'Curves' or style == 'Both':
for st in supertrend_results_list:
params = st["params"]
results = st["results"]
supertrend = results["supertrend"]
trend = results["trend"]
# Plot SuperTrend line with color based on trend
for i in range(1, len(x_vals)):
if trend[i] == 1: # Uptrend
ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_up"], linewidth=self.line_width)
else: # Downtrend
ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_down"], linewidth=self.line_width)
self._plot_metasupertrend_lines(ax, df, supertrend_results_list)
self._add_supertrend_legend(ax, supertrend_results_list)
def _plot_metasupertrend_lines(self, ax, df, supertrend_results_list):
"""
Plot a Meta SuperTrend line where all individual SuperTrends agree on trend.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- supertrend_results_list: list, each item contains SuperTrend 'results' and 'params'
"""
x_vals = np.arange(len(df))
if len(supertrend_results_list) != 3:
raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")
trends = [st["results"]["trend"] for st in supertrend_results_list]
supertrends = [st["results"]["supertrend"] for st in supertrend_results_list]
params = supertrend_results_list[0]["params"] # Use first config for styling
for i in range(1, len(x_vals)):
t1, t2, t3 = trends[0][i], trends[1][i], trends[2][i]
if t1 == t2 == t3:
meta_trend = t1
# Average the 3 supertrend values
st_avg_prev = np.mean([s[i-1] for s in supertrends])
st_avg_curr = np.mean([s[i] for s in supertrends])
color = params["color_up"] if meta_trend == 1 else params["color_down"]
ax.plot(x_vals[i-1:i+1], [st_avg_prev, st_avg_curr], color, linewidth=self.line_width)
def _add_supertrend_legend(self, ax, supertrend_results_list):
"""
Add SuperTrend legend entries to the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to add legend entries to
- supertrend_results_list: list, the SuperTrend results
"""
for st in supertrend_results_list:
params = st["params"]
period = params["period"]
multiplier = params["multiplier"]
color_up = params["color_up"]
color_down = params["color_down"]
ax.plot([], [], color_up, linewidth=self.line_width,
label=f'ST (P:{period}, M:{multiplier}) Up')
ax.plot([], [], color_down, linewidth=self.line_width,
label=f'ST (P:{period}, M:{multiplier}) Down')