288 lines
12 KiB
Python
288 lines
12 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
import ta
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.dates as mdates
|
|
import logging
|
|
import mplfinance as mpf
|
|
from matplotlib.patches import Rectangle
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
import concurrent.futures
|
|
|
|
class TrendDetectorMACD:
    """Detect up/down trend segments in a price series from MACD crossovers.

    A row is labelled ``'up'`` when the MACD line is above its signal line and
    ``'down'`` otherwise; consecutive rows with the same label form one trend
    segment. Rows where either line is undefined (NaN, typically the indicator
    warm-up at the start of the series) belong to no segment.
    """

    def __init__(self, data, verbose=False):
        """Store the price data and configure logging.

        Parameters:
            data: a pandas DataFrame or a plain list of closing prices. A
                DataFrame should have a ``close`` column (otherwise its first
                column is used as the price), ``open``/``high``/``low`` for
                candlestick plots, and optionally a ``datetime`` column used
                for date tick labels.
            verbose: emit INFO-level progress messages when True; only
                warnings and errors otherwise.

        Raises:
            ValueError: if ``data`` is neither a DataFrame nor a list.
        """
        self.data = data
        self.verbose = verbose

        level = logging.INFO if verbose else logging.WARNING
        # basicConfig only has an effect the first time it is called in a
        # process (while the root logger has no handlers yet) ...
        logging.basicConfig(level=level,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetector')
        # ... so also set the level on our own logger; otherwise `verbose`
        # would be ignored for every instance after the first. The logger is
        # shared by name, so the most recently constructed instance wins.
        self.logger.setLevel(level)

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

    # ------------------------------------------------------------------
    # Trend detection
    # ------------------------------------------------------------------

    @staticmethod
    def _price_series(df):
        """Return the 'close' column of *df*, falling back to its first column."""
        if 'close' in df.columns or len(df.columns) == 0:
            # With no columns at all this raises KeyError, as before.
            return df['close']
        return df[df.columns[0]]

    @staticmethod
    def _macd_components(prices):
        """Return the MACD line, signal line and histogram of *prices* (via ``ta``)."""
        return (ta.trend.macd(prices),
                ta.trend.macd_signal(prices),
                ta.trend.macd_diff(prices))

    @staticmethod
    def _segment_trends(labels, prices):
        """Group runs of equal consecutive labels into trend-segment dicts.

        Parameters:
            labels: one entry per row, ``'up'``, ``'down'`` or ``None``
                (undefined). ``None`` rows break segments and are not reported.
            prices: pandas Series aligned with ``labels``; its values at a
                segment's first and last row become ``start_value`` and
                ``end_value``.

        Returns:
            List of dicts with keys ``type``, ``start_index``, ``end_index``
            (inclusive, 0-based row positions), ``start_value``, ``end_value``.
        """
        labels = list(labels)
        segments = []
        start = 0
        for pos in range(1, len(labels) + 1):
            # Keep extending the current run while the label is unchanged;
            # close it at a label change or at the end of the data.
            if pos < len(labels) and labels[pos] == labels[start]:
                continue
            if labels[start] is not None:
                segments.append({
                    "type": labels[start],
                    "start_index": start,
                    "end_index": pos - 1,
                    "start_value": prices.iloc[start],
                    "end_value": prices.iloc[pos - 1],
                })
            start = pos
        return segments

    def detect_trends_MACD_signal(self):
        """Split the price series into up/down segments by MACD vs. signal line.

        Returns:
            ``{"error": ...}`` when there are fewer than 3 rows. Otherwise a
            list of segment dicts (keys ``type``, ``start_index``,
            ``end_index``, ``start_value``, ``end_value``; indices are 0-based
            row positions). Rows where MACD or its signal is NaN are not part
            of any segment, so too short a series yields an empty list.
        """
        self.logger.info("Starting trend detection")
        if len(self.data) < 3:
            self.logger.warning("Not enough data points for trend detection")
            return {"error": "Not enough data points for trend detection"}

        if 'close' not in self.data.columns and len(self.data.columns) > 0:
            self.logger.info(f"'close' column not found, using {self.data.columns[0]} instead")
        close = self._price_series(self.data)

        self.logger.info("Calculating MACD indicators")
        macd, signal, _ = self._macd_components(close)

        self.logger.info("Identifying trend changes")
        # NaN compares False, so without this guard undefined rows would be
        # reported as a (spurious) 'down' trend.
        labels = [
            None if pd.isna(m) or pd.isna(s) else ('up' if m > s else 'down')
            for m, s in zip(macd, signal)
        ]

        self.logger.info("Generating trend segments")
        trends = self._segment_trends(labels, close)

        self.logger.info(f"Detected {len(trends)} trend segments")
        return trends

    def get_strongest_trend(self):
        """Return the segment with the largest absolute price change.

        Returns:
            The segment dict with the largest ``|end_value - start_value|``,
            the ``{"error": ...}`` dict from detection, or
            ``{"message": "No significant trends detected"}`` when there are
            no segments.
        """
        self.logger.info("Finding strongest trend")
        trends = self.detect_trends_MACD_signal()
        if isinstance(trends, dict) and "error" in trends:
            self.logger.warning(f"Error in trend detection: {trends['error']}")
            return trends

        if not trends:
            self.logger.info("No significant trends detected")
            return {"message": "No significant trends detected"}

        strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
        self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
        return strongest

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    @staticmethod
    def _use_date_ticks(ax, datetimes):
        """Label the integer row positions on *ax*'s x-axis with 'YYYY-MM-DD' dates."""
        from matplotlib.ticker import FuncFormatter, MaxNLocator

        date_labels = (pd.to_datetime(datetimes, errors='coerce')
                       .dt.strftime('%Y-%m-%d').fillna('').tolist())

        def _label(x, _pos=None):
            pos = int(round(x))
            return date_labels[pos] if 0 <= pos < len(date_labels) else ''

        # Ticks only at whole row positions, so each label names exactly one row.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.xaxis.set_major_formatter(FuncFormatter(_label))

    def _draw_price(self, ax, df, prices):
        """Draw candlesticks (when OHLC columns exist) or a price line at row positions."""
        positions = np.arange(len(df))
        if all(col in df.columns for col in ('open', 'high', 'low', 'close')):
            width = 0.6
            for pos, (open_val, close_val) in enumerate(zip(df['open'], df['close'])):
                color = 'green' if close_val >= open_val else 'red'
                ax.add_patch(Rectangle((pos - width / 2, min(open_val, close_val)),
                                       width, abs(close_val - open_val),
                                       color=color, alpha=0.8))
            # All wicks in one call instead of one line per candle.
            ax.vlines(positions, df['low'], df['high'], colors='black', linewidths=1)
        else:
            self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
            ax.plot(positions, prices, color='black', alpha=0.7, linewidth=1, label='Price')
        ax.set_xlim(-0.5, len(df) - 0.5)

    def _draw_trend_segments(self, ax, trends, prices):
        """Draw each segment as a line from its start price to its end price."""
        self.logger.info("Highlighting trends on plot")
        labelled = set()
        for trend in trends:
            start_idx, end_idx = trend['start_index'], trend['end_index']
            y_start, y_end = prices.iloc[start_idx], prices.iloc[end_idx]
            color = 'green' if trend['type'] == 'up' else 'red'
            label = f"{trend['type'].capitalize()} Trend"
            # Only the first segment of each type gets a legend entry.
            ax.plot([start_idx, end_idx], [y_start, y_end], color=color, linewidth=2,
                    label=label if label not in labelled else '_nolegend_')
            labelled.add(label)
            ax.scatter(start_idx, y_start, color=color, marker='o', s=50)
            ax.scatter(end_idx, y_end, color=color, marker='s', s=50)

    def _draw_macd(self, ax, df, prices):
        """Draw the MACD and signal lines plus a green/red histogram at row positions."""
        if all(col in df.columns for col in ('macd', 'macd_signal', 'macd_diff')):
            macd, signal, hist = df['macd'], df['macd_signal'], df['macd_diff']
        else:
            macd, signal, hist = self._macd_components(prices)
        positions = np.arange(len(df))
        ax.plot(positions, macd, label='MACD', color='blue')
        ax.plot(positions, signal, label='Signal', color='orange')
        # NaN bars compare False and get 'red'; they have no height to draw.
        ax.bar(positions, hist, color=np.where(hist >= 0, 'green', 'red'),
               alpha=0.5, width=0.8)

    def plot_trends(self, trends):
        """Plot prices with the detected trends highlighted, plus a MACD panel.

        Every series (price, trend segments, MACD) is drawn at integer row
        positions ``0..n-1``, matching the positional ``start_index`` /
        ``end_index`` of the segments. When a ``datetime`` column exists, the
        tick labels show the corresponding dates.

        Parameters:
            trends: output of ``detect_trends_MACD_signal`` (a list of segment
                dicts or an ``{"error": ...}`` dict), or output of
                ``get_strongest_trend`` (a single segment dict or a
                ``{"message": ...}`` dict).

        Returns:
            The ``matplotlib.pyplot`` module after ``plt.show()``, or ``None``
            when there is nothing to plot.
        """
        self.logger.info("Plotting trends with candlesticks")
        if isinstance(trends, dict):
            if "error" in trends:
                self.logger.error(trends["error"])
                print(trends["error"])
                return
            # A single segment (e.g. from get_strongest_trend) is plotted on
            # its own; any other dict (e.g. {"message": ...}) has nothing to draw.
            trends = [trends] if "start_index" in trends else []

        if not trends:
            self.logger.warning("No significant trends detected for plotting")
            print("No significant trends detected")
            return

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8),
                                       gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
        self.logger.info("Creating plot figure with shared x-axis")

        df = self.data
        prices = self._price_series(df)

        if 'datetime' in df.columns:
            x_label = 'Date'
            for ax in (ax1, ax2):
                self._use_date_ticks(ax, df['datetime'])
            fig.autofmt_xdate()
            self.logger.info("Using datetime for x-axis")
        else:
            x_label = 'Index'
            self.logger.info("Using index for x-axis")

        self._draw_price(ax1, df, prices)
        self._draw_trend_segments(ax1, trends, prices)

        ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
        ax1.set_ylabel('Price', fontsize=14)
        ax1.grid(alpha=0.3)
        ax1.legend()

        self.logger.info("Creating MACD subplot")
        self._draw_macd(ax2, df, prices)

        ax2.set_title('MACD Indicator', fontsize=16)
        ax2.set_xlabel(x_label, fontsize=14)
        ax2.set_ylabel('MACD', fontsize=14)
        ax2.grid(alpha=0.3)
        ax2.legend()

        plt.tight_layout()
        plt.subplots_adjust(hspace=0.1)
        plt.show()

        return plt

    # ------------------------------------------------------------------
    # SuperTrend
    # ------------------------------------------------------------------

    def _calculate_supertrend_indicators(self):
        """
        Calculate SuperTrend indicators with several parameter sets concurrently.

        Each parameter set is passed to ``self.calculate_supertrend`` on a
        thread pool.

        Returns:
            - list of ``{"results": ..., "params": ...}`` dicts, in the same
              order as the parameter sets below (``executor.map`` preserves
              input order).
        """
        # NOTE(review): ST_COLOR_UP, ST_COLOR_DOWN and self.calculate_supertrend
        # are not defined anywhere in this file; this method raises NameError /
        # AttributeError unless they are provided elsewhere — confirm.
        supertrend_params = [
            {"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
            {"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
            {"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
        ]

        def run_supertrend(params):
            # NOTE(review): no copy of the data is made here; this is only
            # thread-safe if calculate_supertrend does not mutate shared state
            # such as self.data — confirm.
            return {
                "results": self.calculate_supertrend(
                    period=params["period"],
                    multiplier=params["multiplier"]
                ),
                "params": params
            }

        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(run_supertrend, supertrend_params))

        return results
|