Compare commits

..

3 Commits

6 changed files with 767 additions and 712 deletions

6
.gitignore vendored
View File

@ -168,3 +168,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
An introduction to trading cycles.pdf
An introduction to trading cycles.txt
README.md
.vscode/launch.json
data/btcusd_1-day_data.csv
data/btcusd_1-min_data.csv

View File

@ -1 +1 @@
# Cycles # Cycles

View File

@ -1,248 +1,248 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.signal import argrelextrema from scipy.signal import argrelextrema
class CycleDetector: class CycleDetector:
def __init__(self, data, timeframe='daily'): def __init__(self, data, timeframe='daily'):
""" """
Initialize the CycleDetector with price data. Initialize the CycleDetector with price data.
Parameters: Parameters:
- data: DataFrame with at least 'date' or 'datetime' and 'close' columns - data: DataFrame with at least 'date' or 'datetime' and 'close' columns
- timeframe: 'daily', 'weekly', or 'monthly' - timeframe: 'daily', 'weekly', or 'monthly'
""" """
self.data = data.copy() self.data = data.copy()
self.timeframe = timeframe self.timeframe = timeframe
# Ensure we have a consistent date column name # Ensure we have a consistent date column name
if 'datetime' in self.data.columns and 'date' not in self.data.columns: if 'datetime' in self.data.columns and 'date' not in self.data.columns:
self.data.rename(columns={'datetime': 'date'}, inplace=True) self.data.rename(columns={'datetime': 'date'}, inplace=True)
# Convert data to specified timeframe if needed # Convert data to specified timeframe if needed
if timeframe == 'weekly' and 'date' in self.data.columns: if timeframe == 'weekly' and 'date' in self.data.columns:
self.data = self._convert_data(self.data, 'W') self.data = self._convert_data(self.data, 'W')
elif timeframe == 'monthly' and 'date' in self.data.columns: elif timeframe == 'monthly' and 'date' in self.data.columns:
self.data = self._convert_data(self.data, 'M') self.data = self._convert_data(self.data, 'M')
# Add columns for local minima and maxima detection # Add columns for local minima and maxima detection
self._add_swing_points() self._add_swing_points()
def _convert_data(self, data, timeframe): def _convert_data(self, data, timeframe):
"""Convert daily data to 'timeframe' timeframe.""" """Convert daily data to 'timeframe' timeframe."""
data['date'] = pd.to_datetime(data['date']) data['date'] = pd.to_datetime(data['date'])
data.set_index('date', inplace=True) data.set_index('date', inplace=True)
weekly = data.resample(timeframe).agg({ weekly = data.resample(timeframe).agg({
'open': 'first', 'open': 'first',
'high': 'max', 'high': 'max',
'low': 'min', 'low': 'min',
'close': 'last', 'close': 'last',
'volume': 'sum' 'volume': 'sum'
}) })
return weekly.reset_index() return weekly.reset_index()
def _add_swing_points(self, window=5): def _add_swing_points(self, window=5):
""" """
Identify swing points (local minima and maxima). Identify swing points (local minima and maxima).
Parameters: Parameters:
- window: The window size for local minima/maxima detection - window: The window size for local minima/maxima detection
""" """
# Set the index to make calculations easier # Set the index to make calculations easier
if 'date' in self.data.columns: if 'date' in self.data.columns:
self.data.set_index('date', inplace=True) self.data.set_index('date', inplace=True)
# Detect local minima (swing lows) # Detect local minima (swing lows)
min_idx = argrelextrema(self.data['low'].values, np.less, order=window)[0] min_idx = argrelextrema(self.data['low'].values, np.less, order=window)[0]
self.data['swing_low'] = False self.data['swing_low'] = False
self.data.iloc[min_idx, self.data.columns.get_loc('swing_low')] = True self.data.iloc[min_idx, self.data.columns.get_loc('swing_low')] = True
# Detect local maxima (swing highs) # Detect local maxima (swing highs)
max_idx = argrelextrema(self.data['high'].values, np.greater, order=window)[0] max_idx = argrelextrema(self.data['high'].values, np.greater, order=window)[0]
self.data['swing_high'] = False self.data['swing_high'] = False
self.data.iloc[max_idx, self.data.columns.get_loc('swing_high')] = True self.data.iloc[max_idx, self.data.columns.get_loc('swing_high')] = True
# Reset index # Reset index
self.data.reset_index(inplace=True) self.data.reset_index(inplace=True)
def find_cycle_lows(self): def find_cycle_lows(self):
"""Find all swing lows which represent cycle lows.""" """Find all swing lows which represent cycle lows."""
swing_low_dates = self.data[self.data['swing_low']]['date'].values swing_low_dates = self.data[self.data['swing_low']]['date'].values
return swing_low_dates return swing_low_dates
def calculate_cycle_lengths(self): def calculate_cycle_lengths(self):
"""Calculate the lengths of each cycle between consecutive lows.""" """Calculate the lengths of each cycle between consecutive lows."""
swing_low_indices = np.where(self.data['swing_low'])[0] swing_low_indices = np.where(self.data['swing_low'])[0]
cycle_lengths = np.diff(swing_low_indices) cycle_lengths = np.diff(swing_low_indices)
return cycle_lengths return cycle_lengths
def get_average_cycle_length(self): def get_average_cycle_length(self):
"""Calculate the average cycle length.""" """Calculate the average cycle length."""
cycle_lengths = self.calculate_cycle_lengths() cycle_lengths = self.calculate_cycle_lengths()
if len(cycle_lengths) > 0: if len(cycle_lengths) > 0:
return np.mean(cycle_lengths) return np.mean(cycle_lengths)
return None return None
def get_cycle_window(self, tolerance=0.10): def get_cycle_window(self, tolerance=0.10):
""" """
Get the cycle window with the specified tolerance. Get the cycle window with the specified tolerance.
Parameters: Parameters:
- tolerance: The tolerance as a percentage (default: 10%) - tolerance: The tolerance as a percentage (default: 10%)
Returns: Returns:
- tuple: (min_cycle_length, avg_cycle_length, max_cycle_length) - tuple: (min_cycle_length, avg_cycle_length, max_cycle_length)
""" """
avg_length = self.get_average_cycle_length() avg_length = self.get_average_cycle_length()
if avg_length is not None: if avg_length is not None:
min_length = avg_length * (1 - tolerance) min_length = avg_length * (1 - tolerance)
max_length = avg_length * (1 + tolerance) max_length = avg_length * (1 + tolerance)
return (min_length, avg_length, max_length) return (min_length, avg_length, max_length)
return None return None
def detect_two_drives_pattern(self, lookback=10): def detect_two_drives_pattern(self, lookback=10):
""" """
Detect 2-drives pattern: a swing low, counter trend bounce, and a lower low. Detect 2-drives pattern: a swing low, counter trend bounce, and a lower low.
Parameters: Parameters:
- lookback: Number of periods to look back - lookback: Number of periods to look back
Returns: Returns:
- list: Indices where 2-drives patterns are detected - list: Indices where 2-drives patterns are detected
""" """
patterns = [] patterns = []
for i in range(lookback, len(self.data) - 1): for i in range(lookback, len(self.data) - 1):
if not self.data.iloc[i]['swing_low']: if not self.data.iloc[i]['swing_low']:
continue continue
# Get the segment of data to check for pattern # Get the segment of data to check for pattern
segment = self.data.iloc[i-lookback:i+1] segment = self.data.iloc[i-lookback:i+1]
swing_lows = segment[segment['swing_low']]['low'].values swing_lows = segment[segment['swing_low']]['low'].values
if len(swing_lows) >= 2 and swing_lows[-1] < swing_lows[-2]: if len(swing_lows) >= 2 and swing_lows[-1] < swing_lows[-2]:
# Check if there was a bounce between the two lows # Check if there was a bounce between the two lows
between_lows = segment.iloc[-len(swing_lows):-1] between_lows = segment.iloc[-len(swing_lows):-1]
if len(between_lows) > 0 and max(between_lows['high']) > swing_lows[-2]: if len(between_lows) > 0 and max(between_lows['high']) > swing_lows[-2]:
patterns.append(i) patterns.append(i)
return patterns return patterns
def detect_v_shaped_lows(self, window=5, threshold=0.02): def detect_v_shaped_lows(self, window=5, threshold=0.02):
""" """
Detect V-shaped cycle lows (sharp decline followed by sharp rise). Detect V-shaped cycle lows (sharp decline followed by sharp rise).
Parameters: Parameters:
- window: Window to look for sharp price changes - window: Window to look for sharp price changes
- threshold: Percentage change threshold to consider 'sharp' - threshold: Percentage change threshold to consider 'sharp'
Returns: Returns:
- list: Indices where V-shaped patterns are detected - list: Indices where V-shaped patterns are detected
""" """
patterns = [] patterns = []
# Find all swing lows # Find all swing lows
swing_low_indices = np.where(self.data['swing_low'])[0] swing_low_indices = np.where(self.data['swing_low'])[0]
for idx in swing_low_indices: for idx in swing_low_indices:
# Need enough data points before and after # Need enough data points before and after
if idx < window or idx + window >= len(self.data): if idx < window or idx + window >= len(self.data):
continue continue
# Get the low price at this swing low # Get the low price at this swing low
low_price = self.data.iloc[idx]['low'] low_price = self.data.iloc[idx]['low']
# Check for sharp decline before low (at least window bars before) # Check for sharp decline before low (at least window bars before)
before_segment = self.data.iloc[max(0, idx-window):idx] before_segment = self.data.iloc[max(0, idx-window):idx]
if len(before_segment) > 0: if len(before_segment) > 0:
max_before = before_segment['high'].max() max_before = before_segment['high'].max()
decline = (max_before - low_price) / max_before decline = (max_before - low_price) / max_before
# Check for sharp rise after low (at least window bars after) # Check for sharp rise after low (at least window bars after)
after_segment = self.data.iloc[idx+1:min(len(self.data), idx+window+1)] after_segment = self.data.iloc[idx+1:min(len(self.data), idx+window+1)]
if len(after_segment) > 0: if len(after_segment) > 0:
max_after = after_segment['high'].max() max_after = after_segment['high'].max()
rise = (max_after - low_price) / low_price rise = (max_after - low_price) / low_price
# Both decline and rise must exceed threshold to be considered V-shaped # Both decline and rise must exceed threshold to be considered V-shaped
if decline > threshold and rise > threshold: if decline > threshold and rise > threshold:
patterns.append(idx) patterns.append(idx)
return patterns return patterns
def plot_cycles(self, pattern_detection=None, title_suffix=''): def plot_cycles(self, pattern_detection=None, title_suffix=''):
""" """
Plot the price data with cycle lows and detected patterns. Plot the price data with cycle lows and detected patterns.
Parameters: Parameters:
- pattern_detection: 'two_drives', 'v_shape', or None - pattern_detection: 'two_drives', 'v_shape', or None
- title_suffix: Optional suffix for the plot title - title_suffix: Optional suffix for the plot title
""" """
plt.figure(figsize=(14, 7)) plt.figure(figsize=(14, 7))
# Determine the date column name (could be 'date' or 'datetime') # Determine the date column name (could be 'date' or 'datetime')
date_col = 'date' if 'date' in self.data.columns else 'datetime' date_col = 'date' if 'date' in self.data.columns else 'datetime'
# Plot price data # Plot price data
plt.plot(self.data[date_col], self.data['close'], label='Close Price') plt.plot(self.data[date_col], self.data['close'], label='Close Price')
# Calculate a consistent vertical position for indicators based on price range # Calculate a consistent vertical position for indicators based on price range
price_range = self.data['close'].max() - self.data['close'].min() price_range = self.data['close'].max() - self.data['close'].min()
indicator_offset = price_range * 0.01 # 1% of price range indicator_offset = price_range * 0.01 # 1% of price range
# Plot cycle lows (now at a fixed offset below the low price) # Plot cycle lows (now at a fixed offset below the low price)
swing_lows = self.data[self.data['swing_low']] swing_lows = self.data[self.data['swing_low']]
plt.scatter(swing_lows[date_col], swing_lows['low'] - indicator_offset, plt.scatter(swing_lows[date_col], swing_lows['low'] - indicator_offset,
color='green', marker='^', s=100, label='Cycle Lows') color='green', marker='^', s=100, label='Cycle Lows')
# Plot specific patterns if requested # Plot specific patterns if requested
if 'two_drives' in pattern_detection: if 'two_drives' in pattern_detection:
pattern_indices = self.detect_two_drives_pattern() pattern_indices = self.detect_two_drives_pattern()
if pattern_indices: if pattern_indices:
patterns = self.data.iloc[pattern_indices] patterns = self.data.iloc[pattern_indices]
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2, plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
color='red', marker='o', s=150, label='Two Drives Pattern') color='red', marker='o', s=150, label='Two Drives Pattern')
elif 'v_shape' in pattern_detection: elif 'v_shape' in pattern_detection:
pattern_indices = self.detect_v_shaped_lows() pattern_indices = self.detect_v_shaped_lows()
if pattern_indices: if pattern_indices:
patterns = self.data.iloc[pattern_indices] patterns = self.data.iloc[pattern_indices]
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2, plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
color='purple', marker='o', s=150, label='V-Shape Pattern') color='purple', marker='o', s=150, label='V-Shape Pattern')
# Add cycle lengths and averages # Add cycle lengths and averages
cycle_lengths = self.calculate_cycle_lengths() cycle_lengths = self.calculate_cycle_lengths()
avg_cycle = self.get_average_cycle_length() avg_cycle = self.get_average_cycle_length()
cycle_window = self.get_cycle_window() cycle_window = self.get_cycle_window()
window_text = "" window_text = ""
if cycle_window: if cycle_window:
window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]" window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]"
plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n" plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n"
f"Average Cycle Length: {avg_cycle:.2f} periods, {window_text}") f"Average Cycle Length: {avg_cycle:.2f} periods, {window_text}")
plt.legend() plt.legend()
plt.grid(True) plt.grid(True)
plt.show() plt.show()
# Usage example: # Usage example:
# 1. Load your data # 1. Load your data
# data = pd.read_csv('your_price_data.csv') # data = pd.read_csv('your_price_data.csv')
# 2. Create cycle detector instances for different timeframes # 2. Create cycle detector instances for different timeframes
# weekly_detector = CycleDetector(data, timeframe='weekly') # weekly_detector = CycleDetector(data, timeframe='weekly')
# daily_detector = CycleDetector(data, timeframe='daily') # daily_detector = CycleDetector(data, timeframe='daily')
# 3. Analyze cycles # 3. Analyze cycles
# weekly_cycle_length = weekly_detector.get_average_cycle_length() # weekly_cycle_length = weekly_detector.get_average_cycle_length()
# daily_cycle_length = daily_detector.get_average_cycle_length() # daily_cycle_length = daily_detector.get_average_cycle_length()
# 4. Detect patterns # 4. Detect patterns
# two_drives = weekly_detector.detect_two_drives_pattern() # two_drives = weekly_detector.detect_two_drives_pattern()
# v_shapes = daily_detector.detect_v_shaped_lows() # v_shapes = daily_detector.detect_v_shaped_lows()
# 5. Visualize # 5. Visualize
# weekly_detector.plot_cycles(pattern_detection='two_drives') # weekly_detector.plot_cycles(pattern_detection='two_drives')
# daily_detector.plot_cycles(pattern_detection='v_shape') # daily_detector.plot_cycles(pattern_detection='v_shape')

