first trend detection WIP

This commit is contained in:
Simon Moisy 2025-05-06 15:24:36 +08:00
parent e2cd746dc6
commit 0bbe308321
6 changed files with 720 additions and 1 deletions

View File

@ -1,2 +1 @@
# Cycles

248
cycle_detector.py Normal file
View File

@ -0,0 +1,248 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import argrelextrema
class CycleDetector:
def __init__(self, data, timeframe='daily'):
"""
Initialize the CycleDetector with price data.
Parameters:
- data: DataFrame with at least 'date' or 'datetime' and 'close' columns
- timeframe: 'daily', 'weekly', or 'monthly'
"""
self.data = data.copy()
self.timeframe = timeframe
# Ensure we have a consistent date column name
if 'datetime' in self.data.columns and 'date' not in self.data.columns:
self.data.rename(columns={'datetime': 'date'}, inplace=True)
# Convert data to specified timeframe if needed
if timeframe == 'weekly' and 'date' in self.data.columns:
self.data = self._convert_data(self.data, 'W')
elif timeframe == 'monthly' and 'date' in self.data.columns:
self.data = self._convert_data(self.data, 'M')
# Add columns for local minima and maxima detection
self._add_swing_points()
def _convert_data(self, data, timeframe):
"""Convert daily data to 'timeframe' timeframe."""
data['date'] = pd.to_datetime(data['date'])
data.set_index('date', inplace=True)
weekly = data.resample(timeframe).agg({
'open': 'first',
'high': 'max',
'low': 'min',
'close': 'last',
'volume': 'sum'
})
return weekly.reset_index()
def _add_swing_points(self, window=5):
"""
Identify swing points (local minima and maxima).
Parameters:
- window: The window size for local minima/maxima detection
"""
# Set the index to make calculations easier
if 'date' in self.data.columns:
self.data.set_index('date', inplace=True)
# Detect local minima (swing lows)
min_idx = argrelextrema(self.data['low'].values, np.less, order=window)[0]
self.data['swing_low'] = False
self.data.iloc[min_idx, self.data.columns.get_loc('swing_low')] = True
# Detect local maxima (swing highs)
max_idx = argrelextrema(self.data['high'].values, np.greater, order=window)[0]
self.data['swing_high'] = False
self.data.iloc[max_idx, self.data.columns.get_loc('swing_high')] = True
# Reset index
self.data.reset_index(inplace=True)
def find_cycle_lows(self):
"""Find all swing lows which represent cycle lows."""
swing_low_dates = self.data[self.data['swing_low']]['date'].values
return swing_low_dates
def calculate_cycle_lengths(self):
"""Calculate the lengths of each cycle between consecutive lows."""
swing_low_indices = np.where(self.data['swing_low'])[0]
cycle_lengths = np.diff(swing_low_indices)
return cycle_lengths
def get_average_cycle_length(self):
"""Calculate the average cycle length."""
cycle_lengths = self.calculate_cycle_lengths()
if len(cycle_lengths) > 0:
return np.mean(cycle_lengths)
return None
def get_cycle_window(self, tolerance=0.10):
"""
Get the cycle window with the specified tolerance.
Parameters:
- tolerance: The tolerance as a percentage (default: 10%)
Returns:
- tuple: (min_cycle_length, avg_cycle_length, max_cycle_length)
"""
avg_length = self.get_average_cycle_length()
if avg_length is not None:
min_length = avg_length * (1 - tolerance)
max_length = avg_length * (1 + tolerance)
return (min_length, avg_length, max_length)
return None
def detect_two_drives_pattern(self, lookback=10):
"""
Detect 2-drives pattern: a swing low, counter trend bounce, and a lower low.
Parameters:
- lookback: Number of periods to look back
Returns:
- list: Indices where 2-drives patterns are detected
"""
patterns = []
for i in range(lookback, len(self.data) - 1):
if not self.data.iloc[i]['swing_low']:
continue
# Get the segment of data to check for pattern
segment = self.data.iloc[i-lookback:i+1]
swing_lows = segment[segment['swing_low']]['low'].values
if len(swing_lows) >= 2 and swing_lows[-1] < swing_lows[-2]:
# Check if there was a bounce between the two lows
between_lows = segment.iloc[-len(swing_lows):-1]
if len(between_lows) > 0 and max(between_lows['high']) > swing_lows[-2]:
patterns.append(i)
return patterns
def detect_v_shaped_lows(self, window=5, threshold=0.02):
"""
Detect V-shaped cycle lows (sharp decline followed by sharp rise).
Parameters:
- window: Window to look for sharp price changes
- threshold: Percentage change threshold to consider 'sharp'
Returns:
- list: Indices where V-shaped patterns are detected
"""
patterns = []
# Find all swing lows
swing_low_indices = np.where(self.data['swing_low'])[0]
for idx in swing_low_indices:
# Need enough data points before and after
if idx < window or idx + window >= len(self.data):
continue
# Get the low price at this swing low
low_price = self.data.iloc[idx]['low']
# Check for sharp decline before low (at least window bars before)
before_segment = self.data.iloc[max(0, idx-window):idx]
if len(before_segment) > 0:
max_before = before_segment['high'].max()
decline = (max_before - low_price) / max_before
# Check for sharp rise after low (at least window bars after)
after_segment = self.data.iloc[idx+1:min(len(self.data), idx+window+1)]
if len(after_segment) > 0:
max_after = after_segment['high'].max()
rise = (max_after - low_price) / low_price
# Both decline and rise must exceed threshold to be considered V-shaped
if decline > threshold and rise > threshold:
patterns.append(idx)
return patterns
def plot_cycles(self, pattern_detection=None, title_suffix=''):
"""
Plot the price data with cycle lows and detected patterns.
Parameters:
- pattern_detection: 'two_drives', 'v_shape', or None
- title_suffix: Optional suffix for the plot title
"""
plt.figure(figsize=(14, 7))
# Determine the date column name (could be 'date' or 'datetime')
date_col = 'date' if 'date' in self.data.columns else 'datetime'
# Plot price data
plt.plot(self.data[date_col], self.data['close'], label='Close Price')
# Calculate a consistent vertical position for indicators based on price range
price_range = self.data['close'].max() - self.data['close'].min()
indicator_offset = price_range * 0.01 # 1% of price range
# Plot cycle lows (now at a fixed offset below the low price)
swing_lows = self.data[self.data['swing_low']]
plt.scatter(swing_lows[date_col], swing_lows['low'] - indicator_offset,
color='green', marker='^', s=100, label='Cycle Lows')
# Plot specific patterns if requested
if 'two_drives' in pattern_detection:
pattern_indices = self.detect_two_drives_pattern()
if pattern_indices:
patterns = self.data.iloc[pattern_indices]
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
color='red', marker='o', s=150, label='Two Drives Pattern')
elif 'v_shape' in pattern_detection:
pattern_indices = self.detect_v_shaped_lows()
if pattern_indices:
patterns = self.data.iloc[pattern_indices]
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
color='purple', marker='o', s=150, label='V-Shape Pattern')
# Add cycle lengths and averages
cycle_lengths = self.calculate_cycle_lengths()
avg_cycle = self.get_average_cycle_length()
cycle_window = self.get_cycle_window()
window_text = ""
if cycle_window:
window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]"
plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n"
f"Average Cycle Length: {avg_cycle:.2f} periods, {window_text}")
plt.legend()
plt.grid(True)
plt.show()
# Usage example:
# 1. Load your data
# data = pd.read_csv('your_price_data.csv')
# 2. Create cycle detector instances for different timeframes
# weekly_detector = CycleDetector(data, timeframe='weekly')
# daily_detector = CycleDetector(data, timeframe='daily')
# 3. Analyze cycles
# weekly_cycle_length = weekly_detector.get_average_cycle_length()
# daily_cycle_length = daily_detector.get_average_cycle_length()
# 4. Detect patterns
# two_drives = weekly_detector.detect_two_drives_pattern()
# v_shapes = daily_detector.detect_v_shaped_lows()
# 5. Visualize
# weekly_detector.plot_cycles(pattern_detection='two_drives')
# daily_detector.plot_cycles(pattern_detection='v_shape')

