import logging

import matplotlib.dates as mdates
import numpy as np
import pandas as pd
from scipy import stats
from scipy.signal import find_peaks


class TrendDetectorSimple:
    """Detect local extrema in a price series and fit simple trend lines.

    The detector works on the ``'close'`` column of a DataFrame. Local
    minima/maxima are found with ``scipy.signal.find_peaks``; a linear
    regression is fitted through each set of extrema and 7/15-period simple
    moving averages are computed.
    """

    def __init__(self, data, verbose=False):
        """
        Initialize the TrendDetectorSimple class.

        Parameters:
        - data: pandas DataFrame containing price data (a ``'close'`` column is
          required by detect_trends()), or a plain list of close prices, which
          is wrapped in a DataFrame with a single ``'close'`` column
        - verbose: boolean, whether to display detailed logging information

        Raises:
        - ValueError: if data is neither a DataFrame nor a list
        """
        self.data = data
        self.verbose = verbose

        # Configure logging
        log_level = logging.INFO if verbose else logging.WARNING
        logging.basicConfig(level=log_level,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')
        # basicConfig() does nothing once the root logger already has handlers,
        # so set the level on our own logger to make `verbose` take effect.
        self.logger.setLevel(log_level)

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

        self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")

    @staticmethod
    def _fit_line(positions, prices):
        """Return (slope, intercept, r_squared) of a line through the points.

        All three values are NaN when fewer than two points are available,
        because a regression line is undefined in that case.
        """
        if len(positions) < 2:
            nan = float('nan')
            return nan, nan, nan
        slope, intercept, r_value, _, _ = stats.linregress(positions, prices)
        return slope, intercept, r_value ** 2

    def detect_trends(self):
        """
        Detect trends by identifying local minima and maxima in the price data
        using scipy.signal.find_peaks.

        Returns:
        - A tuple ``(result, analysis_results)``:
          - result: DataFrame with the ``'close'`` column, boolean ``'is_min'``
            and ``'is_max'`` flags, and the ``'datetime'`` column when the
            input data has one
          - analysis_results: dict with
            ``['linear_regression']['min' | 'max']`` -> ``slope``,
            ``intercept``, ``r_squared`` (NaN when fewer than two extrema were
            found; x is the row position), and ``['sma']['7' | '15']`` ->
            arrays of moving-average values with the NaN entries dropped
        """
        self.logger.info("Detecting trends")

        df = self.data.copy()
        close_prices = df['close'].to_numpy()

        max_peaks, _ = find_peaks(close_prices)
        min_peaks, _ = find_peaks(-close_prices)

        self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")

        # find_peaks returns row *positions*, so build the flags positionally.
        # Label-based assignment would hit the wrong rows (or append new ones)
        # whenever the DataFrame index is not 0..n-1.
        is_min = np.zeros(len(df), dtype=bool)
        is_max = np.zeros(len(df), dtype=bool)
        is_min[min_peaks] = True
        is_max[max_peaks] = True
        df['is_min'] = is_min
        df['is_max'] = is_max

        # 'datetime' is absent when the detector was built from a plain list.
        wanted = ('datetime', 'close', 'is_min', 'is_max')
        result = df[[col for col in wanted if col in df.columns]].copy()

        # Perform linear regression on min_peaks and max_peaks
        self.logger.info("Performing linear regression on min and max peaks")
        if len(min_peaks) < 2 or len(max_peaks) < 2:
            self.logger.warning("Fewer than 2 minima or maxima found; "
                                "the affected regression is reported as NaN")
        min_slope, min_intercept, min_r_squared = self._fit_line(
            min_peaks, close_prices[min_peaks])
        max_slope, max_intercept, max_r_squared = self._fit_line(
            max_peaks, close_prices[max_peaks])

        # Calculate Simple Moving Averages (SMA) for 7 and 15 periods
        self.logger.info("Calculating SMA-7 and SMA-15")
        # NaN values (the warm-up window) are excluded, so each array is
        # shorter than the input.
        sma_7 = df['close'].rolling(window=7).mean().dropna().values
        sma_15 = df['close'].rolling(window=15).mean().dropna().values

        analysis_results = {
            'linear_regression': {
                'min': {
                    'slope': min_slope,
                    'intercept': min_intercept,
                    'r_squared': min_r_squared,
                },
                'max': {
                    'slope': max_slope,
                    'intercept': max_intercept,
                    'r_squared': max_r_squared,
                },
            },
            'sma': {
                '7': sma_7,
                '15': sma_15,
            },
        }

        self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_squared:.4f}")
        self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_squared:.4f}")

        return result, analysis_results

    @staticmethod
    def _draw_candlesticks(ax, df, width=0.6, up_color='green', down_color='red'):
        """Draw one candle (body rectangle plus high/low wick) per row of df."""
        from matplotlib.patches import Rectangle

        ohlc = zip(df['open'].to_numpy(), df['high'].to_numpy(),
                   df['low'].to_numpy(), df['close'].to_numpy())
        for i, (open_val, high_val, low_val, close_val) in enumerate(ohlc):
            color = up_color if close_val >= open_val else down_color
            bottom = min(open_val, close_val)
            body_height = abs(close_val - open_val)
            ax.add_patch(Rectangle((i - width / 2, bottom), width, body_height,
                                   color=color, alpha=0.8))
            ax.plot([i, i], [low_val, high_val], color='black', linewidth=1)

    def plot_trends(self, trend_data=None, analysis_results=None):
        """
        Plot the price data with detected trends using a candlestick chart.

        The x axis is the row position (0..n-1) of the data. When the data has
        no open/high/low columns (e.g. it was built from a list), the close
        prices are drawn as a line instead of candles.

        Parameters:
        - trend_data: DataFrame, the first output from detect_trends(). If None,
          detect_trends() will be called.
        - analysis_results: dict, the second output from detect_trends(). If
          trend_data is None and this is None too, the freshly computed results
          are used. If falsy otherwise, no regression/SMA lines are drawn.

        Returns:
        - None (displays the plot)
        """
        import matplotlib.pyplot as plt

        if trend_data is None:
            trend_data, detected = self.detect_trends()
            if analysis_results is None:
                analysis_results = detected

        # Create the figure and axis
        fig, ax = plt.subplots(figsize=(12, 8))

        df = self.data.copy()
        x_vals = np.arange(len(df))
        close = df['close'].to_numpy()

        if {'open', 'high', 'low'}.issubset(df.columns):
            self._draw_candlesticks(ax, df)
        else:
            ax.plot(x_vals, close, color='black', linewidth=1, label='Close')

        # Extremum markers are placed by row position so they line up with the
        # candles regardless of the DataFrame index.
        min_positions = np.flatnonzero(trend_data['is_min'].to_numpy(dtype=bool))
        if min_positions.size:
            ax.scatter(min_positions, close[min_positions], color='darkred', s=200,
                       marker='^', label='Local Minima', zorder=100)

        max_positions = np.flatnonzero(trend_data['is_max'].to_numpy(dtype=bool))
        if max_positions.size:
            ax.scatter(max_positions, close[max_positions], color='darkgreen', s=200,
                       marker='v', label='Local Maxima', zorder=100)

        if analysis_results:
            regression = analysis_results['linear_regression']

            # Minima regression line (support); skipped when the fit is NaN
            min_slope = regression['min']['slope']
            min_intercept = regression['min']['intercept']
            if np.isfinite(min_slope) and np.isfinite(min_intercept):
                ax.plot(x_vals, min_slope * x_vals + min_intercept, 'g--',
                        linewidth=2, label='Minima Regression')

            # Maxima regression line (resistance); skipped when the fit is NaN
            max_slope = regression['max']['slope']
            max_intercept = regression['max']['intercept']
            if np.isfinite(max_slope) and np.isfinite(max_intercept):
                ax.plot(x_vals, max_slope * x_vals + max_intercept, 'r--',
                        linewidth=2, label='Maxima Regression')

            # SMA-7 line. detect_trends() drops the NaN warm-up values, so the
            # series is shorter than the data; right-align it so that each
            # average sits under the last bar of its window.
            sma_7 = np.asarray(analysis_results['sma']['7'], dtype=float)
            if 0 < sma_7.size <= len(x_vals):
                ax.plot(x_vals[len(x_vals) - sma_7.size:], sma_7, 'y-',
                        linewidth=2, label='SMA-7')

            # NOTE: SMA-15 is available in analysis_results['sma']['15'] but is
            # not drawn (it was disabled in the original implementation).

        # Set title and labels
        ax.set_title('Price Candlestick Chart with Local Minima and Maxima', fontsize=14)
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)

        # Set appropriate x-axis limits
        ax.set_xlim(-0.5, len(df) - 0.5)

        # Add a legend
        ax.legend(loc='best')

        # Adjust layout
        plt.tight_layout()

        # Show the plot
        plt.show()