Refactor cycle detection and trend analysis; enhance trend detection with linear regression and moving averages. Update main script for improved data handling and visualization.

This commit is contained in:
Simon Moisy 2025-05-09 12:23:45 +08:00
parent cbc6a7493d
commit e9bfcd03eb
4 changed files with 762 additions and 711 deletions

View File

@ -1,248 +1,248 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.signal import argrelextrema from scipy.signal import argrelextrema
class CycleDetector:
    """Detect price cycles from swing lows and flag simple low patterns.

    The detector keeps its own copy of the input frame in ``self.data`` and
    adds two boolean columns to it: ``swing_low`` and ``swing_high``.

    The frame must contain ``high`` and ``low`` columns (swing detection reads
    them); ``close`` is needed for plotting, and a ``date`` or ``datetime``
    column is needed for resampling to weekly/monthly bars.
    """

    # How each known column is aggregated when resampling to coarser bars.
    _RESAMPLE_RULES = {
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
    }
    _PRICE_COLUMNS = ('open', 'high', 'low', 'close')

    def __init__(self, data, timeframe='daily'):
        """
        Initialize the CycleDetector with price data.

        Parameters:
        - data: DataFrame with 'high' and 'low' columns, normally also
          'close' and a 'date' or 'datetime' column. It is copied, so the
          caller's frame is never modified.
        - timeframe: 'daily', 'weekly', or 'monthly'. Any value other than
          'weekly'/'monthly' leaves the bars as they are.
        """
        self.data = data.copy()
        self.timeframe = timeframe

        # Ensure we have a consistent date column name
        if 'datetime' in self.data.columns and 'date' not in self.data.columns:
            self.data.rename(columns={'datetime': 'date'}, inplace=True)

        # Convert data to specified timeframe if needed
        if 'date' in self.data.columns:
            if timeframe == 'weekly':
                self.data = self._convert_data(self.data, 'W')
            elif timeframe == 'monthly':
                # NOTE(review): newer pandas releases prefer the 'ME' alias
                # for month-end; 'M' is kept for older versions - confirm
                # against the pinned pandas version.
                self.data = self._convert_data(self.data, 'M')

        # Add columns for local minima and maxima detection
        self._add_swing_points()

    def _convert_data(self, data, timeframe):
        """
        Resample bars to the pandas offset alias given in 'timeframe'.

        Only the OHLCV columns that are actually present are aggregated, so a
        frame without e.g. 'volume' no longer raises. Periods that contain no
        source rows (all price columns NaN) are dropped instead of being kept
        as phantom bars.

        Returns:
        - DataFrame with 'date' as a regular column again.

        Raises:
        - KeyError: if none of the open/high/low/close/volume columns exist.
        """
        frame = data.copy()
        frame['date'] = pd.to_datetime(frame['date'])

        rules = {
            column: how
            for column, how in self._RESAMPLE_RULES.items()
            if column in frame.columns
        }
        if not rules:
            raise KeyError("No OHLCV columns available to resample")

        resampled = frame.set_index('date').resample(timeframe).agg(rules)

        price_columns = [c for c in self._PRICE_COLUMNS if c in rules]
        if price_columns:
            resampled = resampled.dropna(subset=price_columns, how='all')
        return resampled.reset_index()

    def _add_swing_points(self, window=5):
        """
        Identify swing points (local minima and maxima).

        A bar is a swing low when its 'low' is strictly below the lows of the
        'window' bars on each side; swing highs use 'high' and strictly
        greater. Results are stored in the boolean columns 'swing_low' and
        'swing_high'. The frame's index and other columns are left untouched.

        Parameters:
        - window: The window size for local minima/maxima detection
        """
        size = len(self.data)
        swing_low = np.zeros(size, dtype=bool)
        swing_high = np.zeros(size, dtype=bool)

        if size > 0:  # nothing to scan in an empty frame
            lows = self.data['low'].to_numpy()
            highs = self.data['high'].to_numpy()
            swing_low[argrelextrema(lows, np.less, order=window)[0]] = True
            swing_high[argrelextrema(highs, np.greater, order=window)[0]] = True

        self.data['swing_low'] = swing_low
        self.data['swing_high'] = swing_high

    def _x_values(self, frame):
        """Return the x-axis values for 'frame': its 'date' column, else its index."""
        if 'date' in frame.columns:
            return frame['date']
        return frame.index

    def find_cycle_lows(self):
        """
        Find all swing lows which represent cycle lows.

        Returns:
        - numpy array with the 'date' of every swing low, or the index labels
          of those rows when the data has no 'date' column.
        """
        is_low = self.data['swing_low'].to_numpy()
        if 'date' in self.data.columns:
            return self.data.loc[is_low, 'date'].values
        return self.data.index[is_low].values

    def calculate_cycle_lengths(self):
        """Calculate the lengths (in bars) of each cycle between consecutive lows."""
        low_positions = np.flatnonzero(self.data['swing_low'].to_numpy())
        return np.diff(low_positions)

    def get_average_cycle_length(self):
        """Calculate the average cycle length, or None with fewer than two lows."""
        cycle_lengths = self.calculate_cycle_lengths()
        if len(cycle_lengths) > 0:
            return float(np.mean(cycle_lengths))
        return None

    def get_cycle_window(self, tolerance=0.10):
        """
        Get the cycle window with the specified tolerance.

        Parameters:
        - tolerance: The tolerance as a fraction of the average (default: 0.10, i.e. 10%)

        Returns:
        - tuple: (min_cycle_length, avg_cycle_length, max_cycle_length),
          or None when no average cycle length is available.
        """
        avg_length = self.get_average_cycle_length()
        if avg_length is None:
            return None
        return (avg_length * (1 - tolerance), avg_length, avg_length * (1 + tolerance))

    def detect_two_drives_pattern(self, lookback=10):
        """
        Detect 2-drives pattern: a swing low, counter trend bounce, and a lower low.

        A swing low at position i is reported when the previous swing low lies
        at most 'lookback' bars earlier, the low at i is lower than that
        previous low, and some bar strictly between the two lows has a high
        above the previous low (the bounce).

        Parameters:
        - lookback: Number of periods to look back

        Returns:
        - list: Indices (row positions) where 2-drives patterns are detected
        """
        lows = self.data['low'].to_numpy()
        highs = self.data['high'].to_numpy()
        low_positions = np.flatnonzero(self.data['swing_low'].to_numpy())
        last_candidate = len(self.data) - 2  # the final bar is never evaluated

        patterns = []
        for previous, current in zip(low_positions[:-1], low_positions[1:]):
            if current < lookback or current > last_candidate:
                continue
            if current - previous > lookback:
                continue
            if not lows[current] < lows[previous]:
                continue
            # Look at the bars that really sit between the two lows.
            bounce_highs = highs[previous + 1:current]
            if bounce_highs.size and bounce_highs.max() > lows[previous]:
                patterns.append(int(current))
        return patterns

    def detect_v_shaped_lows(self, window=5, threshold=0.02):
        """
        Detect V-shaped cycle lows (sharp decline followed by sharp rise).

        Parameters:
        - window: Number of bars inspected on each side of the low
        - threshold: Fractional change (0.02 = 2%) required on both sides

        Returns:
        - list: Indices (row positions) where V-shaped patterns are detected

        Raises:
        - ValueError: if window is smaller than 1.
        """
        if window < 1:
            raise ValueError("window must be a positive integer")

        lows = self.data['low'].to_numpy()
        highs = self.data['high'].to_numpy()
        size = len(self.data)

        patterns = []
        for idx in np.flatnonzero(self.data['swing_low'].to_numpy()):
            # Need a full window of bars before and after the low
            if idx < window or idx + window >= size:
                continue

            low_price = lows[idx]
            peak_before = highs[idx - window:idx].max()
            peak_after = highs[idx + 1:idx + window + 1].max()

            # A relative move is undefined against a zero reference price.
            if peak_before == 0 or low_price == 0:
                continue

            decline = (peak_before - low_price) / peak_before
            rise = (peak_after - low_price) / low_price

            # Both decline and rise must exceed threshold to be considered V-shaped
            if decline > threshold and rise > threshold:
                patterns.append(int(idx))
        return patterns

    def plot_cycles(self, pattern_detection=None, title_suffix=''):
        """
        Plot the price data with cycle lows and detected patterns.

        Parameters:
        - pattern_detection: 'two_drives', 'v_shape', a collection containing
          either or both, or None to plot cycle lows only
        - title_suffix: Optional suffix for the plot title
        """
        # None means "no pattern overlay"; membership tests need a container.
        requested = () if pattern_detection is None else pattern_detection

        plt.figure(figsize=(14, 7))

        # Plot price data
        plt.plot(self._x_values(self.data), self.data['close'], label='Close Price')

        # Calculate a consistent vertical position for indicators based on price range
        price_range = self.data['close'].max() - self.data['close'].min()
        indicator_offset = price_range * 0.01  # 1% of price range

        # Plot cycle lows at a fixed offset below the low price
        swing_lows = self.data[self.data['swing_low']]
        plt.scatter(self._x_values(swing_lows), swing_lows['low'] - indicator_offset,
                    color='green', marker='^', s=100, label='Cycle Lows')

        # Plot specific patterns if requested
        overlays = (
            ('two_drives', self.detect_two_drives_pattern, 'red', 'Two Drives Pattern'),
            ('v_shape', self.detect_v_shaped_lows, 'purple', 'V-Shape Pattern'),
        )
        for key, detect, color, label in overlays:
            if key not in requested:
                continue
            pattern_indices = detect()
            if pattern_indices:
                patterns = self.data.iloc[pattern_indices]
                plt.scatter(self._x_values(patterns), patterns['low'] - indicator_offset * 2,
                            color=color, marker='o', s=150, label=label)

        # Add cycle lengths and averages
        avg_cycle = self.get_average_cycle_length()
        cycle_window = self.get_cycle_window()

        # With fewer than two swing lows there is no average to format.
        avg_text = "n/a" if avg_cycle is None else f"{avg_cycle:.2f}"
        window_text = ""
        if cycle_window:
            window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]"

        plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n"
                  f"Average Cycle Length: {avg_text} periods, {window_text}")
        plt.legend()
        plt.grid(True)
        plt.show()