52
main.py Normal file
View File

@ -0,0 +1,52 @@
import pandas as pd
from trend_detector_macd import TrendDetectorMACD
from trend_detector_simple import TrendDetectorSimple
from cycle_detector import CycleDetector
# Load data from CSV file instead of database
data = pd.read_csv('data/btcusd_1-day_data.csv')
# Convert datetime column to datetime type
one_month_ago = pd.to_datetime('2025-04-05')
daily_data = data[pd.to_datetime(data['datetime']) >= one_month_ago]
print(f"Number of data points: {len(daily_data)}")
trend_detector = TrendDetectorSimple(daily_data, verbose=True)
trends = trend_detector.detect_trends(width=1)
trend_detector.plot_trends(trends)
#trend_detector = TrendDetectorMACD(daily_data, True)
#trends = trend_detector.detect_trends_MACD_signal()
#trend_detector.plot_trends(trends)
# # Cycle detection (new code)
# print("\n===== CYCLE DETECTION =====")
# # Daily cycles
# daily_detector = CycleDetector(daily_data, timeframe='daily')
# daily_avg_cycle = daily_detector.get_average_cycle_length()
# daily_window = daily_detector.get_cycle_window()
# print(f"Daily Timeframe: Average Cycle Length = {daily_avg_cycle:.2f} days")
# if daily_window:
# print(f"Daily Cycle Window: {daily_window[0]:.2f} to {daily_window[2]:.2f} days")
# weekly_detector = CycleDetector(daily_data, timeframe='weekly')
# weekly_avg_cycle = weekly_detector.get_average_cycle_length()
# weekly_window = weekly_detector.get_cycle_window()
# print(f"\nWeekly Timeframe: Average Cycle Length = {weekly_avg_cycle:.2f} weeks")
# if weekly_window:
# print(f"Weekly Cycle Window: {weekly_window[0]:.2f} to {weekly_window[2]:.2f} weeks")
# # Detect patterns
# two_drives = daily_detector.detect_two_drives_pattern()
# v_shapes = daily_detector.detect_v_shaped_lows()
# print(f"\nDetected {len(two_drives)} 'Two Drives' patterns in daily data")
# print(f"Detected {len(v_shapes)} 'V-Shaped' lows in daily data")
# # Plot cycles with detected patterns
# print("\nPlotting cycles and patterns...")
# daily_detector.plot_cycles(pattern_detection=['two_drives', 'v_shape'], title_suffix='(with Two Drives Pattern)')
# weekly_detector.plot_cycles(title_suffix='(Weekly View)')

