4.0 Implement real-time strategy execution and data integration features
- Added `realtime_execution.py` for real-time strategy execution, enabling live signal generation and integration with the dashboard's chart refresh cycle.
- Introduced `data_integration.py` to manage market data orchestration, caching, and technical indicator calculations for strategy signal generation.
- Implemented `validation.py` for comprehensive validation and quality assessment of strategy-generated signals, ensuring reliability and consistency.
- Developed `batch_processing.py` to facilitate efficient backtesting of multiple strategies across large datasets with memory management and performance optimization.
- Updated `__init__.py` files to include new modules and ensure proper exports, enhancing modularity and maintainability.
- Enhanced unit tests for the new features, ensuring robust functionality and adherence to project standards.

These changes establish a solid foundation for real-time strategy execution and data integration, aligning with project goals for modularity, performance, and maintainability.
parent f09864d61b · commit 8c23489ff0
@@ -6,10 +6,12 @@ from .navigation import register_navigation_callbacks
from .charts import register_chart_callbacks
from .indicators import register_indicator_callbacks
from .system_health import register_system_health_callbacks
from .realtime_strategies import register_realtime_strategy_callbacks

__all__ = [
    'register_navigation_callbacks',
    'register_chart_callbacks',
    'register_indicator_callbacks',
    'register_system_health_callbacks',
    'register_realtime_strategy_callbacks'
]
291  dashboard/callbacks/realtime_strategies.py  Normal file
@@ -0,0 +1,291 @@
"""
Real-time Strategy Callbacks

This module provides callbacks for integrating real-time strategy execution
with the dashboard chart refresh cycle and user interactions.
"""

import json
from dash import Output, Input, State, Patch, ctx, html, no_update, dcc, callback
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional

from utils.logger import get_logger
from strategies.realtime_execution import (
    get_realtime_strategy_processor,
    initialize_realtime_strategy_system,
    RealTimeConfig,
    RealTimeSignal
)
from strategies.manager import StrategyManager
from config.strategies.config_utils import StrategyConfigurationManager

logger = get_logger()

# Global processor instance
_processor = None


def get_processor():
    """Get or initialize the real-time strategy processor."""
    global _processor
    if _processor is None:
        config = RealTimeConfig(
            refresh_interval_seconds=30,
            max_strategies_concurrent=3,
            incremental_calculation=True,
            signal_batch_size=50,
            enable_signal_broadcasting=True
        )
        _processor = initialize_realtime_strategy_system(config)
    return _processor


def register_realtime_strategy_callbacks(app):
    """Register real-time strategy callbacks."""

    @app.callback(
        Output('realtime-strategies-store', 'data'),
        [Input('realtime-strategy-toggle', 'value'),
         Input('symbol-dropdown', 'value'),
         Input('timeframe-dropdown', 'value'),
         Input('strategy-dropdown', 'value')],
        [State('realtime-strategies-store', 'data')],
        prevent_initial_call=True
    )
    def manage_realtime_strategies(enable_realtime, symbol, timeframe, strategy_name, current_data):
        """
        Manage real-time strategy registration based on user selections.

        This callback handles enabling/disabling real-time strategy execution
        and registers strategies based on current chart selections.
        """
        try:
            current_data = current_data or {'active_strategies': [], 'enabled': False}
            processor = get_processor()

            if not enable_realtime:
                # Disable all strategies
                for context_id in current_data.get('active_strategies', []):
                    processor.unregister_strategy(context_id)
                    logger.info(f"Unregistered real-time strategy: {context_id}")

                return {'active_strategies': [], 'enabled': False}

            # Enable real-time strategies
            if symbol and timeframe and strategy_name and strategy_name != 'basic':
                # Load strategy configuration
                try:
                    config_manager = StrategyConfigurationManager()
                    strategy_config = config_manager.load_user_strategy_config(strategy_name)

                    if not strategy_config:
                        # Load from templates if user config doesn't exist
                        strategy_config = config_manager.load_strategy_template(strategy_name)

                    if strategy_config:
                        # Register strategy for real-time execution
                        context_id = processor.register_strategy(
                            strategy_name=strategy_name,
                            strategy_config=strategy_config,
                            symbol=symbol,
                            timeframe=timeframe
                        )

                        active_strategies = [context_id]
                        logger.info(f"Registered real-time strategy: {context_id}")

                        return {
                            'active_strategies': active_strategies,
                            'enabled': True,
                            'current_symbol': symbol,
                            'current_timeframe': timeframe,
                            'current_strategy': strategy_name
                        }

                except Exception as e:
                    logger.error(f"Error loading strategy configuration for {strategy_name}: {e}")
                    return current_data

            return current_data

        except Exception as e:
            logger.error(f"Error managing real-time strategies: {e}")
            return current_data or {'active_strategies': [], 'enabled': False}

    @app.callback(
        Output('realtime-strategy-status', 'children'),
        [Input('realtime-strategies-store', 'data'),
         Input('interval-component', 'n_intervals')],
        prevent_initial_call=True
    )
    def update_realtime_status(strategy_data, n_intervals):
        """
        Update real-time strategy status display.

        Shows current status of real-time strategy execution including
        active strategies and performance metrics.
        """
        try:
            if not strategy_data or not strategy_data.get('enabled'):
                return dbc.Alert("Real-time strategy execution is disabled", color="secondary", className="mb-2")

            processor = get_processor()
            active_strategies = processor.get_active_strategies()
            perf_stats = processor.get_performance_stats()

            if not active_strategies:
                return dbc.Alert("No active real-time strategies", color="warning", className="mb-2")

            # Build status display
            status_items = []

            # Active strategies
            for context_id, context in active_strategies.items():
                status_items.append(
                    html.Li([
                        html.Strong(f"{context.strategy_name}: "),
                        f"{context.symbol} {context.timeframe}",
                        html.Span(
                            " ✓" if context.is_active else " ⚠️",
                            style={'color': 'green' if context.is_active else 'orange'}
                        )
                    ])
                )

            # Performance metrics
            success_rate = 0
            if perf_stats['total_calculations'] > 0:
                success_rate = (perf_stats['successful_calculations'] / perf_stats['total_calculations']) * 100

            metrics_text = f"Calculations: {perf_stats['total_calculations']} | " \
                           f"Success Rate: {success_rate:.1f}% | " \
                           f"Signals Generated: {perf_stats['signals_generated']}"

            return dbc.Card([
                dbc.CardHeader("Real-time Strategy Status"),
                dbc.CardBody([
                    html.H6("Active Strategies:", className="mb-2"),
                    html.Ul(status_items, className="mb-3"),
                    html.P(metrics_text, className="small mb-0")
                ])
            ], className="mb-2")

        except Exception as e:
            logger.error(f"Error updating real-time status: {e}")
            return dbc.Alert(f"Error updating status: {str(e)}", color="danger", className="mb-2")

    # Integration with chart refresh cycle
    @app.callback(
        Output('realtime-execution-trigger', 'data'),
        [Input('interval-component', 'n_intervals')],
        [State('symbol-dropdown', 'value'),
         State('timeframe-dropdown', 'value'),
         State('realtime-strategies-store', 'data'),
         State('analysis-mode-toggle', 'value')],
        prevent_initial_call=True
    )
    def trigger_realtime_execution(n_intervals, symbol, timeframe, strategy_data, analysis_mode):
        """
        Trigger real-time strategy execution when new data is available.

        This callback integrates with the existing chart refresh cycle to
        execute real-time strategies when new candle data arrives.
        """
        try:
            # Only execute in live mode
            if analysis_mode == 'locked':
                return no_update

            # Only execute if real-time strategies are enabled
            if not strategy_data or not strategy_data.get('enabled'):
                return no_update

            # Only execute if we have symbol and timeframe
            if not symbol or not timeframe:
                return no_update

            processor = get_processor()

            # Execute real-time strategy update
            signals = processor.execute_realtime_update(
                symbol=symbol,
                timeframe=timeframe,
                exchange="okx"
            )

            if signals:
                logger.info(f"Real-time execution generated {len(signals)} signals for {symbol} {timeframe}")
                return {
                    'timestamp': datetime.now().isoformat(),
                    'signals_generated': len(signals),
                    'symbol': symbol,
                    'timeframe': timeframe
                }

            return no_update

        except Exception as e:
            logger.error(f"Error in real-time strategy execution: {e}")
            return no_update


def add_realtime_strategy_components():
    """
    Add real-time strategy components to the dashboard layout.

    Returns:
        List of Dash components for real-time strategy controls
    """
    return [
        # Real-time strategy toggle
        dbc.Row([
            dbc.Col([
                dbc.Label("Real-time Strategy Execution", className="fw-bold"),
                dbc.Switch(
                    id="realtime-strategy-toggle",
                    label="Enable Real-time Execution",
                    value=False,
                    className="mb-2"
                ),
            ], width=12)
        ], className="mb-3"),

        # Status display
        html.Div(id="realtime-strategy-status"),

        # Hidden stores for state management
        dcc.Store(id="realtime-strategies-store", data={'active_strategies': [], 'enabled': False}),
        dcc.Store(id="realtime-execution-trigger", data={}),
    ]


def setup_chart_update_callback():
    """
    Setup chart update callback for real-time signals.

    This function configures the real-time processor to trigger
    chart updates when new signals are generated.
    """
    def chart_update_callback(signal: RealTimeSignal):
        """Handle chart updates for real-time signals."""
        try:
            # This would trigger chart refresh for the specific symbol/timeframe
            # For now, we'll log the signal and let the regular refresh cycle handle it
            logger.debug(
                f"Chart update requested for signal: {signal.context.strategy_name} "
                f"on {signal.context.symbol} {signal.context.timeframe}"
            )

            # Future enhancement: Could trigger specific chart layer updates here

        except Exception as e:
            logger.error(f"Error in chart update callback: {e}")

    processor = get_processor()
    processor.set_chart_update_callback(chart_update_callback)


# Initialize the chart update callback when module is imported
setup_chart_update_callback()
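A minimal wiring sketch, assuming a bare Dash app; the surrounding dashboard normally supplies the symbol/timeframe/strategy dropdowns and the interval component that these callbacks also reference:

```python
import dash
import dash_bootstrap_components as dbc

from dashboard.callbacks.realtime_strategies import (
    register_realtime_strategy_callbacks,
    add_realtime_strategy_components,
)

app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    suppress_callback_exceptions=True,  # other callback inputs live elsewhere in the real layout
)

# Toggle, status display, and the hidden stores defined in this module.
app.layout = dbc.Container(add_realtime_strategy_components())

register_realtime_strategy_callbacks(app)
```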
@@ -181,6 +181,43 @@ class StrategyRepository(BaseRepository):
            self.log_error(f"Error retrieving strategy signals: {e}")
            raise DatabaseOperationError(f"Failed to retrieve strategy signals: {e}")

    def store_signals_batch(self, signal_data_list: List[Dict[str, Any]]) -> int:
        """
        Store a batch of real-time strategy signals.

        Args:
            signal_data_list: List of signal data dictionaries

        Returns:
            Number of signals stored
        """
        try:
            signals_stored = 0
            with self.get_session() as session:
                for signal_data in signal_data_list:
                    strategy_signal = StrategySignal(
                        run_id=None,  # Real-time signals don't have a run_id
                        strategy_name=signal_data.get('strategy_name'),
                        strategy_config=signal_data.get('strategy_config'),
                        symbol=signal_data.get('symbol'),
                        timeframe=signal_data.get('timeframe'),
                        timestamp=signal_data.get('timestamp'),
                        signal_type=signal_data.get('signal_type', 'HOLD'),
                        price=Decimal(str(signal_data.get('price'))) if signal_data.get('price') else None,
                        confidence=Decimal(str(signal_data.get('confidence', 0.0))),
                        signal_metadata=signal_data.get('signal_metadata', {})
                    )
                    session.add(strategy_signal)
                    signals_stored += 1

                session.commit()

            self.log_info(f"Stored batch of {signals_stored} real-time strategy signals")
            return signals_stored

        except Exception as e:
            self.log_error(f"Error storing signals batch: {e}")
            raise DatabaseOperationError(f"Failed to store signals batch: {e}")

    def get_strategy_signal_stats(self, run_id: Optional[int] = None) -> Dict[str, Any]:
        """Get statistics about strategy signals."""
        try:
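For illustration, a hedged sketch of calling the batch-store API from application code; `get_database_operations` appears elsewhere in this commit, and the field values below are invented:

```python
from datetime import datetime, timezone

from database.operations import get_database_operations

db_ops = get_database_operations(None)  # pass a real logger in practice

# Each dict mirrors the keys store_signals_batch() reads above;
# the concrete values are illustrative only.
stored = db_ops.strategy.store_signals_batch([
    {
        'strategy_name': 'ema_crossover',
        'strategy_config': {'fast': 12, 'slow': 26},
        'symbol': 'BTC-USDT',
        'timeframe': '1m',
        'timestamp': datetime.now(timezone.utc),
        'signal_type': 'BUY',
        'price': 65000.0,
        'confidence': 0.8,
        'signal_metadata': {'source': 'realtime'},
    },
])
print(f"stored {stored} signals")
```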
@@ -16,6 +16,10 @@ from .base import BaseStrategy
from .factory import StrategyFactory
from .data_types import StrategySignal, SignalType, StrategyResult
from .manager import StrategyManager, StrategyConfig, StrategyType, StrategyCategory, get_strategy_manager
from .data_integration import StrategyDataIntegrator, StrategyDataIntegrationConfig, get_strategy_data_integrator
from .validation import StrategySignalValidator, ValidationConfig
from .batch_processing import BacktestingBatchProcessor, BatchProcessingConfig
from .realtime_execution import RealTimeStrategyProcessor, RealTimeConfig, get_realtime_strategy_processor

__all__ = [
    'BaseStrategy',
@@ -27,5 +31,15 @@ __all__ = [
    'StrategyConfig',
    'StrategyType',
    'StrategyCategory',
    'get_strategy_manager',
    'StrategyDataIntegrator',
    'StrategyDataIntegrationConfig',
    'get_strategy_data_integrator',
    'StrategySignalValidator',
    'ValidationConfig',
    'BacktestingBatchProcessor',
    'BatchProcessingConfig',
    'RealTimeStrategyProcessor',
    'RealTimeConfig',
    'get_realtime_strategy_processor'
]
1059  strategies/batch_processing.py  Normal file (diff suppressed because it is too large)
1060  strategies/data_integration.py  Normal file (diff suppressed because it is too large)
649  strategies/realtime_execution.py  Normal file
@@ -0,0 +1,649 @@
"""
Real-time Strategy Execution Pipeline

This module provides real-time strategy execution capabilities that integrate
with the existing chart data refresh cycle. It handles incremental strategy
calculations, real-time signal generation, and live chart updates.
"""

import pandas as pd
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Any, Optional, Callable, Set, Tuple
from dataclasses import dataclass, field
from threading import Thread, Event, Lock
from queue import Queue, Empty, Full
import asyncio
from concurrent.futures import ThreadPoolExecutor
import time

from database.operations import get_database_operations, DatabaseOperationError
from data.common.data_types import OHLCVCandle
from components.charts.data_integration import MarketDataIntegrator
from .data_integration import StrategyDataIntegrator, StrategyDataIntegrationConfig
from .factory import StrategyFactory
from .data_types import StrategyResult, StrategySignal
from utils.logger import get_logger

# Initialize logger
logger = get_logger()


@dataclass
class RealTimeConfig:
    """Configuration for real-time strategy execution"""
    refresh_interval_seconds: int = 30       # How often to check for new data
    max_strategies_concurrent: int = 5       # Maximum concurrent strategy calculations
    incremental_calculation: bool = True     # Use incremental vs full recalculation
    signal_batch_size: int = 100             # Batch size for signal storage
    enable_signal_broadcasting: bool = True  # Enable real-time signal broadcasting
    max_signal_queue_size: int = 1000        # Maximum signals in queue before dropping
    chart_update_throttle_ms: int = 1000     # Minimum time between chart updates
    error_retry_attempts: int = 3            # Number of retries on calculation errors
    error_retry_delay_seconds: int = 5       # Delay between retry attempts


@dataclass
class StrategyExecutionContext:
    """Context for strategy execution"""
    strategy_name: str
    strategy_config: Dict[str, Any]
    symbol: str
    timeframe: str
    exchange: str = "okx"
    last_calculation_time: Optional[datetime] = None
    consecutive_errors: int = 0
    is_active: bool = True


@dataclass
class RealTimeSignal:
    """Real-time signal with metadata"""
    strategy_result: StrategyResult
    context: StrategyExecutionContext
    generation_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    chart_update_required: bool = True


class StrategySignalBroadcaster:
    """
    Handles real-time signal broadcasting and distribution.

    Manages signal queues, chart updates, and database storage
    for real-time strategy signals.
    """

    def __init__(self, config: RealTimeConfig):
        """Initialize signal broadcaster."""
        self.config = config
        self.logger = logger
        self.db_ops = get_database_operations(self.logger)

        # Signal queues
        self._signal_queue: Queue[RealTimeSignal] = Queue(maxsize=self.config.max_signal_queue_size)
        self._chart_update_queue: Queue[RealTimeSignal] = Queue()

        # Chart update throttling
        self._last_chart_update = {}  # symbol_timeframe -> timestamp
        self._chart_update_lock = Lock()

        # Background processing
        self._processing_thread: Optional[Thread] = None
        self._stop_event = Event()
        self._is_running = False

        # Callback for chart updates
        self._chart_update_callback: Optional[Callable] = None

        if self.logger:
            self.logger.info("StrategySignalBroadcaster: Initialized")

    def start(self) -> None:
        """Start the signal broadcasting service."""
        if self._is_running:
            return

        self._is_running = True
        self._stop_event.clear()

        # Start background processing thread
        self._processing_thread = Thread(
            target=self._process_signals_loop,
            name="StrategySignalProcessor",
            daemon=True
        )
        self._processing_thread.start()

        if self.logger:
            self.logger.info("StrategySignalBroadcaster: Started signal processing")

    def stop(self) -> None:
        """Stop the signal broadcasting service."""
        if not self._is_running:
            return

        self._is_running = False
        self._stop_event.set()

        if self._processing_thread and self._processing_thread.is_alive():
            self._processing_thread.join(timeout=5.0)

        if self.logger:
            self.logger.info("StrategySignalBroadcaster: Stopped signal processing")

    def broadcast_signal(self, signal: RealTimeSignal) -> bool:
        """
        Broadcast a real-time signal.

        Args:
            signal: Real-time signal to broadcast

        Returns:
            True if signal was queued successfully, False if queue is full
        """
        try:
            self._signal_queue.put_nowait(signal)
            return True
        except Full:
            # Queue is full, drop the signal
            if self.logger:
                self.logger.warning(f"Signal queue full, dropping signal for {signal.context.symbol}")
            return False

    def set_chart_update_callback(self, callback: Callable[[RealTimeSignal], None]) -> None:
        """Set callback for chart updates."""
        self._chart_update_callback = callback

    def _process_signals_loop(self) -> None:
        """Main signal processing loop."""
        batch_signals = []

        while not self._stop_event.is_set():
            try:
                # Collect signals in batches
                try:
                    signal = self._signal_queue.get(timeout=1.0)
                    batch_signals.append(signal)

                    # Collect more signals if available (up to batch size)
                    while len(batch_signals) < self.config.signal_batch_size:
                        try:
                            signal = self._signal_queue.get_nowait()
                            batch_signals.append(signal)
                        except Empty:
                            break

                    # Process the batch
                    if batch_signals:
                        self._process_signal_batch(batch_signals)
                        batch_signals.clear()

                except Empty:
                    # No signals to process, continue
                    continue

            except Exception as e:
                if self.logger:
                    self.logger.error(f"Error in signal processing loop: {e}")
                time.sleep(1.0)  # Brief pause on error

    def _process_signal_batch(self, signals: List[RealTimeSignal]) -> None:
        """Process a batch of signals."""
        try:
            # Store signals in database
            self._store_signals_batch(signals)

            # Process chart updates
            self._process_chart_updates(signals)

        except Exception as e:
            if self.logger:
                self.logger.error(f"Error processing signal batch: {e}")

    def _store_signals_batch(self, signals: List[RealTimeSignal]) -> None:
        """Store signals in database."""
        try:
            signal_data = []
            for signal in signals:
                result = signal.strategy_result
                context = signal.context

                signal_data.append({
                    'strategy_name': context.strategy_name,
                    'strategy_config': context.strategy_config,
                    'symbol': context.symbol,
                    'timeframe': context.timeframe,
                    'exchange': context.exchange,
                    'timestamp': result.timestamp,
                    'signal_type': result.signal.signal_type.value if result.signal else 'HOLD',
                    'price': float(result.price) if result.price else None,
                    'confidence': result.confidence,
                    'signal_metadata': result.metadata or {},
                    'generation_time': signal.generation_time
                })

            # Batch insert into database
            self.db_ops.strategy.store_signals_batch(signal_data)

            if self.logger:
                self.logger.debug(f"Stored batch of {len(signals)} real-time signals")

        except Exception as e:
            if self.logger:
                self.logger.error(f"Error storing signal batch: {e}")

    def _process_chart_updates(self, signals: List[RealTimeSignal]) -> None:
        """Process chart updates for signals."""
        if not self._chart_update_callback:
            return

        # Group signals by symbol/timeframe for throttling
        signal_groups = {}
        for signal in signals:
            if not signal.chart_update_required:
                continue

            key = f"{signal.context.symbol}_{signal.context.timeframe}"
            if key not in signal_groups:
                signal_groups[key] = []
            signal_groups[key].append(signal)

        # Process chart updates with throttling
        current_time = time.time() * 1000  # milliseconds

        with self._chart_update_lock:
            for key, group_signals in signal_groups.items():
                last_update = self._last_chart_update.get(key, 0)

                if current_time - last_update >= self.config.chart_update_throttle_ms:
                    # Update chart with latest signal from group
                    latest_signal = max(group_signals, key=lambda s: s.generation_time)

                    try:
                        self._chart_update_callback(latest_signal)
                        self._last_chart_update[key] = current_time
                    except Exception as e:
                        if self.logger:
                            self.logger.error(f"Error in chart update callback: {e}")

    def get_signal_stats(self) -> Dict[str, Any]:
        """Get signal broadcasting statistics."""
        return {
            'queue_size': self._signal_queue.qsize(),
            'chart_queue_size': self._chart_update_queue.qsize(),
            'is_running': self._is_running,
            'last_chart_updates': dict(self._last_chart_update)
        }


class RealTimeStrategyProcessor:
    """
    Real-time strategy execution processor.

    Integrates with existing chart data refresh cycle to provide
    real-time strategy signal generation and broadcasting.
    """

    def __init__(self, config: RealTimeConfig = None):
        """Initialize real-time strategy processor."""
        self.config = config or RealTimeConfig()
        self.logger = logger

        # Core components
        self.data_integrator = StrategyDataIntegrator(
            StrategyDataIntegrationConfig(
                cache_timeout_minutes=1,  # Shorter cache for real-time
                enable_indicator_caching=True
            )
        )
        self.market_integrator = MarketDataIntegrator()
        self.strategy_factory = StrategyFactory(self.logger)
        self.signal_broadcaster = StrategySignalBroadcaster(self.config)

        # Strategy execution contexts
        self._execution_contexts: Dict[str, StrategyExecutionContext] = {}
        self._context_lock = Lock()

        # Performance tracking
        self._performance_stats = {
            'total_calculations': 0,
            'successful_calculations': 0,
            'failed_calculations': 0,
            'average_calculation_time_ms': 0.0,
            'signals_generated': 0,
            'last_update_time': None
        }

        # Thread pool for concurrent strategy execution
        self._executor = ThreadPoolExecutor(max_workers=self.config.max_strategies_concurrent)

        if self.logger:
            self.logger.info("RealTimeStrategyProcessor: Initialized")

    def start(self) -> None:
        """Start the real-time strategy processor."""
        self.signal_broadcaster.start()
        if self.logger:
            self.logger.info("RealTimeStrategyProcessor: Started")

    def stop(self) -> None:
        """Stop the real-time strategy processor."""
        self.signal_broadcaster.stop()
        self._executor.shutdown(wait=True)
        if self.logger:
            self.logger.info("RealTimeStrategyProcessor: Stopped")

    def register_strategy(
        self,
        strategy_name: str,
        strategy_config: Dict[str, Any],
        symbol: str,
        timeframe: str,
        exchange: str = "okx"
    ) -> str:
        """
        Register a strategy for real-time execution.

        Args:
            strategy_name: Name of the strategy
            strategy_config: Strategy configuration
            symbol: Trading symbol
            timeframe: Timeframe
            exchange: Exchange name

        Returns:
            Context ID for the registered strategy
        """
        context_id = f"{strategy_name}_{symbol}_{timeframe}_{exchange}"

        context = StrategyExecutionContext(
            strategy_name=strategy_name,
            strategy_config=strategy_config,
            symbol=symbol,
            timeframe=timeframe,
            exchange=exchange
        )

        with self._context_lock:
            self._execution_contexts[context_id] = context

        if self.logger:
            self.logger.info(f"Registered strategy for real-time execution: {context_id}")

        return context_id

    def unregister_strategy(self, context_id: str) -> bool:
        """
        Unregister a strategy from real-time execution.

        Args:
            context_id: Context ID to unregister

        Returns:
            True if strategy was unregistered, False if not found
        """
        with self._context_lock:
            if context_id in self._execution_contexts:
                del self._execution_contexts[context_id]
                if self.logger:
                    self.logger.info(f"Unregistered strategy: {context_id}")
                return True
            return False

    def execute_realtime_update(
        self,
        symbol: str,
        timeframe: str,
        exchange: str = "okx"
    ) -> List[RealTimeSignal]:
        """
        Execute real-time strategy update for new market data.

        This method should be called when new candle data is available,
        typically triggered by the chart refresh cycle.

        Args:
            symbol: Trading symbol that was updated
            timeframe: Timeframe that was updated
            exchange: Exchange name

        Returns:
            List of generated real-time signals
        """
        start_time = time.time()
        generated_signals = []

        try:
            # Find all strategies for this symbol/timeframe
            matching_contexts = []
            with self._context_lock:
                for context_id, context in self._execution_contexts.items():
                    if (context.symbol == symbol and
                            context.timeframe == timeframe and
                            context.exchange == exchange and
                            context.is_active):
                        matching_contexts.append((context_id, context))

            if not matching_contexts:
                return generated_signals

            # Execute strategies concurrently
            futures = []
            for context_id, context in matching_contexts:
                future = self._executor.submit(
                    self._execute_strategy_context,
                    context_id,
                    context
                )
                futures.append((context_id, future))

            # Collect results
            for context_id, future in futures:
                try:
                    signals = future.result(timeout=10.0)  # 10 second timeout
                    generated_signals.extend(signals)
                except Exception as e:
                    if self.logger:
                        self.logger.error(f"Error executing strategy {context_id}: {e}")
                    self._handle_strategy_error(context_id, e)

            # Update performance stats
            calculation_time = (time.time() - start_time) * 1000
            self._update_performance_stats(len(generated_signals), calculation_time, True)

            if self.logger and generated_signals:
                self.logger.debug(f"Generated {len(generated_signals)} real-time signals for {symbol} {timeframe}")

            return generated_signals

        except Exception as e:
            if self.logger:
                self.logger.error(f"Error in real-time strategy execution: {e}")
            calculation_time = (time.time() - start_time) * 1000
            self._update_performance_stats(0, calculation_time, False)
            return generated_signals

    def _execute_strategy_context(
        self,
        context_id: str,
        context: StrategyExecutionContext
    ) -> List[RealTimeSignal]:
        """Execute a single strategy context."""
        try:
            # Calculate strategy signals
            if self.config.incremental_calculation and context.last_calculation_time:
                # Use incremental calculation for better performance
                results = self._calculate_incremental_signals(context)
            else:
                # Full recalculation
                results = self._calculate_full_signals(context)

            # Convert to real-time signals
            real_time_signals = []
            for result in results:
                signal = RealTimeSignal(
                    strategy_result=result,
                    context=context
                )
                real_time_signals.append(signal)

                # Broadcast signal
                self.signal_broadcaster.broadcast_signal(signal)

            # Update context
            with self._context_lock:
                context.last_calculation_time = datetime.now(timezone.utc)
                context.consecutive_errors = 0

            return real_time_signals

        except Exception as e:
            if self.logger:
                self.logger.error(f"Error executing strategy context {context_id}: {e}")
            self._handle_strategy_error(context_id, e)
            return []

    def _calculate_incremental_signals(
        self,
        context: StrategyExecutionContext
    ) -> List[StrategyResult]:
        """Calculate signals incrementally (only for new data)."""
        # For this initial implementation, fall back to full calculation
        # Incremental calculation optimization can be added later
        return self._calculate_full_signals(context)

    def _calculate_full_signals(
        self,
        context: StrategyExecutionContext
    ) -> List[StrategyResult]:
        """Calculate signals with full recalculation."""
        return self.data_integrator.calculate_strategy_signals(
            strategy_name=context.strategy_name,
            strategy_config=context.strategy_config,
            symbol=context.symbol,
            timeframe=context.timeframe,
            days_back=7,  # Use shorter history for real-time
            exchange=context.exchange,
            enable_caching=True
        )

    def _handle_strategy_error(self, context_id: str, error: Exception) -> None:
        """Handle strategy execution error."""
        with self._context_lock:
            if context_id in self._execution_contexts:
                context = self._execution_contexts[context_id]
                context.consecutive_errors += 1

                # Disable strategy if too many consecutive errors
                if context.consecutive_errors >= self.config.error_retry_attempts:
                    context.is_active = False
                    if self.logger:
                        self.logger.warning(
                            f"Disabling strategy {context_id} due to consecutive errors: {context.consecutive_errors}"
                        )

    def _update_performance_stats(
        self,
        signals_generated: int,
        calculation_time_ms: float,
        success: bool
    ) -> None:
        """Update performance statistics."""
        self._performance_stats['total_calculations'] += 1
        if success:
            self._performance_stats['successful_calculations'] += 1
        else:
            self._performance_stats['failed_calculations'] += 1

        self._performance_stats['signals_generated'] += signals_generated

        # Update average calculation time
        total_calcs = self._performance_stats['total_calculations']
        current_avg = self._performance_stats['average_calculation_time_ms']
        self._performance_stats['average_calculation_time_ms'] = (
            (current_avg * (total_calcs - 1) + calculation_time_ms) / total_calcs
        )

        self._performance_stats['last_update_time'] = datetime.now(timezone.utc)

    def set_chart_update_callback(self, callback: Callable[[RealTimeSignal], None]) -> None:
        """Set callback for chart updates."""
        self.signal_broadcaster.set_chart_update_callback(callback)

    def get_active_strategies(self) -> Dict[str, StrategyExecutionContext]:
        """Get all active strategy contexts."""
        with self._context_lock:
            return {
                context_id: context
                for context_id, context in self._execution_contexts.items()
                if context.is_active
            }

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get real-time execution performance statistics."""
        stats = dict(self._performance_stats)
        stats.update(self.signal_broadcaster.get_signal_stats())
        return stats

    def pause_strategy(self, context_id: str) -> bool:
        """Pause a strategy (set as inactive)."""
        with self._context_lock:
            if context_id in self._execution_contexts:
                self._execution_contexts[context_id].is_active = False
                return True
            return False

    def resume_strategy(self, context_id: str) -> bool:
        """Resume a strategy (set as active)."""
        with self._context_lock:
            if context_id in self._execution_contexts:
                context = self._execution_contexts[context_id]
                context.is_active = True
                context.consecutive_errors = 0  # Reset error count
                return True
            return False


# Singleton instance for global access
_realtime_processor: Optional[RealTimeStrategyProcessor] = None


def get_realtime_strategy_processor(config: RealTimeConfig = None) -> RealTimeStrategyProcessor:
    """
    Get the singleton real-time strategy processor instance.

    Args:
        config: Configuration for the processor (only used on first call)

    Returns:
        RealTimeStrategyProcessor instance
    """
    global _realtime_processor

    if _realtime_processor is None:
        _realtime_processor = RealTimeStrategyProcessor(config)

    return _realtime_processor


def initialize_realtime_strategy_system(config: RealTimeConfig = None) -> RealTimeStrategyProcessor:
    """
    Initialize the real-time strategy system.

    Args:
        config: Configuration for the system

    Returns:
        Initialized RealTimeStrategyProcessor
    """
    processor = get_realtime_strategy_processor(config)
    processor.start()
    return processor


def shutdown_realtime_strategy_system() -> None:
    """Shutdown the real-time strategy system."""
    global _realtime_processor

    if _realtime_processor is not None:
        _realtime_processor.stop()
        _realtime_processor = None
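A minimal end-to-end sketch of driving this pipeline outside the dashboard; the strategy name and config dict are assumptions, and any strategy known to the factory and data integrator would do:

```python
from strategies.realtime_execution import (
    RealTimeConfig,
    initialize_realtime_strategy_system,
    shutdown_realtime_strategy_system,
)

# Start the processor and its background signal broadcaster.
processor = initialize_realtime_strategy_system(RealTimeConfig(refresh_interval_seconds=30))

context_id = processor.register_strategy(
    strategy_name="ema_crossover",             # hypothetical strategy name
    strategy_config={"fast": 12, "slow": 26},  # hypothetical config
    symbol="BTC-USDT",
    timeframe="1m",
)

# Call this whenever the chart refresh cycle delivers a new candle.
signals = processor.execute_realtime_update(symbol="BTC-USDT", timeframe="1m")
print(f"{len(signals)} signals; stats: {processor.get_performance_stats()}")

processor.unregister_strategy(context_id)
shutdown_realtime_strategy_system()
```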
375  strategies/validation.py  Normal file
@@ -0,0 +1,375 @@
"""
Strategy Signal Validation Pipeline

This module provides validation, filtering, and quality assessment
for strategy-generated signals to ensure reliability and consistency.
"""

from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from dataclasses import dataclass

from .data_types import StrategySignal, SignalType, StrategyResult
from utils.logger import get_logger


@dataclass
class ValidationConfig:
    """Configuration for signal validation."""
    min_confidence: float = 0.0
    max_confidence: float = 1.0
    required_metadata_fields: List[str] = None
    allowed_signal_types: List[SignalType] = None
    price_tolerance_percent: float = 5.0  # Max price deviation from market

    def __post_init__(self):
        if self.required_metadata_fields is None:
            self.required_metadata_fields = []
        if self.allowed_signal_types is None:
            self.allowed_signal_types = list(SignalType)


class StrategySignalValidator:
    """
    Validates strategy signals for quality, consistency, and compliance.

    Provides comprehensive validation including confidence checks,
    signal type validation, price reasonableness, and metadata validation.
    """

    def __init__(self, config: ValidationConfig = None):
        """
        Initialize signal validator.

        Args:
            config: Validation configuration
        """
        self.config = config or ValidationConfig()
        self.logger = get_logger()

        # Validation statistics
        self._validation_stats = {
            'total_signals_validated': 0,
            'valid_signals': 0,
            'invalid_signals': 0,
            'validation_errors': {}
        }

    def validate_signal(self, signal: StrategySignal) -> Tuple[bool, List[str]]:
        """
        Validate a single strategy signal.

        Args:
            signal: Signal to validate

        Returns:
            Tuple of (is_valid, list_of_errors)
        """
        errors = []
        self._validation_stats['total_signals_validated'] += 1

        # Validate confidence
        if not (self.config.min_confidence <= signal.confidence <= self.config.max_confidence):
            errors.append(f"Invalid confidence {signal.confidence}, must be between {self.config.min_confidence} and {self.config.max_confidence}")

        # Validate signal type
        if signal.signal_type not in self.config.allowed_signal_types:
            errors.append(f"Signal type {signal.signal_type} not in allowed types")

        # Validate price
        if signal.price <= 0:
            errors.append(f"Invalid price {signal.price}, must be positive")

        # Validate required metadata
        if self.config.required_metadata_fields:
            if not signal.metadata:
                errors.append(f"Missing required metadata fields: {self.config.required_metadata_fields}")
            else:
                missing_fields = [field for field in self.config.required_metadata_fields
                                  if field not in signal.metadata]
                if missing_fields:
                    errors.append(f"Missing required metadata fields: {missing_fields}")

        # Update statistics
        is_valid = len(errors) == 0
        if is_valid:
            self._validation_stats['valid_signals'] += 1
        else:
            self._validation_stats['invalid_signals'] += 1
            for error in errors:
                error_type = error.split(':')[0] if ':' in error else error
                self._validation_stats['validation_errors'][error_type] = \
                    self._validation_stats['validation_errors'].get(error_type, 0) + 1

        return is_valid, errors

    def validate_signals_batch(self, signals: List[StrategySignal]) -> Tuple[List[StrategySignal], List[StrategySignal]]:
        """
        Validate multiple signals and return valid and invalid lists.

        Args:
            signals: List of signals to validate

        Returns:
            Tuple of (valid_signals, invalid_signals)
        """
        valid_signals = []
        invalid_signals = []

        for signal in signals:
            is_valid, errors = self.validate_signal(signal)
            if is_valid:
                valid_signals.append(signal)
            else:
                invalid_signals.append(signal)
                self.logger.debug(f"Invalid signal filtered out: {errors}")

        return valid_signals, invalid_signals

    def filter_signals_by_confidence(
        self,
        signals: List[StrategySignal],
        min_confidence: float = None
    ) -> List[StrategySignal]:
        """
        Filter signals by minimum confidence threshold.

        Args:
            signals: List of signals to filter
            min_confidence: Minimum confidence threshold (uses config if None)

        Returns:
            Filtered list of signals
        """
        threshold = min_confidence if min_confidence is not None else self.config.min_confidence

        filtered_signals = [signal for signal in signals if signal.confidence >= threshold]

        self.logger.debug(f"Filtered {len(signals) - len(filtered_signals)} signals below confidence {threshold}")

        return filtered_signals

    def filter_signals_by_type(
        self,
        signals: List[StrategySignal],
        allowed_types: List[SignalType] = None
    ) -> List[StrategySignal]:
        """
        Filter signals by allowed signal types.

        Args:
            signals: List of signals to filter
            allowed_types: Allowed signal types (uses config if None)

        Returns:
            Filtered list of signals
        """
        types = allowed_types if allowed_types is not None else self.config.allowed_signal_types

        filtered_signals = [signal for signal in signals if signal.signal_type in types]

        self.logger.debug(f"Filtered {len(signals) - len(filtered_signals)} signals by type")

        return filtered_signals

    def get_validation_statistics(self) -> Dict[str, Any]:
        """Get comprehensive validation statistics."""
        stats = self._validation_stats.copy()

        if stats['total_signals_validated'] > 0:
            stats['validation_success_rate'] = stats['valid_signals'] / stats['total_signals_validated']
            stats['validation_failure_rate'] = stats['invalid_signals'] / stats['total_signals_validated']
        else:
            stats['validation_success_rate'] = 0.0
            stats['validation_failure_rate'] = 0.0

        return stats

    def transform_signal_confidence(
        self,
        signal: StrategySignal,
        confidence_multiplier: float = 1.0,
        max_confidence: float = None
    ) -> StrategySignal:
        """
        Transform signal confidence with multiplier and cap.

        Args:
            signal: Signal to transform
            confidence_multiplier: Multiplier for confidence
            max_confidence: Maximum confidence cap (uses config if None)

        Returns:
            Transformed signal with updated confidence
        """
        max_conf = max_confidence if max_confidence is not None else self.config.max_confidence

        # Create new signal with transformed confidence
        new_confidence = min(signal.confidence * confidence_multiplier, max_conf)

        transformed_signal = StrategySignal(
            timestamp=signal.timestamp,
            symbol=signal.symbol,
            timeframe=signal.timeframe,
            signal_type=signal.signal_type,
            price=signal.price,
            confidence=new_confidence,
            metadata=signal.metadata.copy() if signal.metadata else None
        )

        return transformed_signal

    def enrich_signal_metadata(
        self,
        signal: StrategySignal,
        additional_metadata: Dict[str, Any]
    ) -> StrategySignal:
        """
        Enrich signal with additional metadata.

        Args:
            signal: Signal to enrich
            additional_metadata: Additional metadata to add

        Returns:
            Signal with enriched metadata
        """
        # Merge metadata
        enriched_metadata = signal.metadata.copy() if signal.metadata else {}
        enriched_metadata.update(additional_metadata)

        enriched_signal = StrategySignal(
            timestamp=signal.timestamp,
            symbol=signal.symbol,
            timeframe=signal.timeframe,
            signal_type=signal.signal_type,
            price=signal.price,
            confidence=signal.confidence,
            metadata=enriched_metadata
        )

        return enriched_signal

    def transform_signals_batch(
        self,
        signals: List[StrategySignal],
        confidence_multiplier: float = 1.0,
        additional_metadata: Dict[str, Any] = None
    ) -> List[StrategySignal]:
        """
        Apply transformations to multiple signals.

        Args:
            signals: List of signals to transform
            confidence_multiplier: Confidence multiplier
            additional_metadata: Additional metadata to add

        Returns:
            List of transformed signals
        """
        transformed_signals = []

        for signal in signals:
            # Apply confidence transformation
            transformed_signal = self.transform_signal_confidence(signal, confidence_multiplier)

            # Apply metadata enrichment if provided
            if additional_metadata:
                transformed_signal = self.enrich_signal_metadata(transformed_signal, additional_metadata)

            transformed_signals.append(transformed_signal)

        self.logger.debug(f"Transformed {len(signals)} signals")

        return transformed_signals

    def calculate_signal_quality_metrics(self, signals: List[StrategySignal]) -> Dict[str, Any]:
        """
        Calculate comprehensive quality metrics for signals.

        Args:
            signals: List of signals to analyze

        Returns:
            Dictionary containing quality metrics
        """
        if not signals:
            return {'error': 'No signals provided for quality analysis'}

        # Basic metrics
        total_signals = len(signals)
        confidence_values = [signal.confidence for signal in signals]

        # Signal type distribution
        signal_type_counts = {}
        for signal in signals:
            signal_type_counts[signal.signal_type.value] = signal_type_counts.get(signal.signal_type.value, 0) + 1

        # Confidence metrics
        avg_confidence = sum(confidence_values) / total_signals
        min_confidence = min(confidence_values)
        max_confidence = max(confidence_values)

        # Quality scoring (0-100)
        high_confidence_signals = sum(1 for conf in confidence_values if conf >= 0.7)
        quality_score = (high_confidence_signals / total_signals) * 100

        # Metadata completeness
        signals_with_metadata = sum(1 for signal in signals if signal.metadata)
        metadata_completeness = (signals_with_metadata / total_signals) * 100

        return {
            'total_signals': total_signals,
            'signal_type_distribution': signal_type_counts,
            'confidence_metrics': {
                'average': round(avg_confidence, 3),
                'minimum': round(min_confidence, 3),
                'maximum': round(max_confidence, 3),
                'high_confidence_count': high_confidence_signals,
                'high_confidence_percentage': round((high_confidence_signals / total_signals) * 100, 1)
            },
            'quality_score': round(quality_score, 1),
            'metadata_completeness_percentage': round(metadata_completeness, 1),
            'recommendations': self._generate_quality_recommendations(signals)
        }

    def _generate_quality_recommendations(self, signals: List[StrategySignal]) -> List[str]:
        """Generate quality improvement recommendations."""
        recommendations = []

        confidence_values = [signal.confidence for signal in signals]
        avg_confidence = sum(confidence_values) / len(confidence_values)

        if avg_confidence < 0.5:
            recommendations.append("Consider increasing confidence thresholds or improving signal generation logic")

        signals_with_metadata = sum(1 for signal in signals if signal.metadata)
        if signals_with_metadata / len(signals) < 0.8:
            recommendations.append("Enhance metadata collection to improve signal traceability")

        signal_types = set(signal.signal_type for signal in signals)
        if len(signal_types) == 1:
            recommendations.append("Consider diversifying signal types for better strategy coverage")

        return recommendations if recommendations else ["Signal quality appears good - no specific recommendations"]

    def generate_validation_report(self) -> Dict[str, Any]:
        """Generate comprehensive validation report."""
        stats = self.get_validation_statistics()

        return {
            'report_timestamp': datetime.now(timezone.utc).isoformat(),
            'validation_summary': {
                'total_validated': stats['total_signals_validated'],
                'success_rate': f"{stats.get('validation_success_rate', 0) * 100:.1f}%",
                'failure_rate': f"{stats.get('validation_failure_rate', 0) * 100:.1f}%"
            },
            'error_analysis': stats.get('validation_errors', {}),
            'configuration': {
                'min_confidence': self.config.min_confidence,
                'max_confidence': self.config.max_confidence,
                'allowed_signal_types': [st.value for st in self.config.allowed_signal_types],
                'required_metadata_fields': self.config.required_metadata_fields
            },
            'health_status': 'good' if stats.get('validation_success_rate', 0) >= 0.8 else 'needs_attention'
        }
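A short usage sketch of the validator, assuming `SignalType` exposes a `BUY` member; the concrete field values are invented:

```python
from datetime import datetime, timezone

from strategies.validation import StrategySignalValidator, ValidationConfig
from strategies.data_types import StrategySignal, SignalType

validator = StrategySignalValidator(ValidationConfig(min_confidence=0.3))

# Invented example signal; the field names follow the StrategySignal
# construction used in transform_signal_confidence() above.
signal = StrategySignal(
    timestamp=datetime.now(timezone.utc),
    symbol="BTC-USDT",
    timeframe="1h",
    signal_type=SignalType.BUY,
    price=65000.0,
    confidence=0.55,
    metadata={"source": "example"},
)

valid, invalid = validator.validate_signals_batch([signal])
print(validator.calculate_signal_quality_metrics(valid))
print(validator.generate_validation_report())
```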
@@ -22,12 +22,21 @@
- `database/migrations/versions/add_strategy_signals_table.py` - Alembic migration for strategy signals table
- `components/charts/layers/strategy_signals.py` - Strategy signal chart layer for visualization
- `components/charts/data_integration.py` - Updated to include strategy data integration
- `strategies/data_integration.py` - Strategy data integration with indicator orchestration and caching
- `strategies/validation.py` - Strategy signal validation and quality assurance
- `strategies/batch_processing.py` - Batch processing engine for backtesting multiple strategies across large datasets
- `strategies/realtime_execution.py` - Real-time strategy execution pipeline for live signal generation
- `dashboard/callbacks/realtime_strategies.py` - Dashboard callbacks for real-time strategy integration
- `tests/strategies/test_base_strategy.py` - Unit tests for BaseStrategy abstract class
- `tests/strategies/test_strategy_factory.py` - Unit tests for strategy factory system
- `tests/strategies/test_strategy_manager.py` - Unit tests for StrategyManager class
- `tests/strategies/implementations/test_ema_crossover.py` - Unit tests for EMA Crossover strategy
- `tests/strategies/implementations/test_rsi.py` - Unit tests for RSI strategy
- `tests/strategies/implementations/test_macd.py` - Unit tests for MACD strategy
- `tests/strategies/test_data_integration.py` - Unit tests for strategy data integration
- `tests/strategies/test_validation.py` - Unit tests for strategy signal validation
- `tests/strategies/test_batch_processing.py` - Unit tests for batch processing capabilities
- `tests/strategies/test_realtime_execution.py` - Unit tests for real-time execution pipeline
- `tests/database/test_strategy_repository.py` - Unit tests for strategy repository

### Notes

@@ -73,6 +82,26 @@
- **Reasoning**: Maintains consistency with existing database access patterns, ensures proper session management, and provides a clean API for strategy data operations.
- **Impact**: All strategy database operations follow the same patterns as other modules, with proper error handling, logging, and transaction management.

### 7. Vectorized Data Integration
- **Decision**: Implement vectorized approaches in `StrategyDataIntegrator` for DataFrame construction, indicator batching, and multi-strategy processing while maintaining iterative interfaces for backward compatibility.
- **Reasoning**: Significant performance improvements for backtesting and bulk analysis scenarios, better memory efficiency with pandas operations, and preparation for multi-strategy batch processing capabilities.
- **Impact**: Enhanced performance for large datasets while maintaining existing single-strategy interfaces. Sets the foundation for efficient multi-strategy and multi-timeframe processing in future phases. A sketch of the pattern follows this list.
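A hedged illustration of the vectorized pattern; this is not the actual `StrategyDataIntegrator` internals, and the column names and indicator parameters are assumptions:

```python
import pandas as pd

def build_ohlcv_frame(candles: list[dict]) -> pd.DataFrame:
    # One vectorized DataFrame construction instead of row-by-row appends.
    df = pd.DataFrame(candles)  # assumed columns: timestamp/open/high/low/close/volume
    df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)
    return df.set_index("timestamp").sort_index()

def batch_indicators(df: pd.DataFrame) -> pd.DataFrame:
    # Compute all indicator columns in one pass so multiple strategies can
    # share the results without recalculating per strategy.
    out = df.copy()
    out["ema_12"] = out["close"].ewm(span=12, adjust=False).mean()
    out["ema_26"] = out["close"].ewm(span=26, adjust=False).mean()
    out["sma_50"] = out["close"].rolling(window=50).mean()
    return out
```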

### 8. Single-Strategy Orchestration Focus
- **Decision**: Implement strategy calculation orchestration focused on single-strategy optimization with indicator dependency resolution, avoiding premature multi-strategy complexity.
- **Reasoning**: Multi-strategy coordination is better handled at the backtesting layer or through parallelization; single-strategy optimization provides immediate benefits while keeping the code maintainable and focused.
- **Impact**: Cleaner, more maintainable code with optimized single-strategy performance. Provides a foundation for future backtester-level parallelization without architectural complexity. A dependency-resolution sketch follows this list.
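A minimal sketch of indicator dependency resolution using the standard library's `graphlib`; the dependency map is hypothetical, and the real orchestration lives in `StrategyDataIntegrator`:

```python
from graphlib import TopologicalSorter

def resolve_indicator_order(dependencies: dict[str, set[str]]) -> list[str]:
    # Order indicators so every dependency is computed before its dependents.
    return list(TopologicalSorter(dependencies).static_order())

# Hypothetical map: MACD is derived from the two EMAs it depends on.
order = resolve_indicator_order({
    "ema_12": set(),
    "ema_26": set(),
    "macd": {"ema_12", "ema_26"},
})
# order ends with "macd", after both EMAs
```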

### 9. Indicator Warm-up Handling for Streaming Batch Processing
- **Decision**: Implemented dynamic warm-up period calculation and overlapping windows with result trimming for streaming batch processing.
- **Reasoning**: Indicators need a certain amount of historical data to warm up, so processing large datasets in chunks without that context yields inaccurate indicator values and false signals.
- **Impact**: Guarantees correct backtest results for strategies relying on indicators with warm-up periods, even when using memory-efficient streaming. Chunk processing automatically includes the necessary historical context and trims duplicate or under-warmed initial signals. The sketch after this list shows the overlap-and-trim idea.
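A sketch of the overlap-and-trim idea under stated assumptions (fixed row counts, a single warm-up length); the shipped implementation in `strategies/batch_processing.py` derives the warm-up dynamically per strategy:

```python
def iter_chunks_with_warmup(total_rows: int, chunk_size: int, warmup: int):
    # Yield (load_start, load_end, emit_start): each chunk is loaded with
    # `warmup` extra rows of history, and results before emit_start are
    # trimmed so no window emits duplicate or under-warmed signals.
    start = 0
    while start < total_rows:
        load_start = max(0, start - warmup)
        load_end = min(start + chunk_size, total_rows)
        yield load_start, load_end, start
        start = load_end

# With a 50-period indicator (warmup=49) and 1000-row chunks, the second
# chunk loads rows 951..1999 but only emits signals for rows 1000..1999.
for bounds in iter_chunks_with_warmup(total_rows=2500, chunk_size=1000, warmup=49):
    print(bounds)
```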

### 10. Real-time Strategy Execution Architecture
- **Decision**: Implemented an event-driven real-time strategy execution pipeline with signal broadcasting, chart integration, and concurrent processing capabilities.
- **Reasoning**: Real-time strategy execution requires a different architecture than batch processing: event-driven triggers, background signal processing, throttled chart updates, and integration with the existing dashboard refresh cycles.
- **Impact**: Enables live strategy signal generation that integrates seamlessly with the existing chart system. Provides concurrent strategy execution, real-time signal storage, error handling with automatic strategy disabling, and performance monitoring for production use.
|
||||
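
For orientation, a usage sketch assembled from the public names exercised by the tests below (`initialize_realtime_strategy_system`, `register_strategy`, `execute_realtime_update`); the parameter values are illustrative, not recommended settings.

```python
from strategies.realtime_execution import (
    RealTimeConfig,
    initialize_realtime_strategy_system,
    shutdown_realtime_strategy_system,
)

config = RealTimeConfig(refresh_interval_seconds=15, max_strategies_concurrent=3)
processor = initialize_realtime_strategy_system(config)

# Context ids follow "<strategy>_<symbol>_<timeframe>_<exchange>" per the tests
context_id = processor.register_strategy(
    strategy_name="ema_crossover",
    strategy_config={"short_period": 12, "long_period": 26},
    symbol="BTC-USDT",
    timeframe="1h",
)

# Called from the dashboard refresh cycle; returns RealTimeSignal objects
signals = processor.execute_realtime_update("BTC-USDT", "1h")

shutdown_realtime_strategy_system()
```
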
## Tasks

- [x] 1.0 Core Strategy Foundation Setup
@@ -109,16 +138,16 @@
- [x] 3.8 Add data retention policies for strategy signals (configurable cleanup of old analysis data)
- [x] 3.9 Implement strategy signal aggregation queries for performance analysis

- [ ] 4.0 Strategy Data Integration
- [ ] 4.1 Create `StrategyDataIntegrator` class in new `strategies/data_integration.py` module
- [ ] 4.2 Implement data loading interface that leverages existing `TechnicalIndicators` class for indicator dependencies
- [x] 4.0 Strategy Data Integration
- [x] 4.1 Create `StrategyDataIntegrator` class in new `strategies/data_integration.py` module
- [x] 4.2 Implement data loading interface that leverages existing `TechnicalIndicators` class for indicator dependencies
- [x] 4.3 Add multi-timeframe data handling for strategies that require indicators from different timeframes
- [ ] 4.4 Implement strategy calculation orchestration with proper indicator dependency resolution
- [ ] 4.5 Create caching layer for computed indicator results to avoid recalculation across strategies
- [ ] 4.6 Add strategy signal generation and validation pipeline
- [ ] 4.7 Implement batch processing capabilities for backtesting large datasets
- [ ] 4.8 Create real-time strategy execution pipeline that integrates with existing chart data refresh
- [ ] 4.9 Add error handling and recovery mechanisms for strategy calculation failures
- [x] 4.4 Implement strategy calculation orchestration with proper indicator dependency resolution
- [x] 4.5 Create caching layer for computed indicator results to avoid recalculation across strategies
- [x] 4.6 Add strategy signal generation and validation pipeline
- [x] 4.7 Implement batch processing capabilities for backtesting large datasets
- [x] 4.8 Create real-time strategy execution pipeline that integrates with existing chart data refresh
- [x] 4.9 Add error handling and recovery mechanisms for strategy calculation failures

- [ ] 5.0 Chart Integration and Visualization
- [ ] 5.1 Create `StrategySignalLayer` class in `components/charts/layers/strategy_signals.py`

tests/strategies/test_batch_processing.py (798 lines, new file)
@@ -0,0 +1,798 @@
"""
|
||||
Tests for Strategy Batch Processing
|
||||
|
||||
This module tests batch processing capabilities for strategy backtesting
|
||||
including memory management, parallel processing, and performance monitoring.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
import pandas as pd
|
||||
|
||||
from strategies.batch_processing import BacktestingBatchProcessor, BatchProcessingConfig
|
||||
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
||||
|
||||
|
||||
class TestBatchProcessingConfig:
    """Tests for BatchProcessingConfig dataclass."""

    def test_default_config(self):
        """Test default batch processing configuration."""
        config = BatchProcessingConfig()

        assert config.max_concurrent_strategies == 4
        assert config.max_memory_usage_percent == 80.0
        assert config.chunk_size_days == 30
        assert config.enable_memory_monitoring is True
        assert config.enable_result_validation is True
        assert config.result_cache_size == 1000
        assert config.progress_reporting_interval == 10

    def test_custom_config(self):
        """Test custom batch processing configuration."""
        config = BatchProcessingConfig(
            max_concurrent_strategies=8,
            max_memory_usage_percent=90.0,
            chunk_size_days=60,
            enable_memory_monitoring=False,
            enable_result_validation=False,
            result_cache_size=500,
            progress_reporting_interval=5
        )

        assert config.max_concurrent_strategies == 8
        assert config.max_memory_usage_percent == 90.0
        assert config.chunk_size_days == 60
        assert config.enable_memory_monitoring is False
        assert config.enable_result_validation is False
        assert config.result_cache_size == 500
        assert config.progress_reporting_interval == 5


class TestBacktestingBatchProcessor:
    """Tests for BacktestingBatchProcessor class."""

    @pytest.fixture
    def processor(self):
        """Create batch processor with default configuration."""
        config = BatchProcessingConfig(
            enable_memory_monitoring=False,  # Disable for testing
            progress_reporting_interval=1,   # Report every strategy for testing
            enable_result_validation=False   # Disable validation for basic tests
        )
        with patch('strategies.batch_processing.StrategyDataIntegrator'):
            return BacktestingBatchProcessor(config)

    @pytest.fixture
    def sample_strategy_configs(self):
        """Create sample strategy configurations for testing."""
        return [
            {
                'name': 'ema_crossover',
                'type': 'trend_following',
                'parameters': {'fast_ema': 12, 'slow_ema': 26}
            },
            {
                'name': 'rsi_momentum',
                'type': 'momentum',
                'parameters': {'rsi_period': 14, 'oversold': 30, 'overbought': 70}
            },
            {
                'name': 'macd_trend',
                'type': 'trend_following',
                'parameters': {'fast_ema': 12, 'slow_ema': 26, 'signal': 9}
            }
        ]

    @pytest.fixture
    def sample_strategy_results(self):
        """Create sample strategy results for testing."""
        return [
            StrategyResult(
                timestamp=datetime.now(timezone.utc),
                symbol='BTC-USDT',
                timeframe='1h',
                strategy_name='test_strategy',
                signals=[
                    StrategySignal(
                        timestamp=datetime.now(timezone.utc),
                        symbol='BTC-USDT',
                        timeframe='1h',
                        signal_type=SignalType.BUY,
                        price=50000.0,
                        confidence=0.8,
                        metadata={'rsi': 30}
                    )
                ],
                indicators_used={'rsi': 30, 'ema': 49000},
                metadata={'execution_time': 0.5}
            )
        ]

    def test_initialization(self, processor):
        """Test batch processor initialization."""
        assert processor.config is not None
        assert processor.logger is not None
        assert processor.data_integrator is not None
        assert processor._processing_stats['strategies_processed'] == 0
        assert processor._processing_stats['total_signals_generated'] == 0
        assert processor._processing_stats['errors_count'] == 0

    def test_initialization_with_validation_disabled(self):
        """Test initialization with validation disabled."""
        config = BatchProcessingConfig(enable_result_validation=False)
        with patch('strategies.batch_processing.StrategyDataIntegrator'):
            processor = BacktestingBatchProcessor(config)
            assert processor.signal_validator is None

    @patch('strategies.batch_processing.StrategyDataIntegrator')
    def test_process_strategies_batch(self, mock_integrator_class, processor, sample_strategy_configs, sample_strategy_results):
        """Test batch processing of multiple strategies."""
        # Setup mock data integrator
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.return_value = sample_strategy_results
        processor.data_integrator = mock_integrator

        symbols = ['BTC-USDT', 'ETH-USDT']
        timeframe = '1h'
        days_back = 30

        results = processor.process_strategies_batch(
            strategy_configs=sample_strategy_configs,
            symbols=symbols,
            timeframe=timeframe,
            days_back=days_back
        )

        # Verify results structure
        assert len(results) == len(sample_strategy_configs)
        assert 'ema_crossover' in results
        assert 'rsi_momentum' in results
        assert 'macd_trend' in results

        # Verify statistics
        stats = processor.get_processing_statistics()
        assert stats['strategies_processed'] == 3
        assert stats['total_signals_generated'] == 6  # 3 strategies × 2 symbols × 1 signal each
        assert stats['errors_count'] == 0

    def test_process_single_strategy_batch(self, processor, sample_strategy_results):
        """Test processing a single strategy across multiple symbols."""
        # Setup mock data integrator
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.return_value = sample_strategy_results
        processor.data_integrator = mock_integrator

        strategy_config = {'name': 'test_strategy', 'type': 'test'}
        symbols = ['BTC-USDT', 'ETH-USDT']

        results = processor._process_single_strategy_batch(
            strategy_config, symbols, '1h', 30, 'okx'
        )

        assert len(results) == 2  # Results for 2 symbols
        assert processor._processing_stats['total_signals_generated'] == 2

    def test_validate_strategy_results(self, processor, sample_strategy_results):
        """Test strategy result validation."""
        # Setup mock signal validator
        mock_validator = MagicMock()
        mock_validator.validate_signals_batch.return_value = (
            sample_strategy_results[0].signals,  # valid signals
            []  # no invalid signals
        )
        processor.signal_validator = mock_validator

        validated_results = processor._validate_strategy_results(sample_strategy_results)

        assert len(validated_results) == 1
        assert len(validated_results[0].signals) == 1
        mock_validator.validate_signals_batch.assert_called_once()

    @patch('strategies.batch_processing.psutil')
    def test_check_memory_usage_normal(self, mock_psutil, processor):
        """Test memory usage monitoring under normal conditions."""
        # Mock memory usage below threshold
        mock_process = MagicMock()
        mock_process.memory_percent.return_value = 60.0  # Below 80% threshold
        mock_process.memory_info.return_value.rss = 500 * 1024 * 1024  # 500 MB
        mock_psutil.Process.return_value = mock_process

        processor._check_memory_usage()

        assert processor._processing_stats['memory_peak_mb'] == 500.0

    @patch('strategies.batch_processing.psutil')
    def test_check_memory_usage_high(self, mock_psutil, processor):
        """Test memory usage monitoring with high usage."""
        # Mock memory usage above threshold
        mock_process = MagicMock()
        mock_process.memory_percent.return_value = 85.0  # Above 80% threshold
        mock_process.memory_info.return_value.rss = 1000 * 1024 * 1024  # 1000 MB
        mock_psutil.Process.return_value = mock_process

        with patch.object(processor, '_cleanup_memory') as mock_cleanup:
            processor._check_memory_usage()
            mock_cleanup.assert_called_once()

    def test_cleanup_memory(self, processor):
        """Test memory cleanup operations."""
        # Fill result cache beyond limit
        for i in range(1500):  # Above 1000 limit
            processor._result_cache[f'key_{i}'] = f'result_{i}'

        initial_cache_size = len(processor._result_cache)

        with patch.object(processor.data_integrator, 'clear_cache') as mock_clear, \
             patch('strategies.batch_processing.gc.collect') as mock_gc:

            processor._cleanup_memory()

            # Verify cache was reduced
            assert len(processor._result_cache) < initial_cache_size
            assert len(processor._result_cache) == 500  # Half of cache size limit

            # Verify other cleanup operations
            mock_clear.assert_called_once()
            mock_gc.assert_called_once()

    def test_get_processing_statistics(self, processor):
        """Test processing statistics calculation."""
        # Set some test statistics
        processor._processing_stats.update({
            'strategies_processed': 5,
            'total_signals_generated': 25,
            'processing_time_seconds': 10.0,
            'errors_count': 1,
            'validation_failures': 2
        })

        stats = processor.get_processing_statistics()

        assert stats['strategies_processed'] == 5
        assert stats['total_signals_generated'] == 25
        assert stats['average_signals_per_strategy'] == 5.0
        assert stats['average_processing_time_per_strategy'] == 2.0
        assert stats['error_rate'] == 20.0  # 1/5 * 100
        assert stats['validation_failure_rate'] == 8.0  # 2/25 * 100

    def test_get_processing_statistics_zero_division(self, processor):
        """Test statistics calculation with zero values."""
        stats = processor.get_processing_statistics()

        assert stats['average_signals_per_strategy'] == 0
        assert stats['average_processing_time_per_strategy'] == 0
        assert stats['error_rate'] == 0.0
        assert stats['validation_failure_rate'] == 0.0

    def test_process_strategies_batch_with_error(self, processor, sample_strategy_configs):
        """Test batch processing with errors."""
        # Setup mock to raise an exception
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.side_effect = Exception("Test error")
        processor.data_integrator = mock_integrator

        results = processor.process_strategies_batch(
            strategy_configs=sample_strategy_configs,
            symbols=['BTC-USDT'],
            timeframe='1h',
            days_back=30
        )

        # Should handle errors gracefully
        assert isinstance(results, dict)
        assert processor._processing_stats['errors_count'] > 0

    @patch('strategies.batch_processing.StrategyDataIntegrator')
    def test_process_strategies_parallel(self, mock_integrator_class, processor, sample_strategy_configs, sample_strategy_results):
        """Test parallel processing of multiple strategies."""
        # Setup mock data integrator
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.return_value = sample_strategy_results
        processor.data_integrator = mock_integrator

        symbols = ['BTC-USDT', 'ETH-USDT']
        timeframe = '1h'
        days_back = 30

        results = processor.process_strategies_parallel(
            strategy_configs=sample_strategy_configs,
            symbols=symbols,
            timeframe=timeframe,
            days_back=days_back
        )

        # Verify results structure (same as sequential processing)
        assert len(results) == len(sample_strategy_configs)
        assert 'ema_crossover' in results
        assert 'rsi_momentum' in results
        assert 'macd_trend' in results

        # Verify statistics
        stats = processor.get_processing_statistics()
        assert stats['strategies_processed'] == 3
        assert stats['total_signals_generated'] == 6  # 3 strategies × 2 symbols × 1 signal each
        assert stats['errors_count'] == 0

    def test_process_symbols_parallel(self, processor, sample_strategy_results):
        """Test parallel processing of single strategy across multiple symbols."""
        # Setup mock data integrator
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.return_value = sample_strategy_results
        processor.data_integrator = mock_integrator

        strategy_config = {'name': 'test_strategy', 'type': 'test'}
        symbols = ['BTC-USDT', 'ETH-USDT', 'BNB-USDT']

        results = processor.process_symbols_parallel(
            strategy_config=strategy_config,
            symbols=symbols,
            timeframe='1h',
            days_back=30
        )

        # Should have results for all symbols
        assert len(results) == 3  # Results for 3 symbols
        assert processor._processing_stats['total_signals_generated'] == 3

    def test_process_strategy_for_symbol(self, processor, sample_strategy_results):
        """Test processing a single strategy for a single symbol."""
        # Setup mock data integrator
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.return_value = sample_strategy_results
        processor.data_integrator = mock_integrator

        strategy_config = {'name': 'test_strategy', 'type': 'test'}

        results = processor._process_strategy_for_symbol(
            strategy_config=strategy_config,
            symbol='BTC-USDT',
            timeframe='1h',
            days_back=30,
            exchange='okx'
        )

        assert len(results) == 1
        assert results[0].strategy_name == 'test_strategy'
        assert results[0].symbol == 'BTC-USDT'

    def test_process_strategy_for_symbol_with_error(self, processor):
        """Test symbol processing with error handling."""
        # Setup mock to raise an exception
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.side_effect = Exception("Test error")
        processor.data_integrator = mock_integrator

        strategy_config = {'name': 'test_strategy', 'type': 'test'}

        results = processor._process_strategy_for_symbol(
            strategy_config=strategy_config,
            symbol='BTC-USDT',
            timeframe='1h',
            days_back=30,
            exchange='okx'
        )

        # Should return empty list on error
        assert results == []

    def test_process_large_dataset_streaming(self, processor, sample_strategy_configs, sample_strategy_results):
        """Test streaming processing for large datasets."""
        # Setup mock data integrator
        mock_integrator = MagicMock()
        mock_integrator.calculate_strategy_signals_orchestrated.return_value = sample_strategy_results
        processor.data_integrator = mock_integrator

        # Mock the parallel processing method to avoid actual parallel execution
        with patch.object(processor, 'process_strategies_parallel') as mock_parallel:
            mock_parallel.return_value = {
                'test_strategy': sample_strategy_results
            }

            # Test streaming with 90 days split into 30-day chunks
            stream = processor.process_large_dataset_streaming(
                strategy_configs=sample_strategy_configs,
                symbols=['BTC-USDT'],
                timeframe='1h',
                total_days_back=90  # Should create 3 chunks
            )

            # Collect all chunks
            chunks = list(stream)

            assert len(chunks) == 3  # 90 days / 30 days per chunk

            # Each chunk should have results for all strategies
            for chunk in chunks:
                assert 'test_strategy' in chunk

    def test_aggregate_streaming_results(self, processor, sample_strategy_results):
        """Test aggregation of streaming results."""
        # Create mock streaming results
        chunk1 = {'strategy1': sample_strategy_results[:1], 'strategy2': []}
        chunk2 = {'strategy1': [], 'strategy2': sample_strategy_results[:1]}
        chunk3 = {'strategy1': sample_strategy_results[:1], 'strategy2': sample_strategy_results[:1]}

        stream = iter([chunk1, chunk2, chunk3])

        aggregated = processor.aggregate_streaming_results(stream)

        assert len(aggregated) == 2
        assert 'strategy1' in aggregated
        assert 'strategy2' in aggregated
        assert len(aggregated['strategy1']) == 2  # From chunk1 and chunk3
        assert len(aggregated['strategy2']) == 2  # From chunk2 and chunk3

    @patch('strategies.batch_processing.psutil')
    def test_process_with_memory_constraints_sufficient_memory(self, mock_psutil, processor, sample_strategy_configs):
        """Test memory-constrained processing with sufficient memory."""
        # Mock low memory usage
        mock_process = MagicMock()
        mock_process.memory_info.return_value.rss = 100 * 1024 * 1024  # 100 MB
        mock_psutil.Process.return_value = mock_process

        with patch.object(processor, 'process_strategies_parallel') as mock_parallel:
            mock_parallel.return_value = {}

            processor.process_with_memory_constraints(
                strategy_configs=sample_strategy_configs,
                symbols=['BTC-USDT'],
                timeframe='1h',
                days_back=30,
                max_memory_mb=1000.0  # High limit
            )

            # Should use parallel processing for sufficient memory
            mock_parallel.assert_called_once()

    @patch('strategies.batch_processing.psutil')
    def test_process_with_memory_constraints_moderate_constraint(self, mock_psutil, processor, sample_strategy_configs):
        """Test memory-constrained processing with moderate constraint."""
        # Mock moderate memory usage
        mock_process = MagicMock()
        mock_process.memory_info.return_value.rss = 400 * 1024 * 1024  # 400 MB
        mock_psutil.Process.return_value = mock_process

        with patch.object(processor, 'process_strategies_batch') as mock_batch:
            mock_batch.return_value = {}

            processor.process_with_memory_constraints(
                strategy_configs=sample_strategy_configs,
                symbols=['BTC-USDT'],
                timeframe='1h',
                days_back=30,
                max_memory_mb=500.0  # Moderate limit
            )

            # Should use sequential batch processing
            mock_batch.assert_called_once()

    @patch('strategies.batch_processing.psutil')
    def test_process_with_memory_constraints_severe_constraint(self, mock_psutil, processor, sample_strategy_configs):
        """Test memory-constrained processing with severe constraint."""
        # Mock high memory usage
        mock_process = MagicMock()
        mock_process.memory_info.return_value.rss = 450 * 1024 * 1024  # 450 MB
        mock_psutil.Process.return_value = mock_process

        with patch.object(processor, 'process_large_dataset_streaming_with_warmup') as mock_streaming, \
             patch.object(processor, 'aggregate_streaming_results') as mock_aggregate:

            mock_streaming.return_value = iter([{}])
            mock_aggregate.return_value = {}

            processor.process_with_memory_constraints(
                strategy_configs=sample_strategy_configs,
                symbols=['BTC-USDT'],
                timeframe='1h',
                days_back=30,
                max_memory_mb=500.0  # Low limit with high current usage
            )

            # Should use streaming processing with warm-up
            mock_streaming.assert_called_once()
            mock_aggregate.assert_called_once()

    def test_get_performance_metrics(self, processor):
        """Test comprehensive performance metrics calculation."""
        # Set some test statistics
        processor._processing_stats.update({
            'strategies_processed': 5,
            'total_signals_generated': 25,
            'processing_time_seconds': 10.0,
            'memory_peak_mb': 500.0,
            'errors_count': 1,
            'validation_failures': 2
        })

        with patch.object(processor.data_integrator, 'get_cache_stats') as mock_cache_stats:
            mock_cache_stats.return_value = {'cache_hits': 80, 'cache_misses': 20}

            metrics = processor.get_performance_metrics()

            assert 'cache_hit_rate' in metrics
            assert 'memory_efficiency' in metrics
            assert 'throughput_signals_per_second' in metrics
            assert 'parallel_efficiency' in metrics
            assert 'optimization_recommendations' in metrics

            assert metrics['cache_hit_rate'] == 80.0  # 80/(80+20) * 100
            assert metrics['throughput_signals_per_second'] == 2.5  # 25/10

    def test_calculate_cache_hit_rate(self, processor):
        """Test cache hit rate calculation."""
        with patch.object(processor.data_integrator, 'get_cache_stats') as mock_cache_stats:
            mock_cache_stats.return_value = {'cache_hits': 70, 'cache_misses': 30}

            hit_rate = processor._calculate_cache_hit_rate()
            assert hit_rate == 70.0  # 70/(70+30) * 100

    def test_calculate_memory_efficiency(self, processor):
        """Test memory efficiency calculation."""
        processor._processing_stats.update({
            'memory_peak_mb': 200.0,
            'strategies_processed': 2
        })

        efficiency = processor._calculate_memory_efficiency()
        # 200MB / 2 strategies = 100MB per strategy
        # Baseline is 100MB, so efficiency should be 50%
        assert efficiency == 50.0

    def test_generate_optimization_recommendations(self, processor):
        """Test optimization recommendations generation."""
        # Set up poor performance metrics
        processor._processing_stats.update({
            'strategies_processed': 1,
            'total_signals_generated': 1,
            'processing_time_seconds': 10.0,
            'memory_peak_mb': 1000.0,  # High memory usage
            'errors_count': 2,  # High error rate
            'validation_failures': 0
        })

        with patch.object(processor.data_integrator, 'get_cache_stats') as mock_cache_stats:
            mock_cache_stats.return_value = {'cache_hits': 1, 'cache_misses': 9}  # Low cache hit rate

            recommendations = processor._generate_optimization_recommendations()

            assert isinstance(recommendations, list)
            assert len(recommendations) > 0
            # Should recommend memory efficiency improvement
            assert any('memory efficiency' in rec.lower() for rec in recommendations)

    def test_optimize_configuration(self, processor):
        """Test automatic configuration optimization."""
        # Set up metrics that indicate poor memory efficiency
        processor._processing_stats.update({
            'strategies_processed': 4,
            'total_signals_generated': 20,
            'processing_time_seconds': 8.0,
            'memory_peak_mb': 2000.0,  # Very high memory usage
            'errors_count': 0,
            'validation_failures': 0
        })

        with patch.object(processor.data_integrator, 'get_cache_stats') as mock_cache_stats:
            mock_cache_stats.return_value = {'cache_hits': 10, 'cache_misses': 90}

            original_workers = processor.config.max_concurrent_strategies
            original_chunk_size = processor.config.chunk_size_days

            optimized_config = processor.optimize_configuration()

            # Should reduce workers and chunk size due to poor memory efficiency
            assert optimized_config.max_concurrent_strategies <= original_workers
            assert optimized_config.chunk_size_days <= original_chunk_size

    def test_benchmark_processing_methods(self, processor, sample_strategy_configs):
        """Test processing method benchmarking."""
        with patch.object(processor, 'process_strategies_batch') as mock_batch, \
             patch.object(processor, 'process_strategies_parallel') as mock_parallel:

            # Mock batch processing results
            mock_batch.return_value = {'strategy1': []}

            # Mock parallel processing results
            mock_parallel.return_value = {'strategy1': []}

            benchmark_results = processor.benchmark_processing_methods(
                strategy_configs=sample_strategy_configs,
                symbols=['BTC-USDT'],
                timeframe='1h',
                days_back=7
            )

            assert 'sequential' in benchmark_results
            assert 'parallel' in benchmark_results
            assert 'recommendation' in benchmark_results

            # Verify both methods were called
            mock_batch.assert_called_once()
            mock_parallel.assert_called_once()

    def test_reset_stats(self, processor):
        """Test statistics reset functionality."""
        # Set some statistics
        processor._processing_stats.update({
            'strategies_processed': 5,
            'total_signals_generated': 25,
            'processing_time_seconds': 10.0
        })
        processor._result_cache['test'] = 'data'

        processor._reset_stats()

        # Verify all stats are reset
        assert processor._processing_stats['strategies_processed'] == 0
        assert processor._processing_stats['total_signals_generated'] == 0
        assert processor._processing_stats['processing_time_seconds'] == 0.0
        assert len(processor._result_cache) == 0

    def test_calculate_warmup_period_ema_strategy(self, processor):
        """Test warm-up period calculation for EMA strategy."""
        strategy_configs = [
            {
                'name': 'ema_crossover',
                'fast_period': 12,
                'slow_period': 26
            }
        ]

        warmup = processor._calculate_warmup_period(strategy_configs)

        # Should be max(12, 26) + 10 safety buffer = 36
        assert warmup == 36

    def test_calculate_warmup_period_macd_strategy(self, processor):
        """Test warm-up period calculation for MACD strategy."""
        strategy_configs = [
            {
                'name': 'macd_trend',
                'slow_period': 26,
                'signal_period': 9
            }
        ]

        warmup = processor._calculate_warmup_period(strategy_configs)

        # Should be max(26, 9) + 10 MACD buffer + 10 safety buffer = 46
        assert warmup == 46

    def test_calculate_warmup_period_rsi_strategy(self, processor):
        """Test warm-up period calculation for RSI strategy."""
        strategy_configs = [
            {
                'name': 'rsi_momentum',
                'period': 14
            }
        ]

        warmup = processor._calculate_warmup_period(strategy_configs)

        # Should be 14 + 5 RSI buffer + 10 safety buffer = 29
        assert warmup == 29

    def test_calculate_warmup_period_multiple_strategies(self, processor):
        """Test warm-up period calculation with multiple strategies."""
        strategy_configs = [
            {'name': 'ema_crossover', 'slow_period': 26},
            {'name': 'rsi_momentum', 'period': 14},
            {'name': 'macd_trend', 'slow_period': 26, 'signal_period': 9}
        ]

        warmup = processor._calculate_warmup_period(strategy_configs)

        # Should be max of all strategies: 46 (from MACD)
        assert warmup == 46

    def test_calculate_warmup_period_unknown_strategy(self, processor):
        """Test warm-up period calculation for unknown strategy type."""
        strategy_configs = [
            {
                'name': 'custom_strategy',
                'some_param': 100
            }
        ]

        warmup = processor._calculate_warmup_period(strategy_configs)

        # Should be 30 default + 10 safety buffer = 40
        assert warmup == 40

    def test_process_large_dataset_streaming_with_warmup(self, processor, sample_strategy_configs, sample_strategy_results):
        """Test streaming processing with warm-up period handling."""
        # Mock the warm-up calculation
        with patch.object(processor, '_calculate_warmup_period') as mock_warmup:
            mock_warmup.return_value = 10  # 10 days warm-up

            # Mock the parallel processing method
            with patch.object(processor, 'process_strategies_parallel') as mock_parallel:
                mock_parallel.return_value = {
                    'test_strategy': sample_strategy_results
                }

                # Mock the trimming method
                with patch.object(processor, '_trim_warmup_from_results') as mock_trim:
                    mock_trim.return_value = {'test_strategy': sample_strategy_results}

                    # Test streaming with 60 days split into 30-day chunks
                    stream = processor.process_large_dataset_streaming_with_warmup(
                        strategy_configs=sample_strategy_configs,
                        symbols=['BTC-USDT'],
                        timeframe='1h',
                        total_days_back=60  # Should create 2 chunks
                    )

                    # Collect all chunks
                    chunks = list(stream)

                    assert len(chunks) == 2  # 60 days / 30 days per chunk

                    # Verify parallel processing was called with correct parameters
                    assert mock_parallel.call_count == 2

                    # First chunk should not have warm-up, second should
                    first_call_args = mock_parallel.call_args_list[0]
                    second_call_args = mock_parallel.call_args_list[1]

                    # First chunk: 30 days (no warm-up)
                    assert first_call_args[1]['days_back'] == 30

                    # Second chunk: 30 + 10 warm-up = 40 days
                    assert second_call_args[1]['days_back'] == 40

                    # Trimming should only be called for second chunk
                    assert mock_trim.call_count == 1

    def test_trim_warmup_from_results(self, processor, sample_strategy_results):
        """Test trimming warm-up period from results."""
        # Create test results with multiple signals
        extended_results = sample_strategy_results * 10  # 10 results total
        chunk_results = {
            'strategy1': extended_results,
            'strategy2': sample_strategy_results * 5  # 5 results
        }

        trimmed = processor._trim_warmup_from_results(
            chunk_results=chunk_results,
            warmup_days=10,
            target_start_days=30,
            target_end_days=60
        )

        # Verify trimming occurred
        assert len(trimmed['strategy1']) <= len(extended_results)
        assert len(trimmed['strategy2']) <= len(sample_strategy_results * 5)

        # Results should be sorted by timestamp
        for strategy_name, results in trimmed.items():
            if len(results) > 1:
                timestamps = [r.timestamp for r in results]
                assert timestamps == sorted(timestamps)

    def test_streaming_with_warmup_chunk_size_adjustment(self, processor, sample_strategy_configs):
        """Test automatic chunk size adjustment when too small for warm-up."""
        # Set up small chunk size relative to warm-up
        processor.config.chunk_size_days = 15  # Small chunk size

        with patch.object(processor, '_calculate_warmup_period') as mock_warmup:
            mock_warmup.return_value = 30  # Large warm-up period

            with patch.object(processor, 'process_strategies_parallel') as mock_parallel:
                mock_parallel.return_value = {}

                # This should trigger chunk size adjustment
                stream = processor.process_large_dataset_streaming_with_warmup(
                    strategy_configs=sample_strategy_configs,
                    symbols=['BTC-USDT'],
                    timeframe='1h',
                    total_days_back=90
                )

                # Consume the stream to trigger processing
                list(stream)

                # Verify warning was logged about chunk size adjustment
                # (In a real implementation, you might want to capture log messages)
tests/strategies/test_data_integration.py (1068 lines, new file)
File diff suppressed because it is too large
tests/strategies/test_realtime_execution.py (558 lines, new file)
@@ -0,0 +1,558 @@
"""
|
||||
Tests for real-time strategy execution pipeline.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import time
|
||||
from queue import Queue, Empty
|
||||
import threading
|
||||
|
||||
from strategies.realtime_execution import (
|
||||
RealTimeStrategyProcessor,
|
||||
StrategySignalBroadcaster,
|
||||
RealTimeConfig,
|
||||
StrategyExecutionContext,
|
||||
RealTimeSignal,
|
||||
get_realtime_strategy_processor,
|
||||
initialize_realtime_strategy_system,
|
||||
shutdown_realtime_strategy_system
|
||||
)
|
||||
from strategies.data_types import StrategyResult, StrategySignal, SignalType
|
||||
from data.common.data_types import OHLCVCandle
|
||||
|
||||
|
||||
class TestRealTimeConfig:
    """Test RealTimeConfig dataclass."""

    def test_default_config(self):
        """Test default configuration values."""
        config = RealTimeConfig()

        assert config.refresh_interval_seconds == 30
        assert config.max_strategies_concurrent == 5
        assert config.incremental_calculation == True
        assert config.signal_batch_size == 100
        assert config.enable_signal_broadcasting == True
        assert config.max_signal_queue_size == 1000
        assert config.chart_update_throttle_ms == 1000
        assert config.error_retry_attempts == 3
        assert config.error_retry_delay_seconds == 5

    def test_custom_config(self):
        """Test custom configuration values."""
        config = RealTimeConfig(
            refresh_interval_seconds=15,
            max_strategies_concurrent=3,
            incremental_calculation=False,
            signal_batch_size=50
        )

        assert config.refresh_interval_seconds == 15
        assert config.max_strategies_concurrent == 3
        assert config.incremental_calculation == False
        assert config.signal_batch_size == 50


class TestStrategyExecutionContext:
    """Test StrategyExecutionContext dataclass."""

    def test_context_creation(self):
        """Test strategy execution context creation."""
        context = StrategyExecutionContext(
            strategy_name="ema_crossover",
            strategy_config={"short_period": 12, "long_period": 26},
            symbol="BTC-USDT",
            timeframe="1h"
        )

        assert context.strategy_name == "ema_crossover"
        assert context.strategy_config == {"short_period": 12, "long_period": 26}
        assert context.symbol == "BTC-USDT"
        assert context.timeframe == "1h"
        assert context.exchange == "okx"
        assert context.last_calculation_time is None
        assert context.consecutive_errors == 0
        assert context.is_active == True

    def test_context_with_custom_exchange(self):
        """Test context with custom exchange."""
        context = StrategyExecutionContext(
            strategy_name="rsi",
            strategy_config={"period": 14},
            symbol="ETH-USDT",
            timeframe="4h",
            exchange="binance"
        )

        assert context.exchange == "binance"


class TestRealTimeSignal:
    """Test RealTimeSignal dataclass."""

    def test_signal_creation(self):
        """Test real-time signal creation."""
        # Create mock strategy result
        strategy_result = Mock(spec=StrategyResult)
        strategy_result.timestamp = datetime.now(timezone.utc)
        strategy_result.confidence = 0.8

        # Create context
        context = StrategyExecutionContext(
            strategy_name="macd",
            strategy_config={"fast_period": 12},
            symbol="BTC-USDT",
            timeframe="1d"
        )

        # Create signal
        signal = RealTimeSignal(
            strategy_result=strategy_result,
            context=context
        )

        assert signal.strategy_result == strategy_result
        assert signal.context == context
        assert signal.chart_update_required == True
        assert isinstance(signal.generation_time, datetime)


class TestStrategySignalBroadcaster:
    """Test StrategySignalBroadcaster class."""

    @pytest.fixture
    def config(self):
        """Test configuration."""
        return RealTimeConfig(
            signal_batch_size=5,
            max_signal_queue_size=10,
            chart_update_throttle_ms=100
        )

    @pytest.fixture
    def mock_db_ops(self):
        """Mock database operations."""
        with patch('strategies.realtime_execution.get_database_operations') as mock:
            db_ops = Mock()
            db_ops.strategy = Mock()
            db_ops.strategy.store_signals_batch = Mock(return_value=5)
            mock.return_value = db_ops
            yield db_ops

    @pytest.fixture
    def broadcaster(self, config, mock_db_ops):
        """Create broadcaster instance."""
        return StrategySignalBroadcaster(config)

    def test_broadcaster_initialization(self, broadcaster, config):
        """Test broadcaster initialization."""
        assert broadcaster.config == config
        assert broadcaster._is_running == False
        assert broadcaster._chart_update_callback is None

    def test_start_stop_broadcaster(self, broadcaster):
        """Test starting and stopping broadcaster."""
        assert not broadcaster._is_running

        broadcaster.start()
        assert broadcaster._is_running
        assert broadcaster._processing_thread is not None

        broadcaster.stop()
        assert not broadcaster._is_running

    def test_broadcast_signal(self, broadcaster):
        """Test broadcasting signals."""
        # Create test signal
        strategy_result = Mock(spec=StrategyResult)
        context = StrategyExecutionContext(
            strategy_name="test",
            strategy_config={},
            symbol="BTC-USDT",
            timeframe="1h"
        )
        signal = RealTimeSignal(strategy_result=strategy_result, context=context)

        # Broadcast signal
        success = broadcaster.broadcast_signal(signal)
        assert success == True

        # Check queue has signal
        assert broadcaster._signal_queue.qsize() == 1

    def test_broadcast_signal_queue_full(self, config, mock_db_ops):
        """Test broadcasting when queue is full."""
        # Create broadcaster with very small queue
        small_config = RealTimeConfig(max_signal_queue_size=1)
        broadcaster = StrategySignalBroadcaster(small_config)

        # Create test signals
        strategy_result = Mock(spec=StrategyResult)
        context = StrategyExecutionContext(
            strategy_name="test",
            strategy_config={},
            symbol="BTC-USDT",
            timeframe="1h"
        )
        signal1 = RealTimeSignal(strategy_result=strategy_result, context=context)
        signal2 = RealTimeSignal(strategy_result=strategy_result, context=context)

        # Fill queue
        success1 = broadcaster.broadcast_signal(signal1)
        assert success1 == True

        # Try to overfill queue
        success2 = broadcaster.broadcast_signal(signal2)
        assert success2 == False  # Should fail due to full queue

    def test_set_chart_update_callback(self, broadcaster):
        """Test setting chart update callback."""
        callback = Mock()
        broadcaster.set_chart_update_callback(callback)
        assert broadcaster._chart_update_callback == callback

    def test_get_signal_stats(self, broadcaster):
        """Test getting signal statistics."""
        stats = broadcaster.get_signal_stats()

        assert 'queue_size' in stats
        assert 'chart_queue_size' in stats
        assert 'is_running' in stats
        assert 'last_chart_updates' in stats
        assert stats['is_running'] == False


class TestRealTimeStrategyProcessor:
    """Test RealTimeStrategyProcessor class."""

    @pytest.fixture
    def config(self):
        """Test configuration."""
        return RealTimeConfig(
            max_strategies_concurrent=2,
            error_retry_attempts=2
        )

    @pytest.fixture
    def mock_dependencies(self):
        """Mock all external dependencies."""
        mocks = {}

        with patch('strategies.realtime_execution.StrategyDataIntegrator') as mock_integrator:
            mocks['data_integrator'] = Mock()
            mock_integrator.return_value = mocks['data_integrator']

            with patch('strategies.realtime_execution.MarketDataIntegrator') as mock_market:
                mocks['market_integrator'] = Mock()
                mock_market.return_value = mocks['market_integrator']

                with patch('strategies.realtime_execution.StrategyFactory') as mock_factory:
                    mocks['strategy_factory'] = Mock()
                    mock_factory.return_value = mocks['strategy_factory']

                    yield mocks

    @pytest.fixture
    def processor(self, config, mock_dependencies):
        """Create processor instance."""
        return RealTimeStrategyProcessor(config)

    def test_processor_initialization(self, processor, config):
        """Test processor initialization."""
        assert processor.config == config
        assert processor._execution_contexts == {}
        assert processor._performance_stats['total_calculations'] == 0

    def test_start_stop_processor(self, processor):
        """Test starting and stopping processor."""
        processor.start()
        assert processor.signal_broadcaster._is_running == True

        processor.stop()
        assert processor.signal_broadcaster._is_running == False

    def test_register_strategy(self, processor):
        """Test registering strategy for real-time execution."""
        context_id = processor.register_strategy(
            strategy_name="ema_crossover",
            strategy_config={"short_period": 12, "long_period": 26},
            symbol="BTC-USDT",
            timeframe="1h"
        )

        expected_id = "ema_crossover_BTC-USDT_1h_okx"
        assert context_id == expected_id
        assert context_id in processor._execution_contexts

        context = processor._execution_contexts[context_id]
        assert context.strategy_name == "ema_crossover"
        assert context.symbol == "BTC-USDT"
        assert context.timeframe == "1h"
        assert context.is_active == True

    def test_unregister_strategy(self, processor):
        """Test unregistering strategy."""
        # Register first
        context_id = processor.register_strategy(
            strategy_name="rsi",
            strategy_config={"period": 14},
            symbol="ETH-USDT",
            timeframe="4h"
        )

        assert context_id in processor._execution_contexts

        # Unregister
        success = processor.unregister_strategy(context_id)
        assert success == True
        assert context_id not in processor._execution_contexts

        # Try to unregister again
        success2 = processor.unregister_strategy(context_id)
        assert success2 == False

    def test_execute_realtime_update_no_strategies(self, processor):
        """Test real-time update with no registered strategies."""
        signals = processor.execute_realtime_update("BTC-USDT", "1h")
        assert signals == []

    def test_execute_realtime_update_with_strategies(self, processor, mock_dependencies):
        """Test real-time update with registered strategies."""
        # Mock strategy calculation results
        mock_result = Mock(spec=StrategyResult)
        mock_result.timestamp = datetime.now(timezone.utc)
        mock_result.confidence = 0.8

        mock_dependencies['data_integrator'].calculate_strategy_signals.return_value = [mock_result]

        # Register strategy
        processor.register_strategy(
            strategy_name="ema_crossover",
            strategy_config={"short_period": 12, "long_period": 26},
            symbol="BTC-USDT",
            timeframe="1h"
        )

        # Execute update
        signals = processor.execute_realtime_update("BTC-USDT", "1h")

        assert len(signals) == 1
        assert isinstance(signals[0], RealTimeSignal)
        assert signals[0].strategy_result == mock_result

    def test_get_active_strategies(self, processor):
        """Test getting active strategies."""
        # Register some strategies
        processor.register_strategy("ema", {}, "BTC-USDT", "1h")
        processor.register_strategy("rsi", {}, "ETH-USDT", "4h")

        active = processor.get_active_strategies()
        assert len(active) == 2

        # Pause one strategy
        context_id = list(active.keys())[0]
        processor.pause_strategy(context_id)

        active_after_pause = processor.get_active_strategies()
        assert len(active_after_pause) == 1

    def test_pause_resume_strategy(self, processor):
        """Test pausing and resuming strategies."""
        context_id = processor.register_strategy("macd", {}, "BTC-USDT", "1d")

        # Pause strategy
        success = processor.pause_strategy(context_id)
        assert success == True
        assert not processor._execution_contexts[context_id].is_active

        # Resume strategy
        success = processor.resume_strategy(context_id)
        assert success == True
        assert processor._execution_contexts[context_id].is_active

        # Test with invalid context_id
        invalid_success = processor.pause_strategy("invalid_id")
        assert invalid_success == False

    def test_get_performance_stats(self, processor):
        """Test getting performance statistics."""
        stats = processor.get_performance_stats()

        assert 'total_calculations' in stats
        assert 'successful_calculations' in stats
        assert 'failed_calculations' in stats
        assert 'average_calculation_time_ms' in stats
        assert 'signals_generated' in stats
        assert 'queue_size' in stats  # From signal broadcaster


class TestSingletonAndInitialization:
    """Test singleton pattern and system initialization."""

    def test_get_realtime_strategy_processor_singleton(self):
        """Test that processor is singleton."""
        # Clean up any existing processor
        shutdown_realtime_strategy_system()

        processor1 = get_realtime_strategy_processor()
        processor2 = get_realtime_strategy_processor()

        assert processor1 is processor2

        # Clean up
        shutdown_realtime_strategy_system()

    def test_initialize_realtime_strategy_system(self):
        """Test system initialization."""
        # Clean up any existing processor
        shutdown_realtime_strategy_system()

        config = RealTimeConfig(max_strategies_concurrent=2)
        processor = initialize_realtime_strategy_system(config)

        assert processor is not None
        assert processor.signal_broadcaster._is_running == True

        # Clean up
        shutdown_realtime_strategy_system()

    def test_shutdown_realtime_strategy_system(self):
        """Test system shutdown."""
        # Initialize system
        processor = initialize_realtime_strategy_system()
        assert processor.signal_broadcaster._is_running == True

        # Shutdown
        shutdown_realtime_strategy_system()

        # Verify shutdown
        # Note: After shutdown, the global processor is set to None
        # So we can't check the processor state, but we can verify
        # a new processor is created on next call
        new_processor = get_realtime_strategy_processor()
        assert new_processor is not None


class TestIntegration:
    """Integration tests for real-time execution pipeline."""

    @pytest.fixture
    def integration_config(self):
        """Configuration for integration tests."""
        return RealTimeConfig(
            signal_batch_size=2,
            max_signal_queue_size=5,
            chart_update_throttle_ms=50
        )

    def test_end_to_end_signal_flow(self, integration_config):
        """Test complete signal flow from strategy to storage."""
        with patch('strategies.realtime_execution.get_database_operations') as mock_db:
            # Setup mocks
            db_ops = Mock()
            db_ops.strategy = Mock()
            db_ops.strategy.store_signals_batch = Mock(return_value=2)
            mock_db.return_value = db_ops

            # Create processor
            processor = RealTimeStrategyProcessor(integration_config)
            processor.start()

            try:
                # Mock strategy calculation
                mock_result = Mock(spec=StrategyResult)
                mock_result.timestamp = datetime.now(timezone.utc)
                mock_result.confidence = 0.8
                mock_result.signal = Mock()
                mock_result.signal.signal_type = SignalType.BUY
                mock_result.price = 50000.0
                mock_result.metadata = {"test": True}

                with patch.object(processor.data_integrator, 'calculate_strategy_signals') as mock_calc:
                    mock_calc.return_value = [mock_result]

                    # Register strategy
                    processor.register_strategy(
                        strategy_name="test_strategy",
                        strategy_config={"param": "value"},
                        symbol="BTC-USDT",
                        timeframe="1h"
                    )

                    # Execute real-time update
                    signals = processor.execute_realtime_update("BTC-USDT", "1h")

                    assert len(signals) == 1

                    # Wait for signal processing
                    time.sleep(0.2)  # Allow background processing

                    # Verify calculation was called
                    mock_calc.assert_called_once()

            finally:
                processor.stop()

    def test_error_handling_and_retry(self, integration_config):
        """Test error handling and retry mechanisms."""
        processor = RealTimeStrategyProcessor(integration_config)
        processor.start()

        try:
            # Mock strategy calculation to raise error
            with patch.object(processor.data_integrator, 'calculate_strategy_signals') as mock_calc:
                mock_calc.side_effect = Exception("Test error")

                # Register strategy
                context_id = processor.register_strategy(
                    strategy_name="error_strategy",
                    strategy_config={},
                    symbol="BTC-USDT",
                    timeframe="1h"
                )

                # Execute multiple times to trigger error handling
                for _ in range(integration_config.error_retry_attempts + 1):
                    processor.execute_realtime_update("BTC-USDT", "1h")

                # Strategy should be disabled after max errors
                context = processor._execution_contexts[context_id]
                assert not context.is_active
                assert context.consecutive_errors >= integration_config.error_retry_attempts

        finally:
            processor.stop()

    def test_concurrent_strategy_execution(self, integration_config):
        """Test concurrent execution of multiple strategies."""
        processor = RealTimeStrategyProcessor(integration_config)
        processor.start()

        try:
            # Mock strategy calculations
            mock_result1 = Mock(spec=StrategyResult)
            mock_result1.timestamp = datetime.now(timezone.utc)
            mock_result1.confidence = 0.7

            mock_result2 = Mock(spec=StrategyResult)
            mock_result2.timestamp = datetime.now(timezone.utc)
            mock_result2.confidence = 0.9

            with patch.object(processor.data_integrator, 'calculate_strategy_signals') as mock_calc:
                mock_calc.side_effect = [[mock_result1], [mock_result2]]

                # Register multiple strategies for same symbol/timeframe
                processor.register_strategy("strategy1", {}, "BTC-USDT", "1h")
                processor.register_strategy("strategy2", {}, "BTC-USDT", "1h")

                # Execute update
                signals = processor.execute_realtime_update("BTC-USDT", "1h")

                # Should get signals from both strategies
                assert len(signals) == 2

        finally:
            processor.stop()
tests/strategies/test_validation.py (478 lines, new file)
@@ -0,0 +1,478 @@
"""
|
||||
Tests for Strategy Signal Validation Pipeline
|
||||
|
||||
This module tests signal validation, filtering, and quality assessment
|
||||
functionality for strategy-generated signals.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
from strategies.validation import StrategySignalValidator, ValidationConfig
|
||||
from strategies.data_types import StrategySignal, SignalType
|
||||
|
||||
|
||||
class TestValidationConfig:
    """Tests for ValidationConfig dataclass."""

    def test_default_config(self):
        """Test default validation configuration."""
        config = ValidationConfig()

        assert config.min_confidence == 0.0
        assert config.max_confidence == 1.0
        assert config.required_metadata_fields == []
        assert config.allowed_signal_types == list(SignalType)
        assert config.price_tolerance_percent == 5.0

    def test_custom_config(self):
        """Test custom validation configuration."""
        config = ValidationConfig(
            min_confidence=0.3,
            max_confidence=0.9,
            required_metadata_fields=['indicator1', 'indicator2'],
            allowed_signal_types=[SignalType.BUY, SignalType.SELL],
            price_tolerance_percent=2.0
        )

        assert config.min_confidence == 0.3
        assert config.max_confidence == 0.9
        assert config.required_metadata_fields == ['indicator1', 'indicator2']
        assert config.allowed_signal_types == [SignalType.BUY, SignalType.SELL]
        assert config.price_tolerance_percent == 2.0


class TestStrategySignalValidator:
    """Tests for StrategySignalValidator class."""

    @pytest.fixture
    def validator(self):
        """Create validator with default configuration."""
        return StrategySignalValidator()

    @pytest.fixture
    def strict_validator(self):
        """Create validator with strict configuration."""
        config = ValidationConfig(
            min_confidence=0.5,
            max_confidence=1.0,
            required_metadata_fields=['rsi', 'macd'],
            allowed_signal_types=[SignalType.BUY, SignalType.SELL]
        )
        return StrategySignalValidator(config)

    @pytest.fixture
    def valid_signal(self):
        """Create a valid strategy signal for testing."""
        return StrategySignal(
            timestamp=datetime.now(timezone.utc),
            symbol='BTC-USDT',
            timeframe='1h',
            signal_type=SignalType.BUY,
            price=50000.0,
            confidence=0.8,
            metadata={'rsi': 30, 'macd': 0.05}
        )

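The valid_signal fixture touches every field the validator inspects later in the file. For orientation, a StrategySignal consistent with this usage would be shaped roughly as below; this is inferred from the fixtures (the Optional metadata comes from the None used in later tests), and the real definition in strategies.data_types may differ.

# Sketch inferred from the fixtures and tests; see strategies/data_types.py
# for the actual definition.
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Optional

@dataclass
class StrategySignal:
    timestamp: datetime
    symbol: str
    timeframe: str
    signal_type: SignalType
    price: float
    confidence: float
    metadata: Optional[Dict[str, Any]] = None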
    def test_initialization(self, validator):
        """Test validator initialization."""
        assert validator.config is not None
        assert validator.logger is not None
        assert validator._validation_stats['total_signals_validated'] == 0
        assert validator._validation_stats['valid_signals'] == 0
        assert validator._validation_stats['invalid_signals'] == 0

    def test_validate_valid_signal(self, validator, valid_signal):
        """Test validation of a completely valid signal."""
        is_valid, errors = validator.validate_signal(valid_signal)

        assert is_valid is True
        assert errors == []
        assert validator._validation_stats['total_signals_validated'] == 1
        assert validator._validation_stats['valid_signals'] == 1
        assert validator._validation_stats['invalid_signals'] == 0

    def test_validate_invalid_confidence_low(self, validator, valid_signal):
        """Test validation with confidence too low."""
        valid_signal.confidence = -0.1

        is_valid, errors = validator.validate_signal(valid_signal)

        assert is_valid is False
        assert len(errors) == 1
        assert "Invalid confidence" in errors[0]
        assert validator._validation_stats['invalid_signals'] == 1

    def test_validate_invalid_confidence_high(self, validator, valid_signal):
        """Test validation with confidence too high."""
        valid_signal.confidence = 1.5

        is_valid, errors = validator.validate_signal(valid_signal)

        assert is_valid is False
        assert len(errors) == 1
        assert "Invalid confidence" in errors[0]

    def test_validate_invalid_signal_type(self, strict_validator, valid_signal):
        """Test validation with disallowed signal type."""
        valid_signal.signal_type = SignalType.HOLD

        is_valid, errors = strict_validator.validate_signal(valid_signal)

        assert is_valid is False
        assert len(errors) == 1
        assert "Signal type" in errors[0] and "not in allowed types" in errors[0]

    def test_validate_invalid_price(self, validator, valid_signal):
        """Test validation with invalid price."""
        valid_signal.price = -100.0

        is_valid, errors = validator.validate_signal(valid_signal)

        assert is_valid is False
        assert len(errors) == 1
        assert "Invalid price" in errors[0]

    def test_validate_missing_required_metadata(self, strict_validator, valid_signal):
        """Test validation with missing required metadata."""
        valid_signal.metadata = {'rsi': 30}  # Missing 'macd'

        is_valid, errors = strict_validator.validate_signal(valid_signal)

        assert is_valid is False
        assert len(errors) == 1
        assert "Missing required metadata fields" in errors[0]
        assert "macd" in errors[0]

    def test_validate_multiple_errors(self, strict_validator, valid_signal):
        """Test validation with multiple errors."""
        valid_signal.confidence = 1.5  # Too high
        valid_signal.price = -100.0  # Invalid
        valid_signal.signal_type = SignalType.HOLD  # Not allowed
        valid_signal.metadata = {}  # Missing required fields

        is_valid, errors = strict_validator.validate_signal(valid_signal)

        assert is_valid is False
        assert len(errors) == 4
        assert any("confidence" in error for error in errors)
        assert any("price" in error for error in errors)
        assert any("Signal type" in error for error in errors)
        assert any("Missing required metadata" in error for error in errors)

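Taken together, these cases specify validate_signal as an error-accumulating check rather than a fail-fast one: every rule runs, and the signal is valid only if the error list stays empty. A minimal method body consistent with the asserted error messages and counters is sketched below; the actual implementation in strategies/validation.py may differ in detail.

# Sketch of the accumulate-all-errors contract the tests pin down.
def validate_signal(self, signal):
    errors = []
    if not (self.config.min_confidence <= signal.confidence <= self.config.max_confidence):
        errors.append(f"Invalid confidence: {signal.confidence}")
    if signal.signal_type not in self.config.allowed_signal_types:
        errors.append(f"Signal type {signal.signal_type} not in allowed types")
    if signal.price is None or signal.price <= 0:
        errors.append(f"Invalid price: {signal.price}")
    missing = [f for f in self.config.required_metadata_fields
               if not signal.metadata or f not in signal.metadata]
    if missing:
        errors.append(f"Missing required metadata fields: {missing}")

    self._validation_stats['total_signals_validated'] += 1
    if errors:
        self._validation_stats['invalid_signals'] += 1
        self._validation_stats['validation_errors'].extend(errors)
    else:
        self._validation_stats['valid_signals'] += 1
    return not errors, errors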
    def test_validation_statistics_tracking(self, validator, valid_signal):
        """Test that validation statistics are properly tracked."""
        # Validate multiple signals
        validator.validate_signal(valid_signal)  # Valid

        # Same instance, mutated after it was already counted as valid
        invalid_signal = valid_signal
        invalid_signal.confidence = 1.5  # Invalid
        validator.validate_signal(invalid_signal)  # Invalid

        stats = validator._validation_stats
        assert stats['total_signals_validated'] == 2
        assert stats['valid_signals'] == 1
        assert stats['invalid_signals'] == 1
        assert len(stats['validation_errors']) > 0

    def test_validate_signals_batch(self, validator, valid_signal):
        """Test batch validation of multiple signals."""
        # Create a mix of valid and invalid signals
        signals = [
            valid_signal,  # Valid
            StrategySignal(  # Invalid confidence
                timestamp=datetime.now(timezone.utc),
                symbol='ETH-USDT',
                timeframe='1h',
                signal_type=SignalType.SELL,
                price=3000.0,
                confidence=1.5,  # Invalid
                metadata={}
            ),
            StrategySignal(  # Valid
                timestamp=datetime.now(timezone.utc),
                symbol='BNB-USDT',
                timeframe='1h',
                signal_type=SignalType.BUY,
                price=300.0,
                confidence=0.7,
                metadata={}
            )
        ]

        valid_signals, invalid_signals = validator.validate_signals_batch(signals)

        assert len(valid_signals) == 2
        assert len(invalid_signals) == 1
        assert invalid_signals[0].confidence == 1.5

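The batch API is then a straightforward partition over validate_signal; a sketch matching the two-list return value the test unpacks:

# Sketch: partition a batch using the single-signal validator.
def validate_signals_batch(self, signals):
    valid, invalid = [], []
    for signal in signals:
        is_valid, _errors = self.validate_signal(signal)
        (valid if is_valid else invalid).append(signal)
    return valid, invalid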
    def test_filter_signals_by_confidence(self, validator, valid_signal):
        """Test filtering signals by confidence threshold."""
        signals = [
            valid_signal,  # confidence 0.8
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='ETH-USDT',
                timeframe='1h',
                signal_type=SignalType.SELL,
                price=3000.0,
                confidence=0.3,  # Low confidence
                metadata={}
            ),
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='BNB-USDT',
                timeframe='1h',
                signal_type=SignalType.BUY,
                price=300.0,
                confidence=0.9,  # High confidence
                metadata={}
            )
        ]

        # Filter with threshold 0.5
        filtered_signals = validator.filter_signals_by_confidence(signals, min_confidence=0.5)

        assert len(filtered_signals) == 2
        assert all(signal.confidence >= 0.5 for signal in filtered_signals)
        assert filtered_signals[0].confidence == 0.8
        assert filtered_signals[1].confidence == 0.9

    def test_filter_signals_by_type(self, validator, valid_signal):
        """Test filtering signals by allowed types."""
        signals = [
            valid_signal,  # BUY
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='ETH-USDT',
                timeframe='1h',
                signal_type=SignalType.SELL,
                price=3000.0,
                confidence=0.8,
                metadata={}
            ),
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='BNB-USDT',
                timeframe='1h',
                signal_type=SignalType.HOLD,
                price=300.0,
                confidence=0.7,
                metadata={}
            )
        ]

        # Filter to only allow BUY and SELL
        filtered_signals = validator.filter_signals_by_type(
            signals,
            allowed_types=[SignalType.BUY, SignalType.SELL]
        )

        assert len(filtered_signals) == 2
        assert filtered_signals[0].signal_type == SignalType.BUY
        assert filtered_signals[1].signal_type == SignalType.SELL

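Both filters are order-preserving (the assertions index into the result), so plain list comprehensions match the specified behavior; a sketch:

# Sketch: order-preserving filters, as the index-based assertions require.
def filter_signals_by_confidence(self, signals, min_confidence):
    return [s for s in signals if s.confidence >= min_confidence]

def filter_signals_by_type(self, signals, allowed_types):
    return [s for s in signals if s.signal_type in allowed_types]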
    def test_get_validation_statistics(self, validator, valid_signal):
        """Test comprehensive validation statistics."""
        # Validate some signals to generate statistics
        validator.validate_signal(valid_signal)  # Valid

        # Same instance, mutated after the first validation was recorded
        invalid_signal = valid_signal
        invalid_signal.confidence = -0.1  # Invalid
        validator.validate_signal(invalid_signal)  # Invalid

        stats = validator.get_validation_statistics()

        assert stats['total_signals_validated'] == 2
        assert stats['valid_signals'] == 1
        assert stats['invalid_signals'] == 1
        assert stats['validation_success_rate'] == 0.5
        assert stats['validation_failure_rate'] == 0.5
        assert 'validation_errors' in stats

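get_validation_statistics evidently derives the two rates from the raw counters. A sketch; the zero-division guard is an assumption, since the empty-history case is not tested here:

# Sketch: derive success/failure rates from the raw counters.
def get_validation_statistics(self):
    stats = dict(self._validation_stats)
    total = stats['total_signals_validated']
    stats['validation_success_rate'] = stats['valid_signals'] / total if total else 0.0
    stats['validation_failure_rate'] = stats['invalid_signals'] / total if total else 0.0
    return stats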
    def test_transform_signal_confidence(self, validator, valid_signal):
        """Test signal confidence transformation."""
        original_confidence = valid_signal.confidence  # 0.8

        # Test confidence multiplier
        transformed_signal = validator.transform_signal_confidence(
            valid_signal,
            confidence_multiplier=1.2
        )

        assert transformed_signal.confidence == original_confidence * 1.2
        assert transformed_signal.symbol == valid_signal.symbol
        assert transformed_signal.signal_type == valid_signal.signal_type
        assert transformed_signal.price == valid_signal.price

        # Test confidence cap
        capped_signal = validator.transform_signal_confidence(
            valid_signal,
            confidence_multiplier=2.0,  # Would exceed 1.0
            max_confidence=1.0
        )

        assert capped_signal.confidence == 1.0  # Capped at max

    def test_enrich_signal_metadata(self, validator, valid_signal):
        """Test signal metadata enrichment."""
        additional_metadata = {
            'validation_timestamp': datetime.now(timezone.utc).isoformat(),
            'validation_status': 'approved',
            'risk_score': 0.2
        }

        enriched_signal = validator.enrich_signal_metadata(valid_signal, additional_metadata)

        # Original metadata should be preserved
        assert enriched_signal.metadata['rsi'] == 30
        assert enriched_signal.metadata['macd'] == 0.05

        # New metadata should be added
        assert enriched_signal.metadata['validation_status'] == 'approved'
        assert enriched_signal.metadata['risk_score'] == 0.2
        assert 'validation_timestamp' in enriched_signal.metadata

        # Other properties should remain unchanged
        assert enriched_signal.confidence == valid_signal.confidence
        assert enriched_signal.signal_type == valid_signal.signal_type

    def test_transform_signals_batch(self, validator, valid_signal):
        """Test batch signal transformation."""
        signals = [
            valid_signal,
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='ETH-USDT',
                timeframe='1h',
                signal_type=SignalType.SELL,
                price=3000.0,
                confidence=0.6,
                metadata={'ema': 2950}
            )
        ]

        additional_metadata = {'batch_id': 'test_batch_001'}

        transformed_signals = validator.transform_signals_batch(
            signals,
            confidence_multiplier=1.1,
            additional_metadata=additional_metadata
        )

        assert len(transformed_signals) == 2

        # Check confidence transformation
        assert transformed_signals[0].confidence == 0.8 * 1.1
        assert transformed_signals[1].confidence == 0.6 * 1.1

        # Check metadata enrichment
        assert transformed_signals[0].metadata['batch_id'] == 'test_batch_001'
        assert transformed_signals[1].metadata['batch_id'] == 'test_batch_001'

        # Verify original metadata preserved
        assert transformed_signals[0].metadata['rsi'] == 30
        assert transformed_signals[1].metadata['ema'] == 2950

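The transform tests imply copy-on-write semantics: the returned signal carries the scaled (and optionally capped) confidence while symbol, type, and price are untouched, and enrichment merges new metadata over the old instead of replacing it. Assuming StrategySignal is a dataclass (which the diff does not confirm), dataclasses.replace captures both behaviors; a sketch:

# Sketch, assuming StrategySignal is a dataclass.
from dataclasses import replace

def transform_signal_confidence(self, signal, confidence_multiplier, max_confidence=None):
    confidence = signal.confidence * confidence_multiplier
    if max_confidence is not None:
        confidence = min(confidence, max_confidence)
    return replace(signal, confidence=confidence)

def enrich_signal_metadata(self, signal, additional_metadata):
    merged = {**(signal.metadata or {}), **additional_metadata}
    return replace(signal, metadata=merged)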
    def test_calculate_signal_quality_metrics(self, validator, valid_signal):
        """Test signal quality metrics calculation."""
        signals = [
            valid_signal,  # confidence 0.8, has metadata
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='ETH-USDT',
                timeframe='1h',
                signal_type=SignalType.SELL,
                price=3000.0,
                confidence=0.9,  # High confidence
                metadata={'volume_spike': True}
            ),
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='BNB-USDT',
                timeframe='1h',
                signal_type=SignalType.HOLD,
                price=300.0,
                confidence=0.4,  # Low confidence
                metadata=None  # No metadata
            )
        ]

        metrics = validator.calculate_signal_quality_metrics(signals)

        assert metrics['total_signals'] == 3
        assert metrics['confidence_metrics']['average'] == round((0.8 + 0.9 + 0.4) / 3, 3)
        assert metrics['confidence_metrics']['minimum'] == 0.4
        assert metrics['confidence_metrics']['maximum'] == 0.9
        assert metrics['confidence_metrics']['high_confidence_count'] == 2  # >= 0.7
        assert metrics['quality_score'] == round((2 / 3) * 100, 1)  # 66.7%
        assert metrics['metadata_completeness_percentage'] == round((2 / 3) * 100, 1)

        # Check signal type distribution
        assert metrics['signal_type_distribution']['buy'] == 1
        assert metrics['signal_type_distribution']['sell'] == 1
        assert metrics['signal_type_distribution']['hold'] == 1

        # Check recommendations
        assert isinstance(metrics['recommendations'], list)
        assert len(metrics['recommendations']) > 0

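The expected numbers follow directly from the three-signal fixture: average confidence (0.8 + 0.9 + 0.4) / 3 = 0.7, and two of the three signals sit at or above the 0.7 high-confidence cutoff, so both the quality score and the metadata completeness come out to 2/3, about 66.7%. A sketch of the core arithmetic, with the 0.7 threshold inferred from the test comment rather than confirmed by the diff:

# Sketch of the metric arithmetic the assertions above pin down.
HIGH_CONFIDENCE_THRESHOLD = 0.7  # inferred from the test's ">= 0.7" comment

def _confidence_metrics(signals):
    confidences = [s.confidence for s in signals]
    return {
        'average': round(sum(confidences) / len(confidences), 3),
        'minimum': min(confidences),
        'maximum': max(confidences),
        'high_confidence_count': sum(c >= HIGH_CONFIDENCE_THRESHOLD for c in confidences),
    }

def _metadata_completeness_percentage(signals):
    with_metadata = sum(1 for s in signals if s.metadata)
    return round(with_metadata / len(signals) * 100, 1)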
    def test_calculate_signal_quality_metrics_empty(self, validator):
        """Test quality metrics with empty signal list."""
        metrics = validator.calculate_signal_quality_metrics([])

        assert 'error' in metrics
        assert metrics['error'] == 'No signals provided for quality analysis'

    def test_generate_quality_recommendations(self, validator):
        """Test quality recommendation generation."""
        # Test low confidence signals
        low_confidence_signals = [
            StrategySignal(
                timestamp=datetime.now(timezone.utc),
                symbol='BTC-USDT',
                timeframe='1h',
                signal_type=SignalType.BUY,
                price=50000.0,
                confidence=0.3,  # Low confidence
                metadata=None  # No metadata
            )
        ]

        recommendations = validator._generate_quality_recommendations(low_confidence_signals)

        assert any("confidence" in rec.lower() for rec in recommendations)
        assert any("metadata" in rec.lower() for rec in recommendations)

    def test_generate_validation_report(self, validator, valid_signal):
        """Test comprehensive validation report generation."""
        # Generate some validation activity
        validator.validate_signal(valid_signal)  # Valid

        # Same instance, mutated after the first validation was recorded
        invalid_signal = valid_signal
        invalid_signal.confidence = -0.1  # Invalid
        validator.validate_signal(invalid_signal)  # Invalid

        report = validator.generate_validation_report()

        assert 'report_timestamp' in report
        assert 'validation_summary' in report
        assert 'error_analysis' in report
        assert 'configuration' in report
        assert 'health_status' in report

        # Check validation summary
        summary = report['validation_summary']
        assert summary['total_validated'] == 2
        assert '50.0%' in summary['success_rate']
        assert '50.0%' in summary['failure_rate']

        # Check configuration
        config = report['configuration']
        assert config['min_confidence'] == 0.0
        assert config['max_confidence'] == 1.0
        assert isinstance(config['allowed_signal_types'], list)

        # Check health status
        assert report['health_status'] in ['good', 'needs_attention']
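End to end, the validator sits between signal generation and display: validate the batch, drop what fails, then keep only confident, actionable signals. A minimal wiring sketch using only the methods exercised above; raw_signals is a hypothetical batch from the real-time processor, and the 0.5 threshold is an illustrative choice, not a project default:

# Sketch: post-processing a hypothetical batch of freshly generated signals.
validator = StrategySignalValidator()

valid_signals, rejected = validator.validate_signals_batch(raw_signals)
actionable = validator.filter_signals_by_type(
    validator.filter_signals_by_confidence(valid_signals, min_confidence=0.5),
    allowed_types=[SignalType.BUY, SignalType.SELL],
)
report = validator.generate_validation_report()  # e.g. surfaced in a health view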