2025-06-12 13:27:30 +08:00

425 lines
15 KiB
Python

"""
Subplot Chart Layers
This module contains subplot layer implementations for indicators that render
in separate subplots below the main price chart, such as RSI, MACD, and other
oscillators and momentum indicators.
"""
import plotly.graph_objects as go
import pandas as pd
from decimal import Decimal
from typing import Dict, Any, Optional, List, Union, Tuple
from dataclasses import dataclass
from .base import BaseChartLayer, LayerConfig
from .indicators import BaseIndicatorLayer, IndicatorLayerConfig
from data.common.indicators import TechnicalIndicators, IndicatorResult
from data.common.data_types import OHLCVCandle
from components.charts.utils import get_indicator_colors
from utils.logger import get_logger
from ..error_handling import (
ChartErrorHandler, ChartError, ErrorSeverity, DataRequirements,
InsufficientDataError, DataValidationError, IndicatorCalculationError,
ErrorRecoveryStrategies, create_error_annotation, get_error_message
)
# Module-level logger obtained from the project's utils.logger helper.
# NOTE(review): the layer classes in this module log through self.logger,
# so this module-level logger may be unused — confirm before removing.
logger = get_logger()
@dataclass
class SubplotLayerConfig(IndicatorLayerConfig):
    """
    Configuration for indicator layers that render in their own subplot.

    Extends IndicatorLayerConfig with subplot layout hints (relative height,
    optional fixed y-axis range) and horizontal guide lines.
    """
    # Fraction of the total figure height for this subplot (0.25 = 25%).
    subplot_height_ratio: float = 0.25
    # Fixed y-axis range as (min, max); None means no fixed range.
    y_axis_range: Optional[Tuple[float, float]] = None
    # When True, a horizontal line is drawn at y=0.
    show_zero_line: bool = False
    # Extra horizontal reference lines (y values). None is accepted and is
    # normalised to an empty list in __post_init__.
    reference_lines: Optional[List[float]] = None

    def __post_init__(self):
        super().__post_init__()
        # Always keep a private list: copying avoids aliasing the caller's
        # list (mutations on either side would otherwise leak across) and
        # also accepts tuples or other iterables of levels.
        if self.reference_lines is None:
            self.reference_lines = []
        else:
            self.reference_lines = list(self.reference_lines)
class BaseSubplotLayer(BaseIndicatorLayer):
    """
    Common base for indicator layers drawn in a dedicated subplot.

    Exposes the subplot-specific settings of a SubplotLayerConfig (height
    ratio, fixed y-axis range, zero line, reference lines) and knows how to
    draw the horizontal guide lines onto a figure.
    """

    def __init__(self, config: SubplotLayerConfig):
        """
        Set up the layer from its subplot configuration.

        Args:
            config: SubplotLayerConfig describing this layer; it is passed to
                the parent initialiser and also kept as ``subplot_config``.
        """
        super().__init__(config)
        self.subplot_config = config

    def get_subplot_height_ratio(self) -> float:
        """Return the configured fraction of figure height for this subplot."""
        cfg = self.subplot_config
        return cfg.subplot_height_ratio

    def has_fixed_range(self) -> bool:
        """Return True when a fixed (min, max) y-axis range is configured."""
        cfg = self.subplot_config
        return cfg.y_axis_range is not None

    def get_y_axis_range(self) -> Optional[Tuple[float, float]]:
        """Return the configured (min, max) y-axis range, or None."""
        cfg = self.subplot_config
        return cfg.y_axis_range

    def should_show_zero_line(self) -> bool:
        """Return whether a horizontal line at y=0 is configured."""
        cfg = self.subplot_config
        return cfg.show_zero_line

    def get_reference_lines(self) -> List[float]:
        """Return the configured extra horizontal reference levels."""
        cfg = self.subplot_config
        return cfg.reference_lines

    def add_reference_lines(self, fig: go.Figure, row: int, col: int = 1) -> None:
        """
        Draw the zero line (if enabled) and every reference line on a subplot.

        Best effort: any failure is logged as a warning and not re-raised.

        Args:
            fig: Figure to draw on.
            row: Subplot row to target.
            col: Subplot column to target (default 1).
        """
        def draw(level, colour, dash_style):
            # One thin horizontal line spanning the target subplot.
            fig.add_hline(
                y=level,
                line=dict(color=colour, width=1, dash=dash_style),
                row=row,
                col=col
            )

        try:
            if self.should_show_zero_line():
                draw(0, 'gray', 'dash')
            for level in self.get_reference_lines():
                draw(level, 'lightgray', 'dot')
        except Exception as exc:
            self.logger.warning(f"Subplot layers: Could not add reference lines: {exc}")