17
requirements.txt Normal file
View File

@ -0,0 +1,17 @@
contourpy==1.3.1
cycler==0.12.1
fonttools==4.57.0
greenlet==3.2.1
kiwisolver==1.4.8
matplotlib==3.10.1
numpy==2.2.4
packaging==24.2
pandas==2.2.3
pillow==11.2.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
six==1.17.0
SQLAlchemy==2.0.40
typing_extensions==4.13.2
tzdata==2025.2

259
trend_detector_macd.py Normal file
View File

@ -0,0 +1,259 @@
import pandas as pd
import numpy as np
import ta
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import logging
import mplfinance as mpf
from matplotlib.patches import Rectangle
class TrendDetectorMACD:
def __init__(self, data, verbose=False):
self.data = data
self.verbose = verbose
# Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger('TrendDetector')
# Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame):
if isinstance(self.data, list):
self.logger.info("Converting list to DataFrame")
self.data = pd.DataFrame({'close': self.data})
else:
self.logger.error("Invalid data format provided")
raise ValueError("Data must be a pandas DataFrame or a list")
self.logger.info(f"Initialized TrendDetector with {len(self.data)} data points")
def detect_trends_MACD_signal(self):
self.logger.info("Starting trend detection")
if len(self.data) < 3:
self.logger.warning("Not enough data points for trend detection")
return {"error": "Not enough data points for trend detection"}
# Create a copy of the DataFrame to avoid modifying the original
df = self.data.copy()
self.logger.info("Created copy of input data")
# If 'close' column doesn't exist, try to use a relevant column
if 'close' not in df.columns and len(df.columns) > 0:
self.logger.info(f"'close' column not found, using {df.columns[0]} instead")
df['close'] = df[df.columns[0]] # Use the first column as 'close'
# Add trend indicators
self.logger.info("Calculating MACD indicators")
# Moving Average Convergence Divergence (MACD)
df['macd'] = ta.trend.macd(df['close'])
df['macd_signal'] = ta.trend.macd_signal(df['close'])
df['macd_diff'] = ta.trend.macd_diff(df['close'])
# Directional Movement Index (DMI)
if all(col in df.columns for col in ['high', 'low', 'close']):
self.logger.info("Calculating ADX indicators")
df['adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close'])
df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close'])
# Identify trend changes
self.logger.info("Identifying trend changes")
df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down')
df['trend_change'] = df['trend'] != df['trend'].shift(1)
# Generate trend segments
self.logger.info("Generating trend segments")
trends = []
trend_start = 0
for i in range(1, len(df)):
if df['trend_change'].iloc[i]:
if i > trend_start:
trends.append({
"type": df['trend'].iloc[i-1],
"start_index": trend_start,
"end_index": i-1,
"start_value": df['close'].iloc[trend_start],
"end_value": df['close'].iloc[i-1]
})
trend_start = i
# Add the last trend
if trend_start < len(df):
trends.append({
"type": df['trend'].iloc[-1],
"start_index": trend_start,
"end_index": len(df)-1,
"start_value": df['close'].iloc[trend_start],
"end_value": df['close'].iloc[-1]
})
self.logger.info(f"Detected {len(trends)} trend segments")
return trends
def get_strongest_trend(self):
self.logger.info("Finding strongest trend")
trends = self.detect_trends_MACD_signal()
if isinstance(trends, dict) and "error" in trends:
self.logger.warning(f"Error in trend detection: {trends['error']}")
return trends
if not trends:
self.logger.info("No significant trends detected")
return {"message": "No significant trends detected"}
strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
return strongest
def plot_trends(self, trends):
"""
Plot price data with identified trends highlighted using candlestick charts.
"""
self.logger.info("Plotting trends with candlesticks")
if isinstance(trends, dict) and "error" in trends:
self.logger.error(trends["error"])
print(trends["error"])
return
if not trends:
self.logger.warning("No significant trends detected for plotting")
print("No significant trends detected")
return
# Create a figure with 2 subplots that share the x-axis
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
self.logger.info("Creating plot figure with shared x-axis")
# Prepare data for candlestick chart
df = self.data.copy()
# Ensure required columns exist for candlestick
required_cols = ['open', 'high', 'low', 'close']
if not all(col in df.columns for col in required_cols):
self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
if 'close' in df.columns:
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
df['close'], color='black', alpha=0.7, linewidth=1, label='Price')
else:
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
df[df.columns[0]], color='black', alpha=0.7, linewidth=1, label='Price')
else:
# Get x values (dates if available, otherwise indices)
if 'datetime' in df.columns:
x_label = 'Date'
# Format date axis
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
fig.autofmt_xdate()
self.logger.info("Using datetime for x-axis")
# For candlestick, ensure datetime is the index
if df.index.name != 'datetime':
df = df.set_index('datetime')
else:
x_label = 'Index'
self.logger.info("Using index for x-axis")
# Plot candlestick chart
up_color = 'green'
down_color = 'red'
# Draw candlesticks manually
width = 0.6
for i in range(len(df)):
# Get OHLC values for this candle
open_val = df['open'].iloc[i]
close_val = df['close'].iloc[i]
high_val = df['high'].iloc[i]
low_val = df['low'].iloc[i]
idx = df.index[i]
# Determine candle color
color = up_color if close_val >= open_val else down_color
# Plot candle body
body_height = abs(close_val - open_val)
bottom = min(open_val, close_val)
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
ax1.add_patch(rect)
# Plot candle wicks
ax1.plot([i, i], [low_val, high_val], color='black', linewidth=1)
# Set appropriate x-axis limits
ax1.set_xlim(-0.5, len(df) - 0.5)
# Highlight each trend with a different color
self.logger.info("Highlighting trends on plot")
for trend in trends:
start_idx = trend['start_index']
end_idx = trend['end_index']
trend_type = trend['type']
# Get x-coordinates for trend plotting
x_start = start_idx
x_end = end_idx
# Get y-coordinates for trend line
if 'close' in df.columns:
y_start = df['close'].iloc[start_idx]
y_end = df['close'].iloc[end_idx]
else:
y_start = df[df.columns[0]].iloc[start_idx]
y_end = df[df.columns[0]].iloc[end_idx]
# Choose color based on trend type
color = 'green' if trend_type == 'up' else 'red'
# Plot trend line
ax1.plot([x_start, x_end], [y_start, y_end], color=color, linewidth=2,
label=f"{trend_type.capitalize()} Trend" if f"{trend_type.capitalize()} Trend" not in ax1.get_legend_handles_labels()[1] else "")
# Add markers at start and end points
ax1.scatter(x_start, y_start, color=color, marker='o', s=50)
ax1.scatter(x_end, y_end, color=color, marker='s', s=50)
# Configure first subplot
ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
ax1.set_ylabel('Price', fontsize=14)
ax1.grid(alpha=0.3)
ax1.legend()
# Create MACD in second subplot
self.logger.info("Creating MACD subplot")
# Calculate MACD indicators if not already present
if 'macd' not in df.columns:
if 'close' not in df.columns and len(df.columns) > 0:
df['close'] = df[df.columns[0]]
df['macd'] = ta.trend.macd(df['close'])
df['macd_signal'] = ta.trend.macd_signal(df['close'])
df['macd_diff'] = ta.trend.macd_diff(df['close'])
# Plot MACD components on second subplot
x_indices = np.arange(len(df))
ax2.plot(x_indices, df['macd'], label='MACD', color='blue')
ax2.plot(x_indices, df['macd_signal'], label='Signal', color='orange')
# Plot MACD histogram
for i in range(len(df)):
if df['macd_diff'].iloc[i] >= 0:
ax2.bar(i, df['macd_diff'].iloc[i], color='green', alpha=0.5, width=0.8)
else:
ax2.bar(i, df['macd_diff'].iloc[i], color='red', alpha=0.5, width=0.8)
ax2.set_title('MACD Indicator', fontsize=16)
ax2.set_xlabel(x_label, fontsize=14)
ax2.set_ylabel('MACD', fontsize=14)
ax2.grid(alpha=0.3)
ax2.legend()
# Enable synchronized zooming
plt.tight_layout()
plt.subplots_adjust(hspace=0.1)
plt.show()
return plt

