205 lines
7.9 KiB
Python
205 lines
7.9 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
import logging
|
|
from scipy.signal import find_peaks
|
|
import matplotlib.dates as mdates
|
|
from scipy import stats
|
|
|
|
class TrendDetectorSimple:
    """Find local extrema in a close-price series and fit trend lines to them.

    The detector works on a pandas DataFrame that has at least a ``close``
    column. A plain list of prices is also accepted and is wrapped in a
    single-column DataFrame. Candlestick plotting additionally uses the
    ``open``, ``high`` and ``low`` columns when they are present.
    """

    def __init__(self, data, verbose=False):
        """
        Initialize the TrendDetectorSimple class.

        Parameters:
        - data: pandas DataFrame containing price data, or a list of close prices
        - verbose: boolean, whether to display detailed logging information

        Raises:
        - ValueError: if data is neither a pandas DataFrame nor a list
        """
        self.data = data
        self.verbose = verbose

        # Configure logging. NOTE: basicConfig only has an effect the first
        # time it is called in a process, so the level chosen by the first
        # instance wins for the root logger.
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

        self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")

    @staticmethod
    def _fit_line(positions, prices):
        """Least-squares line through (position, price); NaN values if < 2 points."""
        if len(positions) < 2:
            # linregress cannot fit a line through fewer than two points.
            nan = float('nan')
            return {'slope': nan, 'intercept': nan, 'r_squared': nan}
        fit = stats.linregress(positions, prices)
        return {
            'slope': fit.slope,
            'intercept': fit.intercept,
            'r_squared': fit.rvalue ** 2,
        }

    def detect_trends(self):
        """
        Detect trends by identifying local minima and maxima in the price data
        using scipy.signal.find_peaks, then fit a regression line through each
        set of extrema and compute 7- and 15-period simple moving averages.

        Returns:
        - result: DataFrame with the 'close' column, boolean 'is_min' and
          'is_max' columns, and the 'datetime' column when the input has one
        - analysis_results: dict with
            'linear_regression' -> {'min': {...}, 'max': {...}}, each holding
            'slope', 'intercept' and 'r_squared' (all NaN when fewer than two
            extrema of that kind were found); the regression x-axis is the
            row position, not the index label
            'sma' -> {'7': array, '15': array} with the leading NaN values
            dropped, so each array is shorter than the input
        """
        self.logger.info("Detecting trends")

        df = self.data.copy()
        close_prices = df['close'].to_numpy(dtype=float)

        # find_peaks returns row positions; minima are peaks of the negated series.
        max_peaks, _ = find_peaks(close_prices)
        min_peaks, _ = find_peaks(-close_prices)

        self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")

        # Build the flags positionally so this works for any index
        # (label-based assignment would misplace or append rows on a
        # non-default index).
        is_min = np.zeros(len(df), dtype=bool)
        is_max = np.zeros(len(df), dtype=bool)
        is_min[min_peaks] = True
        is_max[max_peaks] = True
        df['is_min'] = is_min
        df['is_max'] = is_max

        # 'datetime' is optional: list input only carries 'close'.
        wanted = ('datetime', 'close', 'is_min', 'is_max')
        result = df[[column for column in wanted if column in df.columns]].copy()

        # Perform linear regression on min_peaks and max_peaks
        self.logger.info("Performing linear regression on min and max peaks")
        min_fit = self._fit_line(min_peaks, close_prices[min_peaks])
        max_fit = self._fit_line(max_peaks, close_prices[max_peaks])

        # Calculate Simple Moving Averages (SMA) for 7 and 15 periods
        self.logger.info("Calculating SMA-7 and SMA-15")

        # Calculate SMA values and exclude NaN values
        sma_7 = df['close'].rolling(window=7).mean().dropna().values
        sma_15 = df['close'].rolling(window=15).mean().dropna().values

        analysis_results = {
            'linear_regression': {
                'min': min_fit,
                'max': max_fit,
            },
            'sma': {
                '7': sma_7,
                '15': sma_15,
            },
        }

        self.logger.info(f"Min peaks regression: slope={min_fit['slope']:.4f}, intercept={min_fit['intercept']:.4f}, r²={min_fit['r_squared']:.4f}")
        self.logger.info(f"Max peaks regression: slope={max_fit['slope']:.4f}, intercept={max_fit['intercept']:.4f}, r²={max_fit['r_squared']:.4f}")

        return result, analysis_results

    def plot_trends(self, trend_data=None, analysis_results=None):
        """
        Plot the price data with detected trends using a candlestick chart.

        The x-axis is the row position of each bar. When the data has no
        'open'/'high'/'low' columns, the close price is drawn as a line
        instead of candles.

        Parameters:
        - trend_data: DataFrame, the first output from detect_trends(). If
          None, detect_trends() will be called.
        - analysis_results: dict, the second output from detect_trends(). If
          falsy, no regression or SMA overlays are drawn; when trend_data is
          None and this is None too, the freshly computed results are used.

        Returns:
        - None (displays the plot)
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        if trend_data is None:
            trend_data, computed_results = self.detect_trends()
            if analysis_results is None:
                analysis_results = computed_results

        # Create the figure and axis
        fig, ax = plt.subplots(figsize=(12, 8))

        df = self.data
        bar_count = len(df)
        x_vals = np.arange(bar_count)
        close = df['close'].to_numpy()

        if {'open', 'high', 'low'}.issubset(df.columns):
            # Draw candlesticks manually
            up_color = 'green'
            down_color = 'red'
            width = 0.6
            candles = zip(df['open'].to_numpy(), close,
                          df['high'].to_numpy(), df['low'].to_numpy())
            for i, (open_val, close_val, high_val, low_val) in enumerate(candles):
                color = up_color if close_val >= open_val else down_color

                # Candle body spans open..close
                body_height = abs(close_val - open_val)
                bottom = min(open_val, close_val)
                ax.add_patch(Rectangle((i - width / 2, bottom), width, body_height,
                                       color=color, alpha=0.8))

                # Candle wick spans low..high
                ax.plot([i, i], [low_val, high_val], color='black', linewidth=1)
        else:
            # No OHLC columns (e.g. the detector was built from a list).
            ax.plot(x_vals, close, color='black', linewidth=1, label='Close')

        # Mark extrema by row position so labels of a custom index never
        # get mixed with positional lookups.
        markers = (
            ('is_min', 'darkred', '^', 'Local Minima'),
            ('is_max', 'darkgreen', 'v', 'Local Maxima'),
        )
        for column, color, marker, label in markers:
            positions = np.flatnonzero(trend_data[column].to_numpy(dtype=bool))
            if positions.size:
                ax.scatter(positions, close[positions], color=color, s=200,
                           marker=marker, label=label, zorder=100)

        if analysis_results:
            regression = analysis_results['linear_regression']
            lines = (
                ('min', 'g--', 'Minima Regression'),  # support
                ('max', 'r--', 'Maxima Regression'),  # resistance
            )
            for key, style, label in lines:
                slope = regression[key]['slope']
                intercept = regression[key]['intercept']
                if np.isnan(slope) or np.isnan(intercept):
                    # Fewer than two extrema: there is no line to draw.
                    continue
                ax.plot(x_vals, slope * x_vals + intercept, style,
                        linewidth=2, label=label)

            # SMA-7 line. detect_trends() drops the leading NaN values, so the
            # array is shorter than the data; right-align it with the bars it
            # was computed for. SMA-15 is available in analysis_results but is
            # not drawn here.
            sma_7 = np.asarray(analysis_results['sma']['7'])
            if 0 < sma_7.size <= bar_count:
                ax.plot(x_vals[bar_count - sma_7.size:], sma_7, 'y-',
                        linewidth=2, label='SMA-7')

        # Set title and labels
        ax.set_title('Price Candlestick Chart with Local Minima and Maxima', fontsize=14)
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)

        # Set appropriate x-axis limits
        ax.set_xlim(-0.5, bar_count - 0.5)

        # Add a legend
        ax.legend(loc='best')

        # Adjust layout
        plt.tight_layout()

        # Show the plot
        plt.show()
|
|
|