class RSILayer(BaseSubplotLayer):
    """
    Relative Strength Index (RSI) subplot layer.

    Renders the RSI oscillator in a separate subplot with a fixed 0-100
    y-axis range and reference lines at 30 (oversold) and 70 (overbought).
    """

    def __init__(self, period: int = 14, color: str = None, name: str = None):
        """
        Initialize RSI layer.

        Args:
            period: RSI period (default: 14)
            color: Line color (optional; defaults to the 'rsi' entry from
                get_indicator_colors(), falling back to '#20c997')
            name: Layer name (optional; defaults to "RSI(<period>)")
        """
        if color is None:
            colors = get_indicator_colors()
            color = colors.get('rsi', '#20c997')
        if name is None:
            name = f"RSI({period})"
        # Default to row 2, the first subplot below the main chart; per the
        # original note the final row is expected to be managed by LayerManager.
        subplot_row = 2
        config = SubplotLayerConfig(
            name=name,
            indicator_type="rsi",
            color=color,
            parameters={'period': period},
            subplot_row=subplot_row,
            subplot_height_ratio=0.25,
            y_axis_range=(0, 100),  # RSI is bounded to [0, 100]
            reference_lines=[30, 70],  # Oversold and overbought levels
            style={
                'line_color': color,
                'line_width': 2,
                'opacity': 1.0
            }
        )
        super().__init__(config)
        self.period = period

    def _calculate_rsi(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
        """
        Compute RSI from the ``close`` column of *data*.

        Gains and losses are smoothed with an EWM using alpha = 1/period
        (Wilder-style smoothing). The first *period* rows are discarded as
        warm-up, and only values in [0, 100] are kept.

        Args:
            data: Frame with at least ``timestamp`` and ``close`` columns.
            period: Smoothing period; must be >= 1.

        Returns:
            DataFrame with ``timestamp`` and ``rsi`` columns, keeping the
            index labels of the surviving rows of *data*.

        Raises:
            Exception: if *period* is invalid, a required column is missing,
                the computation fails, or no valid RSI values remain.
        """
        if period < 1:
            raise Exception(f'RSI calculation failed: period must be >= 1 (got {period})')
        try:
            price_change = data['close'].diff()
            gain = price_change.clip(lower=0)
            loss = -price_change.clip(upper=0)
            avg_gain = gain.ewm(alpha=1 / period, adjust=False).mean()
            avg_loss = loss.ewm(alpha=1 / period, adjust=False).mean()
            # avg_loss == 0 with gains gives rs = inf -> RSI 100; a flat
            # window gives 0/0 = NaN, which is filtered out below.
            rs = avg_gain / avg_loss
            # Only the timestamp column is copied, not the whole input frame.
            result_df = data[['timestamp']].copy()
            result_df['rsi'] = 100 - (100 / (1 + rs))
            result_df = result_df.iloc[period:]
            # between() is False for NaN, so this also drops undefined values.
            result_df = result_df[result_df['rsi'].between(0, 100)]
        except Exception as e:
            raise Exception(f'RSI calculation failed: {e}') from e
        # Raised outside the try so the message is not double-prefixed.
        if result_df.empty:
            raise Exception(f'RSI calculation produced no values (period={period}, data_length={len(data)})')
        return result_df[['timestamp', 'rsi']]

    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Draw the RSI line and its reference lines onto *fig*.

        Args:
            fig: Figure to draw into. It is addressed with ``row``/``col``,
                so it needs a subplot grid covering the target cell.
            data: Frame with ``timestamp`` and ``close`` columns.
            **kwargs: Optional ``row`` (default: config.subplot_row, or 2)
                and ``col`` (default: 1).

        Returns:
            The same figure. Errors are logged and the figure is returned
            without raising.
        """
        try:
            rsi_data = self._calculate_rsi(data, self.period)
            if rsi_data.empty:
                return fig
            # Honour the style configured in __init__ (same defaults as before).
            style = getattr(self.config, 'style', None) or {}
            rsi_trace = go.Scatter(
                x=rsi_data['timestamp'],
                y=rsi_data['rsi'],
                mode='lines',
                name=self.config.name,
                line=dict(
                    color=self.config.color,
                    width=style.get('line_width', 2)
                ),
                opacity=style.get('opacity', 1.0),
                showlegend=True
            )
            row = kwargs.get('row', self.config.subplot_row or 2)
            col = kwargs.get('col', 1)
            fig.add_trace(rsi_trace, row=row, col=col)
            self.add_reference_lines(fig, row, col)
            return fig
        except Exception as e:
            self.logger.error(f"Subplot layers: Error rendering RSI layer: {e}")
            return fig
class MACDLayer(BaseSubplotLayer):
    """
    MACD (Moving Average Convergence Divergence) subplot layer.

    Renders the MACD line, its signal line and the histogram
    (MACD - signal) in a separate subplot with a zero line.
    """

    def __init__(self, fast_period: int = 12, slow_period: int = 26, signal_period: int = 9,
                 color: str = None, name: str = None):
        """
        Initialize MACD layer.

        Args:
            fast_period: Fast EMA span (default: 12).
            slow_period: Slow EMA span (default: 26). It must exceed
                fast_period; this is checked when the MACD is calculated,
                not here.
            signal_period: Signal-line EMA span (default: 9).
            color: MACD line color (optional; defaults to the 'macd' entry
                from get_indicator_colors(), falling back to '#fd7e14').
            name: Layer name (optional; defaults to "MACD(fast,slow,signal)").
        """
        if color is None:
            colors = get_indicator_colors()
            color = colors.get('macd', '#fd7e14')
        if name is None:
            name = f"MACD({fast_period},{slow_period},{signal_period})"
        config = SubplotLayerConfig(
            name=name,
            indicator_type="macd",
            color=color,
            parameters={
                'fast_period': fast_period,
                'slow_period': slow_period,
                'signal_period': signal_period
            },
            subplot_row=3,  # Default row; per the original note, managed by LayerManager
            subplot_height_ratio=0.3,
            show_zero_line=True,
            style={
                'line_color': color,
                'line_width': 2,
                'opacity': 1.0
            }
        )
        super().__init__(config)
        self.fast_period = fast_period
        self.slow_period = slow_period
        self.signal_period = signal_period

    def _calculate_macd(self, data: pd.DataFrame, fast_period: int,
                        slow_period: int, signal_period: int) -> pd.DataFrame:
        """
        Compute MACD, signal and histogram from the ``close`` column of *data*.

        The MACD line is EMA(fast) - EMA(slow), the signal line is an EMA of
        the MACD, and the histogram is MACD - signal. All EMAs use
        ``adjust=False``. The first ``slow_period + signal_period`` rows are
        discarded as warm-up.

        Args:
            data: Frame with at least ``timestamp`` and ``close`` columns.
            fast_period: Fast EMA span (>= 1, < slow_period).
            slow_period: Slow EMA span (>= 1).
            signal_period: Signal EMA span (>= 1).

        Returns:
            DataFrame with ``timestamp``, ``macd``, ``signal`` and
            ``histogram`` columns, keeping the index labels of *data*.

        Raises:
            Exception: on invalid periods, missing columns, a failed
                computation, or when no values remain after warm-up.
        """
        if min(fast_period, slow_period, signal_period) < 1:
            raise Exception(
                f'MACD calculation failed: periods must be >= 1 '
                f'(fast={fast_period}, slow={slow_period}, signal={signal_period})'
            )
        if fast_period >= slow_period:
            raise Exception(
                f'MACD calculation failed: Fast period ({fast_period}) '
                f'must be less than slow period ({slow_period})'
            )
        try:
            close = data['close']
            ema_fast = close.ewm(span=fast_period, adjust=False).mean()
            ema_slow = close.ewm(span=slow_period, adjust=False).mean()
            # Only the timestamp column is copied, not the whole input frame.
            result_df = data[['timestamp']].copy()
            result_df['macd'] = ema_fast - ema_slow
            result_df['signal'] = result_df['macd'].ewm(span=signal_period, adjust=False).mean()
            result_df['histogram'] = result_df['macd'] - result_df['signal']
            # Drop the warm-up rows where the EMAs are not yet reliable.
            warmup_period = slow_period + signal_period
            result_df = result_df.iloc[warmup_period:]
            result_df = result_df.dropna(subset=['macd', 'signal', 'histogram'])
        except Exception as e:
            raise Exception(f'MACD calculation failed: {e}') from e
        # Raised outside the try so the message is not double-prefixed.
        if result_df.empty:
            raise Exception(
                f'MACD calculation produced no values '
                f'(fast={fast_period}, slow={slow_period}, signal={signal_period})'
            )
        return result_df[['timestamp', 'macd', 'signal', 'histogram']]

    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Draw the MACD line, signal line, histogram and zero line onto *fig*.

        Args:
            fig: Figure to draw into. It is addressed with ``row``/``col``,
                so it needs a subplot grid covering the target cell.
            data: Frame with ``timestamp`` and ``close`` columns.
            **kwargs: Optional ``row`` (default: config.subplot_row, or 3)
                and ``col`` (default: 1).

        Returns:
            The same figure. Errors are logged and the figure is returned
            without raising.
        """
        try:
            macd_data = self._calculate_macd(data, self.fast_period, self.slow_period, self.signal_period)
            if macd_data.empty:
                return fig
            row = kwargs.get('row', self.config.subplot_row or 3)
            col = kwargs.get('col', 1)
            # Honour the style configured in __init__ (same defaults as before).
            style = getattr(self.config, 'style', None) or {}
            line_width = style.get('line_width', 2)
            timestamps = macd_data['timestamp']
            macd_trace = go.Scatter(
                x=timestamps,
                y=macd_data['macd'],
                mode='lines',
                name=f'{self.config.name} Line',
                line=dict(color=self.config.color, width=line_width),
                opacity=style.get('opacity', 1.0),
                showlegend=True
            )
            signal_trace = go.Scatter(
                x=timestamps,
                y=macd_data['signal'],
                mode='lines',
                name=f'{self.config.name} Signal',
                line=dict(color='#FF9800', width=line_width),
                showlegend=True
            )
            # Green bars where MACD >= signal, red where it is below.
            histogram_colors = ['green' if h >= 0 else 'red' for h in macd_data['histogram']]
            histogram_trace = go.Bar(
                x=timestamps,
                y=macd_data['histogram'],
                name=f'{self.config.name} Histogram',
                marker_color=histogram_colors,
                opacity=0.6,
                showlegend=True
            )
            for trace in (macd_trace, signal_trace, histogram_trace):
                fig.add_trace(trace, row=row, col=col)
            # Zero line (show_zero_line=True in __init__).
            self.add_reference_lines(fig, row, col)
            return fig
        except Exception as e:
            self.logger.error(f"Subplot layers: Error rendering MACD layer: {e}")
            return fig
def create_rsi_layer(period: int = 14, **kwargs) -> 'RSILayer':
    """
    Build an RSILayer.

    Args:
        period: RSI look-back period (default: 14).
        **kwargs: Forwarded unchanged to RSILayer (e.g. ``color``, ``name``).

    Returns:
        A newly constructed RSILayer.
    """
    layer = RSILayer(period=period, **kwargs)
    return layer
def create_macd_layer(fast_period: int = 12, slow_period: int = 26,
                      signal_period: int = 9, **kwargs) -> 'MACDLayer':
    """
    Build a MACDLayer.

    Args:
        fast_period: Fast EMA span (default: 12).
        slow_period: Slow EMA span (default: 26).
        signal_period: Signal-line EMA span (default: 9).
        **kwargs: Forwarded unchanged to MACDLayer (e.g. ``color``, ``name``).

    Returns:
        A newly constructed MACDLayer.
    """
    periods = {
        'fast_period': fast_period,
        'slow_period': slow_period,
        'signal_period': signal_period,
    }
    return MACDLayer(**periods, **kwargs)
def create_common_subplot_indicators() -> List[BaseSubplotLayer]:
    """
    Build the default set of subplot indicators.

    Returns:
        A list of two layers, in order: RSI(14) and MACD(12, 26, 9).
    """
    rsi = RSILayer(period=14)
    macd = MACDLayer(fast_period=12, slow_period=26, signal_period=9)
    return [rsi, macd]