144
trend_detector_simple.py Normal file
View File

@ -0,0 +1,144 @@
import pandas as pd
import numpy as np
import logging
from scipy.signal import find_peaks
import matplotlib.dates as mdates
class TrendDetectorSimple:
def __init__(self, data, verbose=False):
"""
Initialize the TrendDetectorSimple class.
Parameters:
- data: pandas DataFrame containing price data
- verbose: boolean, whether to display detailed logging information
"""
self.data = data
self.verbose = verbose
# Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger('TrendDetectorSimple')
# Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame):
if isinstance(self.data, list):
self.logger.info("Converting list to DataFrame")
self.data = pd.DataFrame({'close': self.data})
else:
self.logger.error("Invalid data format provided")
raise ValueError("Data must be a pandas DataFrame or a list")
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
def detect_trends(self, width):
"""
Detect trends by identifying local minima and maxima in the price data
using scipy.signal.find_peaks.
Parameters:
- prominence: float, required prominence of peaks (relative to the price range)
- width: int, required width of peaks in data points
Returns:
- DataFrame with columns for timestamps, prices, and trend indicators
"""
self.logger.info(f"Detecting trends with width {width}")
df = self.data.copy()
close_prices = df['close'].values
max_peaks, _ = find_peaks(close_prices, width=width)
min_peaks, _ = find_peaks(-close_prices, width=width)
self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")
df['is_min'] = False
df['is_max'] = False
for peak in max_peaks:
df.at[peak, 'is_max'] = True
for peak in min_peaks:
df.at[peak, 'is_min'] = True
result = df[['datetime', 'close', 'is_min', 'is_max']]
return result
def plot_trends(self, trend_data):
"""
Plot the price data with detected trends using a candlestick chart.
Parameters:
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
Returns:
- None (displays the plot)
"""
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Create the figure and axis
fig, ax = plt.subplots(figsize=(12, 8))
# Create a copy of the data
df = self.data.copy()
# Plot candlestick chart
up_color = 'green'
down_color = 'red'
# Draw candlesticks manually
width = 0.6
x_values = range(len(df))
for i in range(len(df)):
# Get OHLC values for this candle
open_val = df['open'].iloc[i]
close_val = df['close'].iloc[i]
high_val = df['high'].iloc[i]
low_val = df['low'].iloc[i]
# Determine candle color
color = up_color if close_val >= open_val else down_color
# Plot candle body
body_height = abs(close_val - open_val)
bottom = min(open_val, close_val)
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
ax.add_patch(rect)
# Plot candle wicks
ax.plot([i, i], [low_val, high_val], color='black', linewidth=1)
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
if min_indices:
min_y = [df['close'].iloc[i] for i in min_indices]
ax.scatter(min_indices, min_y, color='black', s=200, marker='^', label='Local Minima', zorder=100)
max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
if max_indices:
max_y = [df['close'].iloc[i] for i in max_indices]
ax.scatter(max_indices, max_y, color='black', s=200, marker='v', label='Local Maxima', zorder=100)
# Set title and labels
ax.set_title('Price Candlestick Chart with Local Minima and Maxima', fontsize=14)
ax.set_xlabel('Date', fontsize=12)
ax.set_ylabel('Price', fontsize=12)
# Set appropriate x-axis limits
ax.set_xlim(-0.5, len(df) - 0.5)
# Add a legend
ax.legend(loc='best')
# Adjust layout
plt.tight_layout()
# Show the plot
plt.show()