# Usage example: # Usage example:
# 1. Load your data # 1. Load your data
# data = pd.read_csv('your_price_data.csv') # data = pd.read_csv('your_price_data.csv')
# 2. Create cycle detector instances for different timeframes # 2. Create cycle detector instances for different timeframes
# weekly_detector = CycleDetector(data, timeframe='weekly') # weekly_detector = CycleDetector(data, timeframe='weekly')
# daily_detector = CycleDetector(data, timeframe='daily') # daily_detector = CycleDetector(data, timeframe='daily')
# 3. Analyze cycles # 3. Analyze cycles
# weekly_cycle_length = weekly_detector.get_average_cycle_length() # weekly_cycle_length = weekly_detector.get_average_cycle_length()
# daily_cycle_length = daily_detector.get_average_cycle_length() # daily_cycle_length = daily_detector.get_average_cycle_length()
# 4. Detect patterns # 4. Detect patterns
# two_drives = weekly_detector.detect_two_drives_pattern() # two_drives = weekly_detector.detect_two_drives_pattern()
# v_shapes = daily_detector.detect_v_shaped_lows() # v_shapes = daily_detector.detect_v_shaped_lows()
# 5. Visualize # 5. Visualize
# weekly_detector.plot_cycles(pattern_detection='two_drives') # weekly_detector.plot_cycles(pattern_detection='two_drives')
# daily_detector.plot_cycles(pattern_detection='v_shape') # daily_detector.plot_cycles(pattern_detection='v_shape')

