952 lines
37 KiB
Python
952 lines
37 KiB
Python
"""
|
|
Base Chart Layer Components
|
|
|
|
This module contains the foundational layer classes that serve as building blocks
|
|
for all chart components including candlestick charts, indicators, and signals.
|
|
"""
|
|
|
|
import plotly.graph_objects as go
|
|
from plotly.subplots import make_subplots
|
|
import pandas as pd
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, Optional, List, Union
|
|
from dataclasses import dataclass
|
|
|
|
from utils.logger import get_logger
|
|
from ..error_handling import (
|
|
ChartErrorHandler, ChartError, ErrorSeverity,
|
|
InsufficientDataError, DataValidationError, IndicatorCalculationError,
|
|
create_error_annotation, get_error_message
|
|
)
|
|
|
|
# Initialize logger
# Module-level logger: BaseChartLayer and LayerManager instances reuse it as
# self.logger, while BaseLayer obtains its own via get_logger() in __init__.
logger = get_logger()
|
|
|
|
|
|
@dataclass
class LayerConfig:
    """Configuration for chart layers.

    Attributes:
        name: Identifier of the layer. Candlestick layers use it as the trace
            name, and LayerManager uses it in subplot titles and axis labels.
        enabled: Whether the layer should be rendered.
        color: Optional colour for the layer (not read by the layers defined
            in this module).
        style: Free-form styling options read by individual layers, e.g.
            ``increasing_color`` for candlesticks or ``opacity`` for volume.
            ``None`` is replaced by a fresh empty dict per instance.
        subplot_row: ``None`` places the layer on the main chart. Any integer
            places it in a subplot; LayerManager renumbers subplot rows
            sequentially starting at 2, because row 1 is the main price chart.
    """

    name: str
    enabled: bool = True
    color: Optional[str] = None
    # Declared Optional because the default is None. __post_init__ swaps None
    # for a per-instance dict so instances never share one mutable default.
    style: Optional[Dict[str, Any]] = None
    subplot_row: Optional[int] = None  # None = main chart, int = subplot row

    def __post_init__(self):
        if self.style is None:
            self.style = {}
|
|
|
|
|
|
class BaseLayer:
    """
    Base class for all chart layers providing common functionality
    for data validation, error handling, and trace management.

    LayerManager drives layers through ``is_enabled``, ``is_overlay``,
    ``get_subplot_row`` and ``render``. The first three are provided here from
    the layer's ``LayerConfig``. ``render`` must be supplied by subclasses.
    """

    def __init__(self, config: LayerConfig):
        self.config = config
        self.logger = get_logger()
        self.error_handler = ChartErrorHandler()
        self.traces = []
        self._is_valid = False
        self._error_message = None

    def is_enabled(self) -> bool:
        """Return the ``enabled`` flag from the layer configuration."""
        return self.config.enabled

    def is_overlay(self) -> bool:
        """Return True when the layer draws on the main chart (no subplot row)."""
        return self.config.subplot_row is None

    def get_subplot_row(self) -> Optional[int]:
        """Return the configured subplot row (``None`` for the main chart)."""
        return self.config.subplot_row

    def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
        """
        Validate input data for layer requirements.

        Clears the error handler, checks the container type, then delegates the
        sufficiency check to ``error_handler.validate_data_sufficiency``. The
        outcome is recorded in ``_is_valid`` / ``_error_message`` and exposed
        through ``get_error_info``.

        Args:
            data: Input data to validate (DataFrame or list of row dicts)

        Returns:
            True if data is valid, False otherwise. Unexpected exceptions are
            logged and recorded as a VALIDATION_EXCEPTION error.
        """
        # Reset per-call state so the result of a previous call cannot leak
        # into this one (e.g. a stale _is_valid=True after an invalid-type call).
        self._is_valid = False
        self._error_message = None

        try:
            self.error_handler.clear_errors()

            # Check data type
            if not isinstance(data, (pd.DataFrame, list)):
                error = ChartError(
                    code='INVALID_DATA_TYPE',
                    message=f'Invalid data type for {self.__class__.__name__}: {type(data)}',
                    severity=ErrorSeverity.ERROR,
                    context={'layer': self.__class__.__name__, 'data_type': str(type(data))},
                    recovery_suggestion='Provide data as pandas DataFrame or list of dictionaries'
                )
                self.error_handler.errors.append(error)
                self._error_message = error.message
                return False

            # Check data sufficiency
            is_sufficient = self.error_handler.validate_data_sufficiency(
                data,
                chart_type='candlestick',  # Default chart type since LayerConfig doesn't have layer_type
                indicators=[{'type': 'candlestick', 'parameters': {}}]  # Default indicator type
            )

            self._is_valid = is_sufficient
            if not is_sufficient:
                self._error_message = self.error_handler.get_user_friendly_message()

            return is_sufficient

        except Exception as e:
            self.logger.error(f"Base layer: Data validation error in {self.__class__.__name__}: {e}")
            error = ChartError(
                code='VALIDATION_EXCEPTION',
                message=f'Validation error: {str(e)}',
                severity=ErrorSeverity.ERROR,
                context={'layer': self.__class__.__name__, 'exception': str(e)}
            )
            self.error_handler.errors.append(error)
            self._is_valid = False
            self._error_message = str(e)
            return False

    def get_error_info(self) -> Dict[str, Any]:
        """
        Get error information for this layer.

        Returns:
            Dict with the last validation result (``is_valid``), its message,
            the handler's error summary, and ``can_proceed`` (True when no
            errors are recorded; warnings do not block).
        """
        return {
            'is_valid': self._is_valid,
            'error_message': self._error_message,
            'error_summary': self.error_handler.get_error_summary(),
            'can_proceed': len(self.error_handler.errors) == 0
        }

    def create_error_trace(self, error_message: str) -> go.Scatter:
        """
        Create a text-mode scatter trace carrying ``error_message``.

        NOTE(review): x and y are empty, so there is no point to anchor the
        text to and plotly likely renders nothing. Confirm intended use before
        relying on this for visible error display.
        """
        return go.Scatter(
            x=[],
            y=[],
            mode='text',
            text=[error_message],
            textposition='middle center',
            textfont={'size': 14, 'color': '#e74c3c'},
            showlegend=False,
            name=f"{self.__class__.__name__} Error"
        )
