Refactor cycle detection and trend analysis; enhance trend detection with linear regression and moving averages. Update main script for improved data handling and visualization.
This commit is contained in:
parent
cbc6a7493d
commit
e9bfcd03eb
@ -1,248 +1,248 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from scipy.signal import argrelextrema
|
from scipy.signal import argrelextrema
|
||||||
|
|
||||||
class CycleDetector:
|
class CycleDetector:
|
||||||
def __init__(self, data, timeframe='daily'):
|
def __init__(self, data, timeframe='daily'):
|
||||||
"""
|
"""
|
||||||
Initialize the CycleDetector with price data.
|
Initialize the CycleDetector with price data.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- data: DataFrame with at least 'date' or 'datetime' and 'close' columns
|
- data: DataFrame with at least 'date' or 'datetime' and 'close' columns
|
||||||
- timeframe: 'daily', 'weekly', or 'monthly'
|
- timeframe: 'daily', 'weekly', or 'monthly'
|
||||||
"""
|
"""
|
||||||
self.data = data.copy()
|
self.data = data.copy()
|
||||||
self.timeframe = timeframe
|
self.timeframe = timeframe
|
||||||
|
|
||||||
# Ensure we have a consistent date column name
|
# Ensure we have a consistent date column name
|
||||||
if 'datetime' in self.data.columns and 'date' not in self.data.columns:
|
if 'datetime' in self.data.columns and 'date' not in self.data.columns:
|
||||||
self.data.rename(columns={'datetime': 'date'}, inplace=True)
|
self.data.rename(columns={'datetime': 'date'}, inplace=True)
|
||||||
|
|
||||||
# Convert data to specified timeframe if needed
|
# Convert data to specified timeframe if needed
|
||||||
if timeframe == 'weekly' and 'date' in self.data.columns:
|
if timeframe == 'weekly' and 'date' in self.data.columns:
|
||||||
self.data = self._convert_data(self.data, 'W')
|
self.data = self._convert_data(self.data, 'W')
|
||||||
elif timeframe == 'monthly' and 'date' in self.data.columns:
|
elif timeframe == 'monthly' and 'date' in self.data.columns:
|
||||||
self.data = self._convert_data(self.data, 'M')
|
self.data = self._convert_data(self.data, 'M')
|
||||||
|
|
||||||
# Add columns for local minima and maxima detection
|
# Add columns for local minima and maxima detection
|
||||||
self._add_swing_points()
|
self._add_swing_points()
|
||||||
|
|
||||||
def _convert_data(self, data, timeframe):
|
def _convert_data(self, data, timeframe):
|
||||||
"""Convert daily data to 'timeframe' timeframe."""
|
"""Convert daily data to 'timeframe' timeframe."""
|
||||||
data['date'] = pd.to_datetime(data['date'])
|
data['date'] = pd.to_datetime(data['date'])
|
||||||
data.set_index('date', inplace=True)
|
data.set_index('date', inplace=True)
|
||||||
weekly = data.resample(timeframe).agg({
|
weekly = data.resample(timeframe).agg({
|
||||||
'open': 'first',
|
'open': 'first',
|
||||||
'high': 'max',
|
'high': 'max',
|
||||||
'low': 'min',
|
'low': 'min',
|
||||||
'close': 'last',
|
'close': 'last',
|
||||||
'volume': 'sum'
|
'volume': 'sum'
|
||||||
})
|
})
|
||||||
return weekly.reset_index()
|
return weekly.reset_index()
|
||||||
|
|
||||||
|
|
||||||
def _add_swing_points(self, window=5):
|
def _add_swing_points(self, window=5):
|
||||||
"""
|
"""
|
||||||
Identify swing points (local minima and maxima).
|
Identify swing points (local minima and maxima).
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- window: The window size for local minima/maxima detection
|
- window: The window size for local minima/maxima detection
|
||||||
"""
|
"""
|
||||||
# Set the index to make calculations easier
|
# Set the index to make calculations easier
|
||||||
if 'date' in self.data.columns:
|
if 'date' in self.data.columns:
|
||||||
self.data.set_index('date', inplace=True)
|
self.data.set_index('date', inplace=True)
|
||||||
|
|
||||||
# Detect local minima (swing lows)
|
# Detect local minima (swing lows)
|
||||||
min_idx = argrelextrema(self.data['low'].values, np.less, order=window)[0]
|
min_idx = argrelextrema(self.data['low'].values, np.less, order=window)[0]
|
||||||
self.data['swing_low'] = False
|
self.data['swing_low'] = False
|
||||||
self.data.iloc[min_idx, self.data.columns.get_loc('swing_low')] = True
|
self.data.iloc[min_idx, self.data.columns.get_loc('swing_low')] = True
|
||||||
|
|
||||||
# Detect local maxima (swing highs)
|
# Detect local maxima (swing highs)
|
||||||
max_idx = argrelextrema(self.data['high'].values, np.greater, order=window)[0]
|
max_idx = argrelextrema(self.data['high'].values, np.greater, order=window)[0]
|
||||||
self.data['swing_high'] = False
|
self.data['swing_high'] = False
|
||||||
self.data.iloc[max_idx, self.data.columns.get_loc('swing_high')] = True
|
self.data.iloc[max_idx, self.data.columns.get_loc('swing_high')] = True
|
||||||
|
|
||||||
# Reset index
|
# Reset index
|
||||||
self.data.reset_index(inplace=True)
|
self.data.reset_index(inplace=True)
|
||||||
|
|
||||||
def find_cycle_lows(self):
|
def find_cycle_lows(self):
|
||||||
"""Find all swing lows which represent cycle lows."""
|
"""Find all swing lows which represent cycle lows."""
|
||||||
swing_low_dates = self.data[self.data['swing_low']]['date'].values
|
swing_low_dates = self.data[self.data['swing_low']]['date'].values
|
||||||
return swing_low_dates
|
return swing_low_dates
|
||||||
|
|
||||||
def calculate_cycle_lengths(self):
|
def calculate_cycle_lengths(self):
|
||||||
"""Calculate the lengths of each cycle between consecutive lows."""
|
"""Calculate the lengths of each cycle between consecutive lows."""
|
||||||
swing_low_indices = np.where(self.data['swing_low'])[0]
|
swing_low_indices = np.where(self.data['swing_low'])[0]
|
||||||
cycle_lengths = np.diff(swing_low_indices)
|
cycle_lengths = np.diff(swing_low_indices)
|
||||||
return cycle_lengths
|
return cycle_lengths
|
||||||
|
|
||||||
def get_average_cycle_length(self):
|
def get_average_cycle_length(self):
|
||||||
"""Calculate the average cycle length."""
|
"""Calculate the average cycle length."""
|
||||||
cycle_lengths = self.calculate_cycle_lengths()
|
cycle_lengths = self.calculate_cycle_lengths()
|
||||||
if len(cycle_lengths) > 0:
|
if len(cycle_lengths) > 0:
|
||||||
return np.mean(cycle_lengths)
|
return np.mean(cycle_lengths)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_cycle_window(self, tolerance=0.10):
|
def get_cycle_window(self, tolerance=0.10):
|
||||||
"""
|
"""
|
||||||
Get the cycle window with the specified tolerance.
|
Get the cycle window with the specified tolerance.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- tolerance: The tolerance as a percentage (default: 10%)
|
- tolerance: The tolerance as a percentage (default: 10%)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- tuple: (min_cycle_length, avg_cycle_length, max_cycle_length)
|
- tuple: (min_cycle_length, avg_cycle_length, max_cycle_length)
|
||||||
"""
|
"""
|
||||||
avg_length = self.get_average_cycle_length()
|
avg_length = self.get_average_cycle_length()
|
||||||
if avg_length is not None:
|
if avg_length is not None:
|
||||||
min_length = avg_length * (1 - tolerance)
|
min_length = avg_length * (1 - tolerance)
|
||||||
max_length = avg_length * (1 + tolerance)
|
max_length = avg_length * (1 + tolerance)
|
||||||
return (min_length, avg_length, max_length)
|
return (min_length, avg_length, max_length)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def detect_two_drives_pattern(self, lookback=10):
|
def detect_two_drives_pattern(self, lookback=10):
|
||||||
"""
|
"""
|
||||||
Detect 2-drives pattern: a swing low, counter trend bounce, and a lower low.
|
Detect 2-drives pattern: a swing low, counter trend bounce, and a lower low.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- lookback: Number of periods to look back
|
- lookback: Number of periods to look back
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- list: Indices where 2-drives patterns are detected
|
- list: Indices where 2-drives patterns are detected
|
||||||
"""
|
"""
|
||||||
patterns = []
|
patterns = []
|
||||||
|
|
||||||
for i in range(lookback, len(self.data) - 1):
|
for i in range(lookback, len(self.data) - 1):
|
||||||
if not self.data.iloc[i]['swing_low']:
|
if not self.data.iloc[i]['swing_low']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get the segment of data to check for pattern
|
# Get the segment of data to check for pattern
|
||||||
segment = self.data.iloc[i-lookback:i+1]
|
segment = self.data.iloc[i-lookback:i+1]
|
||||||
swing_lows = segment[segment['swing_low']]['low'].values
|
swing_lows = segment[segment['swing_low']]['low'].values
|
||||||
|
|
||||||
if len(swing_lows) >= 2 and swing_lows[-1] < swing_lows[-2]:
|
if len(swing_lows) >= 2 and swing_lows[-1] < swing_lows[-2]:
|
||||||
# Check if there was a bounce between the two lows
|
# Check if there was a bounce between the two lows
|
||||||
between_lows = segment.iloc[-len(swing_lows):-1]
|
between_lows = segment.iloc[-len(swing_lows):-1]
|
||||||
if len(between_lows) > 0 and max(between_lows['high']) > swing_lows[-2]:
|
if len(between_lows) > 0 and max(between_lows['high']) > swing_lows[-2]:
|
||||||
patterns.append(i)
|
patterns.append(i)
|
||||||
|
|
||||||
return patterns
|
return patterns
|
||||||
|
|
||||||
def detect_v_shaped_lows(self, window=5, threshold=0.02):
|
def detect_v_shaped_lows(self, window=5, threshold=0.02):
|
||||||
"""
|
"""
|
||||||
Detect V-shaped cycle lows (sharp decline followed by sharp rise).
|
Detect V-shaped cycle lows (sharp decline followed by sharp rise).
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- window: Window to look for sharp price changes
|
- window: Window to look for sharp price changes
|
||||||
- threshold: Percentage change threshold to consider 'sharp'
|
- threshold: Percentage change threshold to consider 'sharp'
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- list: Indices where V-shaped patterns are detected
|
- list: Indices where V-shaped patterns are detected
|
||||||
"""
|
"""
|
||||||
patterns = []
|
patterns = []
|
||||||
|
|
||||||
# Find all swing lows
|
# Find all swing lows
|
||||||
swing_low_indices = np.where(self.data['swing_low'])[0]
|
swing_low_indices = np.where(self.data['swing_low'])[0]
|
||||||
|
|
||||||
for idx in swing_low_indices:
|
for idx in swing_low_indices:
|
||||||
# Need enough data points before and after
|
# Need enough data points before and after
|
||||||
if idx < window or idx + window >= len(self.data):
|
if idx < window or idx + window >= len(self.data):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get the low price at this swing low
|
# Get the low price at this swing low
|
||||||
low_price = self.data.iloc[idx]['low']
|
low_price = self.data.iloc[idx]['low']
|
||||||
|
|
||||||
# Check for sharp decline before low (at least window bars before)
|
# Check for sharp decline before low (at least window bars before)
|
||||||
before_segment = self.data.iloc[max(0, idx-window):idx]
|
before_segment = self.data.iloc[max(0, idx-window):idx]
|
||||||
if len(before_segment) > 0:
|
if len(before_segment) > 0:
|
||||||
max_before = before_segment['high'].max()
|
max_before = before_segment['high'].max()
|
||||||
decline = (max_before - low_price) / max_before
|
decline = (max_before - low_price) / max_before
|
||||||
|
|
||||||
# Check for sharp rise after low (at least window bars after)
|
# Check for sharp rise after low (at least window bars after)
|
||||||
after_segment = self.data.iloc[idx+1:min(len(self.data), idx+window+1)]
|
after_segment = self.data.iloc[idx+1:min(len(self.data), idx+window+1)]
|
||||||
if len(after_segment) > 0:
|
if len(after_segment) > 0:
|
||||||
max_after = after_segment['high'].max()
|
max_after = after_segment['high'].max()
|
||||||
rise = (max_after - low_price) / low_price
|
rise = (max_after - low_price) / low_price
|
||||||
|
|
||||||
# Both decline and rise must exceed threshold to be considered V-shaped
|
# Both decline and rise must exceed threshold to be considered V-shaped
|
||||||
if decline > threshold and rise > threshold:
|
if decline > threshold and rise > threshold:
|
||||||
patterns.append(idx)
|
patterns.append(idx)
|
||||||
|
|
||||||
return patterns
|
return patterns
|
||||||
|
|
||||||
def plot_cycles(self, pattern_detection=None, title_suffix=''):
|
def plot_cycles(self, pattern_detection=None, title_suffix=''):
|
||||||
"""
|
"""
|
||||||
Plot the price data with cycle lows and detected patterns.
|
Plot the price data with cycle lows and detected patterns.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- pattern_detection: 'two_drives', 'v_shape', or None
|
- pattern_detection: 'two_drives', 'v_shape', or None
|
||||||
- title_suffix: Optional suffix for the plot title
|
- title_suffix: Optional suffix for the plot title
|
||||||
"""
|
"""
|
||||||
plt.figure(figsize=(14, 7))
|
plt.figure(figsize=(14, 7))
|
||||||
|
|
||||||
# Determine the date column name (could be 'date' or 'datetime')
|
# Determine the date column name (could be 'date' or 'datetime')
|
||||||
date_col = 'date' if 'date' in self.data.columns else 'datetime'
|
date_col = 'date' if 'date' in self.data.columns else 'datetime'
|
||||||
|
|
||||||
# Plot price data
|
# Plot price data
|
||||||
plt.plot(self.data[date_col], self.data['close'], label='Close Price')
|
plt.plot(self.data[date_col], self.data['close'], label='Close Price')
|
||||||
|
|
||||||
# Calculate a consistent vertical position for indicators based on price range
|
# Calculate a consistent vertical position for indicators based on price range
|
||||||
price_range = self.data['close'].max() - self.data['close'].min()
|
price_range = self.data['close'].max() - self.data['close'].min()
|
||||||
indicator_offset = price_range * 0.01 # 1% of price range
|
indicator_offset = price_range * 0.01 # 1% of price range
|
||||||
|
|
||||||
# Plot cycle lows (now at a fixed offset below the low price)
|
# Plot cycle lows (now at a fixed offset below the low price)
|
||||||
swing_lows = self.data[self.data['swing_low']]
|
swing_lows = self.data[self.data['swing_low']]
|
||||||
plt.scatter(swing_lows[date_col], swing_lows['low'] - indicator_offset,
|
plt.scatter(swing_lows[date_col], swing_lows['low'] - indicator_offset,
|
||||||
color='green', marker='^', s=100, label='Cycle Lows')
|
color='green', marker='^', s=100, label='Cycle Lows')
|
||||||
|
|
||||||
# Plot specific patterns if requested
|
# Plot specific patterns if requested
|
||||||
if 'two_drives' in pattern_detection:
|
if 'two_drives' in pattern_detection:
|
||||||
pattern_indices = self.detect_two_drives_pattern()
|
pattern_indices = self.detect_two_drives_pattern()
|
||||||
if pattern_indices:
|
if pattern_indices:
|
||||||
patterns = self.data.iloc[pattern_indices]
|
patterns = self.data.iloc[pattern_indices]
|
||||||
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
|
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
|
||||||
color='red', marker='o', s=150, label='Two Drives Pattern')
|
color='red', marker='o', s=150, label='Two Drives Pattern')
|
||||||
|
|
||||||
elif 'v_shape' in pattern_detection:
|
elif 'v_shape' in pattern_detection:
|
||||||
pattern_indices = self.detect_v_shaped_lows()
|
pattern_indices = self.detect_v_shaped_lows()
|
||||||
if pattern_indices:
|
if pattern_indices:
|
||||||
patterns = self.data.iloc[pattern_indices]
|
patterns = self.data.iloc[pattern_indices]
|
||||||
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
|
plt.scatter(patterns[date_col], patterns['low'] - indicator_offset * 2,
|
||||||
color='purple', marker='o', s=150, label='V-Shape Pattern')
|
color='purple', marker='o', s=150, label='V-Shape Pattern')
|
||||||
|
|
||||||
# Add cycle lengths and averages
|
# Add cycle lengths and averages
|
||||||
cycle_lengths = self.calculate_cycle_lengths()
|
cycle_lengths = self.calculate_cycle_lengths()
|
||||||
avg_cycle = self.get_average_cycle_length()
|
avg_cycle = self.get_average_cycle_length()
|
||||||
cycle_window = self.get_cycle_window()
|
cycle_window = self.get_cycle_window()
|
||||||
|
|
||||||
window_text = ""
|
window_text = ""
|
||||||
if cycle_window:
|
if cycle_window:
|
||||||
window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]"
|
window_text = f"Tolerance Window: [{cycle_window[0]:.2f} - {cycle_window[2]:.2f}]"
|
||||||
|
|
||||||
plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n"
|
plt.title(f"Detected Cycles - {self.timeframe.capitalize()} Timeframe {title_suffix}\n"
|
||||||
f"Average Cycle Length: {avg_cycle:.2f} periods, {window_text}")
|
f"Average Cycle Length: {avg_cycle:.2f} periods, {window_text}")
|
||||||
|
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# Usage example:
|
# Usage example:
|
||||||
# 1. Load your data
|
# 1. Load your data
|
||||||
# data = pd.read_csv('your_price_data.csv')
|
# data = pd.read_csv('your_price_data.csv')
|
||||||
|
|
||||||
# 2. Create cycle detector instances for different timeframes
|
# 2. Create cycle detector instances for different timeframes
|
||||||
# weekly_detector = CycleDetector(data, timeframe='weekly')
|
# weekly_detector = CycleDetector(data, timeframe='weekly')
|
||||||
# daily_detector = CycleDetector(data, timeframe='daily')
|
# daily_detector = CycleDetector(data, timeframe='daily')
|
||||||
|
|
||||||
# 3. Analyze cycles
|
# 3. Analyze cycles
|
||||||
# weekly_cycle_length = weekly_detector.get_average_cycle_length()
|
# weekly_cycle_length = weekly_detector.get_average_cycle_length()
|
||||||
# daily_cycle_length = daily_detector.get_average_cycle_length()
|
# daily_cycle_length = daily_detector.get_average_cycle_length()
|
||||||
|
|
||||||
# 4. Detect patterns
|
# 4. Detect patterns
|
||||||
# two_drives = weekly_detector.detect_two_drives_pattern()
|
# two_drives = weekly_detector.detect_two_drives_pattern()
|
||||||
# v_shapes = daily_detector.detect_v_shaped_lows()
|
# v_shapes = daily_detector.detect_v_shaped_lows()
|
||||||
|
|
||||||
# 5. Visualize
|
# 5. Visualize
|
||||||
# weekly_detector.plot_cycles(pattern_detection='two_drives')
|
# weekly_detector.plot_cycles(pattern_detection='two_drives')
|
||||||
# daily_detector.plot_cycles(pattern_detection='v_shape')
|
# daily_detector.plot_cycles(pattern_detection='v_shape')
|
||||||
|
|||||||
1
main.py
1
main.py
@ -6,6 +6,7 @@ from cycle_detector import CycleDetector
|
|||||||
# Load data from CSV file instead of database
|
# Load data from CSV file instead of database
|
||||||
data = pd.read_csv('data/btcusd_1-day_data.csv')
|
data = pd.read_csv('data/btcusd_1-day_data.csv')
|
||||||
|
|
||||||
|
|
||||||
# Convert datetime column to datetime type
|
# Convert datetime column to datetime type
|
||||||
start_date = pd.to_datetime('2025-04-01')
|
start_date = pd.to_datetime('2025-04-01')
|
||||||
stop_date = pd.to_datetime('2025-05-06')
|
stop_date = pd.to_datetime('2025-05-06')
|
||||||
|
|||||||
@ -1,259 +1,259 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ta
|
import ta
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.dates as mdates
|
import matplotlib.dates as mdates
|
||||||
import logging
|
import logging
|
||||||
import mplfinance as mpf
|
import mplfinance as mpf
|
||||||
from matplotlib.patches import Rectangle
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
class TrendDetectorMACD:
|
class TrendDetectorMACD:
|
||||||
def __init__(self, data, verbose=False):
|
def __init__(self, data, verbose=False):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
|
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s')
|
format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
self.logger = logging.getLogger('TrendDetector')
|
self.logger = logging.getLogger('TrendDetector')
|
||||||
|
|
||||||
# Convert data to pandas DataFrame if it's not already
|
# Convert data to pandas DataFrame if it's not already
|
||||||
if not isinstance(self.data, pd.DataFrame):
|
if not isinstance(self.data, pd.DataFrame):
|
||||||
if isinstance(self.data, list):
|
if isinstance(self.data, list):
|
||||||
self.logger.info("Converting list to DataFrame")
|
self.logger.info("Converting list to DataFrame")
|
||||||
self.data = pd.DataFrame({'close': self.data})
|
self.data = pd.DataFrame({'close': self.data})
|
||||||
else:
|
else:
|
||||||
self.logger.error("Invalid data format provided")
|
self.logger.error("Invalid data format provided")
|
||||||
raise ValueError("Data must be a pandas DataFrame or a list")
|
raise ValueError("Data must be a pandas DataFrame or a list")
|
||||||
|
|
||||||
self.logger.info(f"Initialized TrendDetector with {len(self.data)} data points")
|
self.logger.info(f"Initialized TrendDetector with {len(self.data)} data points")
|
||||||
|
|
||||||
def detect_trends_MACD_signal(self):
|
def detect_trends_MACD_signal(self):
|
||||||
self.logger.info("Starting trend detection")
|
self.logger.info("Starting trend detection")
|
||||||
if len(self.data) < 3:
|
if len(self.data) < 3:
|
||||||
self.logger.warning("Not enough data points for trend detection")
|
self.logger.warning("Not enough data points for trend detection")
|
||||||
return {"error": "Not enough data points for trend detection"}
|
return {"error": "Not enough data points for trend detection"}
|
||||||
|
|
||||||
# Create a copy of the DataFrame to avoid modifying the original
|
# Create a copy of the DataFrame to avoid modifying the original
|
||||||
df = self.data.copy()
|
df = self.data.copy()
|
||||||
self.logger.info("Created copy of input data")
|
self.logger.info("Created copy of input data")
|
||||||
|
|
||||||
# If 'close' column doesn't exist, try to use a relevant column
|
# If 'close' column doesn't exist, try to use a relevant column
|
||||||
if 'close' not in df.columns and len(df.columns) > 0:
|
if 'close' not in df.columns and len(df.columns) > 0:
|
||||||
self.logger.info(f"'close' column not found, using {df.columns[0]} instead")
|
self.logger.info(f"'close' column not found, using {df.columns[0]} instead")
|
||||||
df['close'] = df[df.columns[0]] # Use the first column as 'close'
|
df['close'] = df[df.columns[0]] # Use the first column as 'close'
|
||||||
|
|
||||||
# Add trend indicators
|
# Add trend indicators
|
||||||
self.logger.info("Calculating MACD indicators")
|
self.logger.info("Calculating MACD indicators")
|
||||||
# Moving Average Convergence Divergence (MACD)
|
# Moving Average Convergence Divergence (MACD)
|
||||||
df['macd'] = ta.trend.macd(df['close'])
|
df['macd'] = ta.trend.macd(df['close'])
|
||||||
df['macd_signal'] = ta.trend.macd_signal(df['close'])
|
df['macd_signal'] = ta.trend.macd_signal(df['close'])
|
||||||
df['macd_diff'] = ta.trend.macd_diff(df['close'])
|
df['macd_diff'] = ta.trend.macd_diff(df['close'])
|
||||||
|
|
||||||
# Directional Movement Index (DMI)
|
# Directional Movement Index (DMI)
|
||||||
if all(col in df.columns for col in ['high', 'low', 'close']):
|
if all(col in df.columns for col in ['high', 'low', 'close']):
|
||||||
self.logger.info("Calculating ADX indicators")
|
self.logger.info("Calculating ADX indicators")
|
||||||
df['adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
|
df['adx'] = ta.trend.adx(df['high'], df['low'], df['close'])
|
||||||
df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close'])
|
df['adx_pos'] = ta.trend.adx_pos(df['high'], df['low'], df['close'])
|
||||||
df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close'])
|
df['adx_neg'] = ta.trend.adx_neg(df['high'], df['low'], df['close'])
|
||||||
|
|
||||||
# Identify trend changes
|
# Identify trend changes
|
||||||
self.logger.info("Identifying trend changes")
|
self.logger.info("Identifying trend changes")
|
||||||
df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down')
|
df['trend'] = np.where(df['macd'] > df['macd_signal'], 'up', 'down')
|
||||||
df['trend_change'] = df['trend'] != df['trend'].shift(1)
|
df['trend_change'] = df['trend'] != df['trend'].shift(1)
|
||||||
|
|
||||||
# Generate trend segments
|
# Generate trend segments
|
||||||
self.logger.info("Generating trend segments")
|
self.logger.info("Generating trend segments")
|
||||||
trends = []
|
trends = []
|
||||||
trend_start = 0
|
trend_start = 0
|
||||||
|
|
||||||
for i in range(1, len(df)):
|
for i in range(1, len(df)):
|
||||||
|
|
||||||
if df['trend_change'].iloc[i]:
|
if df['trend_change'].iloc[i]:
|
||||||
if i > trend_start:
|
if i > trend_start:
|
||||||
trends.append({
|
trends.append({
|
||||||
"type": df['trend'].iloc[i-1],
|
"type": df['trend'].iloc[i-1],
|
||||||
"start_index": trend_start,
|
"start_index": trend_start,
|
||||||
"end_index": i-1,
|
"end_index": i-1,
|
||||||
"start_value": df['close'].iloc[trend_start],
|
"start_value": df['close'].iloc[trend_start],
|
||||||
"end_value": df['close'].iloc[i-1]
|
"end_value": df['close'].iloc[i-1]
|
||||||
})
|
})
|
||||||
trend_start = i
|
trend_start = i
|
||||||
|
|
||||||
# Add the last trend
|
# Add the last trend
|
||||||
if trend_start < len(df):
|
if trend_start < len(df):
|
||||||
trends.append({
|
trends.append({
|
||||||
"type": df['trend'].iloc[-1],
|
"type": df['trend'].iloc[-1],
|
||||||
"start_index": trend_start,
|
"start_index": trend_start,
|
||||||
"end_index": len(df)-1,
|
"end_index": len(df)-1,
|
||||||
"start_value": df['close'].iloc[trend_start],
|
"start_value": df['close'].iloc[trend_start],
|
||||||
"end_value": df['close'].iloc[-1]
|
"end_value": df['close'].iloc[-1]
|
||||||
})
|
})
|
||||||
|
|
||||||
self.logger.info(f"Detected {len(trends)} trend segments")
|
self.logger.info(f"Detected {len(trends)} trend segments")
|
||||||
return trends
|
return trends
|
||||||
|
|
||||||
def get_strongest_trend(self):
|
def get_strongest_trend(self):
|
||||||
self.logger.info("Finding strongest trend")
|
self.logger.info("Finding strongest trend")
|
||||||
trends = self.detect_trends_MACD_signal()
|
trends = self.detect_trends_MACD_signal()
|
||||||
if isinstance(trends, dict) and "error" in trends:
|
if isinstance(trends, dict) and "error" in trends:
|
||||||
self.logger.warning(f"Error in trend detection: {trends['error']}")
|
self.logger.warning(f"Error in trend detection: {trends['error']}")
|
||||||
return trends
|
return trends
|
||||||
|
|
||||||
if not trends:
|
if not trends:
|
||||||
self.logger.info("No significant trends detected")
|
self.logger.info("No significant trends detected")
|
||||||
return {"message": "No significant trends detected"}
|
return {"message": "No significant trends detected"}
|
||||||
|
|
||||||
strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
|
strongest = max(trends, key=lambda x: abs(x["end_value"] - x["start_value"]))
|
||||||
self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
|
self.logger.info(f"Strongest trend: {strongest['type']} from index {strongest['start_index']} to {strongest['end_index']}")
|
||||||
return strongest
|
return strongest
|
||||||
|
|
||||||
def plot_trends(self, trends):
|
def plot_trends(self, trends):
|
||||||
"""
|
"""
|
||||||
Plot price data with identified trends highlighted using candlestick charts.
|
Plot price data with identified trends highlighted using candlestick charts.
|
||||||
"""
|
"""
|
||||||
self.logger.info("Plotting trends with candlesticks")
|
self.logger.info("Plotting trends with candlesticks")
|
||||||
if isinstance(trends, dict) and "error" in trends:
|
if isinstance(trends, dict) and "error" in trends:
|
||||||
self.logger.error(trends["error"])
|
self.logger.error(trends["error"])
|
||||||
print(trends["error"])
|
print(trends["error"])
|
||||||
return
|
return
|
||||||
|
|
||||||
if not trends:
|
if not trends:
|
||||||
self.logger.warning("No significant trends detected for plotting")
|
self.logger.warning("No significant trends detected for plotting")
|
||||||
print("No significant trends detected")
|
print("No significant trends detected")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create a figure with 2 subplots that share the x-axis
|
# Create a figure with 2 subplots that share the x-axis
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
|
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
|
||||||
self.logger.info("Creating plot figure with shared x-axis")
|
self.logger.info("Creating plot figure with shared x-axis")
|
||||||
|
|
||||||
# Prepare data for candlestick chart
|
# Prepare data for candlestick chart
|
||||||
df = self.data.copy()
|
df = self.data.copy()
|
||||||
|
|
||||||
# Ensure required columns exist for candlestick
|
# Ensure required columns exist for candlestick
|
||||||
required_cols = ['open', 'high', 'low', 'close']
|
required_cols = ['open', 'high', 'low', 'close']
|
||||||
if not all(col in df.columns for col in required_cols):
|
if not all(col in df.columns for col in required_cols):
|
||||||
self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
|
self.logger.warning("Missing required columns for candlestick. Defaulting to line chart.")
|
||||||
if 'close' in df.columns:
|
if 'close' in df.columns:
|
||||||
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
|
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
|
||||||
df['close'], color='black', alpha=0.7, linewidth=1, label='Price')
|
df['close'], color='black', alpha=0.7, linewidth=1, label='Price')
|
||||||
else:
|
else:
|
||||||
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
|
ax1.plot(df.index if 'datetime' not in df.columns else df['datetime'],
|
||||||
df[df.columns[0]], color='black', alpha=0.7, linewidth=1, label='Price')
|
df[df.columns[0]], color='black', alpha=0.7, linewidth=1, label='Price')
|
||||||
else:
|
else:
|
||||||
# Get x values (dates if available, otherwise indices)
|
# Get x values (dates if available, otherwise indices)
|
||||||
if 'datetime' in df.columns:
|
if 'datetime' in df.columns:
|
||||||
x_label = 'Date'
|
x_label = 'Date'
|
||||||
# Format date axis
|
# Format date axis
|
||||||
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||||
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
|
||||||
fig.autofmt_xdate()
|
fig.autofmt_xdate()
|
||||||
self.logger.info("Using datetime for x-axis")
|
self.logger.info("Using datetime for x-axis")
|
||||||
|
|
||||||
# For candlestick, ensure datetime is the index
|
# For candlestick, ensure datetime is the index
|
||||||
if df.index.name != 'datetime':
|
if df.index.name != 'datetime':
|
||||||
df = df.set_index('datetime')
|
df = df.set_index('datetime')
|
||||||
else:
|
else:
|
||||||
x_label = 'Index'
|
x_label = 'Index'
|
||||||
self.logger.info("Using index for x-axis")
|
self.logger.info("Using index for x-axis")
|
||||||
|
|
||||||
# Plot candlestick chart
|
# Plot candlestick chart
|
||||||
up_color = 'green'
|
up_color = 'green'
|
||||||
down_color = 'red'
|
down_color = 'red'
|
||||||
|
|
||||||
# Draw candlesticks manually
|
# Draw candlesticks manually
|
||||||
width = 0.6
|
width = 0.6
|
||||||
for i in range(len(df)):
|
for i in range(len(df)):
|
||||||
# Get OHLC values for this candle
|
# Get OHLC values for this candle
|
||||||
open_val = df['open'].iloc[i]
|
open_val = df['open'].iloc[i]
|
||||||
close_val = df['close'].iloc[i]
|
close_val = df['close'].iloc[i]
|
||||||
high_val = df['high'].iloc[i]
|
high_val = df['high'].iloc[i]
|
||||||
low_val = df['low'].iloc[i]
|
low_val = df['low'].iloc[i]
|
||||||
idx = df.index[i]
|
idx = df.index[i]
|
||||||
|
|
||||||
# Determine candle color
|
# Determine candle color
|
||||||
color = up_color if close_val >= open_val else down_color
|
color = up_color if close_val >= open_val else down_color
|
||||||
|
|
||||||
# Plot candle body
|
# Plot candle body
|
||||||
body_height = abs(close_val - open_val)
|
body_height = abs(close_val - open_val)
|
||||||
bottom = min(open_val, close_val)
|
bottom = min(open_val, close_val)
|
||||||
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
|
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
|
||||||
ax1.add_patch(rect)
|
ax1.add_patch(rect)
|
||||||
|
|
||||||
# Plot candle wicks
|
# Plot candle wicks
|
||||||
ax1.plot([i, i], [low_val, high_val], color='black', linewidth=1)
|
ax1.plot([i, i], [low_val, high_val], color='black', linewidth=1)
|
||||||
|
|
||||||
# Set appropriate x-axis limits
|
# Set appropriate x-axis limits
|
||||||
ax1.set_xlim(-0.5, len(df) - 0.5)
|
ax1.set_xlim(-0.5, len(df) - 0.5)
|
||||||
|
|
||||||
# Highlight each trend with a different color
|
# Highlight each trend with a different color
|
||||||
self.logger.info("Highlighting trends on plot")
|
self.logger.info("Highlighting trends on plot")
|
||||||
for trend in trends:
|
for trend in trends:
|
||||||
start_idx = trend['start_index']
|
start_idx = trend['start_index']
|
||||||
end_idx = trend['end_index']
|
end_idx = trend['end_index']
|
||||||
trend_type = trend['type']
|
trend_type = trend['type']
|
||||||
|
|
||||||
# Get x-coordinates for trend plotting
|
# Get x-coordinates for trend plotting
|
||||||
x_start = start_idx
|
x_start = start_idx
|
||||||
x_end = end_idx
|
x_end = end_idx
|
||||||
|
|
||||||
# Get y-coordinates for trend line
|
# Get y-coordinates for trend line
|
||||||
if 'close' in df.columns:
|
if 'close' in df.columns:
|
||||||
y_start = df['close'].iloc[start_idx]
|
y_start = df['close'].iloc[start_idx]
|
||||||
y_end = df['close'].iloc[end_idx]
|
y_end = df['close'].iloc[end_idx]
|
||||||
else:
|
else:
|
||||||
y_start = df[df.columns[0]].iloc[start_idx]
|
y_start = df[df.columns[0]].iloc[start_idx]
|
||||||
y_end = df[df.columns[0]].iloc[end_idx]
|
y_end = df[df.columns[0]].iloc[end_idx]
|
||||||
|
|
||||||
# Choose color based on trend type
|
# Choose color based on trend type
|
||||||
color = 'green' if trend_type == 'up' else 'red'
|
color = 'green' if trend_type == 'up' else 'red'
|
||||||
|
|
||||||
# Plot trend line
|
# Plot trend line
|
||||||
ax1.plot([x_start, x_end], [y_start, y_end], color=color, linewidth=2,
|
ax1.plot([x_start, x_end], [y_start, y_end], color=color, linewidth=2,
|
||||||
label=f"{trend_type.capitalize()} Trend" if f"{trend_type.capitalize()} Trend" not in ax1.get_legend_handles_labels()[1] else "")
|
label=f"{trend_type.capitalize()} Trend" if f"{trend_type.capitalize()} Trend" not in ax1.get_legend_handles_labels()[1] else "")
|
||||||
|
|
||||||
# Add markers at start and end points
|
# Add markers at start and end points
|
||||||
ax1.scatter(x_start, y_start, color=color, marker='o', s=50)
|
ax1.scatter(x_start, y_start, color=color, marker='o', s=50)
|
||||||
ax1.scatter(x_end, y_end, color=color, marker='s', s=50)
|
ax1.scatter(x_end, y_end, color=color, marker='s', s=50)
|
||||||
|
|
||||||
# Configure first subplot
|
# Configure first subplot
|
||||||
ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
|
ax1.set_title('Price with Trends (Candlestick)', fontsize=16)
|
||||||
ax1.set_ylabel('Price', fontsize=14)
|
ax1.set_ylabel('Price', fontsize=14)
|
||||||
ax1.grid(alpha=0.3)
|
ax1.grid(alpha=0.3)
|
||||||
ax1.legend()
|
ax1.legend()
|
||||||
|
|
||||||
# Create MACD in second subplot
|
# Create MACD in second subplot
|
||||||
self.logger.info("Creating MACD subplot")
|
self.logger.info("Creating MACD subplot")
|
||||||
|
|
||||||
# Calculate MACD indicators if not already present
|
# Calculate MACD indicators if not already present
|
||||||
if 'macd' not in df.columns:
|
if 'macd' not in df.columns:
|
||||||
if 'close' not in df.columns and len(df.columns) > 0:
|
if 'close' not in df.columns and len(df.columns) > 0:
|
||||||
df['close'] = df[df.columns[0]]
|
df['close'] = df[df.columns[0]]
|
||||||
|
|
||||||
df['macd'] = ta.trend.macd(df['close'])
|
df['macd'] = ta.trend.macd(df['close'])
|
||||||
df['macd_signal'] = ta.trend.macd_signal(df['close'])
|
df['macd_signal'] = ta.trend.macd_signal(df['close'])
|
||||||
df['macd_diff'] = ta.trend.macd_diff(df['close'])
|
df['macd_diff'] = ta.trend.macd_diff(df['close'])
|
||||||
|
|
||||||
# Plot MACD components on second subplot
|
# Plot MACD components on second subplot
|
||||||
x_indices = np.arange(len(df))
|
x_indices = np.arange(len(df))
|
||||||
ax2.plot(x_indices, df['macd'], label='MACD', color='blue')
|
ax2.plot(x_indices, df['macd'], label='MACD', color='blue')
|
||||||
ax2.plot(x_indices, df['macd_signal'], label='Signal', color='orange')
|
ax2.plot(x_indices, df['macd_signal'], label='Signal', color='orange')
|
||||||
|
|
||||||
# Plot MACD histogram
|
# Plot MACD histogram
|
||||||
for i in range(len(df)):
|
for i in range(len(df)):
|
||||||
if df['macd_diff'].iloc[i] >= 0:
|
if df['macd_diff'].iloc[i] >= 0:
|
||||||
ax2.bar(i, df['macd_diff'].iloc[i], color='green', alpha=0.5, width=0.8)
|
ax2.bar(i, df['macd_diff'].iloc[i], color='green', alpha=0.5, width=0.8)
|
||||||
else:
|
else:
|
||||||
ax2.bar(i, df['macd_diff'].iloc[i], color='red', alpha=0.5, width=0.8)
|
ax2.bar(i, df['macd_diff'].iloc[i], color='red', alpha=0.5, width=0.8)
|
||||||
|
|
||||||
ax2.set_title('MACD Indicator', fontsize=16)
|
ax2.set_title('MACD Indicator', fontsize=16)
|
||||||
ax2.set_xlabel(x_label, fontsize=14)
|
ax2.set_xlabel(x_label, fontsize=14)
|
||||||
ax2.set_ylabel('MACD', fontsize=14)
|
ax2.set_ylabel('MACD', fontsize=14)
|
||||||
ax2.grid(alpha=0.3)
|
ax2.grid(alpha=0.3)
|
||||||
ax2.legend()
|
ax2.legend()
|
||||||
|
|
||||||
# Enable synchronized zooming
|
# Enable synchronized zooming
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.subplots_adjust(hspace=0.1)
|
plt.subplots_adjust(hspace=0.1)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
return plt
|
return plt
|
||||||
|
|||||||
@ -1,205 +1,255 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from scipy.signal import find_peaks
|
from scipy.signal import find_peaks
|
||||||
import matplotlib.dates as mdates
|
import matplotlib.dates as mdates
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
|
from scipy import stats
|
||||||
class TrendDetectorSimple:
|
|
||||||
def __init__(self, data, verbose=False):
|
class TrendDetectorSimple:
|
||||||
"""
|
def __init__(self, data, verbose=False):
|
||||||
Initialize the TrendDetectorSimple class.
|
"""
|
||||||
|
Initialize the TrendDetectorSimple class.
|
||||||
Parameters:
|
|
||||||
- data: pandas DataFrame containing price data
|
Parameters:
|
||||||
- verbose: boolean, whether to display detailed logging information
|
- data: pandas DataFrame containing price data
|
||||||
"""
|
- verbose: boolean, whether to display detailed logging information
|
||||||
|
"""
|
||||||
self.data = data
|
|
||||||
self.verbose = verbose
|
self.data = data
|
||||||
|
self.verbose = verbose
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
|
# Plot style configuration
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s')
|
self.plot_style = 'dark_background'
|
||||||
self.logger = logging.getLogger('TrendDetectorSimple')
|
self.bg_color = '#181C27'
|
||||||
|
self.plot_size = (12, 8)
|
||||||
# Convert data to pandas DataFrame if it's not already
|
|
||||||
if not isinstance(self.data, pd.DataFrame):
|
# Candlestick configuration
|
||||||
if isinstance(self.data, list):
|
self.candle_width = 0.6
|
||||||
self.logger.info("Converting list to DataFrame")
|
self.candle_up_color = '#089981'
|
||||||
self.data = pd.DataFrame({'close': self.data})
|
self.candle_down_color = '#F23645'
|
||||||
else:
|
self.candle_alpha = 0.8
|
||||||
self.logger.error("Invalid data format provided")
|
self.wick_width = 1
|
||||||
raise ValueError("Data must be a pandas DataFrame or a list")
|
|
||||||
|
# Marker configuration
|
||||||
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
|
self.min_marker = '^'
|
||||||
|
self.min_color = 'red'
|
||||||
def detect_trends(self):
|
self.min_size = 100
|
||||||
"""
|
self.max_marker = 'v'
|
||||||
Detect trends by identifying local minima and maxima in the price data
|
self.max_color = 'green'
|
||||||
using scipy.signal.find_peaks.
|
self.max_size = 100
|
||||||
|
self.marker_zorder = 100
|
||||||
Parameters:
|
|
||||||
- prominence: float, required prominence of peaks (relative to the price range)
|
# Line configuration
|
||||||
- width: int, required width of peaks in data points
|
self.line_width = 2
|
||||||
|
self.min_line_style = 'g--' # green dashed
|
||||||
Returns:
|
self.max_line_style = 'r--' # red dashed
|
||||||
- DataFrame with columns for timestamps, prices, and trend indicators
|
self.sma7_line_style = 'y-' # yellow solid
|
||||||
"""
|
self.sma15_line_style = 'm-' # magenta solid
|
||||||
self.logger.info(f"Detecting trends")
|
|
||||||
|
# Text configuration
|
||||||
df = self.data.copy()
|
self.title_size = 14
|
||||||
close_prices = df['close'].values
|
self.title_color = 'white'
|
||||||
|
self.axis_label_size = 12
|
||||||
max_peaks, _ = find_peaks(close_prices)
|
self.axis_label_color = 'white'
|
||||||
min_peaks, _ = find_peaks(-close_prices)
|
|
||||||
|
# Legend configuration
|
||||||
self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")
|
self.legend_loc = 'best'
|
||||||
|
self.legend_bg_color = '#333333'
|
||||||
df['is_min'] = False
|
|
||||||
df['is_max'] = False
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
|
||||||
for peak in max_peaks:
|
format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
df.at[peak, 'is_max'] = True
|
self.logger = logging.getLogger('TrendDetectorSimple')
|
||||||
for peak in min_peaks:
|
|
||||||
df.at[peak, 'is_min'] = True
|
# Convert data to pandas DataFrame if it's not already
|
||||||
|
if not isinstance(self.data, pd.DataFrame):
|
||||||
result = df[['datetime', 'close', 'is_min', 'is_max']].copy()
|
if isinstance(self.data, list):
|
||||||
|
self.logger.info("Converting list to DataFrame")
|
||||||
# Perform linear regression on min_peaks and max_peaks
|
self.data = pd.DataFrame({'close': self.data})
|
||||||
self.logger.info("Performing linear regression on min and max peaks")
|
else:
|
||||||
min_prices = df['close'].iloc[min_peaks].values
|
self.logger.error("Invalid data format provided")
|
||||||
max_prices = df['close'].iloc[max_peaks].values
|
raise ValueError("Data must be a pandas DataFrame or a list")
|
||||||
|
|
||||||
# Linear regression for min peaks if we have at least 2 points
|
self.logger.info(f"Initialized TrendDetectorSimple with {len(self.data)} data points")
|
||||||
min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
|
|
||||||
# Linear regression for max peaks if we have at least 2 points
|
def detect_trends(self):
|
||||||
max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
|
def detect_trends(self):
|
||||||
|
"""
|
||||||
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
|
Detect trends by identifying local minima and maxima in the price data
|
||||||
self.logger.info("Calculating SMA-7 and SMA-15")
|
using scipy.signal.find_peaks.
|
||||||
|
|
||||||
# Calculate SMA values and exclude NaN values
|
Parameters:
|
||||||
sma_7 = df['close'].rolling(window=7).mean().dropna().values
|
- prominence: float, required prominence of peaks (relative to the price range)
|
||||||
sma_15 = df['close'].rolling(window=15).mean().dropna().values
|
- width: int, required width of peaks in data points
|
||||||
|
|
||||||
# Add SMA values to regression_results
|
Returns:
|
||||||
analysis_results = {}
|
- DataFrame with columns for timestamps, prices, and trend indicators
|
||||||
analysis_results['linear_regression'] = {
|
"""
|
||||||
'min': {
|
self.logger.info(f"Detecting trends")
|
||||||
'slope': min_slope,
|
self.logger.info(f"Detecting trends")
|
||||||
'intercept': min_intercept,
|
|
||||||
'r_squared': min_r_value ** 2
|
df = self.data.copy()
|
||||||
},
|
close_prices = df['close'].values
|
||||||
'max': {
|
|
||||||
'slope': max_slope,
|
max_peaks, _ = find_peaks(close_prices)
|
||||||
'intercept': max_intercept,
|
min_peaks, _ = find_peaks(-close_prices)
|
||||||
'r_squared': max_r_value ** 2
|
max_peaks, _ = find_peaks(close_prices)
|
||||||
}
|
min_peaks, _ = find_peaks(-close_prices)
|
||||||
}
|
|
||||||
analysis_results['sma'] = {
|
self.logger.info(f"Found {len(min_peaks)} local minima and {len(max_peaks)} local maxima")
|
||||||
'7': sma_7,
|
|
||||||
'15': sma_15
|
df['is_min'] = False
|
||||||
}
|
df['is_max'] = False
|
||||||
|
|
||||||
self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}")
|
for peak in max_peaks:
|
||||||
self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}")
|
df.at[peak, 'is_max'] = True
|
||||||
|
for peak in min_peaks:
|
||||||
return result, analysis_results
|
df.at[peak, 'is_min'] = True
|
||||||
|
|
||||||
def plot_trends(self, trend_data, analysis_results):
|
result = df[['datetime', 'close', 'is_min', 'is_max']].copy()
|
||||||
"""
|
|
||||||
Plot the price data with detected trends using a candlestick chart.
|
# Perform linear regression on min_peaks and max_peaks
|
||||||
|
self.logger.info("Performing linear regression on min and max peaks")
|
||||||
Parameters:
|
min_prices = df['close'].iloc[min_peaks].values
|
||||||
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
|
max_prices = df['close'].iloc[max_peaks].values
|
||||||
|
|
||||||
Returns:
|
# Linear regression for min peaks if we have at least 2 points
|
||||||
- None (displays the plot)
|
min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
|
||||||
"""
|
# Linear regression for max peaks if we have at least 2 points
|
||||||
import matplotlib.pyplot as plt
|
max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
|
||||||
from matplotlib.patches import Rectangle
|
|
||||||
|
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
|
||||||
# Create the figure and axis
|
self.logger.info("Calculating SMA-7 and SMA-15")
|
||||||
fig, ax = plt.subplots(figsize=(12, 8))
|
|
||||||
|
sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
|
||||||
# Create a copy of the data
|
sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values
|
||||||
df = self.data.copy()
|
|
||||||
|
analysis_results = {}
|
||||||
# Plot candlestick chart
|
analysis_results['linear_regression'] = {
|
||||||
up_color = 'green'
|
'min': {
|
||||||
down_color = 'red'
|
'slope': min_slope,
|
||||||
|
'intercept': min_intercept,
|
||||||
# Draw candlesticks manually
|
'r_squared': min_r_value ** 2
|
||||||
width = 0.6
|
},
|
||||||
x_values = range(len(df))
|
'max': {
|
||||||
|
'slope': max_slope,
|
||||||
for i in range(len(df)):
|
'intercept': max_intercept,
|
||||||
# Get OHLC values for this candle
|
'r_squared': max_r_value ** 2
|
||||||
open_val = df['open'].iloc[i]
|
}
|
||||||
close_val = df['close'].iloc[i]
|
}
|
||||||
high_val = df['high'].iloc[i]
|
analysis_results['sma'] = {
|
||||||
low_val = df['low'].iloc[i]
|
'7': sma_7,
|
||||||
|
'15': sma_15
|
||||||
# Determine candle color
|
}
|
||||||
color = up_color if close_val >= open_val else down_color
|
|
||||||
|
self.logger.info(f"Min peaks regression: slope={min_slope:.4f}, intercept={min_intercept:.4f}, r²={min_r_value**2:.4f}")
|
||||||
# Plot candle body
|
self.logger.info(f"Max peaks regression: slope={max_slope:.4f}, intercept={max_intercept:.4f}, r²={max_r_value**2:.4f}")
|
||||||
body_height = abs(close_val - open_val)
|
|
||||||
bottom = min(open_val, close_val)
|
return result, analysis_results
|
||||||
rect = Rectangle((i - width/2, bottom), width, body_height, color=color, alpha=0.8)
|
|
||||||
ax.add_patch(rect)
|
def plot_trends(self, trend_data, analysis_results):
|
||||||
|
def plot_trends(self, trend_data, analysis_results):
|
||||||
# Plot candle wicks
|
"""
|
||||||
ax.plot([i, i], [low_val, high_val], color='black', linewidth=1)
|
Plot the price data with detected trends using a candlestick chart.
|
||||||
|
|
||||||
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
|
Parameters:
|
||||||
if min_indices:
|
- trend_data: DataFrame, the output from detect_trends(). If None, detect_trends() will be called.
|
||||||
min_y = [df['close'].iloc[i] for i in min_indices]
|
|
||||||
ax.scatter(min_indices, min_y, color='darkred', s=200, marker='^', label='Local Minima', zorder=100)
|
Returns:
|
||||||
|
- None (displays the plot)
|
||||||
max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
|
"""
|
||||||
if max_indices:
|
import matplotlib.pyplot as plt
|
||||||
max_y = [df['close'].iloc[i] for i in max_indices]
|
from matplotlib.patches import Rectangle
|
||||||
ax.scatter(max_indices, max_y, color='darkgreen', s=200, marker='v', label='Local Maxima', zorder=100)
|
|
||||||
|
# Create the figure and axis with specified background
|
||||||
if analysis_results:
|
plt.style.use(self.plot_style)
|
||||||
x_vals = np.arange(len(df))
|
fig, ax = plt.subplots(figsize=self.plot_size)
|
||||||
# Minima regression line (support)
|
|
||||||
min_slope = analysis_results['linear_regression']['min']['slope']
|
# Set the custom background color
|
||||||
min_intercept = analysis_results['linear_regression']['min']['intercept']
|
fig.patch.set_facecolor(self.bg_color)
|
||||||
min_line = min_slope * x_vals + min_intercept
|
ax.set_facecolor(self.bg_color)
|
||||||
ax.plot(x_vals, min_line, 'g--', linewidth=2, label='Minima Regression')
|
|
||||||
|
# Create a copy of the data
|
||||||
# Maxima regression line (resistance)
|
df = self.data.copy()
|
||||||
max_slope = analysis_results['linear_regression']['max']['slope']
|
|
||||||
max_intercept = analysis_results['linear_regression']['max']['intercept']
|
# Draw candlesticks manually
|
||||||
max_line = max_slope * x_vals + max_intercept
|
x_values = range(len(df))
|
||||||
ax.plot(x_vals, max_line, 'r--', linewidth=2, label='Maxima Regression')
|
|
||||||
|
for i in range(len(df)):
|
||||||
# SMA-7 line
|
# Get OHLC values for this candle
|
||||||
sma_7 = analysis_results['sma']['7']
|
open_val = df['open'].iloc[i]
|
||||||
ax.plot(x_vals, sma_7, 'y-', linewidth=2, label='SMA-7')
|
close_val = df['close'].iloc[i]
|
||||||
|
high_val = df['high'].iloc[i]
|
||||||
# SMA-15 line
|
low_val = df['low'].iloc[i]
|
||||||
# sma_15 = analysis_results['sma']['15']
|
|
||||||
# valid_idx_15 = ~np.isnan(sma_15)
|
# Determine candle color
|
||||||
# ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], 'm-', linewidth=2, label='SMA-15')
|
color = self.candle_up_color if close_val >= open_val else self.candle_down_color
|
||||||
|
|
||||||
# Set title and labels
|
# Plot candle body
|
||||||
ax.set_title('Price Candlestick Chart with Local Minima and Maxima', fontsize=14)
|
body_height = abs(close_val - open_val)
|
||||||
ax.set_xlabel('Date', fontsize=12)
|
bottom = min(open_val, close_val)
|
||||||
ax.set_ylabel('Price', fontsize=12)
|
rect = Rectangle((i - self.candle_width/2, bottom), self.candle_width, body_height,
|
||||||
|
color=color, alpha=self.candle_alpha)
|
||||||
# Set appropriate x-axis limits
|
ax.add_patch(rect)
|
||||||
ax.set_xlim(-0.5, len(df) - 0.5)
|
|
||||||
|
# Plot candle wicks
|
||||||
# Add a legend
|
ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)
|
||||||
ax.legend(loc='best')
|
|
||||||
|
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
|
||||||
# Adjust layout
|
if min_indices:
|
||||||
plt.tight_layout()
|
min_y = [df['close'].iloc[i] for i in min_indices]
|
||||||
|
ax.scatter(min_indices, min_y, color=self.min_color, s=self.min_size,
|
||||||
# Show the plot
|
marker=self.min_marker, label='Local Minima', zorder=self.marker_zorder)
|
||||||
plt.show()
|
|
||||||
|
max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
|
||||||
|
if max_indices:
|
||||||
|
max_y = [df['close'].iloc[i] for i in max_indices]
|
||||||
|
ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size,
|
||||||
|
marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)
|
||||||
|
|
||||||
|
if analysis_results:
|
||||||
|
x_vals = np.arange(len(df))
|
||||||
|
# Minima regression line (support)
|
||||||
|
min_slope = analysis_results['linear_regression']['min']['slope']
|
||||||
|
min_intercept = analysis_results['linear_regression']['min']['intercept']
|
||||||
|
min_line = min_slope * x_vals + min_intercept
|
||||||
|
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
|
||||||
|
label='Minima Regression')
|
||||||
|
|
||||||
|
# Maxima regression line (resistance)
|
||||||
|
max_slope = analysis_results['linear_regression']['max']['slope']
|
||||||
|
max_intercept = analysis_results['linear_regression']['max']['intercept']
|
||||||
|
max_line = max_slope * x_vals + max_intercept
|
||||||
|
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
|
||||||
|
label='Maxima Regression')
|
||||||
|
|
||||||
|
# SMA-7 line
|
||||||
|
sma_7 = analysis_results['sma']['7']
|
||||||
|
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
|
||||||
|
label='SMA-7')
|
||||||
|
|
||||||
|
# SMA-15 line
|
||||||
|
sma_15 = analysis_results['sma']['15']
|
||||||
|
valid_idx_15 = ~np.isnan(sma_15)
|
||||||
|
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
|
||||||
|
linewidth=self.line_width, label='SMA-15')
|
||||||
|
|
||||||
|
# Set title and labels
|
||||||
|
ax.set_title('Price Candlestick Chart with Local Minima and Maxima',
|
||||||
|
fontsize=self.title_size, color=self.title_color)
|
||||||
|
ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
|
||||||
|
ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)
|
||||||
|
|
||||||
|
# Set appropriate x-axis limits
|
||||||
|
ax.set_xlim(-0.5, len(df) - 0.5)
|
||||||
|
|
||||||
|
# Add a legend
|
||||||
|
ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)
|
||||||
|
|
||||||
|
# Adjust layout
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Show the plot
|
||||||
|
plt.show()
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user