From f316571a3c08a444172d17cf835fda6e7029bf07 Mon Sep 17 00:00:00 2001 From: Simon Moisy Date: Fri, 9 May 2025 15:17:30 +0800 Subject: [PATCH] Update date filtering in main.py and enhance TrendDetectorSimple with SuperTrend calculations and improved plotting functionality. Introduce color configurations for better visualization and streamline trend analysis methods. --- main.py | 5 +- trend_detector_simple.py | 539 +++++++++++++++++++++++++++++++++------ 2 files changed, 470 insertions(+), 74 deletions(-) diff --git a/main.py b/main.py index 06a343b..0e54834 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ data = pd.read_csv('data/btcusd_1-day_data.csv') # Convert datetime column to datetime type -start_date = pd.to_datetime('2025-04-01') +start_date = pd.to_datetime('2024-04-06') stop_date = pd.to_datetime('2025-05-06') daily_data = data[(pd.to_datetime(data['datetime']) >= start_date) & @@ -17,13 +17,12 @@ print(f"Number of data points: {len(daily_data)}") trend_detector = TrendDetectorSimple(daily_data, verbose=True) trends, analysis_results = trend_detector.detect_trends() -trend_detector.plot_trends(trends, analysis_results) +trend_detector.plot_trends(trends, analysis_results, "supertrend") #trend_detector = TrendDetectorMACD(daily_data, True) #trends = trend_detector.detect_trends_MACD_signal() #trend_detector.plot_trends(trends) - # # Cycle detection (new code) # print("\n===== CYCLE DETECTION =====") diff --git a/trend_detector_simple.py b/trend_detector_simple.py index efe77b1..bf1ff60 100644 --- a/trend_detector_simple.py +++ b/trend_detector_simple.py @@ -2,10 +2,35 @@ import pandas as pd import numpy as np import logging from scipy.signal import find_peaks -import matplotlib.dates as mdates +from matplotlib.patches import Rectangle from scipy import stats from scipy import stats +# Color configuration +# Plot colors +DARK_BG_COLOR = '#181C27' +LEGEND_BG_COLOR = '#333333' +TITLE_COLOR = 'white' +AXIS_LABEL_COLOR = 'white' + +# Candlestick colors +CANDLE_UP_COLOR = '#089981' # Green +CANDLE_DOWN_COLOR = '#F23645' # Red + +# Marker colors +MIN_COLOR = 'red' +MAX_COLOR = 'green' + +# Line style colors +MIN_LINE_STYLE = 'g--' # Green dashed +MAX_LINE_STYLE = 'r--' # Red dashed +SMA7_LINE_STYLE = 'y-' # Yellow solid +SMA15_LINE_STYLE = 'm-' # Magenta solid + +# SuperTrend colors +ST_COLOR_UP = 'g-' +ST_COLOR_DOWN = 'r-' + class TrendDetectorSimple: def __init__(self, data, verbose=False): """ @@ -21,41 +46,41 @@ class TrendDetectorSimple: # Plot style configuration self.plot_style = 'dark_background' - self.bg_color = '#181C27' + self.bg_color = DARK_BG_COLOR self.plot_size = (12, 8) # Candlestick configuration self.candle_width = 0.6 - self.candle_up_color = '#089981' - self.candle_down_color = '#F23645' + self.candle_up_color = CANDLE_UP_COLOR + self.candle_down_color = CANDLE_DOWN_COLOR self.candle_alpha = 0.8 self.wick_width = 1 # Marker configuration self.min_marker = '^' - self.min_color = 'red' + self.min_color = MIN_COLOR self.min_size = 100 self.max_marker = 'v' - self.max_color = 'green' + self.max_color = MAX_COLOR self.max_size = 100 self.marker_zorder = 100 # Line configuration - self.line_width = 2 - self.min_line_style = 'g--' # green dashed - self.max_line_style = 'r--' # red dashed - self.sma7_line_style = 'y-' # yellow solid - self.sma15_line_style = 'm-' # magenta solid + self.line_width = 1 + self.min_line_style = MIN_LINE_STYLE + self.max_line_style = MAX_LINE_STYLE + self.sma7_line_style = SMA7_LINE_STYLE + self.sma15_line_style = SMA15_LINE_STYLE # Text configuration self.title_size = 14 - self.title_color = 'white' + self.title_color = TITLE_COLOR self.axis_label_size = 12 - self.axis_label_color = 'white' + self.axis_label_color = AXIS_LABEL_COLOR # Legend configuration self.legend_loc = 'best' - self.legend_bg_color = '#333333' + self.legend_bg_color = LEGEND_BG_COLOR # Configure logging logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, @@ -73,6 +98,66 @@ class TrendDetectorSimple: self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points") + def calculate_tr(self): + """ + Calculate True Range (TR) for the price data. + + True Range is the greatest of: + 1. Current high - current low + 2. |Current high - previous close| + 3. |Current low - previous close| + + Returns: + - Numpy array of TR values + """ + df = self.data.copy() + high = df['high'].values + low = df['low'].values + close = df['close'].values + + tr = np.zeros_like(close) + tr[0] = high[0] - low[0] # First TR is just the first day's range + + for i in range(1, len(close)): + # Current high - current low + hl_range = high[i] - low[i] + # |Current high - previous close| + hc_range = abs(high[i] - close[i-1]) + # |Current low - previous close| + lc_range = abs(low[i] - close[i-1]) + + # TR is the maximum of these three values + tr[i] = max(hl_range, hc_range, lc_range) + + return tr + + def calculate_atr(self, period=14): + """ + Calculate Average True Range (ATR) for the price data. + + ATR is the exponential moving average of the True Range over a specified period. + + Parameters: + - period: int, the period for the ATR calculation (default: 14) + + Returns: + - Numpy array of ATR values + """ + + tr = self.calculate_tr() + atr = np.zeros_like(tr) + + # First ATR value is just the first TR + atr[0] = tr[0] + + # Calculate exponential moving average (EMA) of TR + multiplier = 2.0 / (period + 1) + + for i in range(1, len(tr)): + atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier)) + + return atr + def detect_trends(self): """ Detect trends by identifying local minima and maxima in the price data @@ -84,15 +169,13 @@ class TrendDetectorSimple: Returns: - DataFrame with columns for timestamps, prices, and trend indicators + - Dictionary containing analysis results including linear regression, SMAs, and SuperTrend indicators """ self.logger.info(f"Detecting trends") - self.logger.info(f"Detecting trends") df = self.data.copy() close_prices = df['close'].values - max_peaks, _ = find_peaks(close_prices) - min_peaks, _ = find_peaks(-close_prices) max_peaks, _ = find_peaks(close_prices) min_peaks, _ = find_peaks(-close_prices) @@ -118,9 +201,7 @@ class TrendDetectorSimple: # Linear regression for max peaks if we have at least 2 points max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices) - # Calculate Simple Moving Averages (SMA) for 7 and 15 periods - self.logger.info("Calculating SMA-7 and SMA-15") - + # Calculate Simple Moving Averages (SMA) for 7 and 15 periods sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values @@ -142,37 +223,182 @@ class TrendDetectorSimple: '15': sma_15 } - self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}") - self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}") - + # Calculate SuperTrend indicators + supertrend_results_list = self._calculate_supertrend_indicators() + analysis_results['supertrend'] = supertrend_results_list + return result, analysis_results - - def plot_trends(self, trend_data, analysis_results): + + def _calculate_supertrend_indicators(self): """ - Plot the price data with detected trends using a candlestick chart. + Calculate SuperTrend indicators with different parameter sets. + + Returns: + - list, the SuperTrend results + """ + supertrend_params = [ + {"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}, + {"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}, + {"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN} + ] + + supertrend_results_list = [] + for params in supertrend_params: + supertrend_results = self.calculate_supertrend( + period=params["period"], + multiplier=params["multiplier"] + ) + supertrend_results_list.append({ + "results": supertrend_results, + "params": params + }) + + return supertrend_results_list + + def calculate_supertrend(self, period, multiplier): + """ + Calculate SuperTrend indicator for the price data. + + SuperTrend is a trend-following indicator that uses ATR to determine the trend direction. Parameters: - - trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called. + - period: int, the period for the ATR calculation (default: 10) + - multiplier: float, the multiplier for the ATR (default: 3.0) + Returns: + - Dictionary containing SuperTrend values, trend direction, and upper/lower bands + """ + df = self.data.copy() + high = df['high'].values + low = df['low'].values + close = df['close'].values + + # Calculate ATR + atr = self.calculate_atr(period) + + # Calculate basic upper and lower bands + upper_band = np.zeros_like(close) + lower_band = np.zeros_like(close) + + for i in range(len(close)): + # Calculate the basic bands + hl_avg = (high[i] + low[i]) / 2 + upper_band[i] = hl_avg + (multiplier * atr[i]) + lower_band[i] = hl_avg - (multiplier * atr[i]) + + # Calculate final upper and lower bands with trend logic + final_upper = np.zeros_like(close) + final_lower = np.zeros_like(close) + supertrend = np.zeros_like(close) + trend = np.zeros_like(close) # 1 for uptrend, -1 for downtrend + + # Initialize first values + final_upper[0] = upper_band[0] + final_lower[0] = lower_band[0] + + # If close price is above upper band, we're in a downtrend (ST = upper band) + # If close price is below lower band, we're in an uptrend (ST = lower band) + if close[0] <= upper_band[0]: + supertrend[0] = upper_band[0] + trend[0] = -1 # Downtrend + else: + supertrend[0] = lower_band[0] + trend[0] = 1 # Uptrend + + # Calculate SuperTrend for the rest of the data + for i in range(1, len(close)): + # Calculate final upper band + if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]): + final_upper[i] = upper_band[i] + else: + final_upper[i] = final_upper[i-1] + + # Calculate final lower band + if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]): + final_lower[i] = lower_band[i] + else: + final_lower[i] = final_lower[i-1] + + # Determine trend and SuperTrend value + if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]: + # Continuing downtrend + supertrend[i] = final_upper[i] + trend[i] = -1 + elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]: + # Switching to uptrend + supertrend[i] = final_lower[i] + trend[i] = 1 + elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]: + # Continuing uptrend + supertrend[i] = final_lower[i] + trend[i] = 1 + elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]: + # Switching to downtrend + supertrend[i] = final_upper[i] + trend[i] = -1 + + # Prepare result + supertrend_results = { + 'supertrend': supertrend, + 'trend': trend, + 'upper_band': final_upper, + 'lower_band': final_lower + } + + return supertrend_results + + def plot_trends(self, trend_data, analysis_results, view="both"): + """ + Plot the price data with detected trends using a candlestick chart. + Also plots SuperTrend indicators with three different parameter sets. + + Parameters: + - trend_data: DataFrame, the output from detect_trends() + - analysis_results: Dictionary containing analysis results from detect_trends() + - view: str, one of 'both', 'trend', 'supertrend'; determines which plot(s) to display + Returns: - None (displays the plot) """ import matplotlib.pyplot as plt from matplotlib.patches import Rectangle - - # Create the figure and axis with specified background + plt.style.use(self.plot_style) - fig, ax = plt.subplots(figsize=self.plot_size) - - # Set the custom background color + + if view == "both": + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(self.plot_size[0]*2, self.plot_size[1])) + else: + fig, ax = plt.subplots(figsize=self.plot_size) + ax1 = ax2 = None + if view == "trend": + ax1 = ax + elif view == "supertrend": + ax2 = ax + fig.patch.set_facecolor(self.bg_color) - ax.set_facecolor(self.bg_color) - - # Create a copy of the data + if ax1: ax1.set_facecolor(self.bg_color) + if ax2: ax2.set_facecolor(self.bg_color) + df = self.data.copy() + + if ax1: + self._plot_trend_analysis(ax1, df, trend_data, analysis_results) + + if ax2: + self._plot_supertrend_analysis(ax2, df, analysis_results['supertrend']) + + plt.tight_layout() + plt.show() + + def _plot_candlesticks(self, ax, df): + """ + Plot candlesticks on the given axis. - # Draw candlesticks manually - x_values = range(len(df)) + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + """ + from matplotlib.patches import Rectangle for i in range(len(df)): # Get OHLC values for this candle @@ -193,7 +419,39 @@ class TrendDetectorSimple: # Plot candle wicks ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width) + + def _plot_trend_analysis(self, ax, df, trend_data, analysis_results): + """ + Plot trend analysis on the given axis. + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + - trend_data: pandas.DataFrame, the trend data + - analysis_results: dict, the analysis results + """ + # Draw candlesticks + self._plot_candlesticks(ax, df) + + # Plot minima and maxima points + self._plot_min_max_points(ax, df, trend_data) + + # Plot trend lines and moving averages + if analysis_results: + self._plot_trend_lines(ax, df, analysis_results) + + # Configure the subplot + self._configure_subplot(ax, 'Price Chart with Trend Analysis', len(df)) + + def _plot_min_max_points(self, ax, df, trend_data): + """ + Plot minimum and maximum points on the given axis. + + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + - trend_data: pandas.DataFrame, the trend data + """ min_indices = trend_data.index[trend_data['is_min'] == True].tolist() if min_indices: min_y = [df['close'].iloc[i] for i in min_indices] @@ -205,49 +463,188 @@ class TrendDetectorSimple: max_y = [df['close'].iloc[i] for i in max_indices] ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size, marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder) + + def _plot_trend_lines(self, ax, df, analysis_results): + """ + Plot trend lines on the given axis. - if analysis_results: - x_vals = np.arange(len(df)) - # Minima regression line (support) - min_slope = analysis_results['linear_regression']['min']['slope'] - min_intercept = analysis_results['linear_regression']['min']['intercept'] - min_line = min_slope * x_vals + min_intercept - ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width, - label='Minima Regression') - - # Maxima regression line (resistance) - max_slope = analysis_results['linear_regression']['max']['slope'] - max_intercept = analysis_results['linear_regression']['max']['intercept'] - max_line = max_slope * x_vals + max_intercept - ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width, - label='Maxima Regression') - - # SMA-7 line - sma_7 = analysis_results['sma']['7'] - ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width, - label='SMA-7') - - # SMA-15 line - sma_15 = analysis_results['sma']['15'] - valid_idx_15 = ~np.isnan(sma_15) - ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style, - linewidth=self.line_width, label='SMA-15') + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + - analysis_results: dict, the analysis results + """ + x_vals = np.arange(len(df)) + # Minima regression line (support) + min_slope = analysis_results['linear_regression']['min']['slope'] + min_intercept = analysis_results['linear_regression']['min']['intercept'] + min_line = min_slope * x_vals + min_intercept + ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width, + label='Minima Regression') + + # Maxima regression line (resistance) + max_slope = analysis_results['linear_regression']['max']['slope'] + max_intercept = analysis_results['linear_regression']['max']['intercept'] + max_line = max_slope * x_vals + max_intercept + ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width, + label='Maxima Regression') + + # SMA-7 line + sma_7 = analysis_results['sma']['7'] + ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width, + label='SMA-7') + + # SMA-15 line + sma_15 = analysis_results['sma']['15'] + valid_idx_15 = ~np.isnan(sma_15) + ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style, + linewidth=self.line_width, label='SMA-15') + + def _configure_subplot(self, ax, title, data_length): + """ + Configure the subplot with title, labels, limits, and legend. + + Parameters: + - ax: matplotlib.axes.Axes, the axis to configure + - title: str, the title of the subplot + - data_length: int, the length of the data + """ # Set title and labels - ax.set_title('Price Candlestick Chart with Local Minima and Maxima', - fontsize=self.title_size, color=self.title_color) + ax.set_title(title, fontsize=self.title_size, color=self.title_color) ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color) ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color) # Set appropriate x-axis limits - ax.set_xlim(-0.5, len(df) - 0.5) + ax.set_xlim(-0.5, data_length - 0.5) # Add a legend ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color) + + def _plot_supertrend_analysis(self, ax, df, supertrend_results_list=None): + """ + Plot SuperTrend analysis on the given axis. - # Adjust layout - plt.tight_layout() + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + - supertrend_results_list: list, the SuperTrend results (optional) + """ + self._plot_candlesticks(ax, df) + self._plot_supertrend_lines(ax, df, supertrend_results_list, style='Both') + self._configure_subplot(ax, 'Multiple SuperTrend Indicators', len(df)) - # Show the plot - plt.show() - \ No newline at end of file + def _plot_supertrend_lines(self, ax, df, supertrend_results_list, style="Horizontal"): + """ + Plot SuperTrend lines on the given axis. + + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + - supertrend_results_list: list, the SuperTrend results + """ + x_vals = np.arange(len(df)) + + if style == 'Horizontal' or style == 'Both': + if len(supertrend_results_list) != 3: + raise ValueError("Expected exactly 3 SuperTrend results for meta calculation") + + trends = [st["results"]["trend"] for st in supertrend_results_list] + + band_height = 0.02 * (df["high"].max() - df["low"].min()) + y_base = df["low"].min() - band_height * 1.5 + + prev_color = None + for i in range(1, len(x_vals)): + t_vals = [t[i] for t in trends] + up_count = t_vals.count(1) + down_count = t_vals.count(-1) + + if down_count == 3: + color = "red" + elif down_count == 2 and up_count == 1: + color = "orange" + elif down_count == 1 and up_count == 2: + color = "yellow" + elif up_count == 3: + color = "green" + else: + continue # skip if unknown or inconsistent values + + ax.add_patch(Rectangle( + (x_vals[i-1], y_base), + 1, + band_height, + color=color, + linewidth=0, + alpha=0.6 + )) + # Draw a vertical line at the change of color + if prev_color and prev_color != color: + ax.axvline(x_vals[i-1], color="grey", alpha=0.3, linewidth=1) + prev_color = color + + ax.set_ylim(bottom=y_base - band_height * 0.5) + if style == 'Curves' or style == 'Both': + for st in supertrend_results_list: + params = st["params"] + results = st["results"] + supertrend = results["supertrend"] + trend = results["trend"] + + # Plot SuperTrend line with color based on trend + for i in range(1, len(x_vals)): + if trend[i] == 1: # Uptrend + ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_up"], linewidth=self.line_width) + else: # Downtrend + ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_down"], linewidth=self.line_width) + self._plot_metasupertrend_lines(ax, df, supertrend_results_list) + self._add_supertrend_legend(ax, supertrend_results_list) + + def _plot_metasupertrend_lines(self, ax, df, supertrend_results_list): + """ + Plot a Meta SuperTrend line where all individual SuperTrends agree on trend. + + Parameters: + - ax: matplotlib.axes.Axes, the axis to plot on + - df: pandas.DataFrame, the data to plot + - supertrend_results_list: list, each item contains SuperTrend 'results' and 'params' + """ + x_vals = np.arange(len(df)) + + if len(supertrend_results_list) != 3: + raise ValueError("Expected exactly 3 SuperTrend results for meta calculation") + + trends = [st["results"]["trend"] for st in supertrend_results_list] + supertrends = [st["results"]["supertrend"] for st in supertrend_results_list] + params = supertrend_results_list[0]["params"] # Use first config for styling + + for i in range(1, len(x_vals)): + t1, t2, t3 = trends[0][i], trends[1][i], trends[2][i] + if t1 == t2 == t3: + meta_trend = t1 + # Average the 3 supertrend values + st_avg_prev = np.mean([s[i-1] for s in supertrends]) + st_avg_curr = np.mean([s[i] for s in supertrends]) + color = params["color_up"] if meta_trend == 1 else params["color_down"] + ax.plot(x_vals[i-1:i+1], [st_avg_prev, st_avg_curr], color, linewidth=self.line_width) + + def _add_supertrend_legend(self, ax, supertrend_results_list): + """ + Add SuperTrend legend entries to the given axis. + + Parameters: + - ax: matplotlib.axes.Axes, the axis to add legend entries to + - supertrend_results_list: list, the SuperTrend results + """ + for st in supertrend_results_list: + params = st["params"] + period = params["period"] + multiplier = params["multiplier"] + color_up = params["color_up"] + color_down = params["color_down"] + + ax.plot([], [], color_up, linewidth=self.line_width, + label=f'ST (P:{period}, M:{multiplier}) Up') + ax.plot([], [], color_down, linewidth=self.line_width, + label=f'ST (P:{period}, M:{multiplier}) Down') + \ No newline at end of file