View File

@ -6,6 +6,7 @@ from cycle_detector import CycleDetector
# Load data from CSV file instead of database # Load data from CSV file instead of database
data = pd.read_csv('data/btcusd_1-day_data.csv') data = pd.read_csv('data/btcusd_1-day_data.csv')
# Convert datetime column to datetime type # Convert datetime column to datetime type
start_date = pd.to_datetime('2025-04-01') start_date = pd.to_datetime('2025-04-01')
stop_date = pd.to_datetime('2025-05-06') stop_date = pd.to_datetime('2025-05-06')

View File

@ -1,259 +1,259 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import ta import ta
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.dates as mdates import matplotlib.dates as mdates
import logging import logging
import mplfinance as mpf import mplfinance as mpf
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
class TrendDetectorMACD: class TrendDetectorMACD:
def __init__(self, data, verbose=False): def __init__(self, data, verbose=False):
self.data = data self.data = data
self.verbose = verbose self.verbose = verbose
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s') format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger('TrendDetector') self.logger = logging.getLogger('TrendDetector')
# Convert data to pandas DataFrame if it's not already # Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame): if not isinstance(self.data, pd.DataFrame):
if isinstance(self.data, list): if isinstance(self.data, list):
self.logger.info("Converting list to DataFrame") self.logger.info("Converting list to DataFrame")
self.data = pd.DataFrame({'close': self.data}) self.data = pd.DataFrame({'close': self.data})
else: else:
self.logger.error("Invalid data format provided") self.logger.error("Invalid data format provided")
raise ValueError("Data must be a pandas DataFrame or a list") raise ValueError("Data must be a pandas DataFrame or a list")
self.logger.info(f"Initialized TrendDetector with {len(self.data)} data points") self.logger.info(f"Initialized TrendDetector with {len(self.data)} data points")
def detect_trends_MACD_signal(self): def detect_trends_MACD_signal(self):
self.logger.info("Starting trend detection") self.logger.info("Starting trend detection")
if len(self.data) < 3: if len(self.data) < 3:
self.logger.warning("Not enough data points for trend detection") self.logger.warning("Not enough data points for trend detection")
return {"error": "Not enough data points for trend detection"} return {"error": "Not enough data points for trend detection"}
# Create a copy of the DataFrame to avoid modifying the original # Create a copy of the DataFrame to avoid modifying the original
df = self.data.copy() df = self.data.copy()
self.logger.info("Created copy of input data") self.logger.info("Created copy of input data")
# If 'close' column doesn't exist, try to use a relevant column # If 'close' column doesn't exist, try to use a relevant column
if 'close' not in df.columns and len(df.columns) > 0: if 'close' not in df.columns and len(df.columns) > 0:
self.logger.info(f"'close' column not found, using {df.columns[0]} instead") self.logger.info(f"'close' column not found, using {df.columns[0]} instead")
df['close'] = df[df.columns[0]] # Use the first column as 'close' df['close'] = df[df.columns[0]] # Use the first column as 'close'
# Add trend indicators # Add trend indicators
self.logger.info("Calculating MACD indicators") self.logger.info("Calculating MACD indicators")
# Moving Average Convergence Divergence (MACD) # Moving Average Convergence Divergence (MACD)
df['macd'] = ta.trend.macd(df['close']) df['macd'] = ta.trend.macd(df['close'])
df['macd_signal'] = ta.trend.macd_signal(df['close']) df['macd_signal'] = ta.trend.macd_signal(df['close'])
df['macd_diff'] = ta.trend.macd_diff(df['close']) df['macd_diff'] = ta.trend.macd_diff(df['close'])
# Directional Movement Index (DMI) # Directional Movement Index (DMI)
if all(col in df.columns for col in ['high', 'low', 'close']): if all(col in df.columns for col in ['high', 'low', 'close']):
self.logger.info("Calculating ADX indicators") self.logger.info("Calculating ADX indicators")
df['adx'] = ta.trend.adx(df['high'], df['low'], df['close']) df['adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close']) df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close'])
df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close']) df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close'])
# Identify trend changes # Identify trend changes
self.logger.info("Identifying trend changes") self.logger.info("Identifying trend changes")
df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down') df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down')
df['trend_change'] = df['trend'] != df['trend'].shift(1) df['trend_change'] = df['trend'] != df['trend'].shift(1)
# Generate trend segments # Generate trend segments
self.logger.info("Generating trend segments") self.logger.info("Generating trend segments")
trends = [] trends = []
trend_start = 0 trend_start = 0
for i in range(1, len(df)): for i in range(1, len(df)):
if df['trend_change'].iloc[i]: if df['trend_change'].iloc[i]:
if i > trend_start: if i > trend_start:
trends.append({ trends.append({
"type": df['trend'].iloc[i-1], "type": df['trend'].iloc[i-1],
"start_index": trend_start, "start_index": trend_start,
"end_index": i-1, "end_index": i-1,
"start_value": df['close'].iloc[trend_start], "start_value": df['close'].iloc[trend_start],
"end_value": df['close'].iloc[i-1] "end_value": df['close'].iloc[i-1]
}) })
trend_start = i trend_start = i
# Add the last trend # Add the last trend
if trend_start < len(df): if trend_start < len(df):
trends.append({ trends.append({
"type": df['trend'].iloc[-1], "type": df['trend'].iloc[-1],
"start_index": trend_start, "start_index": trend_start,
"end_index": len(df)-1, "end_index": len(df)-1,
"start_value": df['close'].iloc[trend_start], "start_value": df['close'].iloc[trend_start],
"end_value": df['close'].iloc[-1] "end_value": df['close'].iloc[-1]
}) })
self.logger.info(f"Detected {len(trends)} trend segments") self.logger.info(f"Detected {len(trends)} trend segments")
return trends return trends
def get_strongest_trend(self): def get_strongest_trend(self):
self.logger.info("Finding strongest trend") self.logger.info("Finding strongest trend")
trends = self.detect_trends_MACD_signal() trends = self.detect_trends_MACD_signal()
if isinstance(trends, dict) and "error" in trends: if isinstance(trends, dict) and "error" in trends:
self.logger.warning(f"Error in trend detection: {trends['error']}") self.logger.warning(f"Error in trend detection: {trends['error']}")
return trends return trends
if not trends: if not trends:
self.logger.info("No significant trends detected") self.logger.info("No significant trends detected")
return {"message": "No significant trends detected"} return {"message": "No significant trends detected"}
strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"])) strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}") self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
return strongest return strongest
def plot_trends(self, trends): def plot_trends(self, trends):
""" """
Plot price data with identified trends highlighted using candlestick charts. Plot price data with identified trends highlighted using candlestick charts.
""" """
self.logger.info("Plotting trends with candlesticks") self.logger.info("Plotting trends with candlesticks")
if isinstance(trends, dict) and "error" in trends: if isinstance(trends, dict) and "error" in trends:
self.logger.error(trends["error"]) self.logger.error(trends["error"])
print(trends["error"]) print(trends["error"])
return return
if not trends: if not trends:
self.logger.warning("No significant trends detected for plotting") self.logger.warning("No significant trends detected for plotting")
print("No significant trends detected") print("No significant trends detected")
return return
# Create a figure with 2 subplots that share the x-axis # Create a figure with 2 subplots that share the x-axis
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]}, sharex=True) fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
self.logger.info("Creating plot figure with shared x-axis") self.logger.info("Creating plot figure with shared x-axis")
# Prepare data for candlestick chart # Prepare data for candlestick chart
df = self.data.copy() df = self.data.copy()
# Ensure required columns exist for candlestick # Ensure required columns exist for candlestick
required_cols = ['open', 'high', 'low', 'close'] required_cols = ['open', 'high', 'low', 'close']
if not all(col in df.columns for col in required_cols): if not all(col in df.columns for col in required_cols):
self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.") self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
if 'close' in df.columns: if 'close' in df.columns:
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'], ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
df['close'], color='black', alpha=0.7, linewidth=1, label='Price') df['close'], color='black', alpha=0.7, linewidth=1, label='Price')
else: else:
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'], ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
df[df.columns[0]], color='black', alpha=0.7, linewidth=1, label='Price') df[df.columns[0]], color='black', alpha=0.7, linewidth=1, label='Price')
else: else:
# Get x values (dates if available, otherwise indices) # Get x values (dates if available, otherwise indices)
if 'datetime' in df.columns: if 'datetime' in df.columns:
x_label = 'Date' x_label = 'Date'
# Format date axis # Format date axis
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
fig.autofmt_xdate() fig.autofmt_xdate()
self.logger.info("Using datetime for x-axis") self.logger.info("Using datetime for x-axis")
# For candlestick, ensure datetime is the index # For candlestick, ensure datetime is the index
if df.index.name != 'datetime': if df.index.name != 'datetime':
df = df.set_index('datetime') df = df.set_index('datetime')
else: else:
x_label = 'Index' x_label = 'Index'
self.logger.info("Using index for x-axis") self.logger.info("Using index for x-axis")
# Plot candlestick chart # Plot candlestick chart
up_color = 'green' up_color = 'green'
down_color = 'red' down_color = 'red'
# Draw candlesticks manually # Draw candlesticks manually
width = 0.6 width = 0.6
for i in range(len(df)): for i in range(len(df)):
# Get OHLC values for this candle # Get OHLC values for this candle
open_val = df['open'].iloc[i] open_val = df['open'].iloc[i]
close_val = df['close'].iloc[i] close_val = df['close'].iloc[i]
high_val = df['high'].iloc[i] high_val = df['high'].iloc[i]
low_val = df['low'].iloc[i] low_val = df['low'].iloc[i]
idx = df.index[i] idx = df.index[i]
# Determine candle color # Determine candle color
color = up_color if close_val >= open_val else down_color color = up_color if close_val >= open_val else down_color
# Plot candle body # Plot candle body
body_height = abs(close_val - open_val) body_height = abs(close_val - open_val)
bottom = min(open_val, close_val) bottom = min(open_val, close_val)
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8) rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
ax1.add_patch(rect) ax1.add_patch(rect)
# Plot candle wicks # Plot candle wicks
ax1.plot([i, i], [low_val, high_val], color='black', linewidth=1) ax1.plot([i, i], [low_val, high_val], color='black', linewidth=1)
# Set appropriate x-axis limits # Set appropriate x-axis limits
ax1.set_xlim(-0.5, len(df) - 0.5) ax1.set_xlim(-0.5, len(df) - 0.5)
# Highlight each trend with a different color # Highlight each trend with a different color
self.logger.info("Highlighting trends on plot") self.logger.info("Highlighting trends on plot")
for trend in trends: for trend in trends:
start_idx = trend['start_index'] start_idx = trend['start_index']
end_idx = trend['end_index'] end_idx = trend['end_index']
trend_type = trend['type'] trend_type = trend['type']
# Get x-coordinates for trend plotting # Get x-coordinates for trend plotting
x_start = start_idx x_start = start_idx
x_end = end_idx x_end = end_idx
# Get y-coordinates for trend line # Get y-coordinates for trend line
if 'close' in df.columns: if 'close' in df.columns:
y_start = df['close'].iloc[start_idx] y_start = df['close'].iloc[start_idx]
y_end = df['close'].iloc[end_idx] y_end = df['close'].iloc[end_idx]
else: else:
y_start = df[df.columns[0]].iloc[start_idx] y_start = df[df.columns[0]].iloc[start_idx]
y_end = df[df.columns[0]].iloc[end_idx] y_end = df[df.columns[0]].iloc[end_idx]
# Choose color based on trend type # Choose color based on trend type
color = 'green' if trend_type == 'up' else 'red' color = 'green' if trend_type == 'up' else 'red'
# Plot trend line # Plot trend line
ax1.plot([x_start, x_end], [y_start, y_end], color=color, linewidth=2, ax1.plot([x_start, x_end], [y_start, y_end], color=color, linewidth=2,
label=f"{trend_type.capitalize()} Trend" if f"{trend_type.capitalize()} Trend" not in ax1.get_legend_handles_labels()[1] else "") label=f"{trend_type.capitalize()} Trend" if f"{trend_type.capitalize()} Trend" not in ax1.get_legend_handles_labels()[1] else "")
# Add markers at start and end points # Add markers at start and end points
ax1.scatter(x_start, y_start, color=color, marker='o', s=50) ax1.scatter(x_start, y_start, color=color, marker='o', s=50)
ax1.scatter(x_end, y_end, color=color, marker='s', s=50) ax1.scatter(x_end, y_end, color=color, marker='s', s=50)
# Configure first subplot # Configure first subplot
ax1.set_title('Price with Trends (Candlestick)', fontsize=16) ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
ax1.set_ylabel('Price', fontsize=14) ax1.set_ylabel('Price', fontsize=14)
ax1.grid(alpha=0.3) ax1.grid(alpha=0.3)
ax1.legend() ax1.legend()
# Create MACD in second subplot # Create MACD in second subplot
self.logger.info("Creating MACD subplot") self.logger.info("Creating MACD subplot")
# Calculate MACD indicators if not already present # Calculate MACD indicators if not already present
if 'macd' not in df.columns: if 'macd' not in df.columns:
if 'close' not in df.columns and len(df.columns) > 0: if 'close' not in df.columns and len(df.columns) > 0:
df['close'] = df[df.columns[0]] df['close'] = df[df.columns[0]]
df['macd'] = ta.trend.macd(df['close']) df['macd'] = ta.trend.macd(df['close'])
df['macd_signal'] = ta.trend.macd_signal(df['close']) df['macd_signal'] = ta.trend.macd_signal(df['close'])
df['macd_diff'] = ta.trend.macd_diff(df['close']) df['macd_diff'] = ta.trend.macd_diff(df['close'])
# Plot MACD components on second subplot # Plot MACD components on second subplot
x_indices = np.arange(len(df)) x_indices = np.arange(len(df))
ax2.plot(x_indices, df['macd'], label='MACD', color='blue') ax2.plot(x_indices, df['macd'], label='MACD', color='blue')
ax2.plot(x_indices, df['macd_signal'], label='Signal', color='orange') ax2.plot(x_indices, df['macd_signal'], label='Signal', color='orange')
# Plot MACD histogram # Plot MACD histogram
for i in range(len(df)): for i in range(len(df)):
if df['macd_diff'].iloc[i] >= 0: if df['macd_diff'].iloc[i] >= 0:
ax2.bar(i, df['macd_diff'].iloc[i], color='green', alpha=0.5, width=0.8) ax2.bar(i, df['macd_diff'].iloc[i], color='green', alpha=0.5, width=0.8)
else: else:
ax2.bar(i, df['macd_diff'].iloc[i], color='red', alpha=0.5, width=0.8) ax2.bar(i, df['macd_diff'].iloc[i], color='red', alpha=0.5, width=0.8)
ax2.set_title('MACD Indicator', fontsize=16) ax2.set_title('MACD Indicator', fontsize=16)
ax2.set_xlabel(x_label, fontsize=14) ax2.set_xlabel(x_label, fontsize=14)
ax2.set_ylabel('MACD', fontsize=14) ax2.set_ylabel('MACD', fontsize=14)
ax2.grid(alpha=0.3) ax2.grid(alpha=0.3)
ax2.legend() ax2.legend()
# Enable synchronized zooming # Enable synchronized zooming
plt.tight_layout() plt.tight_layout()
plt.subplots_adjust(hspace=0.1) plt.subplots_adjust(hspace=0.1)
plt.show() plt.show()
return plt return plt