|
|
|
|
|
|
class BaseChartLayer(ABC):
    """
    Abstract interface that every chart layer must satisfy.

    Concrete layers (candlestick charts, indicators, signal overlays) provide
    ``render`` and ``validate_data``. Placement helpers below read directly
    from the layer's ``LayerConfig``.
    """

    def __init__(self, config: LayerConfig):
        """
        Store the layer configuration and attach the module-level logger.

        Args:
            config: Layer configuration
        """
        self.logger = logger
        self.config = config

    @abstractmethod
    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Draw this layer onto ``fig``.

        Args:
            fig: Plotly figure to draw onto
            data: Chart data (OHLCV format)
            **kwargs: Extra rendering options supplied by the caller

        Returns:
            The figure with this layer drawn on it
        """
        ...

    @abstractmethod
    def validate_data(self, data: pd.DataFrame) -> bool:
        """
        Decide whether ``data`` can be rendered by this layer.

        Args:
            data: Chart data to check

        Returns:
            True when the data is usable, False otherwise
        """
        ...

    def is_enabled(self) -> bool:
        """Report the ``enabled`` flag of the configuration."""
        enabled = self.config.enabled
        return enabled

    def get_subplot_row(self) -> Optional[int]:
        """Report the configured subplot row (``None`` means main chart)."""
        cfg = self.config
        return cfg.subplot_row

    def is_overlay(self) -> bool:
        """True when the layer lives on the main chart rather than a subplot."""
        row = self.config.subplot_row
        return row is None
|
|
|
|
|
|
class CandlestickLayer(BaseLayer):
    """
    Candlestick chart layer implementation with enhanced error handling.

    This layer renders OHLC data as candlesticks on the main chart. Input may
    be a DataFrame or a list of row dicts containing ``timestamp``, ``open``,
    ``high``, ``low`` and ``close``.
    """

    def __init__(self, config: LayerConfig = None):
        """
        Initialize candlestick layer.

        Args:
            config: Layer configuration (optional, uses defaults)
        """
        if config is None:
            config = LayerConfig(
                name="candlestick",
                enabled=True,
                style={
                    'increasing_color': '#00C851',  # Green for bullish
                    'decreasing_color': '#FF4444',  # Red for bearish
                    'line_width': 1
                }
            )

        super().__init__(config)

    def is_enabled(self) -> bool:
        """Check if the layer is enabled."""
        return self.config.enabled

    def is_overlay(self) -> bool:
        """Check if this layer is an overlay (main chart) or subplot."""
        return self.config.subplot_row is None

    def get_subplot_row(self) -> Optional[int]:
        """Get the subplot row for this layer."""
        return self.config.subplot_row

    @staticmethod
    def _to_frame(data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> pd.DataFrame:
        """Return a DataFrame copy of ``data`` (a DataFrame or a list of row dicts)."""
        if isinstance(data, list):
            return pd.DataFrame(data)
        return data.copy()

    @staticmethod
    def _invalid_price_mask(df: pd.DataFrame) -> pd.Series:
        """Boolean mask of rows with high < low, any negative price, or any missing price."""
        return (
            (df['high'] < df['low']) |
            (df['open'] < 0) | (df['close'] < 0) |
            (df['high'] < 0) | (df['low'] < 0) |
            pd.isna(df[['open', 'high', 'low', 'close']]).any(axis=1)
        )

    def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
        """
        Enhanced validation with comprehensive error handling.

        Runs the base-class checks, then candlestick-specific ones: required
        OHLC columns, non-empty data, and the share of invalid price rows
        (more than 50% is an error; fewer is a warning, since those rows are
        dropped at render time).

        Returns:
            True only if the base validation passed and no errors were recorded.
        """
        try:
            # Use parent class error handling for comprehensive validation
            parent_valid = super().validate_data(data)

            # The parent has already recorded INVALID_DATA_TYPE; any other
            # container cannot be inspected further, so stop here instead of
            # raising (and recording) a second, misleading exception error.
            if not isinstance(data, (pd.DataFrame, list)):
                return False

            df = self._to_frame(data)

            # Additional candlestick-specific validation
            required_columns = ['timestamp', 'open', 'high', 'low', 'close']
            missing = [col for col in required_columns if col not in df.columns]
            if missing:
                error = ChartError(
                    code='MISSING_OHLC_COLUMNS',
                    message=f'Missing required OHLC columns: {missing}',
                    severity=ErrorSeverity.ERROR,
                    context={'missing_columns': missing, 'available_columns': list(df.columns)},
                    recovery_suggestion='Ensure data contains timestamp, open, high, low, close columns'
                )
                self.error_handler.errors.append(error)
                return False

            if len(df) == 0:
                error = ChartError(
                    code='EMPTY_CANDLESTICK_DATA',
                    message='No candlestick data available',
                    severity=ErrorSeverity.ERROR,
                    context={'data_count': 0},
                    recovery_suggestion='Check data source or time range'
                )
                self.error_handler.errors.append(error)
                return False

            # Check for price data validity
            invalid_count = int(self._invalid_price_mask(df).sum())

            if invalid_count > len(df) * 0.5:  # More than 50% invalid
                error = ChartError(
                    code='EXCESSIVE_INVALID_PRICES',
                    message=f'Too many invalid price records: {invalid_count}/{len(df)}',
                    severity=ErrorSeverity.ERROR,
                    context={'invalid_count': invalid_count, 'total_count': len(df)},
                    recovery_suggestion='Check data quality and price data sources'
                )
                self.error_handler.errors.append(error)
                return False
            elif invalid_count > 0:
                # Warning for some invalid data
                error = ChartError(
                    code='SOME_INVALID_PRICES',
                    message=f'Found {invalid_count} invalid price records (will be filtered)',
                    severity=ErrorSeverity.WARNING,
                    context={'invalid_count': invalid_count, 'total_count': len(df)},
                    recovery_suggestion='Invalid records will be automatically removed'
                )
                self.error_handler.warnings.append(error)

            return parent_valid and len(self.error_handler.errors) == 0

        except Exception as e:
            self.logger.error(f"Candlestick layer: Error validating candlestick data: {e}")
            error = ChartError(
                code='CANDLESTICK_VALIDATION_ERROR',
                message=f'Candlestick validation failed: {str(e)}',
                severity=ErrorSeverity.ERROR,
                context={'exception': str(e)}
            )
            self.error_handler.errors.append(error)
            return False

    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Render candlestick chart with error handling and recovery.

        Args:
            fig: Target figure
            data: OHLCV data (DataFrame or list of row dicts)
            **kwargs: Additional parameters (row, col for subplots)

        Returns:
            Figure with candlestick trace added or error display
        """
        try:
            # Validate data
            if not self.validate_data(data):
                self.logger.error("Candlestick layer: Invalid data for candlestick layer")

                # Add error annotation to figure
                if self.error_handler.errors:
                    error_msg = self.error_handler.errors[0].message
                    fig.add_annotation(create_error_annotation(
                        f"Candlestick Error: {error_msg}",
                        position='center'
                    ))
                return fig

            # Clean and prepare data
            clean_data = self._clean_candlestick_data(data)
            if clean_data.empty:
                fig.add_annotation(create_error_annotation(
                    "No valid candlestick data after cleaning",
                    position='center'
                ))
                return fig

            # Extract styling
            style = self.config.style
            increasing_color = style.get('increasing_color', '#00C851')
            decreasing_color = style.get('decreasing_color', '#FF4444')

            # Create candlestick trace
            candlestick = go.Candlestick(
                x=clean_data['timestamp'],
                open=clean_data['open'],
                high=clean_data['high'],
                low=clean_data['low'],
                close=clean_data['close'],
                name=self.config.name,
                increasing_line_color=increasing_color,
                decreasing_line_color=decreasing_color,
                showlegend=False
            )

            # Add to figure
            row = kwargs.get('row', 1)
            col = kwargs.get('col', 1)
            if not self._add_trace(fig, candlestick, row, col):
                return fig

            # Add warning annotations if needed
            if self.error_handler.warnings:
                warning_msg = f"⚠️ {self.error_handler.warnings[0].message}"
                fig.add_annotation({
                    'text': warning_msg,
                    'xref': 'paper', 'yref': 'paper',
                    'x': 0.02, 'y': 0.98,
                    'xanchor': 'left', 'yanchor': 'top',
                    'showarrow': False,
                    'font': {'size': 10, 'color': '#f39c12'},
                    'bgcolor': 'rgba(255,255,255,0.8)'
                })

            self.logger.debug(f"Rendered candlestick layer with {len(clean_data)} candles")
            return fig

        except Exception as e:
            self.logger.error(f"Candlestick layer: Error rendering candlestick layer: {e}")
            fig.add_annotation(create_error_annotation(
                f"Candlestick render error: {str(e)}",
                position='center'
            ))
            return fig

    def _add_trace(self, fig: go.Figure, trace, row: int, col: int) -> bool:
        """
        Add ``trace`` to ``fig`` at (row, col); return False if it could not be added.

        The main-chart cell (1, 1) is added without row/col because a plain
        ``go.Figure`` (not built with make_subplots) rejects those arguments.
        If the targeted add fails, a plain add is attempted. If that also
        fails, an error annotation is placed on the figure.
        """
        try:
            if row == 1 and col == 1:
                # Simple figure without subplots (or the main subplot cell)
                fig.add_trace(trace)
            else:
                # Subplot figure
                fig.add_trace(trace, row=row, col=col)
            return True
        except Exception:
            # If subplot call fails, try simple add_trace
            try:
                fig.add_trace(trace)
                return True
            except Exception as fallback_error:
                self.logger.error(f"Candlestick layer: Failed to add candlestick trace: {fallback_error}")
                fig.add_annotation(create_error_annotation(
                    f"Failed to add candlestick trace: {str(fallback_error)}",
                    position='center'
                ))
                return False

    def _clean_candlestick_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> pd.DataFrame:
        """
        Clean candlestick data for plotting.

        Accepts the same inputs as ``validate_data`` (DataFrame or list of row
        dicts). Drops rows flagged by ``_invalid_price_mask``, converts
        ``timestamp`` to datetime if needed, and sorts by timestamp.

        Returns:
            The cleaned DataFrame, or an empty DataFrame if cleaning failed.
        """
        try:
            clean_data = self._to_frame(data)

            # Remove rows with invalid prices. The .copy() gives an independent
            # frame so the timestamp conversion below does not write into a
            # slice of the filtered input (pandas SettingWithCopy).
            invalid_mask = self._invalid_price_mask(clean_data)
            initial_count = len(clean_data)
            clean_data = clean_data.loc[~invalid_mask].copy()

            removed_count = initial_count - len(clean_data)
            if removed_count > 0:
                self.logger.info(f"Removed {removed_count} invalid candlestick records")

            # Ensure timestamp is properly formatted
            if not pd.api.types.is_datetime64_any_dtype(clean_data['timestamp']):
                clean_data['timestamp'] = pd.to_datetime(clean_data['timestamp'])

            # Sort by timestamp
            return clean_data.sort_values('timestamp')

        except Exception as e:
            self.logger.error(f"Candlestick layer: Error cleaning candlestick data: {e}")
            return pd.DataFrame()
