3.4-2.0 Indicator Layer System Implementation
Implement modular chart layers and error handling for the Crypto Trading Bot Dashboard:

- Introduced a comprehensive chart layer system in `components/charts/layers/` to support various technical indicators and subplots.
- Added base layer components, including `BaseLayer`, `CandlestickLayer`, and `VolumeLayer`, for flexible chart rendering.
- Implemented overlay indicators such as `SMALayer`, `EMALayer`, and `BollingerBandsLayer` with robust error handling.
- Created subplot layers for indicators like `RSILayer` and `MACDLayer`, enhancing visualization capabilities.
- Developed a `MarketDataIntegrator` for seamless data fetching and validation, improving data quality assurance.
- Enhanced error handling utilities in `components/charts/error_handling.py` to manage insufficient-data scenarios effectively.
- Updated documentation to reflect the new chart layer architecture and usage guidelines.
- Added unit tests for all chart layer components to ensure functionality and reliability.
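As a quick orientation before the diff itself, here is a minimal sketch of how the new public API could be exercised; only the imported names and dictionary keys come from the code in this commit, while the symbol, timeframe, and indicator settings are illustrative assumptions.

```python
# Hypothetical usage sketch; "BTC-USDT", "1h", and the RSI settings are assumptions.
from components.charts import check_symbol_data_quality, fetch_indicator_data

symbol, timeframe = "BTC-USDT", "1h"

# Gate chart building on data availability (added in data_integration.py).
quality = check_symbol_data_quality(symbol, timeframe)
if quality["sufficient_for_indicators"]:
    # Calculate one indicator series over the last 7 days of candles.
    results = fetch_indicator_data(
        symbol,
        timeframe,
        indicator_configs=[{"name": "rsi_14", "type": "rsi", "parameters": {"period": 14}}],
        days_back=7,
    )
    print(f"RSI points: {len(results.get('rsi_14', []))}")
else:
    print(quality["message"])  # user-friendly explanation of what is missing
```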
Parent: 371c0a4591
Commit: a969defe1f

app.py (110 changed lines)
@@ -39,60 +39,55 @@ from components.charts import (
# Initialize logger
logger = get_logger("dashboard_app")

def create_app():
"""Create and configure the Dash application."""
# Create the app instance at module level
app = dash.Dash(
__name__,
title="Crypto Trading Bot Dashboard",
update_title="Loading...",
suppress_callback_exceptions=True
)

# Configure app
app.server.secret_key = "crypto-bot-dashboard-secret-key-2024"

logger.info("Initializing Crypto Trading Bot Dashboard")

# Define basic layout
app.layout = html.Div([
# Header
html.Div([
html.H1("🚀 Crypto Trading Bot Dashboard",
style={'margin': '0', 'color': '#2c3e50'}),
html.P("Real-time monitoring and bot management",
style={'margin': '5px 0 0 0', 'color': '#7f8c8d'})
], style={
'padding': '20px',
'background-color': '#ecf0f1',
'border-bottom': '2px solid #bdc3c7'
}),

# Initialize Dash app
app = dash.Dash(
__name__,
title="Crypto Trading Bot Dashboard",
update_title="Loading...",
suppress_callback_exceptions=True
)
# Navigation tabs
dcc.Tabs(id="main-tabs", value='market-data', children=[
dcc.Tab(label='📊 Market Data', value='market-data'),
dcc.Tab(label='🤖 Bot Management', value='bot-management'),
dcc.Tab(label='📈 Performance', value='performance'),
dcc.Tab(label='⚙️ System Health', value='system-health'),
], style={'margin': '10px 20px'}),

# Configure app
app.server.secret_key = "crypto-bot-dashboard-secret-key-2024"
# Main content area
html.Div(id='tab-content', style={'padding': '20px'}),

logger.info("Initializing Crypto Trading Bot Dashboard")
# Auto-refresh interval for real-time updates
dcc.Interval(
id='interval-component',
interval=5000,  # Update every 5 seconds
n_intervals=0
),

# Define basic layout
app.layout = html.Div([
# Header
html.Div([
html.H1("🚀 Crypto Trading Bot Dashboard",
style={'margin': '0', 'color': '#2c3e50'}),
html.P("Real-time monitoring and bot management",
style={'margin': '5px 0 0 0', 'color': '#7f8c8d'})
], style={
'padding': '20px',
'background-color': '#ecf0f1',
'border-bottom': '2px solid #bdc3c7'
}),

# Navigation tabs
dcc.Tabs(id="main-tabs", value='market-data', children=[
dcc.Tab(label='📊 Market Data', value='market-data'),
dcc.Tab(label='🤖 Bot Management', value='bot-management'),
dcc.Tab(label='📈 Performance', value='performance'),
dcc.Tab(label='⚙️ System Health', value='system-health'),
], style={'margin': '10px 20px'}),

# Main content area
html.Div(id='tab-content', style={'padding': '20px'}),

# Auto-refresh interval for real-time updates
dcc.Interval(
id='interval-component',
interval=5000,  # Update every 5 seconds
n_intervals=0
),

# Store components for data sharing between callbacks
dcc.Store(id='market-data-store'),
dcc.Store(id='bot-status-store'),
])

return app
# Store components for data sharing between callbacks
dcc.Store(id='market-data-store'),
dcc.Store(id='bot-status-store'),
])

def get_market_data_layout():
"""Create the market data visualization layout."""
@@ -209,11 +204,8 @@ def get_system_health_layout():
], style={'margin': '20px 0'})
])

# Create the app instance
app = create_app()

# Tab switching callback
@callback(
@app.callback(
Output('tab-content', 'children'),
Input('main-tabs', 'value')
)
@@ -231,7 +223,7 @@ def render_tab_content(active_tab):
return html.Div("Tab not found")

# Market data chart callback
@callback(
@app.callback(
Output('price-chart', 'figure'),
[Input('symbol-dropdown', 'value'),
Input('timeframe-dropdown', 'value'),
@@ -253,7 +245,7 @@ def update_price_chart(symbol, timeframe, n_intervals):
return create_error_chart(f"Error loading chart: {str(e)}")

# Market statistics callback
@callback(
@app.callback(
Output('market-stats', 'children'),
[Input('symbol-dropdown', 'value'),
Input('interval-component', 'n_intervals')]
@@ -279,7 +271,7 @@ def update_market_stats(symbol, n_intervals):
return html.Div("Error loading market statistics")

# System health callbacks
@callback(
@app.callback(
Output('database-status', 'children'),
Input('interval-component', 'n_intervals')
)
@@ -311,7 +303,7 @@ def update_database_status(n_intervals):
html.P(f"Error: {str(e)}", style={'color': '#7f8c8d', 'font-size': '12px'})
])

@callback(
@app.callback(
Output('data-status', 'children'),
[Input('symbol-dropdown', 'value'),
Input('timeframe-dropdown', 'value'),
@@ -362,5 +354,5 @@ def main():
logger.error(f"Failed to start dashboard: {e}")
sys.exit(1)

if __name__ == '__main__':
if __name__ == "__main__":
main()
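The app.py changes above move from module-level `@callback` registrations to an application factory, with callbacks bound via `@app.callback` on the instance returned by `create_app()`. A minimal standalone sketch of that pattern follows; the component IDs and labels are placeholders, not the dashboard's real ones.

```python
# Minimal Dash app-factory sketch; IDs and labels are assumptions, not from app.py.
import dash
from dash import dcc, html, Input, Output


def create_app() -> dash.Dash:
    app = dash.Dash(__name__, suppress_callback_exceptions=True)
    app.layout = html.Div([
        dcc.Dropdown(id="symbol", options=["BTC-USDT", "ETH-USDT"], value="BTC-USDT"),
        html.Div(id="out"),
    ])
    return app


app = create_app()


# Callbacks are registered on the concrete app instance rather than the global registry.
@app.callback(Output("out", "children"), Input("symbol", "value"))
def show_symbol(symbol):
    return f"Selected: {symbol}"


if __name__ == "__main__":
    app.run(debug=True)  # Dash >= 2.7; earlier versions use app.run_server
```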
@@ -13,12 +13,63 @@ Main Components:
- Configuration System: Strategy-driven chart configs
"""

import plotly.graph_objects as go
from .builder import ChartBuilder
from .utils import (
    validate_market_data,
    prepare_chart_data,
    get_indicator_colors
)
from .config import (
    get_available_indicators,
    calculate_indicators,
    get_overlay_indicators,
    get_subplot_indicators,
    get_indicator_display_config
)
from .data_integration import (
    MarketDataIntegrator,
    DataIntegrationConfig,
    get_market_data_integrator,
    fetch_indicator_data,
    check_symbol_data_quality
)
from .error_handling import (
    ChartErrorHandler,
    ChartError,
    ErrorSeverity,
    InsufficientDataError,
    DataValidationError,
    IndicatorCalculationError,
    DataConnectionError,
    check_data_sufficiency,
    get_error_message,
    create_error_annotation
)

# Layer imports with error handling
from .layers.base import (
    LayerConfig,
    BaseLayer,
    CandlestickLayer,
    VolumeLayer,
    LayerManager
)

from .layers.indicators import (
    IndicatorLayerConfig,
    BaseIndicatorLayer,
    SMALayer,
    EMALayer,
    BollingerBandsLayer
)

from .layers.subplots import (
    SubplotLayerConfig,
    BaseSubplotLayer,
    RSILayer,
    MACDLayer
)

# Version information
__version__ = "0.1.0"
@@ -26,35 +77,130 @@ __package_name__ = "charts"

# Public API exports
__all__ = [
    # Core components
    "ChartBuilder",
    "validate_market_data",
    "prepare_chart_data",
    "get_indicator_colors",

    # Chart creation functions
    "create_candlestick_chart",
    "create_strategy_chart",
    "get_supported_symbols",
    "get_supported_timeframes",
    "create_empty_chart",
    "create_error_chart",

    # Data integration
    "MarketDataIntegrator",
    "DataIntegrationConfig",
    "get_market_data_integrator",
    "fetch_indicator_data",
    "check_symbol_data_quality",

    # Error handling
    "ChartErrorHandler",
    "ChartError",
    "ErrorSeverity",
    "InsufficientDataError",
    "DataValidationError",
    "IndicatorCalculationError",
    "DataConnectionError",
    "check_data_sufficiency",
    "get_error_message",
    "create_error_annotation",

    # Utility functions
    "get_supported_symbols",
    "get_supported_timeframes",
    "get_market_statistics",
    "check_data_availability",
    "create_data_status_indicator",
    "create_error_chart"

    # Base layers
    "LayerConfig",
    "BaseLayer",
    "CandlestickLayer",
    "VolumeLayer",
    "LayerManager",

    # Indicator layers
    "IndicatorLayerConfig",
    "BaseIndicatorLayer",
    "SMALayer",
    "EMALayer",
    "BollingerBandsLayer",

    # Subplot layers
    "SubplotLayerConfig",
    "BaseSubplotLayer",
    "RSILayer",
    "MACDLayer",

    # Convenience functions
    "create_basic_chart",
    "create_indicator_chart"
]

def create_candlestick_chart(symbol: str, timeframe: str, days_back: int = 7, **kwargs):
# Initialize logger
from utils.logger import get_logger
logger = get_logger("charts")

def create_candlestick_chart(symbol: str, timeframe: str, days_back: int = 7, **kwargs) -> go.Figure:
    """
    Convenience function to create a basic candlestick chart.
    Create a candlestick chart with enhanced data integration.

    Args:
        symbol: Trading pair (e.g., 'BTC-USDT')
        timeframe: Timeframe (e.g., '1h', '1d')
        days_back: Number of days to look back
        **kwargs: Additional parameters for chart customization
        **kwargs: Additional chart parameters

    Returns:
        Plotly Figure object
        Plotly figure with candlestick chart
    """
    builder = ChartBuilder()
    return builder.create_candlestick_chart(symbol, timeframe, days_back, **kwargs)

    # Check data quality first
    data_quality = builder.check_data_quality(symbol, timeframe)
    if not data_quality['available']:
        logger.warning(f"Data not available for {symbol} {timeframe}: {data_quality['message']}")
        return builder._create_error_chart(f"No data available: {data_quality['message']}")

    if not data_quality['sufficient_for_indicators']:
        logger.warning(f"Insufficient data for indicators: {symbol} {timeframe}")

    # Use enhanced data fetching
    try:
        candles = builder.fetch_market_data_enhanced(symbol, timeframe, days_back)
        if not candles:
            return builder._create_error_chart(f"No market data found for {symbol} {timeframe}")

        # Prepare data for charting
        df = prepare_chart_data(candles)
        if df.empty:
            return builder._create_error_chart("Failed to prepare chart data")

        # Create chart with data quality info
        fig = builder._create_candlestick_with_volume(df, symbol, timeframe)

        # Add data quality annotation if data is stale
        if not data_quality['is_recent']:
            age_hours = data_quality['data_age_minutes'] / 60
            fig.add_annotation(
                text=f"⚠️ Data is {age_hours:.1f}h old",
                xref="paper", yref="paper",
                x=0.02, y=0.98,
                showarrow=False,
                bgcolor="rgba(255,193,7,0.8)",
                bordercolor="orange",
                borderwidth=1
            )

        logger.debug(f"Created enhanced candlestick chart for {symbol} {timeframe} with {len(candles)} candles")
        return fig

    except Exception as e:
        logger.error(f"Error creating enhanced candlestick chart: {e}")
        return builder._create_error_chart(f"Chart creation failed: {str(e)}")

def create_strategy_chart(symbol: str, timeframe: str, strategy_name: str, **kwargs):
    """
@@ -197,4 +343,108 @@ def create_data_status_indicator(symbol: str, timeframe: str):
def create_error_chart(error_message: str):
    """Create an error chart with error message."""
    builder = ChartBuilder()
    return builder._create_error_chart(error_message)
    return builder._create_error_chart(error_message)

def create_basic_chart(symbol: str, data: list,
                       indicators: list = None,
                       error_handling: bool = True) -> 'go.Figure':
    """
    Create a basic chart with error handling.

    Args:
        symbol: Trading symbol
        data: OHLCV data as list of dictionaries
        indicators: List of indicator configurations
        error_handling: Whether to use comprehensive error handling

    Returns:
        Plotly figure with chart or error display
    """
    try:
        from plotly import graph_objects as go

        # Initialize chart builder
        builder = ChartBuilder()

        if error_handling:
            # Use error-aware chart creation
            error_handler = ChartErrorHandler()
            is_valid = error_handler.validate_data_sufficiency(data, indicators=indicators or [])

            if not is_valid:
                # Create error chart
                fig = go.Figure()
                error_msg = error_handler.get_user_friendly_message()
                fig.add_annotation(create_error_annotation(error_msg, position='center'))
                fig.update_layout(
                    title=f"Chart Error - {symbol}",
                    xaxis={'visible': False},
                    yaxis={'visible': False},
                    template='plotly_white',
                    height=400
                )
                return fig

        # Create chart normally
        return builder.create_candlestick_chart(data, symbol=symbol, indicators=indicators or [])

    except Exception as e:
        # Fallback error chart
        from plotly import graph_objects as go
        fig = go.Figure()
        fig.add_annotation(create_error_annotation(
            f"Chart creation failed: {str(e)}",
            position='center'
        ))
        fig.update_layout(
            title=f"Chart Error - {symbol}",
            template='plotly_white',
            height=400
        )
        return fig

def create_indicator_chart(symbol: str, data: list,
                           indicator_type: str, **params) -> 'go.Figure':
    """
    Create a chart focused on a specific indicator.

    Args:
        symbol: Trading symbol
        data: OHLCV data
        indicator_type: Type of indicator ('sma', 'ema', 'bollinger_bands', 'rsi', 'macd')
        **params: Indicator parameters

    Returns:
        Plotly figure with indicator chart
    """
    try:
        # Map indicator types to configurations
        indicator_map = {
            'sma': {'type': 'sma', 'parameters': {'period': params.get('period', 20)}},
            'ema': {'type': 'ema', 'parameters': {'period': params.get('period', 20)}},
            'bollinger_bands': {
                'type': 'bollinger_bands',
                'parameters': {
                    'period': params.get('period', 20),
                    'std_dev': params.get('std_dev', 2)
                }
            },
            'rsi': {'type': 'rsi', 'parameters': {'period': params.get('period', 14)}},
            'macd': {
                'type': 'macd',
                'parameters': {
                    'fast_period': params.get('fast_period', 12),
                    'slow_period': params.get('slow_period', 26),
                    'signal_period': params.get('signal_period', 9)
                }
            }
        }

        if indicator_type not in indicator_map:
            raise ValueError(f"Unknown indicator type: {indicator_type}")

        indicator_config = indicator_map[indicator_type]
        return create_basic_chart(symbol, data, indicators=[indicator_config])

    except Exception as e:
        return create_basic_chart(symbol, data, indicators=[])  # Fallback to basic chart

components/charts/builder.py
@@ -38,6 +38,10 @@ class ChartBuilder:
        self.logger = logger_instance or logger
        self.db_ops = get_database_operations(self.logger)

        # Initialize market data integrator
        from .data_integration import get_market_data_integrator
        self.data_integrator = get_market_data_integrator()

        # Chart styling defaults
        self.default_colors = get_indicator_colors()
        self.default_height = 600
@@ -81,6 +85,38 @@ class ChartBuilder:
            self.logger.error(f"Unexpected error fetching market data: {e}")
            return []

    def fetch_market_data_enhanced(self, symbol: str, timeframe: str,
                                   days_back: int = 7, exchange: str = "okx") -> List[Dict[str, Any]]:
        """
        Enhanced market data fetching with validation and caching.

        Args:
            symbol: Trading pair (e.g., 'BTC-USDT')
            timeframe: Timeframe (e.g., '1h', '1d')
            days_back: Number of days to look back
            exchange: Exchange name

        Returns:
            List of validated candle data dictionaries
        """
        try:
            # Use the data integrator for enhanced data handling
            raw_candles, ohlcv_candles = self.data_integrator.get_market_data_for_indicators(
                symbol, timeframe, days_back, exchange
            )

            if not raw_candles:
                self.logger.warning(f"No market data available for {symbol} {timeframe}")
                return []

            self.logger.debug(f"Enhanced fetch: {len(raw_candles)} candles for {symbol} {timeframe}")
            return raw_candles

        except Exception as e:
            self.logger.error(f"Error in enhanced market data fetch: {e}")
            # Fallback to original method
            return self.fetch_market_data(symbol, timeframe, days_back, exchange)

    def create_candlestick_chart(self, symbol: str, timeframe: str,
                                 days_back: int = 7, **kwargs) -> go.Figure:
        """
@@ -288,4 +324,29 @@ class ChartBuilder:
        # For now, return a basic candlestick chart
        # This will be enhanced in later tasks with strategy configurations
        self.logger.info(f"Creating strategy chart for {strategy_name} (basic implementation)")
        return self.create_candlestick_chart(symbol, timeframe, **kwargs)
        return self.create_candlestick_chart(symbol, timeframe, **kwargs)

    def check_data_quality(self, symbol: str, timeframe: str,
                           exchange: str = "okx") -> Dict[str, Any]:
        """
        Check data quality and availability for chart creation.

        Args:
            symbol: Trading pair
            timeframe: Timeframe
            exchange: Exchange name

        Returns:
            Dictionary with data quality information
        """
        try:
            return self.data_integrator.check_data_availability(symbol, timeframe, exchange)
        except Exception as e:
            self.logger.error(f"Error checking data quality: {e}")
            return {
                'available': False,
                'latest_timestamp': None,
                'data_age_minutes': None,
                'sufficient_for_indicators': False,
                'message': f"Error checking data: {str(e)}"
            }
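To illustrate the error-aware wrappers just defined, here is a hedged sketch of calling `create_indicator_chart` with a deliberately short, synthetic candle list; with too little data the wrapper returns an annotated error figure instead of raising.

```python
# Synthetic, deliberately short candle list (an assumption for illustration).
# RSI(14) needs roughly 24 candles under the rules in error_handling.py, so
# create_indicator_chart falls back to an annotated "Chart Error" figure.
from components.charts import create_indicator_chart

short_candles = [
    {"timestamp": f"2024-01-01T{h:02d}:00:00Z",
     "open": 100.0, "high": 101.0, "low": 99.0, "close": 100.5, "volume": 10.0}
    for h in range(8)
]

fig = create_indicator_chart("BTC-USDT", short_candles, indicator_type="rsi", period=14)
fig.show()  # displays the error annotation produced by create_error_annotation()
```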
components/charts/data_integration.py (new file, 513 lines)
@@ -0,0 +1,513 @@
"""
Market Data Integration for Chart Layers

This module provides seamless integration between database market data and
indicator layer calculations, handling data format conversions, validation,
and optimization for real-time chart updates.
"""

import pandas as pd
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Any, Optional, Union, Tuple
from decimal import Decimal
from dataclasses import dataclass

from database.operations import get_database_operations, DatabaseOperationError
from data.common.data_types import OHLCVCandle
from data.common.indicators import TechnicalIndicators, IndicatorResult
from components.charts.config.indicator_defs import convert_database_candles_to_ohlcv
from utils.logger import get_logger

# Initialize logger
logger = get_logger("data_integration")


@dataclass
class DataIntegrationConfig:
    """Configuration for market data integration"""
    default_days_back: int = 7
    min_candles_required: int = 50
    max_candles_limit: int = 1000
    cache_timeout_minutes: int = 5
    enable_data_validation: bool = True
    enable_sparse_data_handling: bool = True


class MarketDataIntegrator:
    """
    Integrates market data from database with indicator calculations.

    This class handles:
    - Fetching market data from database
    - Converting to indicator-compatible formats
    - Caching for performance
    - Data validation and error handling
    - Sparse data handling (gaps in time series)
    """

    def __init__(self, config: DataIntegrationConfig = None):
        """
        Initialize market data integrator.

        Args:
            config: Integration configuration
        """
        self.config = config or DataIntegrationConfig()
        self.logger = logger
        self.db_ops = get_database_operations(self.logger)
        self.indicators = TechnicalIndicators()

        # Simple in-memory cache for recent data
        self._cache: Dict[str, Dict[str, Any]] = {}

    def get_market_data_for_indicators(
        self,
        symbol: str,
        timeframe: str,
        days_back: Optional[int] = None,
        exchange: str = "okx"
    ) -> Tuple[List[Dict[str, Any]], List[OHLCVCandle]]:
        """
        Fetch and prepare market data for indicator calculations.

        Args:
            symbol: Trading pair (e.g., 'BTC-USDT')
            timeframe: Timeframe (e.g., '1h', '1d')
            days_back: Number of days to look back
            exchange: Exchange name

        Returns:
            Tuple of (raw_candles, ohlcv_candles) for different use cases
        """
        try:
            # Use default or provided days_back
            days_back = days_back or self.config.default_days_back

            # Check cache first
            cache_key = f"{symbol}_{timeframe}_{days_back}_{exchange}"
            cached_data = self._get_cached_data(cache_key)
            if cached_data:
                self.logger.debug(f"Using cached data for {cache_key}")
                return cached_data['raw_candles'], cached_data['ohlcv_candles']

            # Fetch from database
            end_time = datetime.now(timezone.utc)
            start_time = end_time - timedelta(days=days_back)

            raw_candles = self.db_ops.market_data.get_candles(
                symbol=symbol,
                timeframe=timeframe,
                start_time=start_time,
                end_time=end_time,
                exchange=exchange
            )

            if not raw_candles:
                self.logger.warning(f"No market data found for {symbol} {timeframe}")
                return [], []

            # Validate data if enabled
            if self.config.enable_data_validation:
                raw_candles = self._validate_and_clean_data(raw_candles)

            # Handle sparse data if enabled
            if self.config.enable_sparse_data_handling:
                raw_candles = self._handle_sparse_data(raw_candles, timeframe)

            # Convert to OHLCV format for indicators
            ohlcv_candles = convert_database_candles_to_ohlcv(raw_candles)

            # Cache the results
            self._cache_data(cache_key, {
                'raw_candles': raw_candles,
                'ohlcv_candles': ohlcv_candles,
                'timestamp': datetime.now(timezone.utc)
            })

            self.logger.debug(f"Fetched {len(raw_candles)} candles for {symbol} {timeframe}")
            return raw_candles, ohlcv_candles

        except DatabaseOperationError as e:
            self.logger.error(f"Database error fetching market data: {e}")
            return [], []
        except Exception as e:
            self.logger.error(f"Unexpected error fetching market data: {e}")
            return [], []

    def calculate_indicators_for_symbol(
        self,
        symbol: str,
        timeframe: str,
        indicator_configs: List[Dict[str, Any]],
        days_back: Optional[int] = None,
        exchange: str = "okx"
    ) -> Dict[str, List[IndicatorResult]]:
        """
        Calculate multiple indicators for a symbol.

        Args:
            symbol: Trading pair
            timeframe: Timeframe
            indicator_configs: List of indicator configurations
            days_back: Number of days to look back
            exchange: Exchange name

        Returns:
            Dictionary mapping indicator names to their results
        """
        try:
            # Get market data
            raw_candles, ohlcv_candles = self.get_market_data_for_indicators(
                symbol, timeframe, days_back, exchange
            )

            if not ohlcv_candles:
                self.logger.warning(f"No data available for indicator calculations: {symbol} {timeframe}")
                return {}

            # Check minimum data requirements
            if len(ohlcv_candles) < self.config.min_candles_required:
                self.logger.warning(
                    f"Insufficient data for reliable indicators: {len(ohlcv_candles)} < {self.config.min_candles_required}"
                )

            # Calculate indicators
            results = {}
            for config in indicator_configs:
                indicator_name = config.get('name', 'unknown')
                indicator_type = config.get('type', 'unknown')
                parameters = config.get('parameters', {})

                try:
                    indicator_results = self._calculate_single_indicator(
                        indicator_type, ohlcv_candles, parameters
                    )
                    if indicator_results:
                        results[indicator_name] = indicator_results
                        self.logger.debug(f"Calculated {indicator_name}: {len(indicator_results)} points")
                    else:
                        self.logger.warning(f"No results for indicator {indicator_name}")

                except Exception as e:
                    self.logger.error(f"Error calculating indicator {indicator_name}: {e}")
                    continue

            return results

        except Exception as e:
            self.logger.error(f"Error calculating indicators for {symbol}: {e}")
            return {}

    def get_latest_market_data(
        self,
        symbol: str,
        timeframe: str,
        limit: int = 100,
        exchange: str = "okx"
    ) -> Tuple[List[Dict[str, Any]], List[OHLCVCandle]]:
        """
        Get the most recent market data for real-time updates.

        Args:
            symbol: Trading pair
            timeframe: Timeframe
            limit: Maximum number of candles to fetch
            exchange: Exchange name

        Returns:
            Tuple of (raw_candles, ohlcv_candles)
        """
        try:
            # Calculate time range based on limit and timeframe
            end_time = datetime.now(timezone.utc)

            # Estimate time range based on timeframe
            timeframe_minutes = self._parse_timeframe_to_minutes(timeframe)
            start_time = end_time - timedelta(minutes=timeframe_minutes * limit * 2)  # Buffer for sparse data

            raw_candles = self.db_ops.market_data.get_candles(
                symbol=symbol,
                timeframe=timeframe,
                start_time=start_time,
                end_time=end_time,
                exchange=exchange
            )

            # Limit to most recent candles
            if len(raw_candles) > limit:
                raw_candles = raw_candles[-limit:]

            # Convert to OHLCV format
            ohlcv_candles = convert_database_candles_to_ohlcv(raw_candles)

            self.logger.debug(f"Fetched latest {len(raw_candles)} candles for {symbol} {timeframe}")
            return raw_candles, ohlcv_candles

        except Exception as e:
            self.logger.error(f"Error fetching latest market data: {e}")
            return [], []

    def check_data_availability(
        self,
        symbol: str,
        timeframe: str,
        exchange: str = "okx"
    ) -> Dict[str, Any]:
        """
        Check data availability and quality for a symbol/timeframe.

        Args:
            symbol: Trading pair
            timeframe: Timeframe
            exchange: Exchange name

        Returns:
            Dictionary with availability information
        """
        try:
            # Get latest candle
            latest_candle = self.db_ops.market_data.get_latest_candle(symbol, timeframe, exchange)

            if not latest_candle:
                return {
                    'available': False,
                    'latest_timestamp': None,
                    'data_age_minutes': None,
                    'sufficient_for_indicators': False,
                    'message': f"No data available for {symbol} {timeframe}"
                }

            # Calculate data age
            latest_time = latest_candle['timestamp']
            if latest_time.tzinfo is None:
                latest_time = latest_time.replace(tzinfo=timezone.utc)

            data_age = datetime.now(timezone.utc) - latest_time
            data_age_minutes = data_age.total_seconds() / 60

            # Check if we have sufficient data for indicators
            end_time = datetime.now(timezone.utc)
            start_time = end_time - timedelta(days=1)  # Check last day

            recent_candles = self.db_ops.market_data.get_candles(
                symbol=symbol,
                timeframe=timeframe,
                start_time=start_time,
                end_time=end_time,
                exchange=exchange
            )

            sufficient_data = len(recent_candles) >= self.config.min_candles_required

            return {
                'available': True,
                'latest_timestamp': latest_time,
                'data_age_minutes': data_age_minutes,
                'recent_candle_count': len(recent_candles),
                'sufficient_for_indicators': sufficient_data,
                'is_recent': data_age_minutes < 60,  # Less than 1 hour old
                'message': f"Latest: {latest_time.strftime('%Y-%m-%d %H:%M:%S UTC')}, {len(recent_candles)} recent candles"
            }

        except Exception as e:
            self.logger.error(f"Error checking data availability: {e}")
            return {
                'available': False,
                'latest_timestamp': None,
                'data_age_minutes': None,
                'sufficient_for_indicators': False,
                'message': f"Error checking data: {str(e)}"
            }

    def _calculate_single_indicator(
        self,
        indicator_type: str,
        candles: List[OHLCVCandle],
        parameters: Dict[str, Any]
    ) -> List[IndicatorResult]:
        """Calculate a single indicator with given parameters."""
        try:
            if indicator_type == 'sma':
                period = parameters.get('period', 20)
                return self.indicators.sma(candles, period)

            elif indicator_type == 'ema':
                period = parameters.get('period', 20)
                return self.indicators.ema(candles, period)

            elif indicator_type == 'rsi':
                period = parameters.get('period', 14)
                return self.indicators.rsi(candles, period)

            elif indicator_type == 'macd':
                fast = parameters.get('fast_period', 12)
                slow = parameters.get('slow_period', 26)
                signal = parameters.get('signal_period', 9)
                return self.indicators.macd(candles, fast, slow, signal)

            elif indicator_type == 'bollinger_bands':
                period = parameters.get('period', 20)
                std_dev = parameters.get('std_dev', 2)
                return self.indicators.bollinger_bands(candles, period, std_dev)

            else:
                self.logger.warning(f"Unknown indicator type: {indicator_type}")
                return []

        except Exception as e:
            self.logger.error(f"Error calculating {indicator_type}: {e}")
            return []

    def _validate_and_clean_data(self, candles: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Validate and clean market data."""
        cleaned_candles = []

        for i, candle in enumerate(candles):
            try:
                # Check required fields
                required_fields = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
                if not all(field in candle for field in required_fields):
                    self.logger.warning(f"Missing fields in candle {i}")
                    continue

                # Validate OHLC relationships
                o, h, l, c = float(candle['open']), float(candle['high']), float(candle['low']), float(candle['close'])
                if not (h >= max(o, c) and l <= min(o, c)):
                    self.logger.warning(f"Invalid OHLC relationship in candle {i}")
                    continue

                # Validate positive values
                if any(val <= 0 for val in [o, h, l, c]):
                    self.logger.warning(f"Non-positive price in candle {i}")
                    continue

                cleaned_candles.append(candle)

            except (ValueError, TypeError) as e:
                self.logger.warning(f"Error validating candle {i}: {e}")
                continue

        removed_count = len(candles) - len(cleaned_candles)
        if removed_count > 0:
            self.logger.info(f"Removed {removed_count} invalid candles during validation")

        return cleaned_candles

    def _handle_sparse_data(self, candles: List[Dict[str, Any]], timeframe: str) -> List[Dict[str, Any]]:
        """Handle sparse data by detecting and logging gaps."""
        if len(candles) < 2:
            return candles

        # Calculate expected interval
        timeframe_minutes = self._parse_timeframe_to_minutes(timeframe)
        expected_interval = timedelta(minutes=timeframe_minutes)

        gaps_detected = 0
        for i in range(1, len(candles)):
            prev_time = candles[i-1]['timestamp']
            curr_time = candles[i]['timestamp']

            if isinstance(prev_time, str):
                prev_time = datetime.fromisoformat(prev_time.replace('Z', '+00:00'))
            if isinstance(curr_time, str):
                curr_time = datetime.fromisoformat(curr_time.replace('Z', '+00:00'))

            actual_interval = curr_time - prev_time
            if actual_interval > expected_interval * 1.5:  # Allow 50% tolerance
                gaps_detected += 1

        if gaps_detected > 0:
            self.logger.info(f"Detected {gaps_detected} gaps in {timeframe} data (normal for sparse aggregation)")

        return candles

    def _parse_timeframe_to_minutes(self, timeframe: str) -> int:
        """Parse timeframe string to minutes."""
        timeframe_map = {
            '1s': 1/60, '5s': 5/60, '10s': 10/60, '15s': 15/60, '30s': 30/60,
            '1m': 1, '5m': 5, '15m': 15, '30m': 30,
            '1h': 60, '2h': 120, '4h': 240, '6h': 360, '12h': 720,
            '1d': 1440, '3d': 4320, '1w': 10080
        }
        return timeframe_map.get(timeframe, 60)  # Default to 1 hour

    def _get_cached_data(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Get data from cache if still valid."""
        if cache_key not in self._cache:
            return None

        cached_item = self._cache[cache_key]
        cache_age = datetime.now(timezone.utc) - cached_item['timestamp']

        if cache_age.total_seconds() > self.config.cache_timeout_minutes * 60:
            del self._cache[cache_key]
            return None

        return cached_item

    def _cache_data(self, cache_key: str, data: Dict[str, Any]) -> None:
        """Cache data with timestamp."""
        # Simple cache size management
        if len(self._cache) > 50:  # Limit cache size
            # Remove oldest entries
            oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k]['timestamp'])
            del self._cache[oldest_key]

        self._cache[cache_key] = data

    def clear_cache(self) -> None:
        """Clear the data cache."""
        self._cache.clear()
        self.logger.debug("Data cache cleared")


# Convenience functions for common operations
def get_market_data_integrator(config: DataIntegrationConfig = None) -> MarketDataIntegrator:
    """Get a configured market data integrator instance."""
    return MarketDataIntegrator(config)


def fetch_indicator_data(
    symbol: str,
    timeframe: str,
    indicator_configs: List[Dict[str, Any]],
    days_back: int = 7,
    exchange: str = "okx"
) -> Dict[str, List[IndicatorResult]]:
    """
    Convenience function to fetch and calculate indicators.

    Args:
        symbol: Trading pair
        timeframe: Timeframe
        indicator_configs: List of indicator configurations
        days_back: Number of days to look back
        exchange: Exchange name

    Returns:
        Dictionary mapping indicator names to results
    """
    integrator = get_market_data_integrator()
    return integrator.calculate_indicators_for_symbol(
        symbol, timeframe, indicator_configs, days_back, exchange
    )


def check_symbol_data_quality(
    symbol: str,
    timeframe: str,
    exchange: str = "okx"
) -> Dict[str, Any]:
    """
    Convenience function to check data quality for a symbol.

    Args:
        symbol: Trading pair
        timeframe: Timeframe
        exchange: Exchange name

    Returns:
        Data quality information
    """
    integrator = get_market_data_integrator()
    return integrator.check_data_availability(symbol, timeframe, exchange)
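A minimal sketch of driving `MarketDataIntegrator` directly with a custom configuration; it assumes the project's database layer is reachable and that `BTC-USDT`/`1h` candles exist, neither of which is guaranteed by this commit.

```python
# Assumes a reachable database and existing BTC-USDT 1h candles; values are illustrative.
from components.charts.data_integration import DataIntegrationConfig, MarketDataIntegrator

config = DataIntegrationConfig(
    default_days_back=3,        # shorter lookback than the default 7 days
    min_candles_required=30,    # relax the indicator threshold
    cache_timeout_minutes=1,    # refresh cached candles more aggressively
)
integrator = MarketDataIntegrator(config)

raw_candles, ohlcv_candles = integrator.get_market_data_for_indicators("BTC-USDT", "1h")
print(f"raw={len(raw_candles)}, converted={len(ohlcv_candles)}")

availability = integrator.check_data_availability("BTC-USDT", "1h")
if not availability["sufficient_for_indicators"]:
    print(availability["message"])

integrator.clear_cache()  # drop the in-memory cache between dashboard refreshes if needed
```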
components/charts/error_handling.py (new file, 462 lines)
@@ -0,0 +1,462 @@
"""
Error Handling Utilities for Chart Layers

This module provides comprehensive error handling for chart creation,
including custom exceptions, error recovery strategies, and user-friendly
error messaging for various insufficient data scenarios.
"""

import pandas as pd
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Any, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from enum import Enum

from utils.logger import get_logger

# Initialize logger
logger = get_logger("chart_error_handling")


class ErrorSeverity(Enum):
    """Error severity levels for chart operations"""
    INFO = "info"          # Informational, chart can proceed
    WARNING = "warning"    # Warning, chart proceeds with limitations
    ERROR = "error"        # Error, chart creation may fail
    CRITICAL = "critical"  # Critical error, chart creation impossible


@dataclass
class ChartError:
    """Container for chart error information"""
    code: str
    message: str
    severity: ErrorSeverity
    context: Dict[str, Any]
    recovery_suggestion: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert error to dictionary for logging/serialization"""
        return {
            'code': self.code,
            'message': self.message,
            'severity': self.severity.value,
            'context': self.context,
            'recovery_suggestion': self.recovery_suggestion
        }


class ChartDataError(Exception):
    """Base exception for chart data-related errors"""
    def __init__(self, error: ChartError):
        self.error = error
        super().__init__(error.message)


class InsufficientDataError(ChartDataError):
    """Raised when there's insufficient data for chart/indicator calculations"""
    pass


class DataValidationError(ChartDataError):
    """Raised when data validation fails"""
    pass


class IndicatorCalculationError(ChartDataError):
    """Raised when indicator calculations fail"""
    pass


class DataConnectionError(ChartDataError):
    """Raised when database/data source connection fails"""
    pass


class DataRequirements:
    """Data requirements checker for charts and indicators"""

    # Minimum data requirements for different indicators
    INDICATOR_MIN_PERIODS = {
        'sma': lambda period: period + 5,                       # SMA needs period + buffer
        'ema': lambda period: period * 2,                       # EMA needs 2x period for stability
        'rsi': lambda period: period + 10,                      # RSI needs period + warmup
        'macd': lambda fast, slow, signal: slow + signal + 10,  # MACD most demanding
        'bollinger_bands': lambda period: period + 5,           # BB needs period + buffer
        'candlestick': lambda: 10,                              # Basic candlestick minimum
        'volume': lambda: 5                                     # Volume minimum
    }

    @classmethod
    def check_candlestick_requirements(cls, data_count: int) -> ChartError:
        """Check if we have enough data for basic candlestick chart"""
        min_required = cls.INDICATOR_MIN_PERIODS['candlestick']()

        if data_count == 0:
            return ChartError(
                code='NO_DATA',
                message='No market data available',
                severity=ErrorSeverity.CRITICAL,
                context={'data_count': data_count, 'required': min_required},
                recovery_suggestion='Check data collection service or select different symbol/timeframe'
            )
        elif data_count < min_required:
            return ChartError(
                code='INSUFFICIENT_CANDLESTICK_DATA',
                message=f'Insufficient data for candlestick chart: {data_count} candles (need {min_required})',
                severity=ErrorSeverity.WARNING,
                context={'data_count': data_count, 'required': min_required},
                recovery_suggestion='Chart will display with limited data - consider longer time range'
            )
        else:
            return ChartError(
                code='SUFFICIENT_DATA',
                message='Sufficient data for candlestick chart',
                severity=ErrorSeverity.INFO,
                context={'data_count': data_count, 'required': min_required}
            )

    @classmethod
    def check_indicator_requirements(cls, indicator_type: str, data_count: int,
                                     parameters: Dict[str, Any]) -> ChartError:
        """Check if we have enough data for specific indicator"""
        if indicator_type not in cls.INDICATOR_MIN_PERIODS:
            return ChartError(
                code='UNKNOWN_INDICATOR',
                message=f'Unknown indicator type: {indicator_type}',
                severity=ErrorSeverity.ERROR,
                context={'indicator_type': indicator_type, 'data_count': data_count},
                recovery_suggestion='Check indicator type spelling or implementation'
            )

        # Calculate minimum required data
        try:
            if indicator_type in ['sma', 'ema', 'rsi', 'bollinger_bands']:
                period = parameters.get('period', 20)
                min_required = cls.INDICATOR_MIN_PERIODS[indicator_type](period)
            elif indicator_type == 'macd':
                fast = parameters.get('fast_period', 12)
                slow = parameters.get('slow_period', 26)
                signal = parameters.get('signal_period', 9)
                min_required = cls.INDICATOR_MIN_PERIODS[indicator_type](fast, slow, signal)
            else:
                min_required = cls.INDICATOR_MIN_PERIODS[indicator_type]()
        except Exception as e:
            return ChartError(
                code='PARAMETER_ERROR',
                message=f'Invalid parameters for {indicator_type}: {e}',
                severity=ErrorSeverity.ERROR,
                context={'indicator_type': indicator_type, 'parameters': parameters},
                recovery_suggestion='Check indicator parameters for valid values'
            )

        if data_count < min_required:
            # Determine severity based on how insufficient the data is
            if data_count < min_required // 2:
                # Severely insufficient - less than half the required data
                severity = ErrorSeverity.ERROR
            else:
                # Slightly insufficient - can potentially adjust parameters
                severity = ErrorSeverity.WARNING

            return ChartError(
                code='INSUFFICIENT_INDICATOR_DATA',
                message=f'Insufficient data for {indicator_type}: {data_count} candles (need {min_required})',
                severity=severity,
                context={
                    'indicator_type': indicator_type,
                    'data_count': data_count,
                    'required': min_required,
                    'parameters': parameters
                },
                recovery_suggestion=f'Increase data range to at least {min_required} candles or adjust {indicator_type} parameters'
            )
        else:
            return ChartError(
                code='SUFFICIENT_INDICATOR_DATA',
                message=f'Sufficient data for {indicator_type}',
                severity=ErrorSeverity.INFO,
                context={
                    'indicator_type': indicator_type,
                    'data_count': data_count,
                    'required': min_required
                }
            )


class ErrorRecoveryStrategies:
    """Error recovery strategies for different chart scenarios"""

    @staticmethod
    def handle_insufficient_data(error: ChartError, fallback_options: Dict[str, Any]) -> Dict[str, Any]:
        """Handle insufficient data by providing fallback strategies"""
        strategy = {
            'can_proceed': False,
            'fallback_action': None,
            'modified_config': None,
            'user_message': error.message
        }

        if error.code == 'INSUFFICIENT_CANDLESTICK_DATA':
            # For candlestick, we can proceed with warnings
            strategy.update({
                'can_proceed': True,
                'fallback_action': 'display_with_warning',
                'user_message': f"{error.message}. Chart will display available data."
            })

        elif error.code == 'INSUFFICIENT_INDICATOR_DATA':
            # For indicators, try to adjust parameters or skip
            indicator_type = error.context.get('indicator_type')
            data_count = error.context.get('data_count', 0)

            if indicator_type in ['sma', 'ema', 'bollinger_bands']:
                # Try reducing period to fit available data
                max_period = max(5, data_count // 2)  # Conservative estimate
                strategy.update({
                    'can_proceed': True,
                    'fallback_action': 'adjust_parameters',
                    'modified_config': {'period': max_period},
                    'user_message': f"Adjusted {indicator_type} period to {max_period} due to limited data"
                })

            elif indicator_type == 'rsi':
                # RSI can work with reduced period
                max_period = max(7, data_count // 3)
                strategy.update({
                    'can_proceed': True,
                    'fallback_action': 'adjust_parameters',
                    'modified_config': {'period': max_period},
                    'user_message': f"Adjusted RSI period to {max_period} due to limited data"
                })

            else:
                # Skip the indicator entirely
                strategy.update({
                    'can_proceed': True,
                    'fallback_action': 'skip_indicator',
                    'user_message': f"Skipped {indicator_type} due to insufficient data"
                })

        return strategy

    @staticmethod
    def handle_data_validation_error(error: ChartError) -> Dict[str, Any]:
        """Handle data validation errors"""
        return {
            'can_proceed': False,
            'fallback_action': 'show_error',
            'user_message': f"Data validation failed: {error.message}",
            'recovery_suggestion': error.recovery_suggestion
        }

    @staticmethod
    def handle_connection_error(error: ChartError) -> Dict[str, Any]:
        """Handle database/connection errors"""
        return {
            'can_proceed': False,
            'fallback_action': 'show_error',
            'user_message': "Unable to connect to data source",
            'recovery_suggestion': "Check database connection or try again later"
        }


class ChartErrorHandler:
    """Main error handler for chart operations"""

    def __init__(self):
        self.logger = logger
        self.errors: List[ChartError] = []
        self.warnings: List[ChartError] = []

    def clear_errors(self):
        """Clear accumulated errors and warnings"""
        self.errors.clear()
        self.warnings.clear()

    def validate_data_sufficiency(self, data: Union[pd.DataFrame, List[Dict[str, Any]]],
                                  chart_type: str = 'candlestick',
                                  indicators: List[Dict[str, Any]] = None) -> bool:
        """
        Validate if data is sufficient for chart and indicator requirements.

        Args:
            data: Chart data (DataFrame or list of candle dicts)
            chart_type: Type of chart being created
            indicators: List of indicator configurations

        Returns:
            True if data is sufficient, False otherwise
        """
        self.clear_errors()

        # Get data count
        if isinstance(data, pd.DataFrame):
            data_count = len(data)
        elif isinstance(data, list):
            data_count = len(data)
        else:
            self.errors.append(ChartError(
                code='INVALID_DATA_TYPE',
                message=f'Invalid data type: {type(data)}',
                severity=ErrorSeverity.ERROR,
                context={'data_type': str(type(data))}
            ))
            return False

        # Check basic chart requirements
        chart_error = DataRequirements.check_candlestick_requirements(data_count)
        if chart_error.severity in [ErrorSeverity.WARNING]:
            self.warnings.append(chart_error)
        elif chart_error.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]:
            self.errors.append(chart_error)
            return False

        # Check indicator requirements
        if indicators:
            for indicator_config in indicators:
                indicator_type = indicator_config.get('type', 'unknown')
                parameters = indicator_config.get('parameters', {})

                indicator_error = DataRequirements.check_indicator_requirements(
                    indicator_type, data_count, parameters
                )

                if indicator_error.severity == ErrorSeverity.WARNING:
                    self.warnings.append(indicator_error)
                elif indicator_error.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]:
                    self.errors.append(indicator_error)

        # Return True if no critical errors
        return len(self.errors) == 0

    def get_error_summary(self) -> Dict[str, Any]:
        """Get summary of all errors and warnings"""
        return {
            'has_errors': len(self.errors) > 0,
            'has_warnings': len(self.warnings) > 0,
            'error_count': len(self.errors),
            'warning_count': len(self.warnings),
            'errors': [error.to_dict() for error in self.errors],
            'warnings': [warning.to_dict() for warning in self.warnings],
            'can_proceed': len(self.errors) == 0
        }

    def get_user_friendly_message(self) -> str:
        """Get a user-friendly message summarizing errors and warnings"""
        if not self.errors and not self.warnings:
            return "Chart data is ready"

        messages = []

        if self.errors:
            error_msg = f"❌ {len(self.errors)} error(s) prevent chart creation"
            messages.append(error_msg)

        # Add most relevant error message
        if self.errors:
            main_error = self.errors[0]  # Show first error
            messages.append(f"• {main_error.message}")
            if main_error.recovery_suggestion:
                messages.append(f"  💡 {main_error.recovery_suggestion}")

        if self.warnings:
            warning_msg = f"⚠️ {len(self.warnings)} warning(s)"
            messages.append(warning_msg)

        # Add most relevant warning
        if self.warnings:
            main_warning = self.warnings[0]
            messages.append(f"• {main_warning.message}")

        return "\n".join(messages)

    def apply_error_recovery(self, error: ChartError,
                             fallback_options: Dict[str, Any] = None) -> Dict[str, Any]:
        """Apply error recovery strategy for a specific error"""
        fallback_options = fallback_options or {}

        if error.code.startswith('INSUFFICIENT'):
            return ErrorRecoveryStrategies.handle_insufficient_data(error, fallback_options)
        elif 'VALIDATION' in error.code:
            return ErrorRecoveryStrategies.handle_data_validation_error(error)
        elif 'CONNECTION' in error.code:
            return ErrorRecoveryStrategies.handle_connection_error(error)
        else:
            # Default recovery strategy
            return {
                'can_proceed': False,
                'fallback_action': 'show_error',
                'user_message': error.message,
                'recovery_suggestion': error.recovery_suggestion
            }


# Convenience functions
def check_data_sufficiency(data: Union[pd.DataFrame, List[Dict[str, Any]]],
                           indicators: List[Dict[str, Any]] = None) -> Tuple[bool, Dict[str, Any]]:
    """
    Convenience function to check data sufficiency.

    Args:
        data: Chart data
        indicators: List of indicator configurations

    Returns:
        Tuple of (is_sufficient, error_summary)
    """
    handler = ChartErrorHandler()
    is_sufficient = handler.validate_data_sufficiency(data, indicators=indicators)
    return is_sufficient, handler.get_error_summary()


def get_error_message(data: Union[pd.DataFrame, List[Dict[str, Any]]],
                      indicators: List[Dict[str, Any]] = None) -> str:
    """
    Get user-friendly error message for data issues.

    Args:
        data: Chart data
        indicators: List of indicator configurations

    Returns:
        User-friendly error message
    """
    handler = ChartErrorHandler()
    handler.validate_data_sufficiency(data, indicators=indicators)
    return handler.get_user_friendly_message()


def create_error_annotation(error_message: str, position: str = "top") -> Dict[str, Any]:
    """
    Create a Plotly annotation for error display.

    Args:
        error_message: Error message to display
        position: Position of annotation ('top', 'center', 'bottom')

    Returns:
        Plotly annotation configuration
    """
    positions = {
        'top': {'x': 0.5, 'y': 0.9},
        'center': {'x': 0.5, 'y': 0.5},
        'bottom': {'x': 0.5, 'y': 0.1}
    }

    pos = positions.get(position, positions['center'])

    return {
        'text': error_message,
        'xref': 'paper',
        'yref': 'paper',
        'x': pos['x'],
        'y': pos['y'],
        'xanchor': 'center',
        'yanchor': 'middle',
        'showarrow': False,
        'font': {'size': 14, 'color': '#e74c3c'},
        'bgcolor': 'rgba(255,255,255,0.8)',
        'bordercolor': '#e74c3c',
        'borderwidth': 1
    }
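A short sketch of the convenience checks defined above against a synthetic candle list; the fabricated OHLCV rows exist only to illustrate the insufficient-data path.

```python
# Synthetic candles for illustration only. Eight rows trips the candlestick warning
# (minimum 10) and is less than half of what RSI(14) needs (24), so validation fails.
from components.charts.error_handling import check_data_sufficiency, get_error_message

candles = [
    {"timestamp": f"2024-01-01T{h:02d}:00:00Z",
     "open": 100.0, "high": 101.0, "low": 99.0, "close": 100.5, "volume": 10.0}
    for h in range(8)
]
rsi_config = [{"type": "rsi", "parameters": {"period": 14}}]

ok, summary = check_data_sufficiency(candles, indicators=rsi_config)
print(ok, summary["error_count"], summary["warning_count"])  # False 1 1

# Human-readable version of the same findings, suitable for a dashboard annotation.
print(get_error_message(candles, indicators=rsi_config))
```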
@ -1,13 +1,89 @@
|
||||
"""
|
||||
Chart Layers Package
|
||||
|
||||
This package contains the modular chart layer system for rendering different
|
||||
chart components including candlesticks, indicators, and signals.
|
||||
This package contains the modular layer system for building complex charts
|
||||
with multiple indicators, signals, and subplots.
|
||||
|
||||
Components:
|
||||
- BaseChartLayer: Abstract base class for all layers
|
||||
- CandlestickLayer: OHLC price chart layer
|
||||
- VolumeLayer: Volume subplot layer
|
||||
- LayerManager: Orchestrates multiple layers
|
||||
- SMALayer: Simple Moving Average indicator overlay
|
||||
- EMALayer: Exponential Moving Average indicator overlay
|
||||
- BollingerBandsLayer: Bollinger Bands overlay with fill area
|
||||
- RSILayer: RSI oscillator subplot
|
||||
- MACDLayer: MACD lines and histogram subplot
|
||||
"""
|
||||
|
||||
# Package metadata
|
||||
from .base import (
|
||||
BaseChartLayer,
|
||||
CandlestickLayer,
|
||||
VolumeLayer,
|
||||
LayerManager,
|
||||
LayerConfig
|
||||
)
|
||||
|
||||
from .indicators import (
|
||||
BaseIndicatorLayer,
|
||||
IndicatorLayerConfig,
|
||||
SMALayer,
|
||||
EMALayer,
|
||||
BollingerBandsLayer,
|
||||
create_sma_layer,
|
||||
create_ema_layer,
|
||||
create_bollinger_bands_layer,
|
||||
create_common_ma_layers,
|
||||
create_common_overlay_indicators
|
||||
)
|
||||
|
||||
from .subplots import (
|
||||
BaseSubplotLayer,
|
||||
SubplotLayerConfig,
|
||||
RSILayer,
|
||||
MACDLayer,
|
||||
create_rsi_layer,
|
||||
create_macd_layer,
|
||||
create_common_subplot_indicators
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base layers
|
||||
'BaseChartLayer',
|
||||
'CandlestickLayer',
|
||||
'VolumeLayer',
|
||||
'LayerManager',
|
||||
'LayerConfig',
|
||||
|
||||
# Indicator layers (overlays)
|
||||
'BaseIndicatorLayer',
|
||||
'IndicatorLayerConfig',
|
||||
'SMALayer',
|
||||
'EMALayer',
|
||||
'BollingerBandsLayer',
|
||||
|
||||
# Subplot layers
|
||||
'BaseSubplotLayer',
|
||||
'SubplotLayerConfig',
|
||||
'RSILayer',
|
||||
'MACDLayer',
|
||||
|
||||
# Convenience functions
|
||||
'create_sma_layer',
|
||||
'create_ema_layer',
|
||||
'create_bollinger_bands_layer',
|
||||
'create_common_ma_layers',
|
||||
'create_common_overlay_indicators',
|
||||
'create_rsi_layer',
|
||||
'create_macd_layer',
|
||||
'create_common_subplot_indicators'
|
||||
]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__package_name__ = "layers"
|
||||
|
||||
# Package metadata
|
||||
# __version__ = "0.1.0"
|
||||
# __package_name__ = "layers"
|
||||
|
||||
# Layers will be imported once they are created
|
||||
# from .base import BaseCandlestickLayer
|
||||
@@ -16,9 +92,9 @@ __package_name__ = "layers"
|
||||
# from .signals import SignalLayer
|
||||
|
||||
# Public exports (will be populated as layers are implemented)
|
||||
__all__ = [
|
||||
# "BaseCandlestickLayer",
|
||||
# "IndicatorLayer",
|
||||
# "SubplotManager",
|
||||
# "SignalLayer"
|
||||
]
|
||||
# __all__ = [
|
||||
# # "BaseCandlestickLayer",
|
||||
# # "IndicatorLayer",
|
||||
# # "SubplotManager",
|
||||
# # "SignalLayer"
|
||||
# ]
|
||||
components/charts/layers/base.py (new file, 952 lines)
@@ -0,0 +1,952 @@
|
||||
"""
|
||||
Base Chart Layer Components
|
||||
|
||||
This module contains the foundational layer classes that serve as building blocks
|
||||
for all chart components including candlestick charts, indicators, and signals.
|
||||
"""
|
||||
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from utils.logger import get_logger
|
||||
from ..error_handling import (
|
||||
ChartErrorHandler, ChartError, ErrorSeverity,
|
||||
InsufficientDataError, DataValidationError, IndicatorCalculationError,
|
||||
create_error_annotation, get_error_message
|
||||
)
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger("chart_layers")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerConfig:
|
||||
"""Configuration for chart layers"""
|
||||
name: str
|
||||
enabled: bool = True
|
||||
color: Optional[str] = None
|
||||
style: Dict[str, Any] = None
|
||||
subplot_row: Optional[int] = None # None = main chart, 1+ = subplot row
|
||||
|
||||
def __post_init__(self):
|
||||
if self.style is None:
|
||||
self.style = {}
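For reference, a hypothetical overlay configuration built from this dataclass (the field values here are made up, not taken from the repository):

overlay_cfg = LayerConfig(
    name="my_overlay",
    enabled=True,
    color="#2c3e50",
    style={"line_width": 2},
    subplot_row=None,  # None keeps the layer on the main price chart
)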
|
||||
|
||||
|
||||
class BaseLayer:
|
||||
"""
|
||||
Base class for all chart layers providing common functionality
|
||||
for data validation, error handling, and trace management.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig):
|
||||
self.config = config
|
||||
self.logger = get_logger(f"chart_layer_{self.__class__.__name__.lower()}")
|
||||
self.error_handler = ChartErrorHandler()
|
||||
self.traces = []
|
||||
self._is_valid = False
|
||||
self._error_message = None
|
||||
|
||||
def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
|
||||
"""
|
||||
Validate input data for layer requirements.
|
||||
|
||||
Args:
|
||||
data: Input data to validate
|
||||
|
||||
Returns:
|
||||
True if data is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.error_handler.clear_errors()
|
||||
|
||||
# Check data type
|
||||
if not isinstance(data, (pd.DataFrame, list)):
|
||||
error = ChartError(
|
||||
code='INVALID_DATA_TYPE',
|
||||
message=f'Invalid data type for {self.__class__.__name__}: {type(data)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'layer': self.__class__.__name__, 'data_type': str(type(data))},
|
||||
recovery_suggestion='Provide data as pandas DataFrame or list of dictionaries'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check data sufficiency
|
||||
is_sufficient = self.error_handler.validate_data_sufficiency(
|
||||
data,
|
||||
chart_type='candlestick', # Default chart type since LayerConfig doesn't have layer_type
|
||||
indicators=[{'type': 'candlestick', 'parameters': {}}] # Default indicator type
|
||||
)
|
||||
|
||||
self._is_valid = is_sufficient
|
||||
if not is_sufficient:
|
||||
self._error_message = self.error_handler.get_user_friendly_message()
|
||||
|
||||
return is_sufficient
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Data validation error in {self.__class__.__name__}: {e}")
|
||||
error = ChartError(
|
||||
code='VALIDATION_EXCEPTION',
|
||||
message=f'Validation error: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'layer': self.__class__.__name__, 'exception': str(e)}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
self._is_valid = False
|
||||
self._error_message = str(e)
|
||||
return False
|
||||
|
||||
def get_error_info(self) -> Dict[str, Any]:
|
||||
"""Get error information for this layer"""
|
||||
return {
|
||||
'is_valid': self._is_valid,
|
||||
'error_message': self._error_message,
|
||||
'error_summary': self.error_handler.get_error_summary(),
|
||||
'can_proceed': len(self.error_handler.errors) == 0
|
||||
}
|
||||
|
||||
def create_error_trace(self, error_message: str) -> go.Scatter:
|
||||
"""Create an error display trace"""
|
||||
return go.Scatter(
|
||||
x=[],
|
||||
y=[],
|
||||
mode='text',
|
||||
text=[error_message],
|
||||
textposition='middle center',
|
||||
textfont={'size': 14, 'color': '#e74c3c'},
|
||||
showlegend=False,
|
||||
name=f"{self.__class__.__name__} Error"
|
||||
)
|
||||
|
||||
|
||||
class BaseChartLayer(ABC):
|
||||
"""
|
||||
Abstract base class for all chart layers.
|
||||
|
||||
This defines the interface that all chart layers must implement,
|
||||
whether they are candlestick charts, indicators, or signal overlays.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig):
|
||||
"""
|
||||
Initialize the base layer.
|
||||
|
||||
Args:
|
||||
config: Layer configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
|
||||
@abstractmethod
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render the layer onto the provided figure.
|
||||
|
||||
Args:
|
||||
fig: Plotly figure to render onto
|
||||
data: Chart data (OHLCV format)
|
||||
**kwargs: Additional rendering parameters
|
||||
|
||||
Returns:
|
||||
Updated figure with layer rendered
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_data(self, data: pd.DataFrame) -> bool:
|
||||
"""
|
||||
Validate that the data is suitable for this layer.
|
||||
|
||||
Args:
|
||||
data: Chart data to validate
|
||||
|
||||
Returns:
|
||||
True if data is valid, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
|
||||
class CandlestickLayer(BaseLayer):
|
||||
"""
|
||||
Candlestick chart layer implementation with enhanced error handling.
|
||||
|
||||
This layer renders OHLC data as candlesticks on the main chart.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig = None):
|
||||
"""
|
||||
Initialize candlestick layer.
|
||||
|
||||
Args:
|
||||
config: Layer configuration (optional, uses defaults)
|
||||
"""
|
||||
if config is None:
|
||||
config = LayerConfig(
|
||||
name="candlestick",
|
||||
enabled=True,
|
||||
style={
|
||||
'increasing_color': '#00C851', # Green for bullish
|
||||
'decreasing_color': '#FF4444', # Red for bearish
|
||||
'line_width': 1
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
|
||||
"""Enhanced validation with comprehensive error handling"""
|
||||
try:
|
||||
# Use parent class error handling for comprehensive validation
|
||||
parent_valid = super().validate_data(data)
|
||||
|
||||
# Convert to DataFrame if needed for local validation
|
||||
if isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
else:
|
||||
df = data.copy()
|
||||
|
||||
# Additional candlestick-specific validation
|
||||
required_columns = ['timestamp', 'open', 'high', 'low', 'close']
|
||||
|
||||
if not all(col in df.columns for col in required_columns):
|
||||
missing = [col for col in required_columns if col not in df.columns]
|
||||
error = ChartError(
|
||||
code='MISSING_OHLC_COLUMNS',
|
||||
message=f'Missing required OHLC columns: {missing}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'missing_columns': missing, 'available_columns': list(df.columns)},
|
||||
recovery_suggestion='Ensure data contains timestamp, open, high, low, close columns'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
if len(df) == 0:
|
||||
error = ChartError(
|
||||
code='EMPTY_CANDLESTICK_DATA',
|
||||
message='No candlestick data available',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'data_count': 0},
|
||||
recovery_suggestion='Check data source or time range'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check for price data validity
|
||||
invalid_prices = df[
|
||||
(df['high'] < df['low']) |
|
||||
(df['open'] < 0) | (df['close'] < 0) |
|
||||
(df['high'] < 0) | (df['low'] < 0) |
|
||||
pd.isna(df[['open', 'high', 'low', 'close']]).any(axis=1)
|
||||
]
|
||||
|
||||
if len(invalid_prices) > len(df) * 0.5: # More than 50% invalid
|
||||
error = ChartError(
|
||||
code='EXCESSIVE_INVALID_PRICES',
|
||||
message=f'Too many invalid price records: {len(invalid_prices)}/{len(df)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'invalid_count': len(invalid_prices), 'total_count': len(df)},
|
||||
recovery_suggestion='Check data quality and price data sources'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
elif len(invalid_prices) > 0:
|
||||
# Warning for some invalid data
|
||||
error = ChartError(
|
||||
code='SOME_INVALID_PRICES',
|
||||
message=f'Found {len(invalid_prices)} invalid price records (will be filtered)',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'invalid_count': len(invalid_prices), 'total_count': len(df)},
|
||||
recovery_suggestion='Invalid records will be automatically removed'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
return parent_valid and len(self.error_handler.errors) == 0
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error validating candlestick data: {e}")
|
||||
error = ChartError(
|
||||
code='CANDLESTICK_VALIDATION_ERROR',
|
||||
message=f'Candlestick validation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e)}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render candlestick chart with error handling and recovery.
|
||||
|
||||
Args:
|
||||
fig: Target figure
|
||||
data: OHLCV data
|
||||
**kwargs: Additional parameters (row, col for subplots)
|
||||
|
||||
Returns:
|
||||
Figure with candlestick trace added or error display
|
||||
"""
|
||||
try:
|
||||
# Validate data
|
||||
if not self.validate_data(data):
|
||||
self.logger.error("Invalid data for candlestick layer")
|
||||
|
||||
# Add error annotation to figure
|
||||
if self.error_handler.errors:
|
||||
error_msg = self.error_handler.errors[0].message
|
||||
fig.add_annotation(create_error_annotation(
|
||||
f"Candlestick Error: {error_msg}",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
# Clean and prepare data
|
||||
clean_data = self._clean_candlestick_data(data)
|
||||
if clean_data.empty:
|
||||
fig.add_annotation(create_error_annotation(
|
||||
"No valid candlestick data after cleaning",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
# Extract styling
|
||||
style = self.config.style
|
||||
increasing_color = style.get('increasing_color', '#00C851')
|
||||
decreasing_color = style.get('decreasing_color', '#FF4444')
|
||||
|
||||
# Create candlestick trace
|
||||
candlestick = go.Candlestick(
|
||||
x=clean_data['timestamp'],
|
||||
open=clean_data['open'],
|
||||
high=clean_data['high'],
|
||||
low=clean_data['low'],
|
||||
close=clean_data['close'],
|
||||
name=self.config.name,
|
||||
increasing_line_color=increasing_color,
|
||||
decreasing_line_color=decreasing_color,
|
||||
showlegend=False
|
||||
)
|
||||
|
||||
# Add to figure
|
||||
row = kwargs.get('row', 1)
|
||||
col = kwargs.get('col', 1)
|
||||
|
||||
try:
|
||||
if hasattr(fig, 'add_trace') and row == 1 and col == 1:
|
||||
# Simple figure without subplots
|
||||
fig.add_trace(candlestick)
|
||||
elif hasattr(fig, 'add_trace'):
|
||||
# Subplot figure
|
||||
fig.add_trace(candlestick, row=row, col=col)
|
||||
else:
|
||||
# Fallback
|
||||
fig.add_trace(candlestick)
|
||||
except Exception as trace_error:
|
||||
# If subplot call fails, try simple add_trace
|
||||
try:
|
||||
fig.add_trace(candlestick)
|
||||
except Exception as fallback_error:
|
||||
self.logger.error(f"Failed to add candlestick trace: {fallback_error}")
|
||||
fig.add_annotation(create_error_annotation(
|
||||
f"Failed to add candlestick trace: {str(fallback_error)}",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
# Add warning annotations if needed
|
||||
if self.error_handler.warnings:
|
||||
warning_msg = f"⚠️ {self.error_handler.warnings[0].message}"
|
||||
fig.add_annotation({
|
||||
'text': warning_msg,
|
||||
'xref': 'paper', 'yref': 'paper',
|
||||
'x': 0.02, 'y': 0.98,
|
||||
'xanchor': 'left', 'yanchor': 'top',
|
||||
'showarrow': False,
|
||||
'font': {'size': 10, 'color': '#f39c12'},
|
||||
'bgcolor': 'rgba(255,255,255,0.8)'
|
||||
})
|
||||
|
||||
self.logger.debug(f"Rendered candlestick layer with {len(clean_data)} candles")
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering candlestick layer: {e}")
|
||||
fig.add_annotation(create_error_annotation(
|
||||
f"Candlestick render error: {str(e)}",
|
||||
position='center'
|
||||
))
|
||||
return fig
|
||||
|
||||
def _clean_candlestick_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Clean and validate candlestick data"""
|
||||
try:
|
||||
clean_data = data.copy()
|
||||
|
||||
# Remove rows with invalid prices
|
||||
invalid_mask = (
|
||||
(clean_data['high'] < clean_data['low']) |
|
||||
(clean_data['open'] < 0) | (clean_data['close'] < 0) |
|
||||
(clean_data['high'] < 0) | (clean_data['low'] < 0) |
|
||||
pd.isna(clean_data[['open', 'high', 'low', 'close']]).any(axis=1)
|
||||
)
|
||||
|
||||
initial_count = len(clean_data)
|
||||
clean_data = clean_data[~invalid_mask]
|
||||
|
||||
if len(clean_data) < initial_count:
|
||||
removed_count = initial_count - len(clean_data)
|
||||
self.logger.info(f"Removed {removed_count} invalid candlestick records")
|
||||
|
||||
# Ensure timestamp is properly formatted
|
||||
if not pd.api.types.is_datetime64_any_dtype(clean_data['timestamp']):
|
||||
clean_data['timestamp'] = pd.to_datetime(clean_data['timestamp'])
|
||||
|
||||
# Sort by timestamp
|
||||
clean_data = clean_data.sort_values('timestamp')
|
||||
|
||||
return clean_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning candlestick data: {e}")
|
||||
return pd.DataFrame()
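A sketch of the candlestick layer in isolation, assuming the sample data below passes the `ChartErrorHandler` sufficiency checks (the frame is fabricated for illustration):

import pandas as pd
import plotly.graph_objects as go

sample = pd.DataFrame({
    'timestamp': pd.date_range('2024-01-01', periods=50, freq='1min'),
    'open':  [100.0 + 0.1 * i for i in range(50)],
    'high':  [100.5 + 0.1 * i for i in range(50)],
    'low':   [99.5 + 0.1 * i for i in range(50)],
    'close': [100.2 + 0.1 * i for i in range(50)],
    'volume': [1_000 + 10 * i for i in range(50)],
})

layer = CandlestickLayer()                # default LayerConfig(name="candlestick", ...)
fig = layer.render(go.Figure(), sample)   # adds one go.Candlestick trace on success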
|
||||
|
||||
|
||||
class VolumeLayer(BaseLayer):
|
||||
"""
|
||||
Volume subplot layer implementation with enhanced error handling.
|
||||
|
||||
This layer renders volume data as a bar chart in a separate subplot,
|
||||
with bars colored based on price movement.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LayerConfig = None):
|
||||
"""
|
||||
Initialize volume layer.
|
||||
|
||||
Args:
|
||||
config: Layer configuration (optional, uses defaults)
|
||||
"""
|
||||
if config is None:
|
||||
config = LayerConfig(
|
||||
name="volume",
|
||||
enabled=True,
|
||||
subplot_row=2, # Volume goes in second row by default
|
||||
style={
|
||||
'bullish_color': '#00C851',
|
||||
'bearish_color': '#FF4444',
|
||||
'opacity': 0.7
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
def validate_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]]) -> bool:
|
||||
"""Enhanced validation with comprehensive error handling"""
|
||||
try:
|
||||
# Use parent class error handling
|
||||
parent_valid = super().validate_data(data)
|
||||
|
||||
# Convert to DataFrame if needed
|
||||
if isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
else:
|
||||
df = data.copy()
|
||||
|
||||
# Volume-specific validation
|
||||
required_columns = ['timestamp', 'open', 'close', 'volume']
|
||||
|
||||
if not all(col in df.columns for col in required_columns):
|
||||
missing = [col for col in required_columns if col not in df.columns]
|
||||
error = ChartError(
|
||||
code='MISSING_VOLUME_COLUMNS',
|
||||
message=f'Missing required volume columns: {missing}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'missing_columns': missing, 'available_columns': list(df.columns)},
|
||||
recovery_suggestion='Ensure data contains timestamp, open, close, volume columns'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
if len(df) == 0:
|
||||
error = ChartError(
|
||||
code='EMPTY_VOLUME_DATA',
|
||||
message='No volume data available',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'data_count': 0},
|
||||
recovery_suggestion='Check data source or time range'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check if volume data exists and is valid
|
||||
valid_volume_mask = (df['volume'] >= 0) & pd.notna(df['volume'])
|
||||
valid_volume_count = valid_volume_mask.sum()
|
||||
|
||||
if valid_volume_count == 0:
|
||||
error = ChartError(
|
||||
code='NO_VALID_VOLUME',
|
||||
message='No valid volume data found',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'total_records': len(df), 'valid_volume': 0},
|
||||
recovery_suggestion='Volume chart will be skipped'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
elif valid_volume_count < len(df) * 0.5: # Less than 50% valid
|
||||
error = ChartError(
|
||||
code='MOSTLY_INVALID_VOLUME',
|
||||
message=f'Most volume data is invalid: {valid_volume_count}/{len(df)} valid',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'total_records': len(df), 'valid_volume': valid_volume_count},
|
||||
recovery_suggestion='Invalid volume records will be filtered out'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
elif df['volume'].sum() <= 0:
|
||||
error = ChartError(
|
||||
code='ZERO_VOLUME_TOTAL',
|
||||
message='Total volume is zero or negative',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'volume_sum': float(df['volume'].sum())},
|
||||
recovery_suggestion='Volume chart may not be meaningful'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
return parent_valid and valid_volume_count > 0
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error validating volume data: {e}")
|
||||
error = ChartError(
|
||||
code='VOLUME_VALIDATION_ERROR',
|
||||
message=f'Volume validation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e)}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render volume bars with error handling and recovery.
|
||||
|
||||
Args:
|
||||
fig: Target figure (must be subplot figure)
|
||||
data: OHLCV data
|
||||
**kwargs: Additional parameters (row, col for subplots)
|
||||
|
||||
Returns:
|
||||
Figure with volume trace added or error handling
|
||||
"""
|
||||
try:
|
||||
# Validate data
|
||||
if not self.validate_data(data):
|
||||
# Check if we can skip gracefully (warnings only)
|
||||
if not self.error_handler.errors and self.error_handler.warnings:
|
||||
self.logger.debug("Skipping volume layer due to warnings")
|
||||
return fig
|
||||
else:
|
||||
self.logger.error("Invalid data for volume layer")
|
||||
return fig
|
||||
|
||||
# Clean and prepare data
|
||||
clean_data = self._clean_volume_data(data)
|
||||
if clean_data.empty:
|
||||
self.logger.debug("No valid volume data after cleaning")
|
||||
return fig
|
||||
|
||||
# Calculate bar colors based on price movement
|
||||
style = self.config.style
|
||||
bullish_color = style.get('bullish_color', '#00C851')
|
||||
bearish_color = style.get('bearish_color', '#FF4444')
|
||||
opacity = style.get('opacity', 0.7)
|
||||
|
||||
colors = [
|
||||
bullish_color if close >= open_price else bearish_color
|
||||
for close, open_price in zip(clean_data['close'], clean_data['open'])
|
||||
]
|
||||
|
||||
# Create volume bar trace
|
||||
volume_bars = go.Bar(
|
||||
x=clean_data['timestamp'],
|
||||
y=clean_data['volume'],
|
||||
name='Volume',
|
||||
marker_color=colors,
|
||||
opacity=opacity,
|
||||
showlegend=False
|
||||
)
|
||||
|
||||
# Add to figure
|
||||
row = kwargs.get('row', 2) # Default to row 2 for volume
|
||||
col = kwargs.get('col', 1)
|
||||
|
||||
fig.add_trace(volume_bars, row=row, col=col)
|
||||
|
||||
self.logger.debug(f"Rendered volume layer with {len(clean_data)} bars")
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering volume layer: {e}")
|
||||
return fig
|
||||
|
||||
def _clean_volume_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Clean and validate volume data"""
|
||||
try:
|
||||
clean_data = data.copy()
|
||||
|
||||
# Remove rows with invalid volume
|
||||
valid_mask = (clean_data['volume'] >= 0) & pd.notna(clean_data['volume'])
|
||||
initial_count = len(clean_data)
|
||||
clean_data = clean_data[valid_mask]
|
||||
|
||||
if len(clean_data) < initial_count:
|
||||
removed_count = initial_count - len(clean_data)
|
||||
self.logger.info(f"Removed {removed_count} invalid volume records")
|
||||
|
||||
# Ensure timestamp is properly formatted
|
||||
if not pd.api.types.is_datetime64_any_dtype(clean_data['timestamp']):
|
||||
clean_data['timestamp'] = pd.to_datetime(clean_data['timestamp'])
|
||||
|
||||
# Sort by timestamp
|
||||
clean_data = clean_data.sort_values('timestamp')
|
||||
|
||||
return clean_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning volume data: {e}")
|
||||
return pd.DataFrame()
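The bar-coloring rule used in `render` above is simple enough to illustrate inline (values are illustrative):

# close >= open -> bullish (green) bar, otherwise bearish (red)
closes = [101.0, 100.2, 100.9]
opens  = [100.5, 100.8, 100.9]
colors = ['#00C851' if c >= o else '#FF4444' for c, o in zip(closes, opens)]
# colors == ['#00C851', '#FF4444', '#00C851']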
|
||||
|
||||
|
||||
class LayerManager:
|
||||
"""
|
||||
Manager class for coordinating multiple chart layers.
|
||||
|
||||
This class handles the orchestration of multiple layers, including
|
||||
setting up subplots and rendering layers in the correct order.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the layer manager."""
|
||||
self.layers: List[BaseLayer] = []
|
||||
self.logger = logger
|
||||
|
||||
def add_layer(self, layer: BaseLayer) -> None:
|
||||
"""
|
||||
Add a layer to the manager.
|
||||
|
||||
Args:
|
||||
layer: Chart layer to add
|
||||
"""
|
||||
self.layers.append(layer)
|
||||
self.logger.debug(f"Added layer: {layer.config.name}")
|
||||
|
||||
def remove_layer(self, layer_name: str) -> bool:
|
||||
"""
|
||||
Remove a layer by name.
|
||||
|
||||
Args:
|
||||
layer_name: Name of layer to remove
|
||||
|
||||
Returns:
|
||||
True if layer was removed, False if not found
|
||||
"""
|
||||
for i, layer in enumerate(self.layers):
|
||||
if layer.config.name == layer_name:
|
||||
self.layers.pop(i)
|
||||
self.logger.debug(f"Removed layer: {layer_name}")
|
||||
return True
|
||||
|
||||
self.logger.warning(f"Layer not found for removal: {layer_name}")
|
||||
return False
|
||||
|
||||
def get_enabled_layers(self) -> List[BaseLayer]:
|
||||
"""Get list of enabled layers."""
|
||||
return [layer for layer in self.layers if layer.is_enabled()]
|
||||
|
||||
def get_overlay_layers(self) -> List[BaseLayer]:
|
||||
"""Get layers that render on the main chart."""
|
||||
return [layer for layer in self.get_enabled_layers() if layer.is_overlay()]
|
||||
|
||||
def get_subplot_layers(self) -> Dict[int, List[BaseLayer]]:
|
||||
"""Get layers grouped by subplot row."""
|
||||
subplot_layers = {}
|
||||
|
||||
for layer in self.get_enabled_layers():
|
||||
if not layer.is_overlay():
|
||||
row = layer.get_subplot_row()
|
||||
if row not in subplot_layers:
|
||||
subplot_layers[row] = []
|
||||
subplot_layers[row].append(layer)
|
||||
|
||||
return subplot_layers
|
||||
|
||||
def calculate_subplot_layout(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate subplot configuration based on layers.
|
||||
|
||||
Returns:
|
||||
Dict with subplot configuration parameters
|
||||
"""
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
|
||||
if not subplot_layers:
|
||||
# No subplots needed
|
||||
return {
|
||||
'rows': 1,
|
||||
'cols': 1,
|
||||
'subplot_titles': None,
|
||||
'row_heights': None
|
||||
}
|
||||
|
||||
# Reassign subplot rows dynamically to ensure proper ordering
|
||||
self._reassign_subplot_rows()
|
||||
|
||||
# Recalculate after reassignment
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
|
||||
# Calculate number of rows (main chart + subplots)
|
||||
max_subplot_row = max(subplot_layers.keys()) if subplot_layers else 0
|
||||
total_rows = max(1, max_subplot_row) # Row numbers are 1-indexed, so max_subplot_row is the total rows needed
|
||||
|
||||
# Create subplot titles
|
||||
subplot_titles = ['Price'] # Main chart
|
||||
for row in range(2, total_rows + 1):
|
||||
if row in subplot_layers:
|
||||
# Use the first layer's name as the subtitle
|
||||
layer_names = [layer.config.name for layer in subplot_layers[row]]
|
||||
subplot_titles.append(' / '.join(layer_names).title())
|
||||
else:
|
||||
subplot_titles.append(f'Subplot {row}')
|
||||
|
||||
# Calculate row heights based on subplot height ratios
|
||||
row_heights = self._calculate_dynamic_row_heights(subplot_layers, total_rows)
|
||||
|
||||
return {
|
||||
'rows': total_rows,
|
||||
'cols': 1,
|
||||
'subplot_titles': subplot_titles,
|
||||
'row_heights': row_heights,
|
||||
'shared_xaxes': True,
|
||||
'vertical_spacing': 0.03
|
||||
}
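As an illustration (not output from a real run), a manager holding a candlestick overlay, a volume subplot, and one oscillator subplot whose layers report the default 0.25 height ratio would produce a configuration along these lines, assuming the layer names are 'volume' and 'rsi':

# Hypothetical result for main chart + volume + one oscillator:
# {
#     'rows': 3, 'cols': 1,
#     'subplot_titles': ['Price', 'Volume', 'Rsi'],
#     'row_heights': [0.5, 0.25, 0.25],
#     'shared_xaxes': True, 'vertical_spacing': 0.03
# }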
|
||||
|
||||
def _reassign_subplot_rows(self) -> None:
|
||||
"""
|
||||
Reassign subplot rows to ensure proper sequential ordering.
|
||||
|
||||
This method dynamically assigns subplot rows starting from row 2,
|
||||
ensuring no gaps in the subplot layout.
|
||||
"""
|
||||
subplot_layers = []
|
||||
|
||||
# Collect all subplot layers
|
||||
for layer in self.get_enabled_layers():
|
||||
if not layer.is_overlay():
|
||||
subplot_layers.append(layer)
|
||||
|
||||
# Sort by priority: volume first, then by current subplot row
|
||||
def layer_priority(layer):
|
||||
# Volume gets highest priority (0), then by current row
|
||||
if hasattr(layer, 'config') and layer.config.name == 'volume':
|
||||
return (0, layer.get_subplot_row() or 999)
|
||||
else:
|
||||
return (1, layer.get_subplot_row() or 999)
|
||||
|
||||
subplot_layers.sort(key=layer_priority)
|
||||
|
||||
# Reassign rows starting from 2
|
||||
for i, layer in enumerate(subplot_layers):
|
||||
new_row = i + 2 # Start from row 2 (row 1 is main chart)
|
||||
layer.config.subplot_row = new_row
|
||||
self.logger.debug(f"Assigned {layer.config.name} to subplot row {new_row}")
|
||||
|
||||
def _calculate_dynamic_row_heights(self, subplot_layers: Dict[int, List], total_rows: int) -> List[float]:
|
||||
"""
|
||||
Calculate row heights based on subplot height ratios.
|
||||
|
||||
Args:
|
||||
subplot_layers: Dictionary of subplot layers by row
|
||||
total_rows: Total number of rows
|
||||
|
||||
Returns:
|
||||
List of height ratios for each row
|
||||
"""
|
||||
if total_rows == 1:
|
||||
return [1.0] # Single row gets full height
|
||||
|
||||
# Calculate total requested subplot height
|
||||
total_subplot_ratio = 0.0
|
||||
subplot_ratios = {}
|
||||
|
||||
for row in range(2, total_rows + 1):
|
||||
if row in subplot_layers:
|
||||
# Get height ratio from first layer in the row
|
||||
layer = subplot_layers[row][0]
|
||||
if hasattr(layer, 'get_subplot_height_ratio'):
|
||||
ratio = layer.get_subplot_height_ratio()
|
||||
else:
|
||||
ratio = 0.25 # Default ratio
|
||||
subplot_ratios[row] = ratio
|
||||
total_subplot_ratio += ratio
|
||||
else:
|
||||
subplot_ratios[row] = 0.25 # Default for empty rows
|
||||
total_subplot_ratio += 0.25
|
||||
|
||||
# Ensure total doesn't exceed reasonable limits
|
||||
max_subplot_ratio = 0.6 # Maximum 60% for all subplots
|
||||
if total_subplot_ratio > max_subplot_ratio:
|
||||
# Scale down proportionally
|
||||
scale_factor = max_subplot_ratio / total_subplot_ratio
|
||||
for row in subplot_ratios:
|
||||
subplot_ratios[row] *= scale_factor
|
||||
total_subplot_ratio = max_subplot_ratio
|
||||
|
||||
# Main chart gets remaining space
|
||||
main_chart_ratio = 1.0 - total_subplot_ratio
|
||||
|
||||
# Build final height list
|
||||
row_heights = [main_chart_ratio] # Main chart
|
||||
for row in range(2, total_rows + 1):
|
||||
row_heights.append(subplot_ratios.get(row, 0.25))
|
||||
|
||||
return row_heights
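A quick worked example of the scaling rule above (numbers chosen for illustration): with three subplots requesting 0.25 each, the requested total is 0.75, which exceeds the 0.6 cap, so each ratio is scaled by 0.6 / 0.75 = 0.8 down to 0.2, leaving 0.4 for the main chart and a final height list of [0.4, 0.2, 0.2, 0.2].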
|
||||
|
||||
def render_all_layers(self, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""
|
||||
Render all enabled layers onto a new figure.
|
||||
|
||||
Args:
|
||||
data: Chart data (OHLCV format)
|
||||
**kwargs: Additional rendering parameters
|
||||
|
||||
Returns:
|
||||
Complete figure with all layers rendered
|
||||
"""
|
||||
try:
|
||||
# Calculate subplot layout
|
||||
layout_config = self.calculate_subplot_layout()
|
||||
|
||||
# Create figure with subplots if needed
|
||||
if layout_config['rows'] > 1:
|
||||
fig = make_subplots(**layout_config)
|
||||
else:
|
||||
fig = go.Figure()
|
||||
|
||||
# Render overlay layers (main chart)
|
||||
overlay_layers = self.get_overlay_layers()
|
||||
for layer in overlay_layers:
|
||||
fig = layer.render(fig, data, row=1, col=1, **kwargs)
|
||||
|
||||
# Render subplot layers
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
for row, layers in subplot_layers.items():
|
||||
for layer in layers:
|
||||
fig = layer.render(fig, data, row=row, col=1, **kwargs)
|
||||
|
||||
# Update layout styling
|
||||
self._apply_layout_styling(fig, layout_config)
|
||||
|
||||
self.logger.debug(f"Rendered {len(self.get_enabled_layers())} layers")
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering layers: {e}")
|
||||
# Return empty figure on error
|
||||
return go.Figure()
|
||||
|
||||
def _apply_layout_styling(self, fig: go.Figure, layout_config: Dict[str, Any]) -> None:
|
||||
"""Apply consistent styling to the figure layout."""
|
||||
try:
|
||||
# Basic layout settings
|
||||
fig.update_layout(
|
||||
template="plotly_white",
|
||||
showlegend=False,
|
||||
hovermode='x unified',
|
||||
xaxis_rangeslider_visible=False
|
||||
)
|
||||
|
||||
# Update axes for subplots
|
||||
if layout_config['rows'] > 1:
|
||||
# Update main chart axes
|
||||
fig.update_yaxes(title_text="Price (USDT)", row=1, col=1)
|
||||
fig.update_xaxes(showticklabels=False, row=1, col=1)
|
||||
|
||||
# Update subplot axes
|
||||
subplot_layers = self.get_subplot_layers()
|
||||
for row in range(2, layout_config['rows'] + 1):
|
||||
if row in subplot_layers:
|
||||
# Set y-axis title and range based on layer type
|
||||
layers_in_row = subplot_layers[row]
|
||||
layer = layers_in_row[0] # Use first layer for configuration
|
||||
|
||||
# Set y-axis title
|
||||
if hasattr(layer, 'config') and hasattr(layer.config, 'indicator_type'):
|
||||
indicator_type = layer.config.indicator_type
|
||||
if indicator_type == 'rsi':
|
||||
fig.update_yaxes(title_text="RSI", row=row, col=1)
|
||||
elif indicator_type == 'macd':
|
||||
fig.update_yaxes(title_text="MACD", row=row, col=1)
|
||||
else:
|
||||
layer_names = [l.config.name for l in layers_in_row]
|
||||
fig.update_yaxes(title_text=' / '.join(layer_names), row=row, col=1)
|
||||
|
||||
# Set fixed y-axis range if specified
|
||||
if hasattr(layer, 'has_fixed_range') and layer.has_fixed_range():
|
||||
y_range = layer.get_y_axis_range()
|
||||
if y_range:
|
||||
fig.update_yaxes(range=list(y_range), row=row, col=1)
|
||||
|
||||
# Only show x-axis labels on the bottom subplot
|
||||
if row == layout_config['rows']:
|
||||
fig.update_xaxes(title_text="Time", row=row, col=1)
|
||||
else:
|
||||
fig.update_xaxes(showticklabels=False, row=row, col=1)
|
||||
else:
|
||||
# Single chart
|
||||
fig.update_layout(
|
||||
xaxis_title="Time",
|
||||
yaxis_title="Price (USDT)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error applying layout styling: {e}")
|
||||
components/charts/layers/indicators.py (new file, 720 lines)
@@ -0,0 +1,720 @@
|
||||
"""
|
||||
Technical Indicator Chart Layers
|
||||
|
||||
This module implements overlay indicator layers for technical analysis visualization
|
||||
including SMA, EMA, and Bollinger Bands with comprehensive error handling.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
from typing import Dict, Any, Optional, List, Union, Callable
|
||||
from dataclasses import dataclass
from decimal import Decimal  # used below when building OHLCVCandle objects
|
||||
|
||||
from ..error_handling import (
|
||||
ChartErrorHandler, ChartError, ErrorSeverity, DataRequirements,
|
||||
InsufficientDataError, DataValidationError, IndicatorCalculationError,
|
||||
ErrorRecoveryStrategies, create_error_annotation, get_error_message
|
||||
)
|
||||
|
||||
from .base import BaseLayer, LayerConfig
|
||||
from data.common.indicators import TechnicalIndicators, OHLCVCandle
|
||||
from components.charts.utils import get_indicator_colors
|
||||
from utils.logger import get_logger
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger("chart_indicators")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndicatorLayerConfig(LayerConfig):
|
||||
"""Extended configuration for indicator layers"""
|
||||
indicator_type: str = "" # e.g., 'sma', 'ema', 'rsi'
|
||||
parameters: Dict[str, Any] = None # Indicator-specific parameters
|
||||
line_width: int = 2
|
||||
opacity: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.parameters is None:
|
||||
self.parameters = {}
|
||||
|
||||
|
||||
class BaseIndicatorLayer(BaseLayer):
|
||||
"""
|
||||
Enhanced base class for all indicator layers with comprehensive error handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig):
|
||||
"""
|
||||
Initialize base indicator layer.
|
||||
|
||||
Args:
|
||||
config: Indicator layer configuration
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.indicators = TechnicalIndicators()
|
||||
self.colors = get_indicator_colors()
|
||||
self.calculated_data = None
|
||||
self.calculation_errors = []
|
||||
|
||||
def prepare_indicator_data(self, data: pd.DataFrame) -> List[OHLCVCandle]:
|
||||
"""
|
||||
Convert DataFrame to OHLCVCandle format for indicator calculations.
|
||||
|
||||
Args:
|
||||
data: Chart data (OHLCV format)
|
||||
|
||||
Returns:
|
||||
List of OHLCVCandle objects
|
||||
"""
|
||||
try:
|
||||
candles = []
|
||||
for _, row in data.iterrows():
|
||||
# Calculate start_time (assuming 1-minute candles for now)
|
||||
start_time = row['timestamp']
|
||||
end_time = row['timestamp']
|
||||
|
||||
candle = OHLCVCandle(
|
||||
symbol="BTCUSDT", # Default symbol for testing
|
||||
timeframe="1m", # Default timeframe
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
open=Decimal(str(row['open'])),
|
||||
high=Decimal(str(row['high'])),
|
||||
low=Decimal(str(row['low'])),
|
||||
close=Decimal(str(row['close'])),
|
||||
volume=Decimal(str(row.get('volume', 0))),
|
||||
trade_count=1, # Default trade count
|
||||
exchange="test", # Test exchange
|
||||
is_complete=True # Mark as complete for testing
|
||||
)
|
||||
candles.append(candle)
|
||||
|
||||
return candles
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error preparing indicator data: {e}")
|
||||
return []
|
||||
|
||||
def validate_indicator_data(self, data: Union[pd.DataFrame, List[Dict[str, Any]]],
|
||||
required_columns: List[str] = None) -> bool:
|
||||
"""
|
||||
Validate data specifically for indicator calculations.
|
||||
|
||||
Args:
|
||||
data: Input data
|
||||
required_columns: Required columns for this indicator
|
||||
|
||||
Returns:
|
||||
True if data is valid for indicator calculation
|
||||
"""
|
||||
try:
|
||||
# Use parent validation first
|
||||
if not super().validate_data(data):
|
||||
return False
|
||||
|
||||
# Convert to DataFrame if needed
|
||||
if isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
else:
|
||||
df = data.copy()
|
||||
|
||||
# Check required columns for indicator
|
||||
if required_columns:
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
if missing_columns:
|
||||
error = ChartError(
|
||||
code='MISSING_INDICATOR_COLUMNS',
|
||||
message=f'Missing columns for {self.config.indicator_type}: {missing_columns}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={
|
||||
'indicator_type': self.config.indicator_type,
|
||||
'missing_columns': missing_columns,
|
||||
'available_columns': list(df.columns)
|
||||
},
|
||||
recovery_suggestion=f'Ensure data contains required columns: {required_columns}'
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
# Check data sufficiency for indicator
|
||||
indicator_config = {
|
||||
'type': self.config.indicator_type,
|
||||
'parameters': self.config.parameters or {}
|
||||
}
|
||||
|
||||
indicator_error = DataRequirements.check_indicator_requirements(
|
||||
self.config.indicator_type,
|
||||
len(df),
|
||||
self.config.parameters or {}
|
||||
)
|
||||
|
||||
if indicator_error.severity == ErrorSeverity.WARNING:
|
||||
self.error_handler.warnings.append(indicator_error)
|
||||
elif indicator_error.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]:
|
||||
self.error_handler.errors.append(indicator_error)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error validating indicator data: {e}")
|
||||
error = ChartError(
|
||||
code='INDICATOR_VALIDATION_ERROR',
|
||||
message=f'Indicator validation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e), 'indicator_type': self.config.indicator_type}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return False
|
||||
|
||||
def safe_calculate_indicator(self, data: pd.DataFrame,
|
||||
calculation_func: Callable,
|
||||
**kwargs) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Safely calculate indicator with error handling.
|
||||
|
||||
Args:
|
||||
data: Input data
|
||||
calculation_func: Function to calculate indicator
|
||||
**kwargs: Additional arguments for calculation
|
||||
|
||||
Returns:
|
||||
Calculated indicator data or None if failed
|
||||
"""
|
||||
try:
|
||||
# Validate data first
|
||||
if not self.validate_indicator_data(data):
|
||||
return None
|
||||
|
||||
# Try calculation with recovery strategies
|
||||
result = calculation_func(data, **kwargs)
|
||||
|
||||
# Validate result
|
||||
if result is None or (isinstance(result, pd.DataFrame) and result.empty):
|
||||
error = ChartError(
|
||||
code='EMPTY_INDICATOR_RESULT',
|
||||
message=f'Indicator calculation returned no data: {self.config.indicator_type}',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'indicator_type': self.config.indicator_type, 'input_length': len(data)},
|
||||
recovery_suggestion='Check calculation parameters or input data range'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
return None
|
||||
|
||||
# Check for sufficient calculated data
|
||||
if isinstance(result, pd.DataFrame) and len(result) < len(data) * 0.1:
|
||||
error = ChartError(
|
||||
code='INSUFFICIENT_INDICATOR_OUTPUT',
|
||||
message=f'Very few indicator values calculated: {len(result)}/{len(data)}',
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={
|
||||
'indicator_type': self.config.indicator_type,
|
||||
'output_length': len(result),
|
||||
'input_length': len(data)
|
||||
},
|
||||
recovery_suggestion='Consider adjusting indicator parameters'
|
||||
)
|
||||
self.error_handler.warnings.append(error)
|
||||
|
||||
self.calculated_data = result
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error calculating {self.config.indicator_type}: {e}")
|
||||
|
||||
# Try to apply error recovery
|
||||
recovery_strategy = ErrorRecoveryStrategies.handle_insufficient_data(
|
||||
ChartError(
|
||||
code='INDICATOR_CALCULATION_ERROR',
|
||||
message=f'Calculation failed for {self.config.indicator_type}: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e), 'indicator_type': self.config.indicator_type}
|
||||
),
|
||||
fallback_options={'data_length': len(data)}
|
||||
)
|
||||
|
||||
if recovery_strategy['can_proceed'] and recovery_strategy['fallback_action'] == 'adjust_parameters':
|
||||
# Try with adjusted parameters
|
||||
try:
|
||||
modified_config = recovery_strategy.get('modified_config', {})
|
||||
self.logger.info(f"Retrying indicator calculation with adjusted parameters: {modified_config}")
|
||||
|
||||
# Update parameters temporarily
|
||||
original_params = self.config.parameters.copy() if self.config.parameters else {}
|
||||
self.config.parameters.update(modified_config)
|
||||
|
||||
# Retry calculation
|
||||
result = calculation_func(data, **kwargs)
|
||||
|
||||
# Restore original parameters
|
||||
self.config.parameters = original_params
|
||||
|
||||
if result is not None and not (isinstance(result, pd.DataFrame) and result.empty):
|
||||
# Add warning about parameter adjustment
|
||||
warning = ChartError(
|
||||
code='INDICATOR_PARAMETERS_ADJUSTED',
|
||||
message=recovery_strategy['user_message'],
|
||||
severity=ErrorSeverity.WARNING,
|
||||
context={'original_params': original_params, 'adjusted_params': modified_config}
|
||||
)
|
||||
self.error_handler.warnings.append(warning)
|
||||
self.calculated_data = result
|
||||
return result
|
||||
|
||||
except Exception as retry_error:
|
||||
self.logger.error(f"Retry with adjusted parameters also failed: {retry_error}")
|
||||
|
||||
# Final error if all recovery attempts fail
|
||||
error = ChartError(
|
||||
code='INDICATOR_CALCULATION_FAILED',
|
||||
message=f'Failed to calculate {self.config.indicator_type}: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'exception': str(e), 'indicator_type': self.config.indicator_type}
|
||||
)
|
||||
self.error_handler.errors.append(error)
|
||||
return None
|
||||
|
||||
def create_indicator_traces(self, data: pd.DataFrame, subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""
|
||||
Create indicator traces with error handling.
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement create_indicator_traces")
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if the layer is enabled."""
|
||||
return self.config.enabled
|
||||
|
||||
def is_overlay(self) -> bool:
|
||||
"""Check if this layer is an overlay (main chart) or subplot."""
|
||||
return self.config.subplot_row is None
|
||||
|
||||
def get_subplot_row(self) -> Optional[int]:
|
||||
"""Get the subplot row for this layer."""
|
||||
return self.config.subplot_row
|
||||
|
||||
|
||||
class SMALayer(BaseIndicatorLayer):
|
||||
"""Simple Moving Average layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig = None):
|
||||
"""Initialize SMA layer"""
|
||||
if config is None:
|
||||
config = IndicatorLayerConfig(
|
||||
name='sma',
indicator_type='sma',
|
||||
parameters={'period': 20}
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def create_traces(self, data: List[Dict[str, Any]], subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""Create SMA traces with comprehensive error handling"""
|
||||
try:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data) if isinstance(data, list) else data.copy()
|
||||
|
||||
# Validate data
|
||||
if not self.validate_indicator_data(df, required_columns=['close', 'timestamp']):
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"SMA Error: {self._error_message}")]
|
||||
|
||||
# Calculate SMA with error handling
|
||||
period = self.config.parameters.get('period', 20)
|
||||
sma_data = self.safe_calculate_indicator(
|
||||
df,
|
||||
self._calculate_sma,
|
||||
period=period
|
||||
)
|
||||
|
||||
if sma_data is None:
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"SMA calculation failed")]
|
||||
else:
|
||||
return [] # Skip layer gracefully
|
||||
|
||||
# Create trace
|
||||
sma_trace = go.Scatter(
|
||||
x=sma_data['timestamp'],
|
||||
y=sma_data['sma'],
|
||||
mode='lines',
|
||||
name=f'SMA({period})',
|
||||
line=dict(
|
||||
color=self.config.color or '#2196F3',
|
||||
width=self.config.line_width
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
self.traces = [sma_trace]
|
||||
return self.traces
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating SMA traces: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return [self.create_error_trace(error_msg)]
|
||||
|
||||
def _calculate_sma(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
|
||||
"""Calculate SMA with validation"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
result_df['sma'] = result_df['close'].rolling(window=period, min_periods=period).mean()
|
||||
|
||||
# Remove NaN values
|
||||
result_df = result_df.dropna(subset=['sma'])
|
||||
|
||||
if result_df.empty:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='SMA_NO_VALUES',
|
||||
message=f'SMA calculation produced no values (period={period}, data_length={len(data)})',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data)}
|
||||
))
|
||||
|
||||
return result_df[['timestamp', 'sma']]
|
||||
|
||||
except Exception as e:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='SMA_CALCULATION_ERROR',
|
||||
message=f'SMA calculation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data), 'exception': str(e)}
|
||||
))
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render SMA layer for compatibility with base interface"""
|
||||
try:
|
||||
traces = self.create_traces(data.to_dict('records'), subplot_row=kwargs.get('row', 1))
|
||||
for trace in traces:
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(trace, **kwargs)
|
||||
else:
|
||||
fig.add_trace(trace)
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering SMA layer: {e}")
|
||||
return fig
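A hedged usage sketch for the SMA overlay, reusing the `sample` frame from the candlestick example and assuming the data comfortably exceeds the configured period (configuration values are illustrative):

sma_cfg = IndicatorLayerConfig(
    name='sma_20',
    indicator_type='sma',
    parameters={'period': 20},
    color='#6c757d',
)
sma_layer = SMALayer(sma_cfg)
traces = sma_layer.create_traces(sample.to_dict('records'))  # list containing one go.Scatter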
|
||||
|
||||
|
||||
class EMALayer(BaseIndicatorLayer):
|
||||
"""Exponential Moving Average layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig = None):
|
||||
"""Initialize EMA layer"""
|
||||
if config is None:
|
||||
config = IndicatorLayerConfig(
|
||||
name='ema',
indicator_type='ema',
|
||||
parameters={'period': 20}
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def create_traces(self, data: List[Dict[str, Any]], subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""Create EMA traces with comprehensive error handling"""
|
||||
try:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data) if isinstance(data, list) else data.copy()
|
||||
|
||||
# Validate data
|
||||
if not self.validate_indicator_data(df, required_columns=['close', 'timestamp']):
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"EMA Error: {self._error_message}")]
|
||||
|
||||
# Calculate EMA with error handling
|
||||
period = self.config.parameters.get('period', 20)
|
||||
ema_data = self.safe_calculate_indicator(
|
||||
df,
|
||||
self._calculate_ema,
|
||||
period=period
|
||||
)
|
||||
|
||||
if ema_data is None:
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"EMA calculation failed")]
|
||||
else:
|
||||
return [] # Skip layer gracefully
|
||||
|
||||
# Create trace
|
||||
ema_trace = go.Scatter(
|
||||
x=ema_data['timestamp'],
|
||||
y=ema_data['ema'],
|
||||
mode='lines',
|
||||
name=f'EMA({period})',
|
||||
line=dict(
|
||||
color=self.config.color or '#FF9800',
|
||||
width=self.config.line_width
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
self.traces = [ema_trace]
|
||||
return self.traces
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating EMA traces: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return [self.create_error_trace(error_msg)]
|
||||
|
||||
def _calculate_ema(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
|
||||
"""Calculate EMA with validation"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
result_df['ema'] = result_df['close'].ewm(span=period, adjust=False).mean()
|
||||
|
||||
# For EMA, we can start from the first value, but remove obvious outliers
|
||||
# Skip first few values for stability
|
||||
warmup_period = max(1, period // 4)
|
||||
result_df = result_df.iloc[warmup_period:]
|
||||
|
||||
if result_df.empty:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='EMA_NO_VALUES',
|
||||
message=f'EMA calculation produced no values (period={period}, data_length={len(data)})',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data)}
|
||||
))
|
||||
|
||||
return result_df[['timestamp', 'ema']]
|
||||
|
||||
except Exception as e:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='EMA_CALCULATION_ERROR',
|
||||
message=f'EMA calculation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'data_length': len(data), 'exception': str(e)}
|
||||
))
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render EMA layer for compatibility with base interface"""
|
||||
try:
|
||||
traces = self.create_traces(data.to_dict('records'), subplot_row=kwargs.get('row', 1))
|
||||
for trace in traces:
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(trace, **kwargs)
|
||||
else:
|
||||
fig.add_trace(trace)
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering EMA layer: {e}")
|
||||
return fig
|
||||
|
||||
|
||||
class BollingerBandsLayer(BaseIndicatorLayer):
|
||||
"""Bollinger Bands layer with enhanced error handling"""
|
||||
|
||||
def __init__(self, config: IndicatorLayerConfig = None):
|
||||
"""Initialize Bollinger Bands layer"""
|
||||
if config is None:
|
||||
config = IndicatorLayerConfig(
|
||||
name='bollinger_bands',
indicator_type='bollinger_bands',
# show_middle_line lives in parameters; IndicatorLayerConfig has no such field
parameters={'period': 20, 'std_dev': 2, 'show_middle_line': True}
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
def create_traces(self, data: List[Dict[str, Any]], subplot_row: int = 1) -> List[go.Scatter]:
|
||||
"""Create Bollinger Bands traces with comprehensive error handling"""
|
||||
try:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data) if isinstance(data, list) else data.copy()
|
||||
|
||||
# Validate data
|
||||
if not self.validate_indicator_data(df, required_columns=['close', 'timestamp']):
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"Bollinger Bands Error: {self._error_message}")]
|
||||
|
||||
# Calculate Bollinger Bands with error handling
|
||||
period = self.config.parameters.get('period', 20)
|
||||
std_dev = self.config.parameters.get('std_dev', 2)
|
||||
|
||||
bb_data = self.safe_calculate_indicator(
|
||||
df,
|
||||
self._calculate_bollinger_bands,
|
||||
period=period,
|
||||
std_dev=std_dev
|
||||
)
|
||||
|
||||
if bb_data is None:
|
||||
if self.error_handler.errors:
|
||||
return [self.create_error_trace(f"Bollinger Bands calculation failed")]
|
||||
else:
|
||||
return [] # Skip layer gracefully
|
||||
|
||||
# Create traces
|
||||
traces = []
|
||||
|
||||
# Upper band
|
||||
upper_trace = go.Scatter(
|
||||
x=bb_data['timestamp'],
|
||||
y=bb_data['upper_band'],
|
||||
mode='lines',
|
||||
name=f'BB Upper({period})',
|
||||
line=dict(color=self.config.color or '#9C27B0', width=1),
|
||||
|
||||
showlegend=True
|
||||
)
|
||||
traces.append(upper_trace)
|
||||
|
||||
# Lower band with fill
|
||||
lower_trace = go.Scatter(
|
||||
x=bb_data['timestamp'],
|
||||
y=bb_data['lower_band'],
|
||||
mode='lines',
|
||||
name=f'BB Lower({period})',
|
||||
line=dict(color=self.config.color or '#9C27B0', width=1),
|
||||
fill='tonexty',
|
||||
fillcolor='rgba(156, 39, 176, 0.1)',
|
||||
|
||||
showlegend=True
|
||||
)
|
||||
traces.append(lower_trace)
|
||||
|
||||
# Middle line (SMA)
|
||||
if self.config.parameters.get('show_middle_line', True):
|
||||
middle_trace = go.Scatter(
|
||||
x=bb_data['timestamp'],
|
||||
y=bb_data['middle_band'],
|
||||
mode='lines',
|
||||
name=f'BB Middle({period})',
|
||||
line=dict(color=self.config.color or '#9C27B0', width=1, dash='dash'),
|
||||
|
||||
showlegend=True
|
||||
)
|
||||
traces.append(middle_trace)
|
||||
|
||||
self.traces = traces
|
||||
return self.traces
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error creating Bollinger Bands traces: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return [self.create_error_trace(error_msg)]
|
||||
|
||||
def _calculate_bollinger_bands(self, data: pd.DataFrame, period: int, std_dev: float) -> pd.DataFrame:
|
||||
"""Calculate Bollinger Bands with validation"""
|
||||
try:
|
||||
result_df = data.copy()
|
||||
|
||||
# Calculate middle band (SMA)
|
||||
result_df['middle_band'] = result_df['close'].rolling(window=period, min_periods=period).mean()
|
||||
|
||||
# Calculate standard deviation
|
||||
result_df['std'] = result_df['close'].rolling(window=period, min_periods=period).std()
|
||||
|
||||
# Calculate upper and lower bands
|
||||
result_df['upper_band'] = result_df['middle_band'] + (result_df['std'] * std_dev)
|
||||
result_df['lower_band'] = result_df['middle_band'] - (result_df['std'] * std_dev)
|
||||
|
||||
# Remove NaN values
|
||||
result_df = result_df.dropna(subset=['middle_band', 'upper_band', 'lower_band'])
|
||||
|
||||
if result_df.empty:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='BB_NO_VALUES',
|
||||
message=f'Bollinger Bands calculation produced no values (period={period}, data_length={len(data)})',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'std_dev': std_dev, 'data_length': len(data)}
|
||||
))
|
||||
|
||||
return result_df[['timestamp', 'upper_band', 'middle_band', 'lower_band']]
|
||||
|
||||
except Exception as e:
|
||||
raise IndicatorCalculationError(ChartError(
|
||||
code='BB_CALCULATION_ERROR',
|
||||
message=f'Bollinger Bands calculation failed: {str(e)}',
|
||||
severity=ErrorSeverity.ERROR,
|
||||
context={'period': period, 'std_dev': std_dev, 'data_length': len(data), 'exception': str(e)}
|
||||
))
|
||||
|
||||
def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
|
||||
"""Render Bollinger Bands layer for compatibility with base interface"""
|
||||
try:
|
||||
traces = self.create_traces(data.to_dict('records'), subplot_row=kwargs.get('row', 1))
|
||||
for trace in traces:
|
||||
if hasattr(fig, 'add_trace'):
|
||||
fig.add_trace(trace, **kwargs)
|
||||
else:
|
||||
fig.add_trace(trace)
|
||||
return fig
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error rendering Bollinger Bands layer: {e}")
|
||||
return fig
|
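
# Worked example (illustrative): with period=20 and std_dev=2.0, a window whose
# closes average 50,000 with a rolling standard deviation of 150 gives
#   middle_band = 50,000
#   upper_band  = 50,000 + 2.0 * 150 = 50,300
#   lower_band  = 50,000 - 2.0 * 150 = 49,700
# which is what _calculate_bollinger_bands() returns for that window.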


def create_sma_layer(period: int = 20, **kwargs) -> SMALayer:
    """
    Convenience function to create an SMA layer.

    Args:
        period: SMA period
        **kwargs: Additional configuration options

    Returns:
        Configured SMA layer
    """
    return SMALayer(period=period, **kwargs)


def create_ema_layer(period: int = 12, **kwargs) -> EMALayer:
    """
    Convenience function to create an EMA layer.

    Args:
        period: EMA period
        **kwargs: Additional configuration options

    Returns:
        Configured EMA layer
    """
    return EMALayer(period=period, **kwargs)


def create_bollinger_bands_layer(period: int = 20, std_dev: float = 2.0, **kwargs) -> BollingerBandsLayer:
    """
    Convenience function to create a Bollinger Bands layer.

    Args:
        period: BB period (default: 20)
        std_dev: Standard deviation multiplier (default: 2.0)
        **kwargs: Additional configuration options

    Returns:
        Configured Bollinger Bands layer
    """
    return BollingerBandsLayer(period=period, std_dev=std_dev, **kwargs)


def create_common_ma_layers() -> List[BaseIndicatorLayer]:
    """
    Create commonly used moving average layers.

    Returns:
        List of configured MA layers (SMA 20, SMA 50, EMA 12, EMA 26)
    """
    colors = get_indicator_colors()

    return [
        SMALayer(20, color=colors.get('sma', '#007bff'), name="SMA(20)"),
        SMALayer(50, color='#6c757d', name="SMA(50)"),  # Gray for longer SMA
        EMALayer(12, color=colors.get('ema', '#ff6b35'), name="EMA(12)"),
        EMALayer(26, color='#28a745', name="EMA(26)")  # Green for longer EMA
    ]


def create_common_overlay_indicators() -> List[BaseIndicatorLayer]:
    """
    Create commonly used overlay indicators including moving averages and Bollinger Bands.

    Returns:
        List of configured overlay indicator layers
    """
    colors = get_indicator_colors()

    return [
        SMALayer(20, color=colors.get('sma', '#007bff'), name="SMA(20)"),
        EMALayer(12, color=colors.get('ema', '#ff6b35'), name="EMA(12)"),
        BollingerBandsLayer(20, 2.0, color=colors.get('bb_upper', '#6f42c1'), name="BB(20,2)")
    ]
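

# Usage sketch (illustrative only, not part of the public API of this module):
# one way the convenience factories above might be combined on a plain Plotly
# figure. The demo DataFrame below is fabricated for the example, and the sketch
# assumes each overlay layer exposes the render(fig, df) interface used
# elsewhere in this package.
if __name__ == "__main__":
    from datetime import datetime, timedelta, timezone

    start = datetime.now(timezone.utc) - timedelta(minutes=200)
    closes = [50000 + i * 5 for i in range(200)]
    demo_df = pd.DataFrame({
        'timestamp': [start + timedelta(minutes=i) for i in range(200)],
        'open': closes,
        'high': [c + 50 for c in closes],
        'low': [c - 50 for c in closes],
        'close': closes,
        'volume': [1000 + i for i in range(200)],
    })

    fig = go.Figure()
    for layer in create_common_overlay_indicators():
        fig = layer.render(fig, demo_df)  # each overlay draws onto the same figure
    fig.update_layout(title="Overlay indicators demo (SMA, EMA, Bollinger Bands)")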

424 components/charts/layers/subplots.py Normal file
@@ -0,0 +1,424 @@
"""
Subplot Chart Layers

This module contains subplot layer implementations for indicators that render
in separate subplots below the main price chart, such as RSI, MACD, and other
oscillators and momentum indicators.
"""

import plotly.graph_objects as go
import pandas as pd
from decimal import Decimal
from typing import Dict, Any, Optional, List, Union, Tuple
from dataclasses import dataclass

from .base import BaseChartLayer, LayerConfig
from .indicators import BaseIndicatorLayer, IndicatorLayerConfig
from data.common.indicators import TechnicalIndicators, IndicatorResult, OHLCVCandle
from components.charts.utils import get_indicator_colors
from utils.logger import get_logger
from ..error_handling import (
    ChartErrorHandler, ChartError, ErrorSeverity, DataRequirements,
    InsufficientDataError, DataValidationError, IndicatorCalculationError,
    ErrorRecoveryStrategies, create_error_annotation, get_error_message
)

# Initialize logger
logger = get_logger("subplot_layers")


@dataclass
class SubplotLayerConfig(IndicatorLayerConfig):
    """Extended configuration for subplot indicator layers"""
    subplot_height_ratio: float = 0.25  # Height ratio for subplot (0.25 = 25% of total height)
    y_axis_range: Optional[Tuple[float, float]] = None  # Fixed y-axis range (min, max)
    show_zero_line: bool = False  # Show horizontal line at y=0
    reference_lines: List[float] = None  # Additional horizontal reference lines

    def __post_init__(self):
        super().__post_init__()
        if self.reference_lines is None:
            self.reference_lines = []


class BaseSubplotLayer(BaseIndicatorLayer):
    """
    Base class for all subplot indicator layers.

    Provides common functionality for indicators that render in separate subplots
    with their own y-axis scaling and reference lines.
    """

    def __init__(self, config: SubplotLayerConfig):
        """
        Initialize base subplot layer.

        Args:
            config: Subplot layer configuration
        """
        super().__init__(config)
        self.subplot_config = config

    def get_subplot_height_ratio(self) -> float:
        """Get the height ratio for this subplot."""
        return self.subplot_config.subplot_height_ratio

    def has_fixed_range(self) -> bool:
        """Check if this subplot has a fixed y-axis range."""
        return self.subplot_config.y_axis_range is not None

    def get_y_axis_range(self) -> Optional[Tuple[float, float]]:
        """Get the fixed y-axis range if defined."""
        return self.subplot_config.y_axis_range

    def should_show_zero_line(self) -> bool:
        """Check if the zero line should be shown."""
        return self.subplot_config.show_zero_line

    def get_reference_lines(self) -> List[float]:
        """Get additional reference lines to draw."""
        return self.subplot_config.reference_lines

    def add_reference_lines(self, fig: go.Figure, row: int, col: int = 1) -> None:
        """
        Add reference lines to the subplot.

        Args:
            fig: Target figure
            row: Subplot row
            col: Subplot column
        """
        try:
            # Add zero line if enabled
            if self.should_show_zero_line():
                fig.add_hline(
                    y=0,
                    line=dict(color='gray', width=1, dash='dash'),
                    row=row,
                    col=col
                )

            # Add additional reference lines
            for ref_value in self.get_reference_lines():
                fig.add_hline(
                    y=ref_value,
                    line=dict(color='lightgray', width=1, dash='dot'),
                    row=row,
                    col=col
                )

        except Exception as e:
            self.logger.warning(f"Could not add reference lines: {e}")


class RSILayer(BaseSubplotLayer):
    """
    Relative Strength Index (RSI) subplot layer.

    Renders the RSI oscillator in a separate subplot with standard overbought (70)
    and oversold (30) reference lines.
    """

    def __init__(self, period: int = 14, color: str = None, name: str = None):
        """
        Initialize RSI layer.

        Args:
            period: RSI period (default: 14)
            color: Line color (optional, uses default)
            name: Layer name (optional, auto-generated)
        """
        # Use default color if not specified
        if color is None:
            colors = get_indicator_colors()
            color = colors.get('rsi', '#20c997')

        # Generate name if not specified
        if name is None:
            name = f"RSI({period})"

        # Find next available subplot row (will be managed by LayerManager)
        subplot_row = 2  # Default to row 2 (first subplot after main chart)

        config = SubplotLayerConfig(
            name=name,
            indicator_type="rsi",
            color=color,
            parameters={'period': period},
            subplot_row=subplot_row,
            subplot_height_ratio=0.25,
            y_axis_range=(0, 100),  # RSI ranges from 0 to 100
            reference_lines=[30, 70],  # Oversold and overbought levels
            style={
                'line_color': color,
                'line_width': 2,
                'opacity': 1.0
            }
        )

        super().__init__(config)
        self.period = period

    def _calculate_rsi(self, data: pd.DataFrame, period: int) -> pd.DataFrame:
        """Calculate RSI with validation and error handling"""
        try:
            result_df = data.copy()

            # Calculate price changes
            result_df['price_change'] = result_df['close'].diff()

            # Separate gains and losses
            result_df['gain'] = result_df['price_change'].clip(lower=0)
            result_df['loss'] = -result_df['price_change'].clip(upper=0)

            # Calculate average gains and losses using Wilder's smoothing
            result_df['avg_gain'] = result_df['gain'].ewm(alpha=1/period, adjust=False).mean()
            result_df['avg_loss'] = result_df['loss'].ewm(alpha=1/period, adjust=False).mean()

            # Calculate RS and RSI
            result_df['rs'] = result_df['avg_gain'] / result_df['avg_loss']
            result_df['rsi'] = 100 - (100 / (1 + result_df['rs']))

            # Remove rows where RSI cannot be calculated
            result_df = result_df.iloc[period:].copy()

            # Remove NaN values and invalid RSI values
            result_df = result_df.dropna(subset=['rsi'])
            result_df = result_df[
                (result_df['rsi'] >= 0) &
                (result_df['rsi'] <= 100) &
                pd.notna(result_df['rsi'])
            ]

            if result_df.empty:
                raise Exception(f'RSI calculation produced no values (period={period}, data_length={len(data)})')

            return result_df[['timestamp', 'rsi']]

        except Exception as e:
            raise Exception(f'RSI calculation failed: {str(e)}')

    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """Render RSI layer for compatibility with base interface"""
        try:
            # Calculate RSI
            rsi_data = self._calculate_rsi(data, self.period)
            if rsi_data.empty:
                return fig

            # Create RSI trace
            rsi_trace = go.Scatter(
                x=rsi_data['timestamp'],
                y=rsi_data['rsi'],
                mode='lines',
                name=self.config.name,
                line=dict(
                    color=self.config.color,
                    width=2
                ),
                showlegend=True
            )

            # Add trace
            row = kwargs.get('row', self.config.subplot_row or 2)
            col = kwargs.get('col', 1)

            if hasattr(fig, 'add_trace'):
                fig.add_trace(rsi_trace, row=row, col=col)
            else:
                fig.add_trace(rsi_trace)

            # Add reference lines
            self.add_reference_lines(fig, row, col)

            return fig
        except Exception as e:
            self.logger.error(f"Error rendering RSI layer: {e}")
            return fig
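
# Worked example (illustrative): _calculate_rsi() approximates Wilder's smoothing
# with ewm(alpha=1/period). If, for one bar, avg_gain = 1.0 and avg_loss = 0.5,
# then RS = 1.0 / 0.5 = 2.0 and RSI = 100 - 100 / (1 + 2.0) ≈ 66.7, i.e. inside
# the neutral zone between the 30/70 reference lines this layer draws.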


class MACDLayer(BaseSubplotLayer):
    """MACD (Moving Average Convergence Divergence) subplot layer with enhanced error handling"""

    def __init__(self, fast_period: int = 12, slow_period: int = 26, signal_period: int = 9,
                 color: str = None, name: str = None):
        """Initialize MACD layer with custom parameters"""
        # Use default color if not specified
        if color is None:
            colors = get_indicator_colors()
            color = colors.get('macd', '#fd7e14')

        # Generate name if not specified
        if name is None:
            name = f"MACD({fast_period},{slow_period},{signal_period})"

        config = SubplotLayerConfig(
            name=name,
            indicator_type="macd",
            color=color,
            parameters={
                'fast_period': fast_period,
                'slow_period': slow_period,
                'signal_period': signal_period
            },
            subplot_row=3,  # Will be managed by LayerManager
            subplot_height_ratio=0.3,
            show_zero_line=True,
            style={
                'line_color': color,
                'line_width': 2,
                'opacity': 1.0
            }
        )

        super().__init__(config)
        self.fast_period = fast_period
        self.slow_period = slow_period
        self.signal_period = signal_period

    def _calculate_macd(self, data: pd.DataFrame, fast_period: int,
                        slow_period: int, signal_period: int) -> pd.DataFrame:
        """Calculate MACD with validation and error handling"""
        try:
            result_df = data.copy()

            # Validate periods
            if fast_period >= slow_period:
                raise Exception(f'Fast period ({fast_period}) must be less than slow period ({slow_period})')

            # Calculate EMAs
            result_df['ema_fast'] = result_df['close'].ewm(span=fast_period, adjust=False).mean()
            result_df['ema_slow'] = result_df['close'].ewm(span=slow_period, adjust=False).mean()

            # Calculate MACD line
            result_df['macd'] = result_df['ema_fast'] - result_df['ema_slow']

            # Calculate signal line
            result_df['signal'] = result_df['macd'].ewm(span=signal_period, adjust=False).mean()

            # Calculate histogram
            result_df['histogram'] = result_df['macd'] - result_df['signal']

            # Remove rows where MACD cannot be calculated reliably
            warmup_period = slow_period + signal_period
            result_df = result_df.iloc[warmup_period:].copy()

            # Remove NaN values
            result_df = result_df.dropna(subset=['macd', 'signal', 'histogram'])

            if result_df.empty:
                raise Exception(f'MACD calculation produced no values (fast={fast_period}, slow={slow_period}, signal={signal_period})')

            return result_df[['timestamp', 'macd', 'signal', 'histogram']]

        except Exception as e:
            raise Exception(f'MACD calculation failed: {str(e)}')

    def render(self, fig: go.Figure, data: pd.DataFrame, **kwargs) -> go.Figure:
        """Render MACD layer for compatibility with base interface"""
        try:
            # Calculate MACD
            macd_data = self._calculate_macd(data, self.fast_period, self.slow_period, self.signal_period)
            if macd_data.empty:
                return fig

            row = kwargs.get('row', self.config.subplot_row or 3)
            col = kwargs.get('col', 1)

            # Create MACD line trace
            macd_trace = go.Scatter(
                x=macd_data['timestamp'],
                y=macd_data['macd'],
                mode='lines',
                name=f'{self.config.name} Line',
                line=dict(color=self.config.color, width=2),
                showlegend=True
            )

            # Create signal line trace
            signal_trace = go.Scatter(
                x=macd_data['timestamp'],
                y=macd_data['signal'],
                mode='lines',
                name=f'{self.config.name} Signal',
                line=dict(color='#FF9800', width=2),
                showlegend=True
            )

            # Create histogram
            histogram_colors = ['green' if h >= 0 else 'red' for h in macd_data['histogram']]
            histogram_trace = go.Bar(
                x=macd_data['timestamp'],
                y=macd_data['histogram'],
                name=f'{self.config.name} Histogram',
                marker_color=histogram_colors,
                opacity=0.6,
                showlegend=True
            )

            # Add traces
            if hasattr(fig, 'add_trace'):
                fig.add_trace(macd_trace, row=row, col=col)
                fig.add_trace(signal_trace, row=row, col=col)
                fig.add_trace(histogram_trace, row=row, col=col)
            else:
                fig.add_trace(macd_trace)
                fig.add_trace(signal_trace)
                fig.add_trace(histogram_trace)

            # Add zero line
            self.add_reference_lines(fig, row, col)

            return fig
        except Exception as e:
            self.logger.error(f"Error rendering MACD layer: {e}")
            return fig


def create_rsi_layer(period: int = 14, **kwargs) -> 'RSILayer':
    """
    Convenience function to create an RSI layer.

    Args:
        period: RSI period (default: 14)
        **kwargs: Additional configuration options

    Returns:
        Configured RSI layer
    """
    return RSILayer(period=period, **kwargs)


def create_macd_layer(fast_period: int = 12, slow_period: int = 26,
                      signal_period: int = 9, **kwargs) -> 'MACDLayer':
    """
    Convenience function to create a MACD layer.

    Args:
        fast_period: Fast EMA period (default: 12)
        slow_period: Slow EMA period (default: 26)
        signal_period: Signal line period (default: 9)
        **kwargs: Additional configuration options

    Returns:
        Configured MACD layer
    """
    return MACDLayer(
        fast_period=fast_period,
        slow_period=slow_period,
        signal_period=signal_period,
        **kwargs
    )


def create_common_subplot_indicators() -> List[BaseSubplotLayer]:
    """
    Create commonly used subplot indicators.

    Returns:
        List of configured subplot indicator layers (RSI, MACD)
    """
    return [
        RSILayer(period=14),
        MACDLayer(fast_period=12, slow_period=26, signal_period=9)
    ]
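

# Usage sketch (illustrative only, not part of the public API of this module):
# wiring the subplot layers above into a plotly make_subplots figure by hand.
# Row assignment and height ratios are normally handled by LayerManager; the
# manual layout below, and the fabricated demo DataFrame, are assumptions made
# purely for demonstration.
if __name__ == "__main__":
    from datetime import datetime, timedelta, timezone
    from plotly.subplots import make_subplots

    start = datetime.now(timezone.utc) - timedelta(minutes=200)
    closes = [50000 + (i % 20 - 10) * 30 + i for i in range(200)]
    demo_df = pd.DataFrame({
        'timestamp': [start + timedelta(minutes=i) for i in range(200)],
        'open': closes,
        'high': [c + 50 for c in closes],
        'low': [c - 50 for c in closes],
        'close': closes,
        'volume': [1000 + i for i in range(200)],
    })

    rsi_layer, macd_layer = create_common_subplot_indicators()
    price_ratio = 1.0 - rsi_layer.get_subplot_height_ratio() - macd_layer.get_subplot_height_ratio()
    fig = make_subplots(
        rows=3, cols=1, shared_xaxes=True,
        row_heights=[price_ratio,
                     rsi_layer.get_subplot_height_ratio(),
                     macd_layer.get_subplot_height_ratio()]
    )
    fig.add_trace(go.Candlestick(
        x=demo_df['timestamp'], open=demo_df['open'], high=demo_df['high'],
        low=demo_df['low'], close=demo_df['close'], name='Price'
    ), row=1, col=1)
    fig = rsi_layer.render(fig, demo_df, row=2, col=1)   # RSI with 30/70 reference lines
    fig = macd_layer.render(fig, demo_df, row=3, col=1)  # MACD line, signal and histogram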

@@ -12,8 +12,8 @@ Implementation of a flexible, strategy-driven chart system that supports technic
- `components/charts/config/indicator_defs.py` - Base indicator definitions, schemas, and default parameters
- `components/charts/config/strategy_charts.py` - Strategy-specific chart configurations and presets
- `components/charts/config/defaults.py` - Default chart configurations and fallback settings
- `components/charts/layers/__init__.py` - Chart layers package initialization
- `components/charts/layers/base.py` - Base candlestick chart layer implementation
- `components/charts/layers/__init__.py` - Chart layers package initialization with base layer exports
- `components/charts/layers/base.py` - Base layer system with CandlestickLayer, VolumeLayer, and LayerManager
- `components/charts/layers/indicators.py` - Indicator overlay rendering (SMA, EMA, Bollinger Bands)
- `components/charts/layers/subplots.py` - Subplot management for indicators like RSI and MACD
- `components/charts/layers/signals.py` - Strategy signal overlays and trade markers (future bot integration)
@@ -42,15 +42,15 @@ Implementation of a flexible, strategy-driven chart system that supports technic
- [x] 1.5 Setup backward compatibility with existing components/charts.py API
- [x] 1.6 Create basic unit tests for ChartBuilder class

- [ ] 2.0 Indicator Layer System Implementation
- [ ] 2.1 Create base candlestick chart layer with volume subplot
- [ ] 2.2 Implement overlay indicator rendering (SMA, EMA)
- [ ] 2.3 Add Bollinger Bands overlay functionality
- [ ] 2.4 Create subplot management system for secondary indicators
- [ ] 2.5 Implement RSI subplot with proper scaling and styling
- [ ] 2.6 Add MACD subplot with signal line and histogram
- [ ] 2.7 Create indicator calculation integration with market data
- [ ] 2.8 Add error handling for insufficient data scenarios
- [x] 2.0 Indicator Layer System Implementation
- [x] 2.1 Create base candlestick chart layer with volume subplot
- [x] 2.2 Implement overlay indicator rendering (SMA, EMA)
- [x] 2.3 Add Bollinger Bands overlay functionality
- [x] 2.4 Create subplot management system for secondary indicators
- [x] 2.5 Implement RSI subplot with proper scaling and styling
- [x] 2.6 Add MACD subplot with signal line and histogram
- [x] 2.7 Create indicator calculation integration with market data
- [x] 2.8 Add comprehensive error handling for insufficient data scenarios
- [ ] 2.9 Unit test all indicator layer components

- [ ] 3.0 Strategy Configuration System

711 tests/test_chart_layers.py Normal file
@@ -0,0 +1,711 @@
#!/usr/bin/env python3
"""
Comprehensive Unit Tests for Chart Layer Components

Tests for all chart layer functionality including:
- Error handling system
- Base layer components (CandlestickLayer, VolumeLayer, LayerManager)
- Indicator layers (SMA, EMA, Bollinger Bands)
- Subplot layers (RSI, MACD)
- Integration and error recovery
"""

import pytest
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime, timezone, timedelta
from unittest.mock import Mock, patch, MagicMock
from typing import List, Dict, Any
from decimal import Decimal

import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import components to test
from components.charts.error_handling import (
    ChartErrorHandler, ChartError, ErrorSeverity, DataRequirements,
    ErrorRecoveryStrategies, check_data_sufficiency, get_error_message
)

from components.charts.layers.base import (
    LayerConfig, BaseLayer, CandlestickLayer, VolumeLayer, LayerManager
)

from components.charts.layers.indicators import (
    IndicatorLayerConfig, BaseIndicatorLayer, SMALayer, EMALayer, BollingerBandsLayer
)

from components.charts.layers.subplots import (
    SubplotLayerConfig, BaseSubplotLayer, RSILayer, MACDLayer
)


class TestErrorHandlingSystem:
    """Test suite for chart error handling system"""

    @pytest.fixture
    def sample_data(self):
        """Sample market data for testing"""
        base_time = datetime.now(timezone.utc) - timedelta(hours=24)
        return [
            {
                'timestamp': base_time + timedelta(minutes=i),
                'open': 50000 + i * 10,
                'high': 50100 + i * 10,
                'low': 49900 + i * 10,
                'close': 50050 + i * 10,
                'volume': 1000 + i * 5
            }
            for i in range(50)  # 50 data points
        ]

    @pytest.fixture
    def insufficient_data(self):
        """Insufficient market data for testing"""
        base_time = datetime.now(timezone.utc)
        return [
            {
                'timestamp': base_time + timedelta(minutes=i),
                'open': 50000,
                'high': 50100,
                'low': 49900,
                'close': 50050,
                'volume': 1000
            }
            for i in range(5)  # Only 5 data points
        ]

    def test_chart_error_creation(self):
        """Test ChartError dataclass creation"""
        error = ChartError(
            code='TEST_ERROR',
            message='Test error message',
            severity=ErrorSeverity.ERROR,
            context={'test': 'value'},
            recovery_suggestion='Fix the test'
        )

        assert error.code == 'TEST_ERROR'
        assert error.message == 'Test error message'
        assert error.severity == ErrorSeverity.ERROR
        assert error.context == {'test': 'value'}
        assert error.recovery_suggestion == 'Fix the test'

        # Test dict conversion
        error_dict = error.to_dict()
        assert error_dict['code'] == 'TEST_ERROR'
        assert error_dict['severity'] == 'error'

    def test_data_requirements_candlestick(self):
        """Test data requirements checking for candlestick charts"""
        # Test sufficient data
        error = DataRequirements.check_candlestick_requirements(50)
        assert error.severity == ErrorSeverity.INFO
        assert error.code == 'SUFFICIENT_DATA'

        # Test insufficient data
        error = DataRequirements.check_candlestick_requirements(5)
        assert error.severity == ErrorSeverity.WARNING
        assert error.code == 'INSUFFICIENT_CANDLESTICK_DATA'

        # Test no data
        error = DataRequirements.check_candlestick_requirements(0)
        assert error.severity == ErrorSeverity.CRITICAL
        assert error.code == 'NO_DATA'

    def test_data_requirements_indicators(self):
        """Test data requirements checking for indicators"""
        # Test SMA with sufficient data
        error = DataRequirements.check_indicator_requirements('sma', 50, {'period': 20})
        assert error.severity == ErrorSeverity.INFO

        # Test SMA with insufficient data
        error = DataRequirements.check_indicator_requirements('sma', 15, {'period': 20})
        assert error.severity == ErrorSeverity.WARNING
        assert error.code == 'INSUFFICIENT_INDICATOR_DATA'

        # Test unknown indicator
        error = DataRequirements.check_indicator_requirements('unknown', 50, {})
        assert error.severity == ErrorSeverity.ERROR
        assert error.code == 'UNKNOWN_INDICATOR'

    def test_chart_error_handler(self, sample_data, insufficient_data):
        """Test ChartErrorHandler functionality"""
        handler = ChartErrorHandler()

        # Test with sufficient data
        is_valid = handler.validate_data_sufficiency(sample_data)
        assert is_valid == True
        assert len(handler.errors) == 0

        # Test with insufficient data and indicators
        indicators = [{'type': 'sma', 'parameters': {'period': 30}}]
        is_valid = handler.validate_data_sufficiency(insufficient_data, indicators=indicators)
        assert is_valid == False
        assert len(handler.errors) > 0 or len(handler.warnings) > 0

        # Test error summary
        summary = handler.get_error_summary()
        assert 'has_errors' in summary
        assert 'can_proceed' in summary

    def test_convenience_functions(self, sample_data, insufficient_data):
        """Test convenience functions for error handling"""
        # Test check_data_sufficiency
        is_sufficient, summary = check_data_sufficiency(sample_data)
        assert is_sufficient == True
        assert summary['can_proceed'] == True

        # Test with insufficient data
        indicators = [{'type': 'sma', 'parameters': {'period': 100}}]
        is_sufficient, summary = check_data_sufficiency(insufficient_data, indicators)
        assert is_sufficient == False

        # Test get_error_message
        error_msg = get_error_message(insufficient_data, indicators)
        assert isinstance(error_msg, str)
        assert len(error_msg) > 0


class TestBaseLayerSystem:
    """Test suite for base layer components"""

    @pytest.fixture
    def sample_df(self):
        """Sample DataFrame for testing"""
        base_time = datetime.now(timezone.utc) - timedelta(hours=24)
        data = []
        for i in range(100):
            data.append({
                'timestamp': base_time + timedelta(minutes=i),
                'open': 50000 + i * 10,
                'high': 50100 + i * 10,
                'low': 49900 + i * 10,
                'close': 50050 + i * 10,
                'volume': 1000 + i * 5
            })
        return pd.DataFrame(data)

    @pytest.fixture
    def invalid_df(self):
        """Invalid DataFrame for testing error handling"""
        return pd.DataFrame([
            {'timestamp': datetime.now(), 'open': -100, 'high': 50, 'low': 60, 'close': 40, 'volume': -50},
            {'timestamp': datetime.now(), 'open': None, 'high': None, 'low': None, 'close': None, 'volume': None}
        ])

    def test_layer_config(self):
        """Test LayerConfig creation"""
        config = LayerConfig(name="test", enabled=True, color="#FF0000")
        assert config.name == "test"
        assert config.enabled == True
        assert config.color == "#FF0000"
        assert config.style == {}
        assert config.subplot_row is None

    def test_base_layer(self):
        """Test BaseLayer functionality"""
        config = LayerConfig(name="test_layer")
        layer = BaseLayer(config)

        assert layer.config.name == "test_layer"
        assert hasattr(layer, 'error_handler')
        assert hasattr(layer, 'logger')

    def test_candlestick_layer_validation(self, sample_df, invalid_df):
        """Test CandlestickLayer data validation"""
        layer = CandlestickLayer()

        # Test valid data
        is_valid = layer.validate_data(sample_df)
        assert is_valid == True

        # Test invalid data
        is_valid = layer.validate_data(invalid_df)
        assert is_valid == False
        assert len(layer.error_handler.errors) > 0

    def test_candlestick_layer_render(self, sample_df):
        """Test CandlestickLayer rendering"""
        layer = CandlestickLayer()
        fig = go.Figure()

        result_fig = layer.render(fig, sample_df)
        assert result_fig is not None
        assert len(result_fig.data) >= 1  # Should have candlestick trace

    def test_volume_layer_validation(self, sample_df, invalid_df):
        """Test VolumeLayer data validation"""
        layer = VolumeLayer()

        # Test valid data
        is_valid = layer.validate_data(sample_df)
        assert is_valid == True

        # Test invalid data (some volume issues)
        is_valid = layer.validate_data(invalid_df)
        # Volume layer should handle invalid data gracefully
        assert len(layer.error_handler.warnings) >= 0  # May have warnings

    def test_volume_layer_render(self, sample_df):
        """Test VolumeLayer rendering"""
        layer = VolumeLayer()
        fig = go.Figure()

        result_fig = layer.render(fig, sample_df)
        assert result_fig is not None

    def test_layer_manager(self, sample_df):
        """Test LayerManager functionality"""
        manager = LayerManager()

        # Add layers
        candlestick_layer = CandlestickLayer()
        volume_layer = VolumeLayer()
        manager.add_layer(candlestick_layer)
        manager.add_layer(volume_layer)

        assert len(manager.layers) == 2

        # Test enabled layers
        enabled = manager.get_enabled_layers()
        assert len(enabled) == 2

        # Test overlay vs subplot layers
        overlays = manager.get_overlay_layers()
        subplots = manager.get_subplot_layers()

        assert len(overlays) == 1  # Candlestick is overlay
        assert len(subplots) >= 1  # Volume is subplot

        # Test layout calculation
        layout_config = manager.calculate_subplot_layout()
        assert 'rows' in layout_config
        assert 'cols' in layout_config
        assert layout_config['rows'] >= 2  # Main chart + volume subplot

        # Test rendering all layers
        fig = manager.render_all_layers(sample_df)
        assert fig is not None
        assert len(fig.data) >= 2  # Candlestick + volume


class TestIndicatorLayers:
    """Test suite for indicator layer components"""

    @pytest.fixture
    def sample_df(self):
        """Sample DataFrame with trend for indicator testing"""
        base_time = datetime.now(timezone.utc) - timedelta(hours=24)
        data = []
        for i in range(100):
            # Create trending data for better indicator calculation
            trend = i * 0.1
            base_price = 50000 + trend
            data.append({
                'timestamp': base_time + timedelta(minutes=i),
                'open': base_price + (i % 3) * 10,
                'high': base_price + 50 + (i % 3) * 10,
                'low': base_price - 50 + (i % 3) * 10,
                'close': base_price + (i % 2) * 10,
                'volume': 1000 + i * 5
            })
        return pd.DataFrame(data)

    @pytest.fixture
    def insufficient_df(self):
        """Insufficient data for indicator testing"""
        base_time = datetime.now(timezone.utc)
        data = []
        for i in range(10):  # Only 10 data points
            data.append({
                'timestamp': base_time + timedelta(minutes=i),
                'open': 50000,
                'high': 50100,
                'low': 49900,
                'close': 50050,
                'volume': 1000
            })
        return pd.DataFrame(data)

    def test_indicator_layer_config(self):
        """Test IndicatorLayerConfig creation"""
        config = IndicatorLayerConfig(
            name="test_indicator",
            indicator_type="sma",
            parameters={'period': 20}
        )

        assert config.name == "test_indicator"
        assert config.indicator_type == "sma"
        assert config.parameters == {'period': 20}
        assert config.line_width == 2
        assert config.opacity == 1.0

    def test_sma_layer(self, sample_df, insufficient_df):
        """Test SMALayer functionality"""
        config = IndicatorLayerConfig(
            name="SMA(20)",
            indicator_type='sma',
            parameters={'period': 20}
        )
        layer = SMALayer(config)

        # Test with sufficient data
        is_valid = layer.validate_indicator_data(sample_df, required_columns=['close', 'timestamp'])
        assert is_valid == True

        # Test calculation
        sma_data = layer._calculate_sma(sample_df, 20)
        assert sma_data is not None
        assert 'sma' in sma_data.columns
        assert len(sma_data) > 0

        # Test with insufficient data
        is_valid = layer.validate_indicator_data(insufficient_df, required_columns=['close', 'timestamp'])
        # Should have warnings but may still be valid for short periods
        assert len(layer.error_handler.warnings) >= 0

    def test_ema_layer(self, sample_df):
        """Test EMALayer functionality"""
        config = IndicatorLayerConfig(
            name="EMA(12)",
            indicator_type='ema',
            parameters={'period': 12}
        )
        layer = EMALayer(config)

        # Test validation
        is_valid = layer.validate_indicator_data(sample_df, required_columns=['close', 'timestamp'])
        assert is_valid == True

        # Test calculation
        ema_data = layer._calculate_ema(sample_df, 12)
        assert ema_data is not None
        assert 'ema' in ema_data.columns
        assert len(ema_data) > 0

    def test_bollinger_bands_layer(self, sample_df):
        """Test BollingerBandsLayer functionality"""
        config = IndicatorLayerConfig(
            name="BB(20,2)",
            indicator_type='bollinger_bands',
            parameters={'period': 20, 'std_dev': 2}
        )
        layer = BollingerBandsLayer(config)

        # Test validation
        is_valid = layer.validate_indicator_data(sample_df, required_columns=['close', 'timestamp'])
        assert is_valid == True

        # Test calculation
        bb_data = layer._calculate_bollinger_bands(sample_df, 20, 2)
        assert bb_data is not None
        assert 'upper_band' in bb_data.columns
        assert 'middle_band' in bb_data.columns
        assert 'lower_band' in bb_data.columns
        assert len(bb_data) > 0

    def test_safe_calculate_indicator(self, sample_df, insufficient_df):
        """Test safe indicator calculation with error handling"""
        config = IndicatorLayerConfig(
            name="SMA(20)",
            indicator_type='sma',
            parameters={'period': 20}
        )
        layer = SMALayer(config)

        # Test successful calculation
        result = layer.safe_calculate_indicator(
            sample_df,
            layer._calculate_sma,
            period=20
        )
        assert result is not None

        # Test with insufficient data - should attempt recovery
        result = layer.safe_calculate_indicator(
            insufficient_df,
            layer._calculate_sma,
            period=50  # Too large for data
        )
        # Should either return adjusted result or None
        assert result is None or len(result) > 0


class TestSubplotLayers:
    """Test suite for subplot layer components"""

    @pytest.fixture
    def sample_df(self):
        """Sample DataFrame for RSI/MACD testing"""
        base_time = datetime.now(timezone.utc) - timedelta(hours=24)
        data = []

        # Create more realistic price data for RSI/MACD
        prices = [50000]
        for i in range(100):
            # Random walk with trend
            change = (i % 7 - 3) * 50  # Some volatility
            new_price = prices[-1] + change
            prices.append(new_price)

            data.append({
                'timestamp': base_time + timedelta(minutes=i),
                'open': prices[i],
                'high': prices[i] + abs(change) + 20,
                'low': prices[i] - abs(change) - 20,
                'close': prices[i+1],
                'volume': 1000 + i * 5
            })

        return pd.DataFrame(data)

    def test_subplot_layer_config(self):
        """Test SubplotLayerConfig creation"""
        config = SubplotLayerConfig(
            name="RSI(14)",
            indicator_type="rsi",
            parameters={'period': 14},
            subplot_height_ratio=0.25,
            y_axis_range=(0, 100),
            reference_lines=[30, 70]
        )

        assert config.name == "RSI(14)"
        assert config.indicator_type == "rsi"
        assert config.subplot_height_ratio == 0.25
        assert config.y_axis_range == (0, 100)
        assert config.reference_lines == [30, 70]

    def test_rsi_layer(self, sample_df):
        """Test RSILayer functionality"""
        layer = RSILayer(period=14)

        # Test validation
        is_valid = layer.validate_indicator_data(sample_df, required_columns=['close', 'timestamp'])
        assert is_valid == True

        # Test RSI calculation
        rsi_data = layer._calculate_rsi(sample_df, 14)
        assert rsi_data is not None
        assert 'rsi' in rsi_data.columns
        assert len(rsi_data) > 0

        # Validate RSI values are in correct range
        assert (rsi_data['rsi'] >= 0).all()
        assert (rsi_data['rsi'] <= 100).all()

        # Test subplot properties
        assert layer.has_fixed_range() == True
        assert layer.get_y_axis_range() == (0, 100)
        assert 30 in layer.get_reference_lines()
        assert 70 in layer.get_reference_lines()

    def test_macd_layer(self, sample_df):
        """Test MACDLayer functionality"""
        layer = MACDLayer(fast_period=12, slow_period=26, signal_period=9)

        # Test validation
        is_valid = layer.validate_indicator_data(sample_df, required_columns=['close', 'timestamp'])
        assert is_valid == True

        # Test MACD calculation
        macd_data = layer._calculate_macd(sample_df, 12, 26, 9)
        assert macd_data is not None
        assert 'macd' in macd_data.columns
        assert 'signal' in macd_data.columns
        assert 'histogram' in macd_data.columns
        assert len(macd_data) > 0

        # Test subplot properties
        assert layer.should_show_zero_line() == True
        assert layer.get_subplot_height_ratio() == 0.3

    def test_rsi_calculation_edge_cases(self, sample_df):
        """Test RSI calculation with edge cases"""
        layer = RSILayer(period=14)

        # Test with very short period
        short_data = sample_df.head(20)
        rsi_data = layer._calculate_rsi(short_data, 5)  # Short period
        assert rsi_data is not None
        assert len(rsi_data) > 0

        # Test with period too large for data; the calculation should raise
        with pytest.raises(Exception):
            layer._calculate_rsi(sample_df.head(10), 20)  # Period larger than data

    def test_macd_calculation_edge_cases(self, sample_df):
        """Test MACD calculation with edge cases"""
        layer = MACDLayer(fast_period=12, slow_period=26, signal_period=9)

        # Test with invalid periods (fast >= slow); the calculation should raise
        with pytest.raises(Exception):
            layer._calculate_macd(sample_df, 26, 12, 9)  # fast >= slow


class TestLayerIntegration:
    """Test suite for layer integration and complex scenarios"""

    @pytest.fixture
    def sample_df(self):
        """Sample DataFrame for integration testing"""
        base_time = datetime.now(timezone.utc) - timedelta(hours=24)
        data = []
        for i in range(150):  # Enough data for all indicators
            trend = i * 0.1
            base_price = 50000 + trend
            volatility = (i % 10) * 20

            data.append({
                'timestamp': base_time + timedelta(minutes=i),
                'open': base_price + volatility,
                'high': base_price + volatility + 50,
                'low': base_price + volatility - 50,
                'close': base_price + volatility + (i % 3 - 1) * 10,
                'volume': 1000 + i * 5
            })
        return pd.DataFrame(data)

    def test_full_chart_creation(self, sample_df):
        """Test creating a full chart with multiple layers"""
        manager = LayerManager()

        # Add base layers
        manager.add_layer(CandlestickLayer())
        manager.add_layer(VolumeLayer())

        # Add indicator layers
        manager.add_layer(SMALayer(IndicatorLayerConfig(
            name="SMA(20)",
            indicator_type='sma',
            parameters={'period': 20}
        )))
        manager.add_layer(EMALayer(IndicatorLayerConfig(
            name="EMA(12)",
            indicator_type='ema',
            parameters={'period': 12}
        )))

        # Add subplot layers
        manager.add_layer(RSILayer(period=14))
        manager.add_layer(MACDLayer(fast_period=12, slow_period=26, signal_period=9))

        # Calculate layout
        layout_config = manager.calculate_subplot_layout()
        assert layout_config['rows'] >= 4  # Main + volume + RSI + MACD

        # Render all layers
        fig = manager.render_all_layers(sample_df)
        assert fig is not None
        assert len(fig.data) >= 6  # Candlestick + volume + SMA + EMA + RSI + MACD components

    def test_error_recovery_integration(self):
        """Test error recovery with insufficient data"""
        manager = LayerManager()

        # Create insufficient data
        base_time = datetime.now(timezone.utc)
        insufficient_data = pd.DataFrame([
            {
                'timestamp': base_time + timedelta(minutes=i),
                'open': 50000,
                'high': 50100,
                'low': 49900,
                'close': 50050,
                'volume': 1000
            }
            for i in range(15)  # Only 15 data points
        ])

        # Add layers that require more data
        manager.add_layer(CandlestickLayer())
        manager.add_layer(SMALayer(IndicatorLayerConfig(
            name="SMA(50)",  # Requires too much data
            indicator_type='sma',
            parameters={'period': 50}
        )))

        # Should still create a chart (graceful degradation)
        fig = manager.render_all_layers(insufficient_data)
        assert fig is not None
        # Should have at least candlestick layer
        assert len(fig.data) >= 1

    def test_mixed_valid_invalid_data(self):
        """Test handling mixed valid and invalid data"""
        # Create data with some invalid entries
        base_time = datetime.now(timezone.utc)
        mixed_data = []

        for i in range(50):
            if i % 10 == 0:  # Every 10th entry is invalid
                data_point = {
                    'timestamp': base_time + timedelta(minutes=i),
                    'open': -100,  # Invalid negative price
                    'high': None,  # Missing data
                    'low': None,
                    'close': None,
                    'volume': -50  # Invalid negative volume
                }
            else:
                data_point = {
                    'timestamp': base_time + timedelta(minutes=i),
                    'open': 50000 + i * 10,
                    'high': 50100 + i * 10,
                    'low': 49900 + i * 10,
                    'close': 50050 + i * 10,
                    'volume': 1000 + i * 5
                }
            mixed_data.append(data_point)

        df = pd.DataFrame(mixed_data)

        # Test candlestick layer with mixed data
        candlestick_layer = CandlestickLayer()
        is_valid = candlestick_layer.validate_data(df)

        # Should handle mixed data gracefully
        if not is_valid:
            # Should have warnings but possibly still proceed
            assert len(candlestick_layer.error_handler.warnings) > 0

    def test_layer_manager_dynamic_layout(self):
        """Test LayerManager dynamic layout calculation"""
        manager = LayerManager()

        # Test with no subplots
        manager.add_layer(CandlestickLayer())
        layout = manager.calculate_subplot_layout()
        assert layout['rows'] == 1

        # Add one subplot
        manager.add_layer(VolumeLayer())
        layout = manager.calculate_subplot_layout()
        assert layout['rows'] == 2

        # Add more subplots
        manager.add_layer(RSILayer(period=14))
        manager.add_layer(MACDLayer(fast_period=12, slow_period=26, signal_period=9))
        layout = manager.calculate_subplot_layout()
        assert layout['rows'] == 4  # Main + volume + RSI + MACD
        assert layout['cols'] == 1
        assert len(layout['subplot_titles']) == 4
        assert len(layout['row_heights']) == 4

        # Test row height calculation
        total_height = sum(layout['row_heights'])
        assert abs(total_height - 1.0) < 0.01  # Should sum to approximately 1.0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
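

# Example (illustrative): the convenience helpers exercised above could guard a
# dashboard callback before any chart is built. A commented sketch, assuming the
# callback receives the same list-of-OHLCV-dicts format used in these tests
# (`create_empty_chart` is a hypothetical placeholder, not an existing helper):
#
#     requested = [{'type': 'sma', 'parameters': {'period': 20}},
#                  {'type': 'rsi', 'parameters': {'period': 14}}]
#     ok, summary = check_data_sufficiency(candles, requested)
#     if not ok:
#         return create_empty_chart(get_error_message(candles, requested))
#     # ...otherwise proceed to LayerManager.render_all_layers(df)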