Cycles/trend_detector_simple.py

255 lines
10 KiB
Python
Raw Normal View History

import pandas as pd
import numpy as np
import logging
from scipy.signal import find_peaks
import matplotlib.dates as mdates
from scipy import stats
from scipy import stats
class TrendDetectorSimple:
    """Detect local extrema in a price series and visualise them.

    Local minima/maxima of the ``close`` column are located with
    ``scipy.signal.find_peaks``. A least-squares line is fitted through each
    set of extrema, and 7- and 15-period simple moving averages are computed.
    ``plot_trends`` draws the result over a candlestick chart.
    """

    def __init__(self, data, verbose=False):
        """
        Initialize the TrendDetectorSimple class.

        Parameters:
        - data: pandas DataFrame containing price data, or a list of close
          prices. A list is converted to a DataFrame with a single 'close'
          column. detect_trends() needs a 'close' column; plot_trends()
          additionally needs 'open', 'high' and 'low'.
        - verbose: boolean, whether to display detailed logging information

        Raises:
        - ValueError: if data is neither a pandas DataFrame nor a list
        """
        self.data = data
        self.verbose = verbose

        # Plot style configuration
        self.plot_style = 'dark_background'
        self.bg_color = '#181C27'
        self.plot_size = (12, 8)

        # Candlestick configuration
        self.candle_width = 0.6
        self.candle_up_color = '#089981'
        self.candle_down_color = '#F23645'
        self.candle_alpha = 0.8
        self.wick_width = 1

        # Marker configuration
        self.min_marker = '^'
        self.min_color = 'red'
        self.min_size = 100
        self.max_marker = 'v'
        self.max_color = 'green'
        self.max_size = 100
        self.marker_zorder = 100

        # Line configuration (matplotlib format strings: color + line style)
        self.line_width = 2
        self.min_line_style = 'g--'  # green dashed
        self.max_line_style = 'r--'  # red dashed
        self.sma7_line_style = 'y-'  # yellow solid
        self.sma15_line_style = 'm-'  # magenta solid

        # Text configuration
        self.title_size = 14
        self.title_color = 'white'
        self.axis_label_size = 12
        self.axis_label_color = 'white'

        # Legend configuration
        self.legend_loc = 'best'
        self.legend_bg_color = '#333333'

        # Configure logging. basicConfig does nothing once the root logger
        # already has handlers, so the named logger's level is also set
        # explicitly. This makes `verbose` take effect for every instance.
        # The logger is shared by name, so the most recently constructed
        # instance decides its level.
        level = logging.INFO if verbose else logging.WARNING
        logging.basicConfig(level=level,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')
        self.logger.setLevel(level)

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

        self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")

    @staticmethod
    def _fit_line(x, y):
        """Least-squares fit of y = slope * x + intercept.

        Returns (slope, intercept, r_squared). All three are NaN when fewer
        than two points are given, because no line is defined in that case
        (linregress raises on empty input).
        """
        if len(x) < 2:
            return np.nan, np.nan, np.nan
        slope, intercept, r_value, _, _ = stats.linregress(x, y)
        return slope, intercept, r_value ** 2

    def detect_trends(self):
        """
        Detect trends by identifying local minima and maxima in the close
        prices using scipy.signal.find_peaks with its default settings.

        A least-squares line is fitted through the minima and through the
        maxima, using the row position (0 .. n-1) as x. 7- and 15-period
        simple moving averages of the close price are also computed.

        Returns:
        - result: DataFrame with 'datetime' (only when the data has that
          column), 'close', and boolean 'is_min' / 'is_max' columns. It keeps
          the index of the input data.
        - analysis_results: dict with
            'linear_regression': {'min': {...}, 'max': {...}}, each holding
                'slope', 'intercept' and 'r_squared'. These are NaN when
                fewer than two extrema of that kind were found.
            'sma': {'7': ndarray, '15': ndarray}, aligned with the rows.
        """
        self.logger.info("Detecting trends")
        df = self.data.copy()
        close_prices = df['close'].to_numpy(dtype=float)

        # find_peaks only finds maxima; minima are the maxima of the negated series.
        max_peaks, _ = find_peaks(close_prices)
        min_peaks, _ = find_peaks(-close_prices)
        self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")

        # find_peaks returns row *positions*. Build the flags positionally
        # rather than via label-based `.at` so a non-RangeIndex
        # (e.g. a DatetimeIndex) is handled correctly.
        is_min = np.zeros(len(df), dtype=bool)
        is_min[min_peaks] = True
        is_max = np.zeros(len(df), dtype=bool)
        is_max[max_peaks] = True
        df['is_min'] = is_min
        df['is_max'] = is_max

        # 'datetime' is optional: list input only provides a 'close' column.
        columns = ['close', 'is_min', 'is_max']
        if 'datetime' in df.columns:
            columns.insert(0, 'datetime')
        result = df[columns].copy()

        # Perform linear regression on min_peaks and max_peaks
        self.logger.info("Performing linear regression on min and max peaks")
        min_slope, min_intercept, min_r_squared = self._fit_line(min_peaks, close_prices[min_peaks])
        max_slope, max_intercept, max_r_squared = self._fit_line(max_peaks, close_prices[max_peaks])

        # Calculate Simple Moving Averages (SMA) for 7 and 15 periods.
        # min_periods=1 yields a (shorter-window) average for the first rows too.
        self.logger.info("Calculating SMA-7 and SMA-15")
        close_series = pd.Series(close_prices)
        sma_7 = close_series.rolling(window=7, min_periods=1).mean().to_numpy()
        sma_15 = close_series.rolling(window=15, min_periods=1).mean().to_numpy()

        analysis_results = {
            'linear_regression': {
                'min': {
                    'slope': min_slope,
                    'intercept': min_intercept,
                    'r_squared': min_r_squared,
                },
                'max': {
                    'slope': max_slope,
                    'intercept': max_intercept,
                    'r_squared': max_r_squared,
                },
            },
            'sma': {
                '7': sma_7,
                '15': sma_15,
            },
        }

        self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_squared:.4f}")
        self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_squared:.4f}")
        return result, analysis_results

    def plot_trends(self, trend_data=None, analysis_results=None):
        """
        Plot the price data with detected trends using a candlestick chart.

        Parameters:
        - trend_data: DataFrame, the first output of detect_trends(). If None,
          detect_trends() is called and its results are used.
        - analysis_results: dict, the second output of detect_trends(). When
          it is falsy, the regression and SMA lines are not drawn.

        The data must contain 'open', 'high', 'low' and 'close' columns.
        Candles are drawn at their row position (0 .. len(data) - 1).

        Returns:
        - None (displays the plot)
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        if trend_data is None:
            trend_data, computed_results = self.detect_trends()
            if analysis_results is None:
                analysis_results = computed_results

        # Create the figure and axis with specified background
        plt.style.use(self.plot_style)
        fig, ax = plt.subplots(figsize=self.plot_size)
        fig.patch.set_facecolor(self.bg_color)
        ax.set_facecolor(self.bg_color)

        df = self.data
        opens = df['open'].to_numpy()
        closes = df['close'].to_numpy()
        highs = df['high'].to_numpy()
        lows = df['low'].to_numpy()

        # Draw candlesticks manually: a body rectangle plus a low-high wick.
        for i, (open_val, close_val, high_val, low_val) in enumerate(zip(opens, closes, highs, lows)):
            color = self.candle_up_color if close_val >= open_val else self.candle_down_color
            bottom = min(open_val, close_val)
            body_height = abs(close_val - open_val)
            rect = Rectangle((i - self.candle_width / 2, bottom), self.candle_width, body_height,
                             color=color, alpha=self.candle_alpha)
            ax.add_patch(rect)
            ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)

        # Use row positions rather than index labels so the markers line up
        # with the candles even when the data has a non-default index.
        min_pos = np.flatnonzero(trend_data['is_min'].to_numpy(dtype=bool))
        if min_pos.size:
            ax.scatter(min_pos, closes[min_pos], color=self.min_color, s=self.min_size,
                       marker=self.min_marker, label='Local Minima', zorder=self.marker_zorder)
        max_pos = np.flatnonzero(trend_data['is_max'].to_numpy(dtype=bool))
        if max_pos.size:
            ax.scatter(max_pos, closes[max_pos], color=self.max_color, s=self.max_size,
                       marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)

        if analysis_results:
            x_vals = np.arange(len(df))
            regression = analysis_results['linear_regression']

            # Minima regression line (support); skipped when no line could be fitted.
            min_fit = regression['min']
            if np.isfinite(min_fit['slope']):
                ax.plot(x_vals, min_fit['slope'] * x_vals + min_fit['intercept'],
                        self.min_line_style, linewidth=self.line_width, label='Minima Regression')

            # Maxima regression line (resistance); skipped when no line could be fitted.
            max_fit = regression['max']
            if np.isfinite(max_fit['slope']):
                ax.plot(x_vals, max_fit['slope'] * x_vals + max_fit['intercept'],
                        self.max_line_style, linewidth=self.line_width, label='Maxima Regression')

            # SMA lines; NaN points (e.g. from missing prices) are left out.
            for key, style, label in (('7', self.sma7_line_style, 'SMA-7'),
                                      ('15', self.sma15_line_style, 'SMA-15')):
                sma = np.asarray(analysis_results['sma'][key], dtype=float)
                valid = ~np.isnan(sma)
                ax.plot(x_vals[valid], sma[valid], style, linewidth=self.line_width, label=label)

        # Set title and labels
        ax.set_title('Price Candlestick Chart with Local Minima and Maxima',
                     fontsize=self.title_size, color=self.title_color)
        ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
        ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)

        # Half a candle of padding on each side of the positional x range
        ax.set_xlim(-0.5, len(df) - 0.5)

        ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)
        plt.tight_layout()
        plt.show()
2025-05-06 15:24:36 +08:00