|
|
|
|
|
|
class VolumeLayer(BaseLayer):
    """
    Volume subplot layer implementation with enhanced error handling.

    This layer renders volume data as a bar chart in a separate subplot,
    with bars colored by whether close >= open. Input may be a DataFrame or a
    list of row dicts containing ``timestamp``, ``open``, ``close`` and ``volume``.
    """

    def __init__(self, config: LayerConfig = None):
        """
        Initialize volume layer.

        Args:
            config: Layer configuration (optional, uses defaults)
        """
        if config is None:
            config = LayerConfig(
                name="volume",
                enabled=True,
                subplot_row=2,  # Volume goes in second row by default
                style={
                    'bullish_color': '#00C851',
                    'bearish_color': '#FF4444',
                    'opacity': 0.7
                }
            )

        super().__init__(config)

    def is_enabled(self) -> bool:
        """Check if the layer is enabled."""
        return self.config.enabled

    def is_overlay(self) -> bool:
        """Check if this layer is an overlay (main chart) or subplot."""
        return self.config.subplot_row is None

    def get_subplot_row(self) -> Optional[int]:
        """Get the subplot row for this layer."""
        return self.config.subplot_row

    @staticmethod
    def _to_frame(data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> pd.DataFrame:
        """Return a DataFrame copy of ``data`` (a DataFrame or a list of row dicts)."""
        if isinstance(data, list):
            return pd.DataFrame(data)
        return data.copy()

    @staticmethod
    def _valid_volume_mask(df: pd.DataFrame) -> pd.Series:
        """Boolean mask of rows whose volume is present and non-negative."""
        return (df['volume'] >= 0) & pd.notna(df['volume'])

    def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
        """
        Enhanced validation with comprehensive error handling.

        Runs the base-class checks, then volume-specific ones. Missing columns
        and empty data are errors. Zero/mostly-invalid/non-positive-total volume
        only produces warnings.

        Returns:
            True if the base validation passed and at least one row has valid volume.
        """
        try:
            # Use parent class error handling
            parent_valid = super().validate_data(data)

            # The parent has already recorded INVALID_DATA_TYPE; any other
            # container cannot be inspected further, so stop here instead of
            # raising (and recording) a second, misleading exception error.
            if not isinstance(data, (pd.DataFrame, list)):
                return False

            df = self._to_frame(data)

            # Volume-specific validation
            required_columns = ['timestamp', 'open', 'close', 'volume']
            missing = [col for col in required_columns if col not in df.columns]
            if missing:
                error = ChartError(
                    code='MISSING_VOLUME_COLUMNS',
                    message=f'Missing required volume columns: {missing}',
                    severity=ErrorSeverity.ERROR,
                    context={'missing_columns': missing, 'available_columns': list(df.columns)},
                    recovery_suggestion='Ensure data contains timestamp, open, close, volume columns'
                )
                self.error_handler.errors.append(error)
                return False

            if len(df) == 0:
                error = ChartError(
                    code='EMPTY_VOLUME_DATA',
                    message='No volume data available',
                    severity=ErrorSeverity.ERROR,
                    context={'data_count': 0},
                    recovery_suggestion='Check data source or time range'
                )
                self.error_handler.errors.append(error)
                return False

            # Check if volume data exists and is valid
            valid_volume_count = self._valid_volume_mask(df).sum()

            if valid_volume_count == 0:
                error = ChartError(
                    code='NO_VALID_VOLUME',
                    message='No valid volume data found',
                    severity=ErrorSeverity.WARNING,
                    context={'total_records': len(df), 'valid_volume': 0},
                    recovery_suggestion='Volume chart will be skipped'
                )
                self.error_handler.warnings.append(error)

            elif valid_volume_count < len(df) * 0.5:  # Less than 50% valid
                error = ChartError(
                    code='MOSTLY_INVALID_VOLUME',
                    message=f'Most volume data is invalid: {valid_volume_count}/{len(df)} valid',
                    severity=ErrorSeverity.WARNING,
                    context={'total_records': len(df), 'valid_volume': valid_volume_count},
                    recovery_suggestion='Invalid volume records will be filtered out'
                )
                self.error_handler.warnings.append(error)

            elif df['volume'].sum() <= 0:
                error = ChartError(
                    code='ZERO_VOLUME_TOTAL',
                    message='Total volume is zero or negative',
                    severity=ErrorSeverity.WARNING,
                    context={'volume_sum': float(df['volume'].sum())},
                    recovery_suggestion='Volume chart may not be meaningful'
                )
                self.error_handler.warnings.append(error)

            return parent_valid and valid_volume_count > 0

        except Exception as e:
            self.logger.error(f"Volume layer: Error validating volume data: {e}")
            error = ChartError(
                code='VOLUME_VALIDATION_ERROR',
                message=f'Volume validation failed: {str(e)}',
                severity=ErrorSeverity.ERROR,
                context={'exception': str(e)}
            )
            self.error_handler.errors.append(error)
            return False

    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Render volume bars with error handling and recovery.

        Invalid data is logged and the figure is returned unchanged. No error
        annotation is added for volume.

        Args:
            fig: Target figure (must be subplot figure)
            data: OHLCV data (DataFrame or list of row dicts)
            **kwargs: Additional parameters (row, col for subplots)

        Returns:
            Figure with volume trace added or error handling
        """
        try:
            # Validate data
            if not self.validate_data(data):
                # Warnings only -> skip quietly; real errors -> log as error
                if not self.error_handler.errors and self.error_handler.warnings:
                    self.logger.debug("Skipping volume layer due to warnings")
                else:
                    self.logger.error("Volume layer: Invalid data for volume layer")
                return fig

            # Clean and prepare data
            clean_data = self._clean_volume_data(data)
            if clean_data.empty:
                self.logger.debug("No valid volume data after cleaning")
                return fig

            # Calculate bar colors based on price movement
            style = self.config.style
            bullish_color = style.get('bullish_color', '#00C851')
            bearish_color = style.get('bearish_color', '#FF4444')
            opacity = style.get('opacity', 0.7)

            colors = [
                bullish_color if close >= open_price else bearish_color
                for close, open_price in zip(clean_data['close'], clean_data['open'])
            ]

            # Create volume bar trace
            volume_bars = go.Bar(
                x=clean_data['timestamp'],
                y=clean_data['volume'],
                name='Volume',
                marker_color=colors,
                opacity=opacity,
                showlegend=False
            )

            # Add to figure
            row = kwargs.get('row', 2)  # Default to row 2 for volume
            col = kwargs.get('col', 1)
            fig.add_trace(volume_bars, row=row, col=col)

            self.logger.debug(f"Volume layer: Rendered volume layer with {len(clean_data)} bars")
            return fig

        except Exception as e:
            self.logger.error(f"Volume layer: Error rendering volume layer: {e}")
            return fig

    def _clean_volume_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> pd.DataFrame:
        """
        Clean volume data for plotting.

        Accepts the same inputs as ``validate_data`` (DataFrame or list of row
        dicts). Drops rows with missing or negative volume, converts
        ``timestamp`` to datetime if needed, and sorts by timestamp.

        Returns:
            The cleaned DataFrame, or an empty DataFrame if cleaning failed.
        """
        try:
            clean_data = self._to_frame(data)

            # Remove rows with invalid volume. The .copy() gives an independent
            # frame so the timestamp conversion below does not write into a
            # slice of the filtered input (pandas SettingWithCopy).
            valid_mask = self._valid_volume_mask(clean_data)
            initial_count = len(clean_data)
            clean_data = clean_data.loc[valid_mask].copy()

            removed_count = initial_count - len(clean_data)
            if removed_count > 0:
                self.logger.info(f"Removed {removed_count} invalid volume records")

            # Ensure timestamp is properly formatted
            if not pd.api.types.is_datetime64_any_dtype(clean_data['timestamp']):
                clean_data['timestamp'] = pd.to_datetime(clean_data['timestamp'])

            # Sort by timestamp
            return clean_data.sort_values('timestamp')

        except Exception as e:
            self.logger.error(f"Volume layer: Error cleaning volume data: {e}")
            return pd.DataFrame()
|
|
|
|
|
|
class LayerManager:
    """
    Manager class for coordinating multiple chart layers.

    Holds an ordered list of layers, derives the subplot grid from their
    configs, renders overlay layers on row 1 and subplot layers on their
    assigned rows, then applies shared layout styling.
    """

    def __init__(self):
        """Initialize the layer manager with no layers."""
        self.layers: List[BaseLayer] = []
        self.logger = logger

    def add_layer(self, layer: BaseLayer) -> None:
        """
        Add a layer to the manager.

        Args:
            layer: Chart layer to add
        """
        self.layers.append(layer)
        self.logger.debug(f"Added layer: {layer.config.name}")

    def remove_layer(self, layer_name: str) -> bool:
        """
        Remove the first layer whose ``config.name`` equals ``layer_name``.

        Args:
            layer_name: Name of layer to remove

        Returns:
            True if layer was removed, False if not found
        """
        for i, layer in enumerate(self.layers):
            if layer.config.name == layer_name:
                self.layers.pop(i)
                self.logger.debug(f"Removed layer: {layer_name}")
                return True

        self.logger.warning(f"Layer not found for removal: {layer_name}")
        return False

    def get_enabled_layers(self) -> List[BaseLayer]:
        """Get list of enabled layers, in insertion order."""
        return [layer for layer in self.layers if layer.is_enabled()]

    def get_overlay_layers(self) -> List[BaseLayer]:
        """Get enabled layers that render on the main chart."""
        return [layer for layer in self.get_enabled_layers() if layer.is_overlay()]

    def get_subplot_layers(self) -> Dict[int, List[BaseLayer]]:
        """Get enabled non-overlay layers grouped by their subplot row."""
        subplot_layers: Dict[int, List[BaseLayer]] = {}
        for layer in self.get_enabled_layers():
            if not layer.is_overlay():
                subplot_layers.setdefault(layer.get_subplot_row(), []).append(layer)
        return subplot_layers

    def calculate_subplot_layout(self) -> Dict[str, Any]:
        """
        Calculate subplot configuration based on layers.

        Side effect: when subplot layers exist, their ``config.subplot_row`` is
        renumbered sequentially from 2 (see ``_reassign_subplot_rows``).

        Returns:
            Dict with subplot configuration parameters. For more than one row
            these are ``make_subplots`` keyword arguments.
        """
        subplot_layers = self.get_subplot_layers()

        if not subplot_layers:
            # No subplots needed
            return {
                'rows': 1,
                'cols': 1,
                'subplot_titles': None,
                'row_heights': None
            }

        # Reassign subplot rows dynamically to ensure proper ordering
        self._reassign_subplot_rows()

        # Recalculate after reassignment
        subplot_layers = self.get_subplot_layers()

        # Row numbers are 1-indexed and contiguous after reassignment, so the
        # highest subplot row is the total number of rows needed.
        max_subplot_row = max(subplot_layers.keys()) if subplot_layers else 0
        total_rows = max(1, max_subplot_row)

        # Create subplot titles
        subplot_titles = ['Price']  # Main chart
        for row in range(2, total_rows + 1):
            if row in subplot_layers:
                # Join the names of every layer in the row
                layer_names = [layer.config.name for layer in subplot_layers[row]]
                subplot_titles.append(' / '.join(layer_names).title())
            else:
                subplot_titles.append(f'Subplot {row}')

        # Calculate row heights based on subplot height ratios
        row_heights = self._calculate_dynamic_row_heights(subplot_layers, total_rows)

        return {
            'rows': total_rows,
            'cols': 1,
            'subplot_titles': subplot_titles,
            'row_heights': row_heights,
            'shared_xaxes': True,
            'vertical_spacing': 0.03
        }

    def _reassign_subplot_rows(self) -> None:
        """
        Reassign subplot rows to ensure proper sequential ordering.

        Every enabled subplot layer gets its own row, starting from row 2, with
        no gaps. The layer named 'volume' comes first; the rest keep the order
        of their current row (unset/0 rows sort last). The sort is stable, so
        ties keep insertion order.
        """
        subplot_layers = [layer for layer in self.get_enabled_layers() if not layer.is_overlay()]

        def layer_priority(layer):
            # Volume gets highest priority (0), then by current row
            is_volume = hasattr(layer, 'config') and layer.config.name == 'volume'
            return (0 if is_volume else 1, layer.get_subplot_row() or 999)

        subplot_layers.sort(key=layer_priority)

        # Row 1 is the main chart, so subplots start at row 2
        for new_row, layer in enumerate(subplot_layers, start=2):
            layer.config.subplot_row = new_row
            self.logger.debug(f"Assigned {layer.config.name} to subplot row {new_row}")

    def _calculate_dynamic_row_heights(self, subplot_layers: Dict[int, List], total_rows: int) -> List[float]:
        """
        Calculate row heights based on subplot height ratios.

        Each subplot row uses its first layer's ``get_subplot_height_ratio()``
        if available, else 0.25. When subplots together exceed 0.6, they are
        scaled down proportionally; the main chart gets the remainder.

        Args:
            subplot_layers: Dictionary of subplot layers by row
            total_rows: Total number of rows

        Returns:
            List of height ratios for each row (main chart first)
        """
        if total_rows == 1:
            return [1.0]  # Single row gets full height

        # Calculate total requested subplot height
        total_subplot_ratio = 0.0
        subplot_ratios = {}

        for row in range(2, total_rows + 1):
            if row in subplot_layers:
                # Get height ratio from first layer in the row
                layer = subplot_layers[row][0]
                if hasattr(layer, 'get_subplot_height_ratio'):
                    ratio = layer.get_subplot_height_ratio()
                else:
                    ratio = 0.25  # Default ratio
            else:
                ratio = 0.25  # Default for empty rows
            subplot_ratios[row] = ratio
            total_subplot_ratio += ratio

        # Ensure total doesn't exceed reasonable limits
        max_subplot_ratio = 0.6  # Maximum 60% for all subplots
        if total_subplot_ratio > max_subplot_ratio:
            # Scale down proportionally
            scale_factor = max_subplot_ratio / total_subplot_ratio
            for row in subplot_ratios:
                subplot_ratios[row] *= scale_factor
            total_subplot_ratio = max_subplot_ratio

        # Main chart gets remaining space
        main_chart_ratio = 1.0 - total_subplot_ratio

        # Build final height list
        row_heights = [main_chart_ratio]
        for row in range(2, total_rows + 1):
            row_heights.append(subplot_ratios.get(row, 0.25))

        return row_heights

    def render_all_layers(self, data: pd.DataFrame, **kwargs) -> go.Figure:
        """
        Render all enabled layers onto a new figure.

        Args:
            data: Chart data (OHLCV format)
            **kwargs: Additional rendering parameters forwarded to each layer's
                ``render``. ``row``/``col`` are ignored because placement is
                computed here.

        Returns:
            Complete figure with all layers rendered (an empty figure on error)
        """
        try:
            # Calculate subplot layout
            layout_config = self.calculate_subplot_layout()

            # Create figure with subplots if needed
            if layout_config['rows'] > 1:
                fig = make_subplots(**layout_config)
            else:
                fig = go.Figure()

            # Placement is decided below. A caller-supplied row/col would clash
            # with the explicit row=/col= arguments ("multiple values for
            # keyword argument") and abort the whole render.
            render_kwargs = {key: value for key, value in kwargs.items() if key not in ('row', 'col')}
            if len(render_kwargs) != len(kwargs):
                self.logger.debug("Layer manager: ignoring caller-supplied row/col; placement comes from layer configs")

            # Render overlay layers (main chart)
            for layer in self.get_overlay_layers():
                fig = layer.render(fig, data, row=1, col=1, **render_kwargs)

            # Render subplot layers
            for row, layers in self.get_subplot_layers().items():
                for layer in layers:
                    fig = layer.render(fig, data, row=row, col=1, **render_kwargs)

            # Update layout styling
            self._apply_layout_styling(fig, layout_config)

            self.logger.debug(f"Rendered {len(self.get_enabled_layers())} layers")
            return fig

        except Exception as e:
            self.logger.error(f"Layer manager: Error rendering layers: {e}")
            # Return empty figure on error
            return go.Figure()

    @staticmethod
    def _subplot_axis_title(layers_in_row: List) -> str:
        """Y-axis title for a subplot row: RSI/MACD for those indicators, else the joined layer names."""
        indicator_type = getattr(layers_in_row[0].config, 'indicator_type', None)
        if indicator_type == 'rsi':
            return "RSI"
        if indicator_type == 'macd':
            return "MACD"
        # Any other (or no) indicator type falls back to the layer names so
        # the axis is never left untitled.
        return ' / '.join(row_layer.config.name for row_layer in layers_in_row)

    def _apply_layout_styling(self, fig: go.Figure, layout_config: Dict[str, Any]) -> None:
        """Apply consistent styling to the figure layout; errors are logged, not raised."""
        try:
            # Basic layout settings
            fig.update_layout(
                template="plotly_white",
                showlegend=False,
                hovermode='x unified',
                xaxis_rangeslider_visible=False
            )

            if layout_config['rows'] > 1:
                # Update main chart axes
                fig.update_yaxes(title_text="Price (USDT)", row=1, col=1)
                fig.update_xaxes(showticklabels=False, row=1, col=1)

                # Update subplot axes
                subplot_layers = self.get_subplot_layers()
                for row in range(2, layout_config['rows'] + 1):
                    if row in subplot_layers:
                        layers_in_row = subplot_layers[row]
                        layer = layers_in_row[0]  # Use first layer for configuration

                        # Set y-axis title
                        fig.update_yaxes(title_text=self._subplot_axis_title(layers_in_row), row=row, col=1)

                        # Set fixed y-axis range if specified
                        if hasattr(layer, 'has_fixed_range') and layer.has_fixed_range():
                            y_range = layer.get_y_axis_range()
                            if y_range:
                                fig.update_yaxes(range=list(y_range), row=row, col=1)

                    # Only show x-axis labels on the bottom subplot
                    if row == layout_config['rows']:
                        fig.update_xaxes(title_text="Time", row=row, col=1)
                    else:
                        fig.update_xaxes(showticklabels=False, row=row, col=1)
            else:
                # Single chart
                fig.update_layout(
                    xaxis_title="Time",
                    yaxis_title="Price (USDT)"
                )

        except Exception as e:
            self.logger.error(f"Layer manager: Error applying layout styling: {e}")