View File

@ -6,6 +6,7 @@ from cycle_detector import CycleDetector
# Load data from CSV file instead of database # Load data from CSV file instead of database
data = pd.read_csv('data/btcusd_1-day_data.csv') data = pd.read_csv('data/btcusd_1-day_data.csv')
# Convert datetime column to datetime type # Convert datetime column to datetime type
start_date = pd.to_datetime('2025-04-01') start_date = pd.to_datetime('2025-04-01')
stop_date = pd.to_datetime('2025-05-06') stop_date = pd.to_datetime('2025-05-06')

View File

@ -1,259 +1,259 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import ta import ta
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.dates as mdates import matplotlib.dates as mdates
import logging import logging
import mplfinance as mpf import mplfinance as mpf
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
class TrendDetectorMACD:
    """Split a price series into up/down trend segments using MACD vs. its signal line.

    A bar is labelled 'up' when MACD is above the signal line and 'down'
    otherwise; consecutive bars with the same label form one trend segment.
    """

    _OHLC_COLUMNS = ('open', 'high', 'low', 'close')

    def __init__(self, data, verbose=False):
        """
        Parameters:
        - data: pandas DataFrame with price data, or a plain list of closing
          prices (converted to a DataFrame with a single 'close' column)
        - verbose: when True, logging is configured at INFO level

        Raises:
        - ValueError: if data is neither a DataFrame nor a list.
        """
        self.data = data
        self.verbose = verbose

        # Configure logging
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetector')

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

        self.logger.info("Initialized TrendDetector with %d data points", len(self.data))

    def _ensure_close(self, df):
        """Add a 'close' column copied from the first column when 'close' is missing."""
        if 'close' not in df.columns and len(df.columns) > 0:
            self.logger.info(f"'close' column not found, using {df.columns[0]} instead")
            df['close'] = df[df.columns[0]]
        return df

    @staticmethod
    def _add_macd(df):
        """Add the 'macd', 'macd_signal' and 'macd_diff' columns computed from 'close'."""
        df['macd'] = ta.trend.macd(df['close'])
        df['macd_signal'] = ta.trend.macd_signal(df['close'])
        df['macd_diff'] = ta.trend.macd_diff(df['close'])
        return df

    @staticmethod
    def _segments_from_labels(labels, closes):
        """
        Group consecutive equal labels into trend segments.

        Parameters:
        - labels: sequence of per-bar trend labels ('up' / 'down')
        - closes: sequence of per-bar prices, same length as labels

        Returns:
        - list of dicts with keys 'type', 'start_index', 'end_index',
          'start_value' and 'end_value'; indices are inclusive row positions.
        """
        labels = np.asarray(labels)
        closes = np.asarray(closes)
        if labels.size == 0:
            return []

        # A segment starts at bar 0 and wherever the label differs from the previous bar.
        starts = np.concatenate(([0], np.flatnonzero(labels[1:] != labels[:-1]) + 1))
        ends = np.concatenate((starts[1:] - 1, [labels.size - 1]))

        return [
            {
                "type": str(labels[end]),
                "start_index": int(start),
                "end_index": int(end),
                "start_value": closes[start],
                "end_value": closes[end],
            }
            for start, end in zip(starts, ends)
        ]

    def detect_trends_MACD_signal(self):
        """
        Detect trend segments from the MACD / signal-line relationship.

        Returns:
        - list of segment dicts (see _segments_from_labels), or
          {"error": ...} when fewer than 3 data points are available.
        """
        self.logger.info("Starting trend detection")

        if len(self.data) < 3:
            self.logger.warning("Not enough data points for trend detection")
            return {"error": "Not enough data points for trend detection"}

        # Work on a copy of the DataFrame to avoid modifying the original
        df = self._ensure_close(self.data.copy())
        self.logger.info("Created copy of input data")

        # Moving Average Convergence Divergence (MACD)
        self.logger.info("Calculating MACD indicators")
        self._add_macd(df)

        # Directional Movement Index (DMI); computed but not used for the labels below
        if all(col in df.columns for col in ['high', 'low', 'close']):
            self.logger.info("Calculating ADX indicators")
            df['adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
            df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close'])
            df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close'])

        # Identify trend changes
        # NOTE(review): rows where MACD or its signal is NaN compare False and
        # are therefore labelled 'down'; presumably this covers the indicator
        # warm-up period - confirm whether those rows should be excluded.
        self.logger.info("Identifying trend changes")
        df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down')
        df['trend_change'] = df['trend'] != df['trend'].shift(1)

        # Generate trend segments
        self.logger.info("Generating trend segments")
        trends = self._segments_from_labels(df['trend'].to_numpy(), df['close'].to_numpy())

        self.logger.info("Detected %d trend segments", len(trends))
        return trends

    def get_strongest_trend(self):
        """
        Return the segment with the largest absolute price change.

        Returns:
        - the segment dict, the {"error": ...} dict from detection, or
          {"message": ...} when no segments were found.
        """
        self.logger.info("Finding strongest trend")
        trends = self.detect_trends_MACD_signal()

        if isinstance(trends, dict) and "error" in trends:
            self.logger.warning(f"Error in trend detection: {trends['error']}")
            return trends

        if not trends:
            self.logger.info("No significant trends detected")
            return {"message": "No significant trends detected"}

        strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
        self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
        return strongest

    @staticmethod
    def _draw_candles(ax, df, positions):
        """Draw candlestick bodies and wicks on 'ax' at integer bar positions."""
        opens = df['open'].to_numpy()
        closes = df['close'].to_numpy()
        colors = np.where(closes >= opens, 'green', 'red').tolist()

        # Candle bodies
        ax.bar(positions, np.abs(closes - opens), width=0.6,
               bottom=np.minimum(opens, closes), color=colors, alpha=0.8)
        # Candle wicks
        ax.vlines(positions, df['low'].to_numpy(), df['high'].to_numpy(),
                  color='black', linewidth=1)

    @staticmethod
    def _label_positions_with_dates(ax, index, max_ticks=10):
        """Place at most ~max_ticks ticks on 'ax' and label each bar position with its date."""
        count = len(index)
        if count == 0:
            return
        tick_positions = np.arange(0, count, max(1, count // max_ticks))
        try:
            stamps = pd.to_datetime(index[tick_positions])
            labels = [stamp.strftime('%Y-%m-%d') for stamp in stamps]
        except (ValueError, TypeError):
            # Values that cannot be parsed as dates are shown as-is.
            labels = [str(index[p]) for p in tick_positions]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(labels)

    def plot_trends(self, trends):
        """
        Plot price data with identified trends highlighted, plus a MACD panel.

        Candlesticks are drawn when open/high/low/close are all present;
        otherwise the price is drawn as a line. All artists share integer bar
        positions on the x-axis; when a 'datetime' column exists the ticks are
        labelled with the corresponding dates.

        Parameters:
        - trends: result of detect_trends_MACD_signal()

        Returns:
        - the matplotlib.pyplot module after showing the figure, or None when
          there is nothing to plot.
        """
        self.logger.info("Plotting trends with candlesticks")

        if isinstance(trends, dict) and "error" in trends:
            self.logger.error(trends["error"])
            print(trends["error"])
            return

        if not trends:
            self.logger.warning("No significant trends detected for plotting")
            print("No significant trends detected")
            return

        # Create a figure with 2 subplots that share the x-axis
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8),
                                       gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
        self.logger.info("Creating plot figure with shared x-axis")

        df = self.data.copy()
        has_dates = 'datetime' in df.columns
        if has_dates:
            df = df.set_index('datetime')
            self.logger.info("Using datetime for x-axis")
        else:
            self.logger.info("Using index for x-axis")
        x_label = 'Date' if has_dates else 'Index'

        # Every artist is placed at the bar's row position so they line up.
        positions = np.arange(len(df))
        price = df['close'] if 'close' in df.columns else df[df.columns[0]]

        if all(col in df.columns for col in self._OHLC_COLUMNS):
            self._draw_candles(ax1, df, positions)
        else:
            self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
            ax1.plot(positions, price, color='black', alpha=0.7, linewidth=1, label='Price')

        # Set appropriate x-axis limits
        ax1.set_xlim(-0.5, len(df) - 0.5)

        # Highlight each trend with a different color
        self.logger.info("Highlighting trends on plot")
        labelled = set()
        for trend in trends:
            start_idx = trend['start_index']
            end_idx = trend['end_index']
            y_start = price.iloc[start_idx]
            y_end = price.iloc[end_idx]

            # Choose color based on trend type
            color = 'green' if trend['type'] == 'up' else 'red'

            # Only the first line of each trend type gets a legend entry.
            legend_label = f"{trend['type'].capitalize()} Trend"
            line_kwargs = {}
            if legend_label not in labelled:
                line_kwargs['label'] = legend_label
                labelled.add(legend_label)

            # Plot trend line with markers at start and end points
            ax1.plot([start_idx, end_idx], [y_start, y_end], color=color, linewidth=2, **line_kwargs)
            ax1.scatter(start_idx, y_start, color=color, marker='o', s=50)
            ax1.scatter(end_idx, y_end, color=color, marker='s', s=50)

        # Configure first subplot
        ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
        ax1.set_ylabel('Price', fontsize=14)
        ax1.grid(alpha=0.3)
        ax1.legend()

        # Create MACD in second subplot
        self.logger.info("Creating MACD subplot")

        # Calculate MACD indicators if not already present
        if 'macd' not in df.columns:
            self._add_macd(self._ensure_close(df))

        # Plot MACD components on second subplot
        ax2.plot(positions, df['macd'], label='MACD', color='blue')
        ax2.plot(positions, df['macd_signal'], label='Signal', color='orange')

        # Plot MACD histogram in one call; bars without a value are skipped.
        histogram = df['macd_diff'].to_numpy(dtype=float)
        drawable = np.isfinite(histogram)
        bar_colors = np.where(histogram[drawable] >= 0, 'green', 'red').tolist()
        ax2.bar(positions[drawable], histogram[drawable], color=bar_colors, alpha=0.5, width=0.8)

        ax2.set_title('MACD Indicator', fontsize=16)
        ax2.set_xlabel(x_label, fontsize=14)
        ax2.set_ylabel('MACD', fontsize=14)
        ax2.grid(alpha=0.3)
        ax2.legend()

        if has_dates:
            # The x-axis is shared, so labelling the lower axis is enough.
            self._label_positions_with_dates(ax2, df.index)
            fig.autofmt_xdate()

        # Enable synchronized zooming
        plt.tight_layout()
        plt.subplots_adjust(hspace=0.1)

        plt.show()
        return plt

View File

@ -1,205 +1,255 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import logging import logging
from scipy.signal import find_peaks from scipy.signal import find_peaks
import matplotlib.dates as mdates import matplotlib.dates as mdates
from scipy import stats from scipy import stats
from scipy import stats
class TrendDetectorSimple:
    """Detect local price extrema, fit trend lines through them and plot the result.

    The detector works on a DataFrame with a 'close' column (a plain list of
    closing prices is also accepted). Plotting additionally needs 'open',
    'high' and 'low' columns to draw candlesticks.
    """

    def __init__(self, data, verbose=False):
        """
        Initialize the TrendDetectorSimple class.

        Parameters:
        - data: pandas DataFrame containing price data, or a list of closing
          prices (converted to a DataFrame with a single 'close' column)
        - verbose: boolean, whether to display detailed logging information

        Raises:
        - ValueError: if data is neither a DataFrame nor a list
        """
        self.data = data
        self.verbose = verbose

        # Plot style configuration
        self.plot_style = 'dark_background'
        self.bg_color = '#181C27'
        self.plot_size = (12, 8)

        # Candlestick configuration
        self.candle_width = 0.6
        self.candle_up_color = '#089981'
        self.candle_down_color = '#F23645'
        self.candle_alpha = 0.8
        self.wick_width = 1

        # Marker configuration
        self.min_marker = '^'
        self.min_color = 'red'
        self.min_size = 100
        self.max_marker = 'v'
        self.max_color = 'green'
        self.max_size = 100
        self.marker_zorder = 100

        # Line configuration
        self.line_width = 2
        self.min_line_style = 'g--'  # green dashed
        self.max_line_style = 'r--'  # red dashed
        self.sma7_line_style = 'y-'  # yellow solid
        self.sma15_line_style = 'm-'  # magenta solid

        # Text configuration
        self.title_size = 14
        self.title_color = 'white'
        self.axis_label_size = 12
        self.axis_label_color = 'white'

        # Legend configuration
        self.legend_loc = 'best'
        self.legend_bg_color = '#333333'

        # Configure logging. basicConfig is a no-op once the root logger has
        # handlers, so the level is also set on this class's logger to make
        # `verbose` effective for instances created later.
        log_level = logging.INFO if verbose else logging.WARNING
        logging.basicConfig(level=log_level,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')
        self.logger.setLevel(log_level)

        # Convert data to pandas DataFrame if it's not already
        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
                self.logger.info("Converting list to DataFrame")
                self.data = pd.DataFrame({'close': self.data})
            else:
                self.logger.error("Invalid data format provided")
                raise ValueError("Data must be a pandas DataFrame or a list")

        self.logger.info("Initialized TrendDetectorSimple with %d data points",
                         len(self.data))

    @staticmethod
    def _fit_line(positions, prices):
        """Return (slope, intercept, r_squared) of a least-squares line.

        A line needs at least two points; with fewer, all three values are NaN
        instead of letting scipy raise.
        """
        if len(positions) < 2:
            return float('nan'), float('nan'), float('nan')
        fit = stats.linregress(positions, prices)
        return float(fit.slope), float(fit.intercept), float(fit.rvalue) ** 2

    def detect_trends(self):
        """
        Detect trends by identifying local minima and maxima in the close
        prices using scipy.signal.find_peaks, then fit a regression line
        through each set of extrema and compute 7- and 15-period moving
        averages.

        Returns:
        - result: DataFrame with 'close', 'is_min' and 'is_max' columns
          (plus 'datetime' when the input has it)
        - analysis_results: dict with
            'linear_regression' -> {'min'|'max': {'slope', 'intercept', 'r_squared'}}
              (values are NaN when fewer than two extrema were found)
            'sma' -> {'7', '15'}: numpy arrays the same length as the data
        """
        self.logger.info("Detecting trends")

        df = self.data.copy()
        close_prices = df['close'].to_numpy(dtype=float)

        # find_peaks returns positions (0..n-1), not index labels.
        max_peaks, _ = find_peaks(close_prices)
        min_peaks, _ = find_peaks(-close_prices)

        self.logger.info("Found %d local minima and %d local maxima",
                         len(min_peaks), len(max_peaks))

        # Flag extrema positionally so any DataFrame index (dates, offsets,
        # filtered frames) is handled correctly.
        is_min = np.zeros(len(df), dtype=bool)
        is_max = np.zeros(len(df), dtype=bool)
        is_min[min_peaks] = True
        is_max[max_peaks] = True
        df['is_min'] = is_min
        df['is_max'] = is_max

        # 'datetime' is absent when the detector was built from a plain list.
        wanted = ('datetime', 'close', 'is_min', 'is_max')
        result = df[[col for col in wanted if col in df.columns]].copy()

        # Perform linear regression on min_peaks and max_peaks
        self.logger.info("Performing linear regression on min and max peaks")
        min_slope, min_intercept, min_r2 = self._fit_line(
            min_peaks, close_prices[min_peaks])
        max_slope, max_intercept, max_r2 = self._fit_line(
            max_peaks, close_prices[max_peaks])

        # Calculate Simple Moving Averages (SMA) for 7 and 15 periods.
        # min_periods=1 keeps the arrays aligned with the price series.
        self.logger.info("Calculating SMA-7 and SMA-15")
        close_series = pd.Series(close_prices)
        sma_7 = close_series.rolling(window=7, min_periods=1).mean().to_numpy()
        sma_15 = close_series.rolling(window=15, min_periods=1).mean().to_numpy()

        analysis_results = {
            'linear_regression': {
                'min': {
                    'slope': min_slope,
                    'intercept': min_intercept,
                    'r_squared': min_r2,
                },
                'max': {
                    'slope': max_slope,
                    'intercept': max_intercept,
                    'r_squared': max_r2,
                },
            },
            'sma': {
                '7': sma_7,
                '15': sma_15,
            },
        }

        self.logger.info(
            "Min peaks regression: slope=%.4f, intercept=%.4f, r²=%.4f",
            min_slope, min_intercept, min_r2)
        self.logger.info(
            "Max peaks regression: slope=%.4f, intercept=%.4f, r²=%.4f",
            max_slope, max_intercept, max_r2)

        return result, analysis_results

    def plot_trends(self, trend_data=None, analysis_results=None):
        """
        Plot the price data with detected trends using a candlestick chart.

        Parameters:
        - trend_data: DataFrame, the first output of detect_trends(). If None,
          detect_trends() will be called.
        - analysis_results: dict, the second output of detect_trends(). If
          falsy and trend_data was supplied, no regression/SMA overlay is drawn.

        Returns:
        - None (displays the plot)
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        if trend_data is None:
            trend_data, detected = self.detect_trends()
            if analysis_results is None:
                analysis_results = detected

        # Create the figure and axis with specified background
        plt.style.use(self.plot_style)
        fig, ax = plt.subplots(figsize=self.plot_size)

        # Set the custom background color
        fig.patch.set_facecolor(self.bg_color)
        ax.set_facecolor(self.bg_color)

        df = self.data
        closes = df['close'].to_numpy()
        candles = zip(df['open'].to_numpy(), closes,
                      df['high'].to_numpy(), df['low'].to_numpy())

        # Draw candlesticks manually
        half_width = self.candle_width / 2
        for pos, (open_val, close_val, high_val, low_val) in enumerate(candles):
            color = (self.candle_up_color if close_val >= open_val
                     else self.candle_down_color)

            # Candle body
            ax.add_patch(Rectangle((pos - half_width, min(open_val, close_val)),
                                   self.candle_width, abs(close_val - open_val),
                                   color=color, alpha=self.candle_alpha))

            # Candle wick
            ax.plot([pos, pos], [low_val, high_val], color=color,
                    linewidth=self.wick_width)

        # Extrema markers: use row positions so they line up with the
        # candles regardless of the DataFrame index.
        min_positions = np.flatnonzero(trend_data['is_min'].to_numpy(dtype=bool))
        if min_positions.size:
            ax.scatter(min_positions, closes[min_positions], color=self.min_color,
                       s=self.min_size, marker=self.min_marker,
                       label='Local Minima', zorder=self.marker_zorder)

        max_positions = np.flatnonzero(trend_data['is_max'].to_numpy(dtype=bool))
        if max_positions.size:
            ax.scatter(max_positions, closes[max_positions], color=self.max_color,
                       s=self.max_size, marker=self.max_marker,
                       label='Local Maxima', zorder=self.marker_zorder)

        if analysis_results:
            x_vals = np.arange(len(df))
            regression = analysis_results['linear_regression']

            # Support (minima) and resistance (maxima) regression lines.
            for key, style, label in (
                    ('min', self.min_line_style, 'Minima Regression'),
                    ('max', self.max_line_style, 'Maxima Regression')):
                slope = regression[key]['slope']
                intercept = regression[key]['intercept']
                # NaN coefficients mean there were too few extrema to fit.
                if np.isfinite(slope) and np.isfinite(intercept):
                    ax.plot(x_vals, slope * x_vals + intercept, style,
                            linewidth=self.line_width, label=label)

            # Moving averages; NaN entries (if any) are left out of the line.
            for key, style, label in (
                    ('7', self.sma7_line_style, 'SMA-7'),
                    ('15', self.sma15_line_style, 'SMA-15')):
                sma = np.asarray(analysis_results['sma'][key], dtype=float)
                valid = ~np.isnan(sma)
                ax.plot(x_vals[valid], sma[valid], style,
                        linewidth=self.line_width, label=label)

        # Set title and labels
        ax.set_title('Price Candlestick Chart with Local Minima and Maxima',
                     fontsize=self.title_size, color=self.title_color)
        ax.set_xlabel('Date', fontsize=self.axis_label_size,
                      color=self.axis_label_color)
        ax.set_ylabel('Price', fontsize=self.axis_label_size,
                      color=self.axis_label_color)

        # Set appropriate x-axis limits
        ax.set_xlim(-0.5, len(df) - 0.5)

        # Add a legend
        ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)

        # Adjust layout
        plt.tight_layout()

        # Show the plot
        plt.show()