3.4 -2.0 Indicator Layer System Implementation
Implement modular chart layers and error handling for Crypto Trading Bot Dashboard - Introduced a comprehensive chart layer system in `components/charts/layers/` to support various technical indicators and subplots. - Added base layer components including `BaseLayer`, `CandlestickLayer`, and `VolumeLayer` for flexible chart rendering. - Implemented overlay indicators such as `SMALayer`, `EMALayer`, and `BollingerBandsLayer` with robust error handling. - Created subplot layers for indicators like `RSILayer` and `MACDLayer`, enhancing visualization capabilities. - Developed a `MarketDataIntegrator` for seamless data fetching and validation, improving data quality assurance. - Enhanced error handling utilities in `components/charts/error_handling.py` to manage insufficient data scenarios effectively. - Updated documentation to reflect the new chart layer architecture and usage guidelines. - Added unit tests for all chart layer components to ensure functionality and reliability.
This commit is contained in:
@@ -1,13 +1,89 @@
|
||||
"""
|
||||
Chart Layers Package
|
||||
|
||||
This package contains the modular chart layer system for rendering different
|
||||
chart components including candlesticks, indicators, and signals.
|
||||
This package contains the modular layer system for building complex charts
|
||||
with multiple indicators, signals, and subplots.
|
||||
|
||||
Components:
|
||||
- BaseChartLayer: Abstract base class for all layers
|
||||
- CandlestickLayer: OHLC price chart layer
|
||||
- VolumeLayer: Volume subplot layer
|
||||
- LayerManager: Orchestrates multiple layers
|
||||
- SMALayer: Simple Moving Average indicator overlay
|
||||
- EMALayer: Exponential Moving Average indicator overlay
|
||||
- BollingerBandsLayer: Bollinger Bands overlay with fill area
|
||||
- RSILayer: RSI oscillator subplot
|
||||
- MACDLayer: MACD lines and histogram subplot
|
||||
"""
|
||||
|
||||
# Package metadata
|
||||
from .base import (
|
||||
BaseChartLayer,
|
||||
CandlestickLayer,
|
||||
VolumeLayer,
|
||||
LayerManager,
|
||||
LayerConfig
|
||||
)
|
||||
|
||||
from .indicators import (
|
||||
BaseIndicatorLayer,
|
||||
IndicatorLayerConfig,
|
||||
SMALayer,
|
||||
EMALayer,
|
||||
BollingerBandsLayer,
|
||||
create_sma_layer,
|
||||
create_ema_layer,
|
||||
create_bollinger_bands_layer,
|
||||
create_common_ma_layers,
|
||||
create_common_overlay_indicators
|
||||
)
|
||||
|
||||
from .subplots import (
|
||||
BaseSubplotLayer,
|
||||
SubplotLayerConfig,
|
||||
RSILayer,
|
||||
MACDLayer,
|
||||
create_rsi_layer,
|
||||
create_macd_layer,
|
||||
create_common_subplot_indicators
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base layers
|
||||
'BaseChartLayer',
|
||||
'CandlestickLayer',
|
||||
'VolumeLayer',
|
||||
'LayerManager',
|
||||
'LayerConfig',
|
||||
|
||||
# Indicator layers (overlays)
|
||||
'BaseIndicatorLayer',
|
||||
'IndicatorLayerConfig',
|
||||
'SMALayer',
|
||||
'EMALayer',
|
||||
'BollingerBandsLayer',
|
||||
|
||||
# Subplot layers
|
||||
'BaseSubplotLayer',
|
||||
'SubplotLayerConfig',
|
||||
'RSILayer',
|
||||
'MACDLayer',
|
||||
|
||||
# Convenience functions
|
||||
'create_sma_layer',
|
||||
'create_ema_layer',
|
||||
'create_bollinger_bands_layer',
|
||||
'create_common_ma_layers',
|
||||
'create_common_overlay_indicators',
|
||||
'create_rsi_layer',
|
||||
'create_macd_layer',
|
||||
'create_common_subplot_indicators'
|
||||
]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__package_name__ = "layers"
|
||||
|
||||
# Package metadata
|
||||
# __version__ = "0.1.0"
|
||||
# __package_name__ = "layers"
|
||||
|
||||
# Layers will be imported once they are created
|
||||
# from .base import BaseCandlestickLayer
|
||||
@@ -16,9 +92,9 @@ __package_name__ = "layers"
|
||||
# from .signals import SignalLayer
|
||||
|
||||
# Public exports (will be populated as layers are implemented)
|
||||
__all__ = [
|
||||
# "BaseCandlestickLayer",
|
||||
# "IndicatorLayer",
|
||||
# "SubplotManager",
|
||||
# "SignalLayer"
|
||||
]
|
||||
# __all__ = [
|
||||
# # "BaseCandlestickLayer",
|
||||
# # "IndicatorLayer",
|
||||
# # "SubplotManager",
|
||||
# # "SignalLayer"
|
||||
# ]
|
||||
952
components/charts/layers/base.py
Normal file
952
components/charts/layers/base.py
Normal file
@@ -0,0 +1,952 @@
|
||||
"""
|
||||
Base Chart Layer Components
|
||||
|
||||
This module contains the foundational layer classes that serve as building blocks
|
||||
for all chart components including candlestick charts, indicators, and signals.
|
||||
"""
|
||||
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from utils.logger import get_logger
|
||||
from ..error_handling import (
|
||||
ChartErrorHandler, ChartError, ErrorSeverity,
|
||||
InsufficientDataError, DataValidationError, IndicatorCalculationError,
|
||||
create_error_annotation, get_error_message
|
||||
)
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger("chart_layers")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerConfig:
|
||||
"""Configuration for chart layers"""
|
||||
name: str
|
||||
enabled: bool = True
|
||||
color: Optional[str] = None
|
||||
style: Dict[str, Any] = None
|
||||
subplot_row: Optional[int] = None # None = main chart, 1+ = subplot row
|
||||
|
||||
def __post_init__(self):
|
||||
if self.style is None:
|
||||
self.style = {}
|
||||
|
||||
|
||||
class BaseLayer:
|
||||
"""
|
||||
Base class for all chart layers providing common functionality
|
||||
for data validation, error handling, and trace management.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig):
|
||||
self.config = config
|
||||
self.logger = get_logger(f"chart_layer_{self.__class__.__name__.lower()}")
|
||||
self.error_handler = ChartErrorHandler()
|
||||
self.traces = []
|
||||
self._is_valid = False
|
||||
self._error_message = None
|
||||
|
||||
def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
|
||||
"""
|
||||
Validate input data for layer requirements.
|
||||
|
||||
Args:
|
||||
data: Input data to validate
|
||||
|
||||
Returns:
|
||||
True if data is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.error_handler.clear_errors()
|
||||
|
||||
# Check data type
|
||||
if not isinstance(data, (pd.DataFrame, list)):
|
||||
error = ChartError(
|
||||
code='INVALID_DATA_TYPE',
|
||||
message=f'Invalid data type for {self.__class__.__name__}: {type(data)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'layer': self.__class__.__name__, 'data_type': str(type(data))},
|
||||
recovery_suggestion='Provide data as pandas DataFrame or list of dictionaries'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check data sufficiency
|
||||
is_sufficient = self.error_handler.validate_data_sufficiency(
|
||||
data,
|
||||
chart_type='candlestick', # Default chart type since LayerConfig doesn't have layer_type
|
||||
indicators=[{'type': 'candlestick', 'parameters': {}}] # Default indicator type
|
||||
)
|
||||
|
||||
self._is_valid = is_sufficient
|
||||
if not is_sufficient:
|
||||
self._error_message = self.error_handler.get_user_friendly_message()
|
||||
|
||||
return is_sufficient
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Data validation error in {self.__class__.__name__}: {e}")
|
||||
error = ChartError(
|
||||
code='VALIDATION_EXCEPTION',
|
||||
message=f'Validation error: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'layer': self.__class__.__name__, 'exception': str(e)}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
self._is_valid = False
|
||||
self._error_message = str(e)
|
||||
return False
|
||||
|
||||
def get_error_info(self) -> Dict[str, Any]:
|
||||
"""Get error information for this layer"""
|
||||
return {
|
||||
'is_valid': self._is_valid,
|
||||
'error_message': self._error_message,
|
||||
'error_summary': self.error_handler.get_error_summary(),
|
||||
'can_proceed': len(self.error_handler.errors) == 0
|
||||
}
|
||||
|
||||
def create_error_trace(self, error_message: str) -> go.Scatter:
|
||||
"""Create an error display trace"""
|
||||
return go.Scatter(
|
||||
x=[],
|
||||
y=[],
|
||||
mode='text',
|
||||
text=[error_message],
|
||||
textposition='middle center',
|
||||
textfont={'size': 14, 'color': '#e74c3c'},
|
||||
showlegend=False,
|
||||
name=f"{self.__class__.__name__} Error"
|
||||
)
|
||||
|
||||
|
||||
class BaseChartLayer(ABC):
|
||||
"""
|
||||
Abstract base class for all chart layers.
|
||||
|
||||
This defines the interface that all chart layers must implement,
|
||||
whether they are candlestick charts, indicators, or signal overlays.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig):
|
||||
"""
|
||||
Initialize the base layer.
|
||||
|
||||
Args:
|
||||
config: Layer configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
|
||||
@abstractmethod
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render the layer onto the provided figure.
|
||||
|
||||
Args:
|
||||
fig: Plotly figure to render onto
|
||||
data: Chart data (OHLCV format)
|
||||
**kwargs: Additional rendering parameters
|
||||
|
||||
Returns:
|
||||
Updated figure with layer rendered
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_data(self, data: pd.DataFrame) -> bool:
|
||||
"""
|
||||
Validate that the data is suitable for this layer.
|
||||
|
||||
Args:
|
||||
data: Chart data to validate
|
||||
|
||||
Returns:
|
||||
True if data is valid, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
|
||||
class CandlestickLayer(BaseLayer):
|
||||
"""
|
||||
Candlestick chart layer implementation with enhanced error handling.
|
||||
|
||||
This layer renders OHLC data as candlesticks on the main chart.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig = None):
|
||||
"""
|
||||
Initialize candlestick layer.
|
||||
|
||||
Args:
|
||||
config: Layer configuration (optional, uses defaults)
|
||||
"""
|
||||
if config is None:
|
||||
config = LayerConfig(
|
||||
name="candlestick",
|
||||
enabled=True,
|
||||
style={
|
||||
'increasing_color': '#00C851', # Green for bullish
|
||||
'decreasing_color': '#FF4444', # Red for bearish
|
||||
'line_width': 1
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
|
||||
"""Enhanced validation with comprehensive error handling"""
|
||||
try:
|
||||
# Use parent class error handling for comprehensive validation
|
||||
parent_valid = super().validate_data(data)
|
||||
|
||||
# Convert to DataFrame if needed for local validation
|
||||
if isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
else:
|
||||
df = data.copy()
|
||||
|
||||
# Additional candlestick-specific validation
|
||||
required_columns = ['timestamp', 'open', 'high', 'low', 'close']
|
||||
|
||||
if not all(col in df.columns for col in required_columns):
|
||||
missing = [col for col in required_columns if col not in df.columns]
|
||||
error = ChartError(
|
||||
code='MISSING_OHLC_COLUMNS',
|
||||
message=f'Missing required OHLC columns: {missing}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'missing_columns': missing, 'available_columns': list(df.columns)},
|
||||
recovery_suggestion='Ensure data contains timestamp, open, high, low, close columns'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
if len(df) == 0:
|
||||
error = ChartError(
|
||||
code='EMPTY_CANDLESTICK_DATA',
|
||||
message='No candlestick data available',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'data_count': 0},
|
||||
recovery_suggestion='Check data source or time range'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check for price data validity
|
||||
invalid_prices = df[
|
||||
(df['high'] < df['low']) |
|
||||
(df['open'] < 0) | (df['close'] < 0) |
|
||||
(df['high'] < 0) | (df['low'] < 0) |
|
||||
pd.isna(df[['open', 'high', 'low', 'close']]).any(axis=1)
|
||||
]
|
||||
|
||||
if len(invalid_prices) > len(df) * 0.5: # More than 50% invalid
|
||||
error = ChartError(
|
||||
code='EXCESSIVE_INVALID_PRICES',
|
||||
message=f'Too many invalid price records: {len(invalid_prices)}/{len(df)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'invalid_count': len(invalid_prices), 'total_count': len(df)},
|
||||
recovery_suggestion='Check data quality and price data sources'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
elif len(invalid_prices) > 0:
|
||||
# Warning for some invalid data
|
||||
error = ChartError(
|
||||
code='SOME_INVALID_PRICES',
|
||||
message=f'Found {len(invalid_prices)} invalid price records (will be filtered)',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'invalid_count': len(invalid_prices), 'total_count': len(df)},
|
||||
recovery_suggestion='Invalid records will be automatically removed'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
return parent_valid and len(self.error_handler.errors) == 0
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error validating candlestick data: {e}")
|
||||
error = ChartError(
|
||||
code='CANDLESTICK_VALIDATION_ERROR',
|
||||
message=f'Candlestick validation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e)}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render candlestick chart with error handling and recovery.
|
||||
|
||||
Args:
|
||||
fig: Target figure
|
||||
data: OHLCV data
|
||||
**kwargs: Additional parameters (row, col for subplots)
|
||||
|
||||
Returns:
|
||||
Figure with candlestick trace added or error display
|
||||
"""
|
||||
try:
|
||||
# Validate data
|
||||
if not self.validate_data(data):
|
||||
self.logger.error("Invalid data for candlestick layer")
|
||||
|
||||
# Add error annotation to figure
|
||||
if self.error_handler.errors:
|
||||
error_msg = self.error_handler.errors[0].message
|
||||
fig.add_annotation(create_error_annotation(
|
||||
f"Candlestick Error: {error_msg}",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
# Clean and prepare data
|
||||
clean_data = self._clean_candlestick_data(data)
|
||||
if clean_data.empty:
|
||||
fig.add_annotation(create_error_annotation(
|
||||
"No valid candlestick data after cleaning",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
# Extract styling
|
||||
style = self.config.style
|
||||
increasing_color = style.get('increasing_color', '#00C851')
|
||||
decreasing_color = style.get('decreasing_color', '#FF4444')
|
||||
|
||||
# Create candlestick trace
|
||||
candlestick = go.Candlestick(
|
||||
x=clean_data['timestamp'],
|
||||
open=clean_data['open'],
|
||||
high=clean_data['high'],
|
||||
low=clean_data['low'],
|
||||
close=clean_data['close'],
|
||||
name=self.config.name,
|
||||
increasing_line_color=increasing_color,
|
||||
decreasing_line_color=decreasing_color,
|
||||
showlegend=False
|
||||
)
|
||||
|
||||
# Add to figure
|
||||
row = kwargs.get('row', 1)
|
||||
col = kwargs.get('col', 1)
|
||||
|
||||
try:
|
||||
if hasattr(fig, 'add_trace') and row == 1 and col == 1:
|
||||
# Simple figure without subplots
|
||||
fig.add_trace(candlestick)
|
||||
elif hasattr(fig, 'add_trace'):
|
||||
# Subplot figure
|
||||
fig.add_trace(candlestick, row=row, col=col)
|
||||
else:
|
||||
# Fallback
|
||||
fig.add_trace(candlestick)
|
||||
except Exception as trace_error:
|
||||
# If subplot call fails, try simple add_trace
|
||||
try:
|
||||
fig.add_trace(candlestick)
|
||||
except Exception as fallback_error:
|
||||
self.logger.error(f"Failed to add candlestick trace: {fallback_error}")
|
||||
fig.add_annotation(create_error_annotation(
|
||||
f"Failed to add candlestick trace: {str(fallback_error)}",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
# Add warning annotations if needed
|
||||
if self.error_handler.warnings:
|
||||
warning_msg = f"⚠️ {self.error_handler.warnings[0].message}"
|
||||
fig.add_annotation({
|
||||
'text': warning_msg,
|
||||
'xref': 'paper', 'yref': 'paper',
|
||||
'x': 0.02, 'y': 0.98,
|
||||
'xanchor': 'left', 'yanchor': 'top',
|
||||
'showarrow': False,
|
||||
'font': {'size': 10, 'color': '#f39c12'},
|
||||
'bgcolor': 'rgba(255,255,255,0.8)'
|
||||
})
|
||||
|
||||
self.logger.debug(f"Rendered candlestick layer with {len(clean_data)} candles")
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering candlestick layer: {e}")
|
||||
fig.add_annotation(create_error_annotation(
|
||||
f"Candlestick render error: {str(e)}",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
def _clean_candlestick_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Clean and validate candlestick data"""
|
||||
try:
|
||||
clean_data = data.copy()
|
||||
|
||||
# Remove rows with invalid prices
|
||||
invalid_mask = (
|
||||
(clean_data['high'] < clean_data['low']) |
|
||||
(clean_data['open'] < 0) | (clean_data['close'] < 0) |
|
||||
(clean_data['high'] < 0) | (clean_data['low'] < 0) |
|
||||
pd.isna(clean_data[['open', 'high', 'low', 'close']]).any(axis=1)
|
||||
)
|
||||
|
||||
initial_count = len(clean_data)
|
||||
clean_data = clean_data[~invalid_mask]
|
||||
|
||||
if len(clean_data) < initial_count:
|
||||
removed_count = initial_count - len(clean_data)
|
||||
self.logger.info(f"Removed {removed_count} invalid candlestick records")
|
||||
|
||||
# Ensure timestamp is properly formatted
|
||||
if not pd.api.types.is_datetime64_any_dtype(clean_data['timestamp']):
|
||||
clean_data['timestamp'] = pd.to_datetime(clean_data['timestamp'])
|
||||
|
||||
# Sort by timestamp
|
||||
clean_data = clean_data.sort_values('timestamp')
|
||||
|
||||
return clean_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning candlestick data: {e}")
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
class VolumeLayer(BaseLayer):
|
||||
"""
|
||||
Volume subplot layer implementation with enhanced error handling.
|
||||
|
||||
This layer renders volume data as a bar chart in a separate subplot,
|
||||
with bars colored based on price movement.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig = None):
|
||||
"""
|
||||
Initialize volume layer.
|
||||
|
||||
Args:
|
||||
config: Layer configuration (optional, uses defaults)
|
||||
"""
|
||||
if config is None:
|
||||
config = LayerConfig(
|
||||
name="volume",
|
||||
enabled=True,
|
||||
subplot_row=2, # Volume goes in second row by default
|
||||
style={
|
||||
'bullish_color': '#00C851',
|
||||
'bearish_color': '#FF4444',
|
||||
'opacity': 0.7
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
|
||||
"""Enhanced validation with comprehensive error handling"""
|
||||
try:
|
||||
# Use parent class error handling
|
||||
parent_valid = super().validate_data(data)
|
||||
|
||||
# Convert to DataFrame if needed
|
||||
if isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
else:
|
||||
df = data.copy()
|
||||
|
||||
# Volume-specific validation
|
||||
required_columns = ['timestamp', 'open', 'close', 'volume']
|
||||
|
||||
if not all(col in df.columns for col in required_columns):
|
||||
missing = [col for col in required_columns if col not in df.columns]
|
||||
error = ChartError(
|
||||
code='MISSING_VOLUME_COLUMNS',
|
||||
message=f'Missing required volume columns: {missing}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'missing_columns': missing, 'available_columns': list(df.columns)},
|
||||
recovery_suggestion='Ensure data contains timestamp, open, close, volume columns'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
if len(df) == 0:
|
||||
error = ChartError(
|
||||
code='EMPTY_VOLUME_DATA',
|
||||
message='No volume data available',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'data_count': 0},
|
||||
recovery_suggestion='Check data source or time range'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check if volume data exists and is valid
|
||||
valid_volume_mask = (df['volume'] >= 0) & pd.notna(df['volume'])
|
||||
valid_volume_count = valid_volume_mask.sum()
|
||||
|
||||
if valid_volume_count == 0:
|
||||
error = ChartError(
|
||||
code='NO_VALID_VOLUME',
|
||||
message='No valid volume data found',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'total_records': len(df), 'valid_volume': 0},
|
||||
recovery_suggestion='Volume chart will be skipped'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
elif valid_volume_count < len(df) * 0.5: # Less than 50% valid
|
||||
error = ChartError(
|
||||
code='MOSTLY_INVALID_VOLUME',
|
||||
message=f'Most volume data is invalid: {valid_volume_count}/{len(df)} valid',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'total_records': len(df), 'valid_volume': valid_volume_count},
|
||||
recovery_suggestion='Invalid volume records will be filtered out'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
elif df['volume'].sum() <= 0:
|
||||
error = ChartError(
|
||||
code='ZERO_VOLUME_TOTAL',
|
||||
message='Total volume is zero or negative',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'volume_sum': float(df['volume'].sum())},
|
||||
recovery_suggestion='Volume chart may not be meaningful'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
return parent_valid and valid_volume_count > 0
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error validating volume data: {e}")
|
||||
error = ChartError(
|
||||
code='VOLUME_VALIDATION_ERROR',
|
||||
message=f'Volume validation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e)}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render volume bars with error handling and recovery.
|
||||
|
||||
Args:
|
||||
fig: Target figure (must be subplot figure)
|
||||
data: OHLCV data
|
||||
**kwargs: Additional parameters (row, col for subplots)
|
||||
|
||||
Returns:
|
||||
Figure with volume trace added or error handling
|
||||
"""
|
||||
try:
|
||||
# Validate data
|
||||
if not self.validate_data(data):
|
||||
# Check if we can skip gracefully (warnings only)
|
||||
if not self.error_handler.errors and self.error_handler.warnings:
|
||||
self.logger.debug("Skipping volume layer due to warnings")
|
||||
return fig
|
||||
else:
|
||||
self.logger.error("Invalid data for volume layer")
|
||||
return fig
|
||||
|
||||
# Clean and prepare data
|
||||
clean_data = self._clean_volume_data(data)
|
||||
if clean_data.empty:
|
||||
self.logger.debug("No valid volume data after cleaning")
|
||||
return fig
|
||||
|
||||
# Calculate bar colors based on price movement
|
||||
style = self.config.style
|
||||
bullish_color = style.get('bullish_color', '#00C851')
|
||||
bearish_color = style.get('bearish_color', '#FF4444')
|
||||
opacity = style.get('opacity', 0.7)
|
||||
|
||||
colors = [
|
||||
bullish_color if close >= open_price else bearish_color
|
||||
for close, open_price in zip(clean_data['close'], clean_data['open'])
|
||||
]
|
||||
|
||||
# Create volume bar trace
|
||||
volume_bars = go.Bar(
|
||||
x=clean_data['timestamp'],
|
||||
y=clean_data['volume'],
|
||||
name='Volume',
|
||||
marker_color=colors,
|
||||
opacity=opacity,
|
||||
showlegend=False
|
||||
)
|
||||
|
||||
# Add to figure
|
||||
row = kwargs.get('row', 2) # Default to row 2 for volume
|
||||
col = kwargs.get('col', 1)
|
||||
|
||||
fig.add_trace(volume_bars, row=row, col=col)
|
||||
|
||||
self.logger.debug(f"Rendered volume layer with {len(clean_data)} bars")
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering volume layer: {e}")
|
||||
return fig
|
||||
|
||||
def _clean_volume_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Clean and validate volume data"""
|
||||
try:
|
||||
clean_data = data.copy()
|
||||
|
||||
# Remove rows with invalid volume
|
||||
valid_mask = (clean_data['volume'] >= 0) & pd.notna(clean_data['volume'])
|
||||
initial_count = len(clean_data)
|
||||
clean_data = clean_data[valid_mask]
|
||||
|
||||
if len(clean_data) < initial_count:
|
||||
removed_count = initial_count - len(clean_data)
|
||||
self.logger.info(f"Removed {removed_count} invalid volume records")
|
||||
|
||||
# Ensure timestamp is properly formatted
|
||||
if not pd.api.types.is_datetime64_any_dtype(clean_data['timestamp']):
|
||||
clean_data['timestamp'] = pd.to_datetime(clean_data['timestamp'])
|
||||
|
||||
# Sort by timestamp
|
||||
clean_data = clean_data.sort_values('timestamp')
|
||||
|
||||
return clean_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning volume data: {e}")
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
class LayerManager:
|
||||
"""
|
||||
Manager class for coordinating multiple chart layers.
|
||||
|
||||
This class handles the orchestration of multiple layers, including
|
||||
setting up subplots and rendering layers in the correct order.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the layer manager."""
|
||||
self.layers: List[BaseLayer] = []
|
||||
self.logger = logger
|
||||
|
||||
def add_layer(self, layer: BaseLayer) -> None:
|
||||
"""
|
||||
Add a layer to the manager.
|
||||
|
||||
Args:
|
||||
layer: Chart layer to add
|
||||
"""
|
||||
self.layers.append(layer)
|
||||
self.logger.debug(f"Added layer: {layer.config.name}")
|
||||
|
||||
def remove_layer(self, layer_name: str) -> bool:
|
||||
"""
|
||||
Remove a layer by name.
|
||||
|
||||
Args:
|
||||
layer_name: Name of layer to remove
|
||||
|
||||
Returns:
|
||||
True if layer was removed, False if not found
|
||||
"""
|
||||
for i, layer in enumerate(self.layers):
|
||||
if layer.config.name == layer_name:
|
||||
self.layers.pop(i)
|
||||
self.logger.debug(f"Removed layer: {layer_name}")
|
||||
return True
|
||||
|
||||
self.logger.warning(f"Layer not found for removal: {layer_name}")
|
||||
return False
|
||||
|
||||
def get_enabled_layers(self) -> List[BaseLayer]:
|
||||
"""Get list of enabled layers."""
|
||||
return [layer for layer in self.layers if layer.is_enabled()]
|
||||
|
||||
def get_overlay_layers(self) -> List[BaseLayer]:
|
||||
"""Get layers that render on the main chart."""
|
||||
return [layer for layer in self.get_enabled_layers() if layer.is_overlay()]
|
||||
|
||||
def get_subplot_layers(self) -> Dict[int, List[BaseLayer]]:
|
||||
"""Get layers grouped by subplot row."""
|
||||
subplot_layers = {}
|
||||
|
||||
for layer in self.get_enabled_layers():
|
||||
if not layer.is_overlay():
|
||||
row = layer.get_subplot_row()
|
||||
if row not in subplot_layers:
|
||||
subplot_layers[row] = []
|
||||
subplot_layers[row].append(layer)
|
||||
|
||||
return subplot_layers
|
||||
|
||||
def calculate_subplot_layout(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate subplot configuration based on layers.
|
||||
|
||||
Returns:
|
||||
Dict with subplot configuration parameters
|
||||
"""
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
|
||||
if not subplot_layers:
|
||||
# No subplots needed
|
||||
return {
|
||||
'rows': 1,
|
||||
'cols': 1,
|
||||
'subplot_titles': None,
|
||||
'row_heights': None
|
||||
}
|
||||
|
||||
# Reassign subplot rows dynamically to ensure proper ordering
|
||||
self._reassign_subplot_rows()
|
||||
|
||||
# Recalculate after reassignment
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
|
||||
# Calculate number of rows (main chart + subplots)
|
||||
max_subplot_row = max(subplot_layers.keys()) if subplot_layers else 0
|
||||
total_rows = max(1, max_subplot_row) # Row numbers are 1-indexed, so max_subplot_row is the total rows needed
|
||||
|
||||
# Create subplot titles
|
||||
subplot_titles = ['Price'] # Main chart
|
||||
for row in range(2, total_rows + 1):
|
||||
if row in subplot_layers:
|
||||
# Use the first layer's name as the subtitle
|
||||
layer_names = [layer.config.name for layer in subplot_layers[row]]
|
||||
subplot_titles.append(' / '.join(layer_names).title())
|
||||
else:
|
||||
subplot_titles.append(f'Subplot {row}')
|
||||
|
||||
# Calculate row heights based on subplot height ratios
|
||||
row_heights = self._calculate_dynamic_row_heights(subplot_layers, total_rows)
|
||||
|
||||
return {
|
||||
'rows': total_rows,
|
||||
'cols': 1,
|
||||
'subplot_titles': subplot_titles,
|
||||
'row_heights': row_heights,
|
||||
'shared_xaxes': True,
|
||||
'vertical_spacing': 0.03
|
||||
}
|
||||
|
||||
def _reassign_subplot_rows(self) -> None:
|
||||
"""
|
||||
Reassign subplot rows to ensure proper sequential ordering.
|
||||
|
||||
This method dynamically assigns subplot rows starting from row 2,
|
||||
ensuring no gaps in the subplot layout.
|
||||
"""
|
||||
subplot_layers = []
|
||||
|
||||
# Collect all subplot layers
|
||||
for layer in self.get_enabled_layers():
|
||||
if not layer.is_overlay():
|
||||
subplot_layers.append(layer)
|
||||
|
||||
# Sort by priority: volume first, then by current subplot row
|
||||
def layer_priority(layer):
|
||||
# Volume gets highest priority (0), then by current row
|
||||
if hasattr(layer, 'config') and layer.config.name == 'volume':
|
||||
return (0, layer.get_subplot_row() or 999)
|
||||
else:
|
||||
return (1, layer.get_subplot_row() or 999)
|
||||
|
||||
subplot_layers.sort(key=layer_priority)
|
||||
|
||||
# Reassign rows starting from 2
|
||||
for i, layer in enumerate(subplot_layers):
|
||||
new_row = i + 2 # Start from row 2 (row 1 is main chart)
|
||||
layer.config.subplot_row = new_row
|
||||
self.logger.debug(f"Assigned {layer.config.name} to subplot row {new_row}")
|
||||
|
||||
def _calculate_dynamic_row_heights(self, subplot_layers: Dict[int, List], total_rows: int) -> List[float]:
|
||||
"""
|
||||
Calculate row heights based on subplot height ratios.
|
||||
|
||||
Args:
|
||||
subplot_layers: Dictionary of subplot layers by row
|
||||
total_rows: Total number of rows
|
||||
|
||||
Returns:
|
||||
List of height ratios for each row
|
||||
"""
|
||||
if total_rows == 1:
|
||||
return [1.0] # Single row gets full height
|
||||
|
||||
# Calculate total requested subplot height
|
||||
total_subplot_ratio = 0.0
|
||||
subplot_ratios = {}
|
||||
|
||||
for row in range(2, total_rows + 1):
|
||||
if row in subplot_layers:
|
||||
# Get height ratio from first layer in the row
|
||||
layer = subplot_layers[row][0]
|
||||
if hasattr(layer, 'get_subplot_height_ratio'):
|
||||
ratio = layer.get_subplot_height_ratio()
|
||||
else:
|
||||
ratio = 0.25 # Default ratio
|
||||
subplot_ratios[row] = ratio
|
||||
total_subplot_ratio += ratio
|
||||
else:
|
||||
subplot_ratios[row] = 0.25 # Default for empty rows
|
||||
total_subplot_ratio += 0.25
|
||||
|
||||
# Ensure total doesn't exceed reasonable limits
|
||||
max_subplot_ratio = 0.6 # Maximum 60% for all subplots
|
||||
if total_subplot_ratio > max_subplot_ratio:
|
||||
# Scale down proportionally
|
||||
scale_factor = max_subplot_ratio / total_subplot_ratio
|
||||
for row in subplot_ratios:
|
||||
subplot_ratios[row] *= scale_factor
|
||||
total_subplot_ratio = max_subplot_ratio
|
||||
|
||||
# Main chart gets remaining space
|
||||
main_chart_ratio = 1.0 - total_subplot_ratio
|
||||
|
||||
# Build final height list
|
||||
row_heights = [main_chart_ratio] # Main chart
|
||||
for row in range(2, total_rows + 1):
|
||||
row_heights.append(subplot_ratios.get(row, 0.25))
|
||||
|
||||
return row_heights
|
||||
|
||||
def render_all_layers(self, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render all enabled layers onto a new figure.
|
||||
|
||||
Args:
|
||||
data: Chart data (OHLCV format)
|
||||
**kwargs: Additional rendering parameters
|
||||
|
||||
Returns:
|
||||
Complete figure with all layers rendered
|
||||
"""
|
||||
try:
|
||||
# Calculate subplot layout
|
||||
layout_config = self.calculate_subplot_layout()
|
||||
|
||||
# Create figure with subplots if needed
|
||||
if layout_config['rows'] > 1:
|
||||
fig = make_subplots(**layout_config)
|
||||
else:
|
||||
fig = go.Figure()
|
||||
|
||||
# Render overlay layers (main chart)
|
||||
overlay_layers = self.get_overlay_layers()
|
||||
for layer in overlay_layers:
|
||||
fig = layer.render(fig, data, row=1, col=1, **kwargs)
|
||||
|
||||
# Render subplot layers
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
for row, layers in subplot_layers.items():
|
||||
for layer in layers:
|
||||
fig = layer.render(fig, data, row=row, col=1, **kwargs)
|
||||
|
||||
# Update layout styling
|
||||
self._apply_layout_styling(fig, layout_config)
|
||||
|
||||
self.logger.debug(f"Rendered {len(self.get_enabled_layers())} layers")
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering layers: {e}")
|
||||
# Return empty figure on error
|
||||
return go.Figure()
|
||||
|
||||
def _apply_layout_styling(self, fig: go.Figure, layout_config: Dict[str, Any]) -> None:
|
||||
"""Apply consistent styling to the figure layout."""
|
||||
try:
|
||||
# Basic layout settings
|
||||
fig.update_layout(
|
||||
template="plotly_white",
|
||||
showlegend=False,
|
||||
hovermode='x unified',
|
||||
xaxis_rangeslider_visible=False
|
||||
)
|
||||
|
||||
# Update axes for subplots
|
||||
if layout_config['rows'] > 1:
|
||||
# Update main chart axes
|
||||
fig.update_yaxes(title_text="Price (USDT)", row=1, col=1)
|
||||
fig.update_xaxes(showticklabels=False, row=1, col=1)
|
||||
|
||||
# Update subplot axes
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
for row in range(2, layout_config['rows'] + 1):
|
||||
if row in subplot_layers:
|
||||
# Set y-axis title and range based on layer type
|
||||
layers_in_row = subplot_layers[row]
|
||||
layer = layers_in_row[0] # Use first layer for configuration
|
||||
|
||||
# Set y-axis title
|
||||
if hasattr(layer, 'config') and hasattr(layer.config, 'indicator_type'):
|
||||
indicator_type = layer.config.indicator_type
|
||||
if indicator_type == 'rsi':
|
||||
fig.update_yaxes(title_text="RSI", row=row, col=1)
|
||||
elif indicator_type == 'macd':
|
||||
fig.update_yaxes(title_text="MACD", row=row, col=1)
|
||||
else:
|
||||
layer_names = [l.config.name for l in layers_in_row]
|
||||
fig.update_yaxes(title_text=' / '.join(layer_names), row=row, col=1)
|
||||
|
||||
# Set fixed y-axis range if specified
|
||||
if hasattr(layer, 'has_fixed_range') and layer.has_fixed_range():
|
||||
y_range = layer.get_y_axis_range()
|
||||
if y_range:
|
||||
fig.update_yaxes(range=list(y_range), row=row, col=1)
|
||||
|
||||
# Only show x-axis labels on the bottom subplot
|
||||
if row == layout_config['rows']:
|
||||
fig.update_xaxes(title_text="Time", row=row, col=1)
|
||||
else:
|
||||
fig.update_xaxes(showticklabels=False, row=row, col=1)
|
||||
else:
|
||||
# Single chart
|
||||
fig.update_layout(
|
||||
xaxis_title="Time",
|
||||
yaxis_title="Price (USDT)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error applying layout styling: {e}")
|
||||
720
components/charts/layers/indicators.py
Normal file
720
components/charts/layers/indicators.py
Normal file
@@ -0,0 +1,720 @@
|
||||
"""
|
||||
Technical Indicator Chart Layers
|
||||
|
||||
This module implements overlay indicator layers for technical analysis visualization
|
||||
including SMA, EMA, and Bollinger Bands with comprehensive error handling.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
from typing import Dict, Any, Optional, List, Union, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..error_handling import (
|
||||
ChartErrorHandler, ChartError, ErrorSeverity, DataRequirements,
|
||||
InsufficientDataError, DataValidationError, IndicatorCalculationError,
|
||||
ErrorRecoveryStrategies, create_error_annotation, get_error_message
|
||||
)
|
||||
|
||||
from .base import BaseLayer, LayerConfig
|
||||
from data.common.indicators import TechnicalIndicators, OHLCVCandle
|
||||
from components.charts.utils import get_indicator_colors
|
||||
from utils.logger import get_logger
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger("chart_indicators")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndicatorLayerConfig(LayerConfig):
|
||||
"""Extended configuration for indicator layers"""
|
||||
indicator_type: str = "" # e.g., 'sma', 'ema', 'rsi'
|
||||
parameters: Dict[str, Any] = None # Indicator-specific parameters
|
||||
line_width: int = 2
|
||||
opacity: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.parameters is None:
|
||||
self.parameters = {}
|
||||
|
||||
|
||||
class BaseIndicatorLayer(BaseLayer):
|
||||
"""
|
||||
Enhanced base class for all indicator layers with comprehensive error handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig):
|
||||
"""
|
||||
Initialize base indicator layer.
|
||||
|
||||
Args:
|
||||
config: Indicator layer configuration
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.indicators = TechnicalIndicators()
|
||||
self.colors = get_indicator_colors()
|
||||
self.calculated_data = None
|
||||
self.calculation_errors = []
|
||||
|
||||
def prepare_indicator_data(self, data: pd.DataFrame) -> List[OHLCVCandle]:
|
||||
"""
|
||||
Convert DataFrame to OHLCVCandle format for indicator calculations.
|
||||
|
||||
Args:
|
||||
data: Chart data (OHLCV format)
|
||||
|
||||
Returns:
|
||||
List of OHLCVCandle objects
|
||||
"""
|
||||
try:
|
||||
candles = []
|
||||
for _, row in data.iterrows():
|
||||
# Calculate start_time (assuming 1-minute candles for now)
|
||||
start_time = row['timestamp']
|
||||
end_time = row['timestamp']
|
||||
|
||||
candle = OHLCVCandle(
|
||||
symbol="BTCUSDT", # Default symbol for testing
|
||||
timeframe="1m", # Default timeframe
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
open=Decimal(str(row['open'])),
|
||||
high=Decimal(str(row['high'])),
|
||||
low=Decimal(str(row['low'])),
|
||||
close=Decimal(str(row['close'])),
|
||||
volume=Decimal(str(row.get('volume', 0))),
|
||||
trade_count=1, # Default trade count
|
||||
exchange="test", # Test exchange
|
||||
is_complete=True # Mark as complete for testing
|
||||
)
|
||||
candles.append(candle)
|
||||
|
||||
return candles
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error preparing indicator data: {e}")
|
||||
return []
|
||||
|
||||
def validate_indicator_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]],
|
||||
required_columns: List[str] = None) -> bool:
|
||||
"""
|
||||
Validate data specifically for indicator calculations.
|
||||
|
||||
Args:
|
||||
data: Input data
|
||||
required_columns: Required columns for this indicator
|
||||
|
||||
Returns:
|
||||
True if data is valid for indicator calculation
|
||||
"""
|
||||
try:
|
||||
# Use parent validation first
|
||||
if not super().validate_data(data):
|
||||
return False
|
||||
|
||||
# Convert to DataFrame if needed
|
||||
if isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
else:
|
||||
df = data.copy()
|
||||
|
||||
# Check required columns for indicator
|
||||
if required_columns:
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
if missing_columns:
|
||||
error = ChartError(
|
||||
code='MISSING_INDICATOR_COLUMNS',
|
||||
message=f'Missing columns for {self.config.indicator_type}: {missing_columns}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={
|
||||
'indicator_type': self.config.indicator_type,
|
||||
'missing_columns': missing_columns,
|
||||
'available_columns': list(df.columns)
|
||||
},
|
||||
recovery_suggestion=f'Ensure data contains required columns: {required_columns}'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check data sufficiency for indicator
|
||||
indicator_config = {
|
||||
'type': self.config.indicator_type,
|
||||
'parameters': self.config.parameters or {}
|
||||
}
|
||||
|
||||
indicator_error = DataRequirements.check_indicator_requirements(
|
||||
self.config.indicator_type,
|
||||
len(df),
|
||||
self.config.parameters or {}
|
||||
)
|
||||
|
||||
if indicator_error.severity == ErrorSeverity.WARNING:
|
||||
self.error_handler.warnings.append(indicator_error)
|
||||
elif indicator_error.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]:
|
||||
self.error_handler.errors.append(indicator_error)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error validating indicator data: {e}")
|
||||
error = ChartError(
|
||||
code='INDICATOR_VALIDATION_ERROR',
|
||||
message=f'Indicator validation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e), 'indicator_type': self.config.indicator_type}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
def safe_calculate_indicator(self, data: pd.DataFrame,
|
||||
calculation_func: Callable,
|
||||
**kwargs) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Safely calculate indicator with error handling.
|
||||
|
||||
Args:
|
||||
data: Input data
|
||||
calculation_func: Function to calculate indicator
|
||||
**kwargs: Additional arguments for calculation
|
||||
|
||||
Returns:
|
||||
Calculated indicator data or None if failed
|
||||
"""
|
||||
try:
|
||||
# Validate data first
|
||||
if not self.validate_indicator_data(data):
|
||||
return None
|
||||
|
||||
# Try calculation with recovery strategies
|
||||
result = calculation_func(data, **kwargs)
|
||||
|
||||
# Validate result
|
||||
if result is None or (isinstance(result, pd.DataFrame) and result.empty):
|
||||
error = ChartError(
|
||||
code='EMPTY_INDICATOR_RESULT',
|
||||
message=f'Indicator calculation returned no data: {self.config.indicator_type}',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'indicator_type': self.config.indicator_type, 'input_length': len(data)},
|
||||
recovery_suggestion='Check calculation parameters or input data range'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
return None
|
||||
|
||||
# Check for sufficient calculated data
|
||||
if isinstance(result, pd.DataFrame) and len(result) < len(data) * 0.1:
|
||||
error = ChartError(
|
||||
code='INSUFFICIENT_INDICATOR_OUTPUT',
|
||||
message=f'Very few indicator values calculated: {len(result)}/{len(data)}',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={
|
||||
'indicator_type': self.config.indicator_type,
|
||||
'output_length': len(result),
|
||||
'input_length': len(data)
|
||||
},
|
||||
recovery_suggestion='Consider adjusting indicator parameters'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
self.calculated_data = result
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error calculating {self.config.indicator_type}: {e}")
|
||||
|
||||
# Try to apply error recovery
|
||||
recovery_strategy = ErrorRecoveryStrategies.handle_insufficient_data(
|
||||
ChartError(
|
||||
code='INDICATOR_CALCULATION_ERROR',
|
||||
message=f'Calculation failed for {self.config.indicator_type}: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e), 'indicator_type': self.config.indicator_type}
|
||||
),
|
||||
fallback_options={'data_length': len(data)}
|
||||
)
|
||||
|
||||
if recovery_strategy['can_proceed'] and recovery_strategy['fallback_action'] == 'adjust_parameters':
|
||||
# Try with adjusted parameters
|
||||
try:
|
||||
modified_config = recovery_strategy.get('modified_config', {})
|
||||
self.logger.info(f"Retrying indicator calculation with adjusted parameters: {modified_config}")
|
||||
|
||||
# Update parameters temporarily
|
||||
original_params = self.config.parameters.copy() if self.config.parameters else {}
|
||||
self.config.parameters.update(modified_config)
|
||||
|
||||
# Retry calculation
|
||||
result = calculation_func(data, **kwargs)
|
||||
|
||||
# Restore original parameters
|
||||
self.config.parameters = original_params
|
||||
|
||||
if result is not None and not (isinstance(result, pd.DataFrame) and result.empty):
|
||||
# Add warning about parameter adjustment
|
||||
warning = ChartError(
|
||||
code='INDICATOR_PARAMETERS_ADJUSTED',
|
||||
message=recovery_strategy['user_message'],
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'original_params': original_params, 'adjusted_params': modified_config}
|
||||
)
|
||||
self.error_handler.warnings.append(warning)
|
||||
self.calculated_data = result
|
||||
return result
|
||||
|
||||
except Exception as retry_error:
|
||||
self.logger.error(f"Retry with adjusted parameters also failed: {retry_error}")
|
||||
|
||||
# Final error if all recovery attempts fail
|
||||
error = ChartError(
|
||||
code='INDICATOR_CALCULATION_FAILED',
|
||||
message=f'Failed to calculate {self.config.indicator_type}: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e), 'indicator_type': self.config.indicator_type}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return None
|
||||
|
||||
def create_indicator_traces(self, data: pd.DataFrame, subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""
|
||||
Create indicator traces with error handling.
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement create_indicator_traces")
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
|
||||
class SMALayer(BaseIndicatorLayer):
|
||||
"""Simple Moving Average layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig = None):
|
||||
"""Initialize SMA layer"""
|
||||
if config is None:
|
||||
config = IndicatorLayerConfig(
|
||||
indicator_type='sma',
|
||||
parameters={'period': 20}
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def create_traces(self, data: List[Dict[str, Any]], subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""Create SMA traces with comprehensive error handling"""
|
||||
try:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data) if isinstance(data, list) else data.copy()
|
||||
|
||||
# Validate data
|
||||
if not self.validate_indicator_data(df, required_columns=['close', 'timestamp']):
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"SMA Error: {self._error_message}")]
|
||||
|
||||
# Calculate SMA with error handling
|
||||
period = self.config.parameters.get('period', 20)
|
||||
sma_data = self.safe_calculate_indicator(
|
||||
df,
|
||||
self._calculate_sma,
|
||||
period=period
|
||||
)
|
||||
|
||||
if sma_data is None:
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"SMA calculation failed")]
|
||||
else:
|
||||
return [] # Skip layer gracefully
|
||||
|
||||
# Create trace
|
||||
sma_trace = go.Scatter(
|
||||
x=sma_data['timestamp'],
|
||||
y=sma_data['sma'],
|
||||
mode='lines',
|
||||
name=f'SMA({period})',
|
||||
line=dict(
|
||||
color=self.config.color or '#2196F3',
|
||||
width=self.config.line_width
|
||||
),
|
||||
row=subplot_row,
|
||||
col=1
|
||||
)
|
||||
|
||||
self.traces = [sma_trace]
|
||||
return self.traces
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating SMA traces: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return [self.create_error_trace(error_msg)]
|
||||
|
||||
def _calculate_sma(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
|
||||
"""Calculate SMA with validation"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
result_df['sma'] = result_df['close'].rolling(window=period, min_periods=period).mean()
|
||||
|
||||
# Remove NaN values
|
||||
result_df = result_df.dropna(subset=['sma'])
|
||||
|
||||
if result_df.empty:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='SMA_NO_VALUES',
|
||||
message=f'SMA calculation produced no values (period={period}, data_length={len(data)})',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data)}
|
||||
))
|
||||
|
||||
return result_df[['timestamp', 'sma']]
|
||||
|
||||
except Exception as e:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='SMA_CALCULATION_ERROR',
|
||||
message=f'SMA calculation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data), 'exception': str(e)}
|
||||
))
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render SMA layer for compatibility with base interface"""
|
||||
try:
|
||||
traces = self.create_traces(data.to_dict('records'), **kwargs)
|
||||
for trace in traces:
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(trace, **kwargs)
|
||||
else:
|
||||
fig.add_trace(trace)
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering SMA layer: {e}")
|
||||
return fig
|
||||
|
||||
|
||||
class EMALayer(BaseIndicatorLayer):
|
||||
"""Exponential Moving Average layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig = None):
|
||||
"""Initialize EMA layer"""
|
||||
if config is None:
|
||||
config = IndicatorLayerConfig(
|
||||
indicator_type='ema',
|
||||
parameters={'period': 20}
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def create_traces(self, data: List[Dict[str, Any]], subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""Create EMA traces with comprehensive error handling"""
|
||||
try:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data) if isinstance(data, list) else data.copy()
|
||||
|
||||
# Validate data
|
||||
if not self.validate_indicator_data(df, required_columns=['close', 'timestamp']):
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"EMA Error: {self._error_message}")]
|
||||
|
||||
# Calculate EMA with error handling
|
||||
period = self.config.parameters.get('period', 20)
|
||||
ema_data = self.safe_calculate_indicator(
|
||||
df,
|
||||
self._calculate_ema,
|
||||
period=period
|
||||
)
|
||||
|
||||
if ema_data is None:
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"EMA calculation failed")]
|
||||
else:
|
||||
return [] # Skip layer gracefully
|
||||
|
||||
# Create trace
|
||||
ema_trace = go.Scatter(
|
||||
x=ema_data['timestamp'],
|
||||
y=ema_data['ema'],
|
||||
mode='lines',
|
||||
name=f'EMA({period})',
|
||||
line=dict(
|
||||
color=self.config.color or '#FF9800',
|
||||
width=self.config.line_width
|
||||
),
|
||||
row=subplot_row,
|
||||
col=1
|
||||
)
|
||||
|
||||
self.traces = [ema_trace]
|
||||
return self.traces
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating EMA traces: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return [self.create_error_trace(error_msg)]
|
||||
|
||||
def _calculate_ema(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
|
||||
"""Calculate EMA with validation"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
result_df['ema'] = result_df['close'].ewm(span=period, adjust=False).mean()
|
||||
|
||||
# For EMA, we can start from the first value, but remove obvious outliers
|
||||
# Skip first few values for stability
|
||||
warmup_period = max(1, period // 4)
|
||||
result_df = result_df.iloc[warmup_period:]
|
||||
|
||||
if result_df.empty:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='EMA_NO_VALUES',
|
||||
message=f'EMA calculation produced no values (period={period}, data_length={len(data)})',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data)}
|
||||
))
|
||||
|
||||
return result_df[['timestamp', 'ema']]
|
||||
|
||||
except Exception as e:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='EMA_CALCULATION_ERROR',
|
||||
message=f'EMA calculation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data), 'exception': str(e)}
|
||||
))
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render EMA layer for compatibility with base interface"""
|
||||
try:
|
||||
traces = self.create_traces(data.to_dict('records'), **kwargs)
|
||||
for trace in traces:
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(trace, **kwargs)
|
||||
else:
|
||||
fig.add_trace(trace)
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering EMA layer: {e}")
|
||||
return fig
|
||||
|
||||
|
||||
class BollingerBandsLayer(BaseIndicatorLayer):
|
||||
"""Bollinger Bands layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig = None):
|
||||
"""Initialize Bollinger Bands layer"""
|
||||
if config is None:
|
||||
config = IndicatorLayerConfig(
|
||||
indicator_type='bollinger_bands',
|
||||
parameters={'period': 20, 'std_dev': 2},
|
||||
show_middle_line=True
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def create_traces(self, data: List[Dict[str, Any]], subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""Create Bollinger Bands traces with comprehensive error handling"""
|
||||
try:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data) if isinstance(data, list) else data.copy()
|
||||
|
||||
# Validate data
|
||||
if not self.validate_indicator_data(df, required_columns=['close', 'timestamp']):
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"Bollinger Bands Error: {self._error_message}")]
|
||||
|
||||
# Calculate Bollinger Bands with error handling
|
||||
period = self.config.parameters.get('period', 20)
|
||||
std_dev = self.config.parameters.get('std_dev', 2)
|
||||
|
||||
bb_data = self.safe_calculate_indicator(
|
||||
df,
|
||||
self._calculate_bollinger_bands,
|
||||
period=period,
|
||||
std_dev=std_dev
|
||||
)
|
||||
|
||||
if bb_data is None:
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"Bollinger Bands calculation failed")]
|
||||
else:
|
||||
return [] # Skip layer gracefully
|
||||
|
||||
# Create traces
|
||||
traces = []
|
||||
|
||||
# Upper band
|
||||
upper_trace = go.Scatter(
|
||||
x=bb_data['timestamp'],
|
||||
y=bb_data['upper_band'],
|
||||
mode='lines',
|
||||
name=f'BB Upper({period})',
|
||||
line=dict(color=self.config.color or '#9C27B0', width=1),
|
||||
row=subplot_row,
|
||||
col=1,
|
||||
showlegend=True
|
||||
)
|
||||
traces.append(upper_trace)
|
||||
|
||||
# Lower band with fill
|
||||
lower_trace = go.Scatter(
|
||||
x=bb_data['timestamp'],
|
||||
y=bb_data['lower_band'],
|
||||
mode='lines',
|
||||
name=f'BB Lower({period})',
|
||||
line=dict(color=self.config.color or '#9C27B0', width=1),
|
||||
fill='tonexty',
|
||||
fillcolor='rgba(156, 39, 176, 0.1)',
|
||||
row=subplot_row,
|
||||
col=1,
|
||||
showlegend=True
|
||||
)
|
||||
traces.append(lower_trace)
|
||||
|
||||
# Middle line (SMA)
|
||||
if self.config.show_middle_line:
|
||||
middle_trace = go.Scatter(
|
||||
x=bb_data['timestamp'],
|
||||
y=bb_data['middle_band'],
|
||||
mode='lines',
|
||||
name=f'BB Middle({period})',
|
||||
line=dict(color=self.config.color or '#9C27B0', width=1, dash='dash'),
|
||||
row=subplot_row,
|
||||
col=1,
|
||||
showlegend=True
|
||||
)
|
||||
traces.append(middle_trace)
|
||||
|
||||
self.traces = traces
|
||||
return self.traces
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating Bollinger Bands traces: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return [self.create_error_trace(error_msg)]
|
||||
|
||||
def _calculate_bollinger_bands(self, data: pd.DataFrame, period: int, std_dev: float) -> pd.DataFrame:
|
||||
"""Calculate Bollinger Bands with validation"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
|
||||
# Calculate middle band (SMA)
|
||||
result_df['middle_band'] = result_df['close'].rolling(window=period, min_periods=period).mean()
|
||||
|
||||
# Calculate standard deviation
|
||||
result_df['std'] = result_df['close'].rolling(window=period, min_periods=period).std()
|
||||
|
||||
# Calculate upper and lower bands
|
||||
result_df['upper_band'] = result_df['middle_band'] + (result_df['std'] * std_dev)
|
||||
result_df['lower_band'] = result_df['middle_band'] - (result_df['std'] * std_dev)
|
||||
|
||||
# Remove NaN values
|
||||
result_df = result_df.dropna(subset=['middle_band', 'upper_band', 'lower_band'])
|
||||
|
||||
if result_df.empty:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='BB_NO_VALUES',
|
||||
message=f'Bollinger Bands calculation produced no values (period={period}, data_length={len(data)})',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'std_dev': std_dev, 'data_length': len(data)}
|
||||
))
|
||||
|
||||
return result_df[['timestamp', 'upper_band', 'middle_band', 'lower_band']]
|
||||
|
||||
except Exception as e:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='BB_CALCULATION_ERROR',
|
||||
message=f'Bollinger Bands calculation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'std_dev': std_dev, 'data_length': len(data), 'exception': str(e)}
|
||||
))
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render Bollinger Bands layer for compatibility with base interface"""
|
||||
try:
|
||||
traces = self.create_traces(data.to_dict('records'), **kwargs)
|
||||
for trace in traces:
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(trace, **kwargs)
|
||||
else:
|
||||
fig.add_trace(trace)
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering Bollinger Bands layer: {e}")
|
||||
return fig
|
||||
|
||||
|
||||
def create_sma_layer(period: int = 20, **kwargs) -> SMALayer:
|
||||
"""
|
||||
Convenience function to create an SMA layer.
|
||||
|
||||
Args:
|
||||
period: SMA period
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Configured SMA layer
|
||||
"""
|
||||
return SMALayer(period=period, **kwargs)
|
||||
|
||||
|
||||
def create_ema_layer(period: int = 12, **kwargs) -> EMALayer:
|
||||
"""
|
||||
Convenience function to create an EMA layer.
|
||||
|
||||
Args:
|
||||
period: EMA period
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Configured EMA layer
|
||||
"""
|
||||
return EMALayer(period=period, **kwargs)
|
||||
|
||||
|
||||
def create_bollinger_bands_layer(period: int = 20, std_dev: float = 2.0, **kwargs) -> BollingerBandsLayer:
|
||||
"""
|
||||
Convenience function to create a Bollinger Bands layer.
|
||||
|
||||
Args:
|
||||
period: BB period (default: 20)
|
||||
std_dev: Standard deviation multiplier (default: 2.0)
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Configured Bollinger Bands layer
|
||||
"""
|
||||
return BollingerBandsLayer(period=period, std_dev=std_dev, **kwargs)
|
||||
|
||||
|
||||
def create_common_ma_layers() -> List[BaseIndicatorLayer]:
|
||||
"""
|
||||
Create commonly used moving average layers.
|
||||
|
||||
Returns:
|
||||
List of configured MA layers (SMA 20, SMA 50, EMA 12, EMA 26)
|
||||
"""
|
||||
colors = get_indicator_colors()
|
||||
|
||||
return [
|
||||
SMALayer(20, color=colors.get('sma', '#007bff'), name="SMA(20)"),
|
||||
SMALayer(50, color='#6c757d', name="SMA(50)"), # Gray for longer SMA
|
||||
EMALayer(12, color=colors.get('ema', '#ff6b35'), name="EMA(12)"),
|
||||
EMALayer(26, color='#28a745', name="EMA(26)") # Green for longer EMA
|
||||
]
|
||||
|
||||
|
||||
def create_common_overlay_indicators() -> List[BaseIndicatorLayer]:
|
||||
"""
|
||||
Create commonly used overlay indicators including moving averages and Bollinger Bands.
|
||||
|
||||
Returns:
|
||||
List of configured overlay indicator layers
|
||||
"""
|
||||
colors = get_indicator_colors()
|
||||
|
||||
return [
|
||||
SMALayer(20, color=colors.get('sma', '#007bff'), name="SMA(20)"),
|
||||
EMALayer(12, color=colors.get('ema', '#ff6b35'), name="EMA(12)"),
|
||||
BollingerBandsLayer(20, 2.0, color=colors.get('bb_upper', '#6f42c1'), name="BB(20,2)")
|
||||
]
|
||||
424
components/charts/layers/subplots.py
Normal file
424
components/charts/layers/subplots.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Subplot Chart Layers
|
||||
|
||||
This module contains subplot layer implementations for indicators that render
|
||||
in separate subplots below the main price chart, such as RSI, MACD, and other
|
||||
oscillators and momentum indicators.
|
||||
"""
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Any, Optional, List, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseChartLayer, LayerConfig
|
||||
from .indicators import BaseIndicatorLayer, IndicatorLayerConfig
|
||||
from data.common.indicators import TechnicalIndicators, IndicatorResult, OHLCVCandle
|
||||
from components.charts.utils import get_indicator_colors
|
||||
from utils.logger import get_logger
|
||||
from ..error_handling import (
|
||||
ChartErrorHandler, ChartError, ErrorSeverity, DataRequirements,
|
||||
InsufficientDataError, DataValidationError, IndicatorCalculationError,
|
||||
ErrorRecoveryStrategies, create_error_annotation, get_error_message
|
||||
)
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger("subplot_layers")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubplotLayerConfig(IndicatorLayerConfig):
|
||||
"""Extended configuration for subplot indicator layers"""
|
||||
subplot_height_ratio: float = 0.25 # Height ratio for subplot (0.25 = 25% of total height)
|
||||
y_axis_range: Optional[Tuple[float, float]] = None # Fixed y-axis range (min, max)
|
||||
show_zero_line: bool = False # Show horizontal line at y=0
|
||||
reference_lines: List[float] = None # Additional horizontal reference lines
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.reference_lines is None:
|
||||
self.reference_lines = []
|
||||
|
||||
|
||||
class BaseSubplotLayer(BaseIndicatorLayer):
|
||||
"""
|
||||
Base class for all subplot indicator layers.
|
||||
|
||||
Provides common functionality for indicators that render in separate subplots
|
||||
with their own y-axis scaling and reference lines.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SubplotLayerConfig):
|
||||
"""
|
||||
Initialize base subplot layer.
|
||||
|
||||
Args:
|
||||
config: Subplot layer configuration
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.subplot_config = config
|
||||
|
||||
def get_subplot_height_ratio(self) -> float:
|
||||
"""Get the height ratio for this subplot."""
|
||||
return self.subplot_config.subplot_height_ratio
|
||||
|
||||
def has_fixed_range(self) -> bool:
|
||||
"""Check if this subplot has a fixed y-axis range."""
|
||||
return self.subplot_config.y_axis_range is not None
|
||||
|
||||
def get_y_axis_range(self) -> Optional[Tuple[float, float]]:
|
||||
"""Get the fixed y-axis range if defined."""
|
||||
return self.subplot_config.y_axis_range
|
||||
|
||||
def should_show_zero_line(self) -> bool:
|
||||
"""Check if zero line should be shown."""
|
||||
return self.subplot_config.show_zero_line
|
||||
|
||||
def get_reference_lines(self) -> List[float]:
|
||||
"""Get additional reference lines to draw."""
|
||||
return self.subplot_config.reference_lines
|
||||
|
||||
def add_reference_lines(self, fig: go.Figure, row: int, col: int = 1) -> None:
|
||||
"""
|
||||
Add reference lines to the subplot.
|
||||
|
||||
Args:
|
||||
fig: Target figure
|
||||
row: Subplot row
|
||||
col: Subplot column
|
||||
"""
|
||||
try:
|
||||
# Add zero line if enabled
|
||||
if self.should_show_zero_line():
|
||||
fig.add_hline(
|
||||
y=0,
|
||||
line=dict(color='gray', width=1, dash='dash'),
|
||||
row=row,
|
||||
col=col
|
||||
)
|
||||
|
||||
# Add additional reference lines
|
||||
for ref_value in self.get_reference_lines():
|
||||
fig.add_hline(
|
||||
y=ref_value,
|
||||
line=dict(color='lightgray', width=1, dash='dot'),
|
||||
row=row,
|
||||
col=col
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not add reference lines: {e}")
|
||||
|
||||
|
||||
class RSILayer(BaseSubplotLayer):
|
||||
"""
|
||||
Relative Strength Index (RSI) subplot layer.
|
||||
|
||||
Renders RSI oscillator in a separate subplot with standard overbought (70)
|
||||
and oversold (30) reference lines.
|
||||
"""
|
||||
|
||||
def __init__(self, period: int = 14, color: str = None, name: str = None):
|
||||
"""
|
||||
Initialize RSI layer.
|
||||
|
||||
Args:
|
||||
period: RSI period (default: 14)
|
||||
color: Line color (optional, uses default)
|
||||
name: Layer name (optional, auto-generated)
|
||||
"""
|
||||
# Use default color if not specified
|
||||
if color is None:
|
||||
colors = get_indicator_colors()
|
||||
color = colors.get('rsi', '#20c997')
|
||||
|
||||
# Generate name if not specified
|
||||
if name is None:
|
||||
name = f"RSI({period})"
|
||||
|
||||
# Find next available subplot row (will be managed by LayerManager)
|
||||
subplot_row = 2 # Default to row 2 (first subplot after main chart)
|
||||
|
||||
config = SubplotLayerConfig(
|
||||
name=name,
|
||||
indicator_type="rsi",
|
||||
color=color,
|
||||
parameters={'period': period},
|
||||
subplot_row=subplot_row,
|
||||
subplot_height_ratio=0.25,
|
||||
y_axis_range=(0, 100), # RSI ranges from 0 to 100
|
||||
reference_lines=[30, 70], # Oversold and overbought levels
|
||||
style={
|
||||
'line_color': color,
|
||||
'line_width': 2,
|
||||
'opacity': 1.0
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
self.period = period
|
||||
|
||||
def _calculate_rsi(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
|
||||
"""Calculate RSI with validation and error handling"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
|
||||
# Calculate price changes
|
||||
result_df['price_change'] = result_df['close'].diff()
|
||||
|
||||
# Separate gains and losses
|
||||
result_df['gain'] = result_df['price_change'].clip(lower=0)
|
||||
result_df['loss'] = -result_df['price_change'].clip(upper=0)
|
||||
|
||||
# Calculate average gains and losses using Wilder's smoothing
|
||||
result_df['avg_gain'] = result_df['gain'].ewm(alpha=1/period, adjust=False).mean()
|
||||
result_df['avg_loss'] = result_df['loss'].ewm(alpha=1/period, adjust=False).mean()
|
||||
|
||||
# Calculate RS and RSI
|
||||
result_df['rs'] = result_df['avg_gain'] / result_df['avg_loss']
|
||||
result_df['rsi'] = 100 - (100 / (1 + result_df['rs']))
|
||||
|
||||
# Remove rows where RSI cannot be calculated
|
||||
result_df = result_df.iloc[period:].copy()
|
||||
|
||||
# Remove NaN values and invalid RSI values
|
||||
result_df = result_df.dropna(subset=['rsi'])
|
||||
result_df = result_df[
|
||||
(result_df['rsi'] >= 0) &
|
||||
(result_df['rsi'] <= 100) &
|
||||
pd.notna(result_df['rsi'])
|
||||
]
|
||||
|
||||
if result_df.empty:
|
||||
raise Exception(f'RSI calculation produced no values (period={period}, data_length={len(data)})')
|
||||
|
||||
return result_df[['timestamp', 'rsi']]
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f'RSI calculation failed: {str(e)}')
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render RSI layer for compatibility with base interface"""
|
||||
try:
|
||||
# Calculate RSI
|
||||
rsi_data = self._calculate_rsi(data, self.period)
|
||||
if rsi_data.empty:
|
||||
return fig
|
||||
|
||||
# Create RSI trace
|
||||
rsi_trace = go.Scatter(
|
||||
x=rsi_data['timestamp'],
|
||||
y=rsi_data['rsi'],
|
||||
mode='lines',
|
||||
name=self.config.name,
|
||||
line=dict(
|
||||
color=self.config.color,
|
||||
width=2
|
||||
),
|
||||
showlegend=True
|
||||
)
|
||||
|
||||
# Add trace
|
||||
row = kwargs.get('row', self.config.subplot_row or 2)
|
||||
col = kwargs.get('col', 1)
|
||||
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(rsi_trace, row=row, col=col)
|
||||
else:
|
||||
fig.add_trace(rsi_trace)
|
||||
|
||||
# Add reference lines
|
||||
self.add_reference_lines(fig, row, col)
|
||||
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering RSI layer: {e}")
|
||||
return fig
|
||||
|
||||
|
||||
class MACDLayer(BaseSubplotLayer):
|
||||
"""MACD (Moving Average Convergence Divergence) subplot layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, fast_period: int = 12, slow_period: int = 26, signal_period: int = 9,
|
||||
color: str = None, name: str = None):
|
||||
"""Initialize MACD layer with custom parameters"""
|
||||
# Use default color if not specified
|
||||
if color is None:
|
||||
colors = get_indicator_colors()
|
||||
color = colors.get('macd', '#fd7e14')
|
||||
|
||||
# Generate name if not specified
|
||||
if name is None:
|
||||
name = f"MACD({fast_period},{slow_period},{signal_period})"
|
||||
|
||||
config = SubplotLayerConfig(
|
||||
name=name,
|
||||
indicator_type="macd",
|
||||
color=color,
|
||||
parameters={
|
||||
'fast_period': fast_period,
|
||||
'slow_period': slow_period,
|
||||
'signal_period': signal_period
|
||||
},
|
||||
subplot_row=3, # Will be managed by LayerManager
|
||||
subplot_height_ratio=0.3,
|
||||
show_zero_line=True,
|
||||
style={
|
||||
'line_color': color,
|
||||
'line_width': 2,
|
||||
'opacity': 1.0
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
self.fast_period = fast_period
|
||||
self.slow_period = slow_period
|
||||
self.signal_period = signal_period
|
||||
|
||||
def _calculate_macd(self, data: pd.DataFrame, fast_period: int,
|
||||
slow_period: int, signal_period: int) -> pd.DataFrame:
|
||||
"""Calculate MACD with validation and error handling"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
|
||||
# Validate periods
|
||||
if fast_period >= slow_period:
|
||||
raise Exception(f'Fast period ({fast_period}) must be less than slow period ({slow_period})')
|
||||
|
||||
# Calculate EMAs
|
||||
result_df['ema_fast'] = result_df['close'].ewm(span=fast_period, adjust=False).mean()
|
||||
result_df['ema_slow'] = result_df['close'].ewm(span=slow_period, adjust=False).mean()
|
||||
|
||||
# Calculate MACD line
|
||||
result_df['macd'] = result_df['ema_fast'] - result_df['ema_slow']
|
||||
|
||||
# Calculate signal line
|
||||
result_df['signal'] = result_df['macd'].ewm(span=signal_period, adjust=False).mean()
|
||||
|
||||
# Calculate histogram
|
||||
result_df['histogram'] = result_df['macd'] - result_df['signal']
|
||||
|
||||
# Remove rows where MACD cannot be calculated reliably
|
||||
warmup_period = slow_period + signal_period
|
||||
result_df = result_df.iloc[warmup_period:].copy()
|
||||
|
||||
# Remove NaN values
|
||||
result_df = result_df.dropna(subset=['macd', 'signal', 'histogram'])
|
||||
|
||||
if result_df.empty:
|
||||
raise Exception(f'MACD calculation produced no values (fast={fast_period}, slow={slow_period}, signal={signal_period})')
|
||||
|
||||
return result_df[['timestamp', 'macd', 'signal', 'histogram']]
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f'MACD calculation failed: {str(e)}')
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render MACD layer for compatibility with base interface"""
|
||||
try:
|
||||
# Calculate MACD
|
||||
macd_data = self._calculate_macd(data, self.fast_period, self.slow_period, self.signal_period)
|
||||
if macd_data.empty:
|
||||
return fig
|
||||
|
||||
row = kwargs.get('row', self.config.subplot_row or 3)
|
||||
col = kwargs.get('col', 1)
|
||||
|
||||
# Create MACD line trace
|
||||
macd_trace = go.Scatter(
|
||||
x=macd_data['timestamp'],
|
||||
y=macd_data['macd'],
|
||||
mode='lines',
|
||||
name=f'{self.config.name} Line',
|
||||
line=dict(color=self.config.color, width=2),
|
||||
showlegend=True
|
||||
)
|
||||
|
||||
# Create signal line trace
|
||||
signal_trace = go.Scatter(
|
||||
x=macd_data['timestamp'],
|
||||
y=macd_data['signal'],
|
||||
mode='lines',
|
||||
name=f'{self.config.name} Signal',
|
||||
line=dict(color='#FF9800', width=2),
|
||||
showlegend=True
|
||||
)
|
||||
|
||||
# Create histogram
|
||||
histogram_colors = ['green' if h >= 0 else 'red' for h in macd_data['histogram']]
|
||||
histogram_trace = go.Bar(
|
||||
x=macd_data['timestamp'],
|
||||
y=macd_data['histogram'],
|
||||
name=f'{self.config.name} Histogram',
|
||||
marker_color=histogram_colors,
|
||||
opacity=0.6,
|
||||
showlegend=True
|
||||
)
|
||||
|
||||
# Add traces
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(macd_trace, row=row, col=col)
|
||||
fig.add_trace(signal_trace, row=row, col=col)
|
||||
fig.add_trace(histogram_trace, row=row, col=col)
|
||||
else:
|
||||
fig.add_trace(macd_trace)
|
||||
fig.add_trace(signal_trace)
|
||||
fig.add_trace(histogram_trace)
|
||||
|
||||
# Add zero line
|
||||
self.add_reference_lines(fig, row, col)
|
||||
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering MACD layer: {e}")
|
||||
return fig
|
||||
|
||||
|
||||
def create_rsi_layer(period: int = 14, **kwargs) -> 'RSILayer':
|
||||
"""
|
||||
Convenience function to create an RSI layer.
|
||||
|
||||
Args:
|
||||
period: RSI period (default: 14)
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Configured RSI layer
|
||||
"""
|
||||
return RSILayer(period=period, **kwargs)
|
||||
|
||||
|
||||
def create_macd_layer(fast_period: int = 12, slow_period: int = 26,
|
||||
signal_period: int = 9, **kwargs) -> 'MACDLayer':
|
||||
"""
|
||||
Convenience function to create a MACD layer.
|
||||
|
||||
Args:
|
||||
fast_period: Fast EMA period (default: 12)
|
||||
slow_period: Slow EMA period (default: 26)
|
||||
signal_period: Signal line period (default: 9)
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Configured MACD layer
|
||||
"""
|
||||
return MACDLayer(
|
||||
fast_period=fast_period,
|
||||
slow_period=slow_period,
|
||||
signal_period=signal_period,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def create_common_subplot_indicators() -> List[BaseSubplotLayer]:
|
||||
"""
|
||||
Create commonly used subplot indicators.
|
||||
|
||||
Returns:
|
||||
List of configured subplot indicator layers (RSI, MACD)
|
||||
"""
|
||||
return [
|
||||
RSILayer(period=14),
|
||||
MACDLayer(fast_period=12, slow_period=26, signal_period=9)
|
||||
]
|
||||
Reference in New Issue
Block a user