View File

@ -1,205 +1,253 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import logging import logging
from scipy.signal import find_peaks from scipy.signal import find_peaks
import matplotlib.dates as mdates import matplotlib.dates as mdates
from scipy import stats from scipy import stats
from scipy import stats
class TrendDetectorSimple:
def __init__(self, data, verbose=False): class TrendDetectorSimple:
""" def __init__(self, data, verbose=False):
Initialize the TrendDetectorSimple class. """
Initialize the TrendDetectorSimple class.
Parameters:
- data: pandas DataFrame containing price data Parameters:
- verbose: boolean, whether to display detailed logging information - data: pandas DataFrame containing price data
""" - verbose: boolean, whether to display detailed logging information
"""
self.data = data
self.verbose = verbose self.data = data
self.verbose = verbose
# Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, # Plot style configuration
format='%(asctime)s - %(levelname)s - %(message)s') self.plot_style = 'dark_background'
self.logger = logging.getLogger('TrendDetectorSimple') self.bg_color = '#181C27'
self.plot_size = (12, 8)
# Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame): # Candlestick configuration
if isinstance(self.data, list): self.candle_width = 0.6
self.logger.info("Converting list to DataFrame") self.candle_up_color = '#089981'
self.data = pd.DataFrame({'close': self.data}) self.candle_down_color = '#F23645'
else: self.candle_alpha = 0.8
self.logger.error("Invalid data format provided") self.wick_width = 1
raise ValueError("Data must be a pandas DataFrame or a list")
# Marker configuration
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points") self.min_marker = '^'
self.min_color = 'red'
def detect_trends(self): self.min_size = 100
""" self.max_marker = 'v'
Detect trends by identifying local minima and maxima in the price data self.max_color = 'green'
using scipy.signal.find_peaks. self.max_size = 100
self.marker_zorder = 100
Parameters:
- prominence: float, required prominence of peaks (relative to the price range) # Line configuration
- width: int, required width of peaks in data points self.line_width = 2
self.min_line_style = 'g--' # green dashed
Returns: self.max_line_style = 'r--' # red dashed
- DataFrame with columns for timestamps, prices, and trend indicators self.sma7_line_style = 'y-' # yellow solid
""" self.sma15_line_style = 'm-' # magenta solid
self.logger.info(f"Detecting trends")
# Text configuration
df = self.data.copy() self.title_size = 14
close_prices = df['close'].values self.title_color = 'white'
self.axis_label_size = 12
max_peaks, _ = find_peaks(close_prices) self.axis_label_color = 'white'
min_peaks, _ = find_peaks(-close_prices)
# Legend configuration
self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima") self.legend_loc = 'best'
self.legend_bg_color = '#333333'
df['is_min'] = False
df['is_max'] = False # Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
for peak in max_peaks: format='%(asctime)s - %(levelname)s - %(message)s')
df.at[peak, 'is_max'] = True self.logger = logging.getLogger('TrendDetectorSimple')
for peak in min_peaks:
df.at[peak, 'is_min'] = True # Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame):
result = df[['datetime', 'close', 'is_min', 'is_max']].copy() if isinstance(self.data, list):
self.logger.info("Converting list to DataFrame")
# Perform linear regression on min_peaks and max_peaks self.data = pd.DataFrame({'close': self.data})
self.logger.info("Performing linear regression on min and max peaks") else:
min_prices = df['close'].iloc[min_peaks].values self.logger.error("Invalid data format provided")
max_prices = df['close'].iloc[max_peaks].values raise ValueError("Data must be a pandas DataFrame or a list")
# Linear regression for min peaks if we have at least 2 points self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
# Linear regression for max peaks if we have at least 2 points def detect_trends(self):
max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices) """
Detect trends by identifying local minima and maxima in the price data
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods using scipy.signal.find_peaks.
self.logger.info("Calculating SMA-7 and SMA-15")
Parameters:
# Calculate SMA values and exclude NaN values - prominence: float, required prominence of peaks (relative to the price range)
sma_7 = df['close'].rolling(window=7).mean().dropna().values - width: int, required width of peaks in data points
sma_15 = df['close'].rolling(window=15).mean().dropna().values
Returns:
# Add SMA values to regression_results - DataFrame with columns for timestamps, prices, and trend indicators
analysis_results = {} """
analysis_results['linear_regression'] = { self.logger.info(f"Detecting trends")
'min': { self.logger.info(f"Detecting trends")
'slope': min_slope,
'intercept': min_intercept, df = self.data.copy()
'r_squared': min_r_value ** 2 close_prices = df['close'].values
},
'max': { max_peaks, _ = find_peaks(close_prices)
'slope': max_slope, min_peaks, _ = find_peaks(-close_prices)
'intercept': max_intercept, max_peaks, _ = find_peaks(close_prices)
'r_squared': max_r_value ** 2 min_peaks, _ = find_peaks(-close_prices)
}
} self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")
analysis_results['sma'] = {
'7': sma_7, df['is_min'] = False
'15': sma_15 df['is_max'] = False
}
for peak in max_peaks:
self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}") df.at[peak, 'is_max'] = True
self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}") for peak in min_peaks:
df.at[peak, 'is_min'] = True
return result, analysis_results
result = df[['datetime', 'close', 'is_min', 'is_max']].copy()
def plot_trends(self, trend_data, analysis_results):
""" # Perform linear regression on min_peaks and max_peaks
Plot the price data with detected trends using a candlestick chart. self.logger.info("Performing linear regression on min and max peaks")
min_prices = df['close'].iloc[min_peaks].values
Parameters: max_prices = df['close'].iloc[max_peaks].values
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
# Linear regression for min peaks if we have at least 2 points
Returns: min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
- None (displays the plot) # Linear regression for max peaks if we have at least 2 points
""" max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle # Calculate Simple Moving Averages (SMA) for 7 and 15 periods
self.logger.info("Calculating SMA-7 and SMA-15")
# Create the figure and axis
fig, ax = plt.subplots(figsize=(12, 8)) sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values
# Create a copy of the data
df = self.data.copy() analysis_results = {}
analysis_results['linear_regression'] = {
# Plot candlestick chart 'min': {
up_color = 'green' 'slope': min_slope,
down_color = 'red' 'intercept': min_intercept,
'r_squared': min_r_value ** 2
# Draw candlesticks manually },
width = 0.6 'max': {
x_values = range(len(df)) 'slope': max_slope,
'intercept': max_intercept,
for i in range(len(df)): 'r_squared': max_r_value ** 2
# Get OHLC values for this candle }
open_val = df['open'].iloc[i] }
close_val = df['close'].iloc[i] analysis_results['sma'] = {
high_val = df['high'].iloc[i] '7': sma_7,
low_val = df['low'].iloc[i] '15': sma_15
}
# Determine candle color
color = up_color if close_val >= open_val else down_color self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}")
self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}")
# Plot candle body
body_height = abs(close_val - open_val) return result, analysis_results
bottom = min(open_val, close_val)
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8) def plot_trends(self, trend_data, analysis_results):
ax.add_patch(rect) """
Plot the price data with detected trends using a candlestick chart.
# Plot candle wicks
ax.plot([i, i], [low_val, high_val], color='black', linewidth=1) Parameters:
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
if min_indices: Returns:
min_y = [df['close'].iloc[i] for i in min_indices] - None (displays the plot)
ax.scatter(min_indices, min_y, color='darkred', s=200, marker='^', label='Local Minima', zorder=100) """
import matplotlib.pyplot as plt
max_indices = trend_data.index[trend_data['is_max'] == True].tolist() from matplotlib.patches import Rectangle
if max_indices:
max_y = [df['close'].iloc[i] for i in max_indices] # Create the figure and axis with specified background
ax.scatter(max_indices, max_y, color='darkgreen', s=200, marker='v', label='Local Maxima', zorder=100) plt.style.use(self.plot_style)
fig, ax = plt.subplots(figsize=self.plot_size)
if analysis_results:
x_vals = np.arange(len(df)) # Set the custom background color
# Minima regression line (support) fig.patch.set_facecolor(self.bg_color)
min_slope = analysis_results['linear_regression']['min']['slope'] ax.set_facecolor(self.bg_color)
min_intercept = analysis_results['linear_regression']['min']['intercept']
min_line = min_slope * x_vals + min_intercept # Create a copy of the data
ax.plot(x_vals, min_line, 'g--', linewidth=2, label='Minima Regression') df = self.data.copy()
# Maxima regression line (resistance) # Draw candlesticks manually
max_slope = analysis_results['linear_regression']['max']['slope'] x_values = range(len(df))
max_intercept = analysis_results['linear_regression']['max']['intercept']
max_line = max_slope * x_vals + max_intercept for i in range(len(df)):
ax.plot(x_vals, max_line, 'r--', linewidth=2, label='Maxima Regression') # Get OHLC values for this candle
open_val = df['open'].iloc[i]
# SMA-7 line close_val = df['close'].iloc[i]
sma_7 = analysis_results['sma']['7'] high_val = df['high'].iloc[i]
ax.plot(x_vals, sma_7, 'y-', linewidth=2, label='SMA-7') low_val = df['low'].iloc[i]
# SMA-15 line # Determine candle color
# sma_15 = analysis_results['sma']['15'] color = self.candle_up_color if close_val >= open_val else self.candle_down_color
# valid_idx_15 = ~np.isnan(sma_15)
# ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], 'm-', linewidth=2, label='SMA-15') # Plot candle body
body_height = abs(close_val - open_val)
# Set title and labels bottom = min(open_val, close_val)
ax.set_title('Price Candlestick Chart with Local Minima and Maxima', fontsize=14) rect = Rectangle((i - self.candle_width/2, bottom), self.candle_width, body_height,
ax.set_xlabel('Date', fontsize=12) color=color, alpha=self.candle_alpha)
ax.set_ylabel('Price', fontsize=12) ax.add_patch(rect)
# Set appropriate x-axis limits # Plot candle wicks
ax.set_xlim(-0.5, len(df) - 0.5) ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)
# Add a legend min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
ax.legend(loc='best') if min_indices:
min_y = [df['close'].iloc[i] for i in min_indices]
# Adjust layout ax.scatter(min_indices, min_y, color=self.min_color, s=self.min_size,
plt.tight_layout() marker=self.min_marker, label='Local Minima', zorder=self.marker_zorder)
# Show the plot max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
plt.show() if max_indices:
max_y = [df['close'].iloc[i] for i in max_indices]
ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size,
marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)
if analysis_results:
x_vals = np.arange(len(df))
# Minima regression line (support)
min_slope = analysis_results['linear_regression']['min']['slope']
min_intercept = analysis_results['linear_regression']['min']['intercept']
min_line = min_slope * x_vals + min_intercept
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
label='Minima Regression')
# Maxima regression line (resistance)
max_slope = analysis_results['linear_regression']['max']['slope']
max_intercept = analysis_results['linear_regression']['max']['intercept']
max_line = max_slope * x_vals + max_intercept
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
label='Maxima Regression')
# SMA-7 line
sma_7 = analysis_results['sma']['7']
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
label='SMA-7')
# SMA-15 line
sma_15 = analysis_results['sma']['15']
valid_idx_15 = ~np.isnan(sma_15)
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
linewidth=self.line_width, label='SMA-15')
# Set title and labels
ax.set_title('Price Candlestick Chart with Local Minima and Maxima',
fontsize=self.title_size, color=self.title_color)
ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)
# Set appropriate x-axis limits
ax.set_xlim(-0.5, len(df) - 0.5)
# Add a legend
ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)
# Adjust layout
plt.tight_layout()
# Show the plot
plt.show()