"""Repository for strategy_signals and strategy_runs table operations."""

from datetime import datetime, timedelta, timezone
from typing import Dict, Any, Optional, List
from decimal import Decimal

from sqlalchemy import desc, and_, func
from sqlalchemy.orm import joinedload

from ..models import StrategySignal, StrategyRun
from strategies.data_types import StrategySignal as StrategySignalDataType, StrategyResult
from .base_repository import BaseRepository, DatabaseOperationError


class StrategyRepository(BaseRepository):
    """Repository for strategy_signals and strategy_runs table operations."""

    # Strategy Run Operations

    def create_strategy_run(self, run_data: Dict[str, Any]) -> StrategyRun:
        """
        Create a new strategy run session.

        Args:
            run_data: Dictionary containing run information (strategy_name, symbol, timeframe, etc.).
                Keys are passed straight to the ``StrategyRun`` constructor.

        Returns:
            The newly created StrategyRun object (refreshed from the database after commit).

        Raises:
            DatabaseOperationError: If the run could not be created.
        """
        try:
            with self.get_session() as session:
                new_run = StrategyRun(**run_data)
                session.add(new_run)
                session.commit()
                # Refresh so server-generated columns (e.g. the primary key) are loaded.
                session.refresh(new_run)
                self.log_info(f"Created strategy run: {new_run.strategy_name} for {new_run.symbol}")
                return new_run
        except Exception as e:
            self.log_error(f"Error creating strategy run: {e}")
            raise DatabaseOperationError(f"Failed to create strategy run: {e}") from e

    def get_strategy_run_by_id(self, run_id: int) -> Optional[StrategyRun]:
        """
        Get a strategy run by its ID.

        Args:
            run_id: Primary key of the strategy run.

        Returns:
            The matching StrategyRun, or None if no row has that ID.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            with self.get_session() as session:
                return session.query(StrategyRun).filter(StrategyRun.id == run_id).first()
        except Exception as e:
            self.log_error(f"Error getting strategy run by ID {run_id}: {e}")
            raise DatabaseOperationError(f"Failed to get strategy run by ID: {e}") from e

    def update_strategy_run(self, run_id: int, update_data: Dict[str, Any]) -> Optional[StrategyRun]:
        """
        Update a strategy run's information.

        Args:
            run_id: Primary key of the strategy run to update.
            update_data: Mapping of attribute name -> new value; each pair is applied with setattr.

        Returns:
            The updated StrategyRun, or None if no run with ``run_id`` exists.

        Raises:
            DatabaseOperationError: If the update fails.
        """
        try:
            with self.get_session() as session:
                strategy_run = session.query(StrategyRun).filter(StrategyRun.id == run_id).first()
                if strategy_run is None:
                    return None

                for key, value in update_data.items():
                    setattr(strategy_run, key, value)

                session.commit()
                session.refresh(strategy_run)
                self.log_info(f"Updated strategy run {run_id}")
                return strategy_run
        except Exception as e:
            self.log_error(f"Error updating strategy run {run_id}: {e}")
            raise DatabaseOperationError(f"Failed to update strategy run: {e}") from e

    def complete_strategy_run(self, run_id: int, total_signals: int) -> bool:
        """
        Mark a strategy run as completed.

        Sets ``status`` to ``'completed'``, ``end_time`` to the current UTC time and
        records ``total_signals``. Failures are logged rather than raised.

        Args:
            run_id: Primary key of the strategy run.
            total_signals: Number of signals produced by the run.

        Returns:
            True if the run existed and was updated, False otherwise (including on error).
        """
        try:
            update_data = {
                'status': 'completed',
                # timezone.utc (not datetime.timezone) -- ``datetime`` here is the class,
                # which has no ``timezone`` attribute.
                'end_time': datetime.now(timezone.utc),
                'total_signals': total_signals
            }
            result = self.update_strategy_run(run_id, update_data)
            return result is not None
        except Exception as e:
            # Deliberate best-effort: completion bookkeeping should not crash the caller.
            self.log_error(f"Error completing strategy run {run_id}: {e}")
            return False

    # Strategy Signal Operations

    def store_strategy_signals(self, run_id: int, strategy_results: List[StrategyResult]) -> int:
        """
        Store multiple strategy signals from strategy results.

        All signals are added in a single session and committed together.

        Args:
            run_id: The strategy run ID these signals belong to
            strategy_results: List of StrategyResult objects containing signals

        Returns:
            Number of signals stored

        Raises:
            DatabaseOperationError: If the signals could not be stored.
        """
        try:
            signals_stored = 0
            with self.get_session() as session:
                for result in strategy_results:
                    for signal in result.signals:
                        strategy_signal = StrategySignal(
                            run_id=run_id,
                            strategy_name=result.strategy_name,
                            strategy_config=None,  # Could be populated from StrategyRun.config
                            symbol=signal.symbol,
                            timeframe=signal.timeframe,
                            timestamp=signal.timestamp,
                            signal_type=signal.signal_type.value,
                            # Go through str() so float values convert to their
                            # shortest repr instead of their full binary expansion.
                            price=Decimal(str(signal.price)),
                            confidence=Decimal(str(signal.confidence)),
                            signal_metadata={
                                'indicators_used': result.indicators_used,
                                'metadata': signal.metadata or {}
                            }
                        )
                        session.add(strategy_signal)
                        signals_stored += 1

                session.commit()

            self.log_info(f"Stored {signals_stored} strategy signals for run {run_id}")
            return signals_stored
        except Exception as e:
            self.log_error(f"Error storing strategy signals for run {run_id}: {e}")
            raise DatabaseOperationError(f"Failed to store strategy signals: {e}") from e

    def get_strategy_signals(
        self,
        run_id: Optional[int] = None,
        strategy_name: Optional[str] = None,
        symbol: Optional[str] = None,
        timeframe: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        signal_type: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[StrategySignal]:
        """
        Retrieve strategy signals with flexible filtering.

        Args:
            run_id: Filter by strategy run ID
            strategy_name: Filter by strategy name
            symbol: Filter by trading symbol
            timeframe: Filter by timeframe
            start_time: Filter signals at or after this time (inclusive)
            end_time: Filter signals at or before this time (inclusive)
            signal_type: Filter by signal type
            limit: Maximum number of signals to return; a falsy value (None or 0) means no limit

        Returns:
            List of StrategySignal objects, newest first

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            with self.get_session() as session:
                query = session.query(StrategySignal)

                # Apply filters
                if run_id is not None:
                    query = query.filter(StrategySignal.run_id == run_id)
                if strategy_name:
                    query = query.filter(StrategySignal.strategy_name == strategy_name)
                if symbol:
                    query = query.filter(StrategySignal.symbol == symbol)
                if timeframe:
                    query = query.filter(StrategySignal.timeframe == timeframe)
                if start_time:
                    query = query.filter(StrategySignal.timestamp >= start_time)
                if end_time:
                    query = query.filter(StrategySignal.timestamp <= end_time)
                if signal_type:
                    query = query.filter(StrategySignal.signal_type == signal_type)

                # Order by timestamp descending
                query = query.order_by(desc(StrategySignal.timestamp))

                # Apply limit (0/None are treated as "no limit")
                if limit:
                    query = query.limit(limit)

                return query.all()
        except Exception as e:
            self.log_error(f"Error retrieving strategy signals: {e}")
            raise DatabaseOperationError(f"Failed to retrieve strategy signals: {e}") from e

    def get_strategy_signal_stats(self, run_id: Optional[int] = None) -> Dict[str, Any]:
        """
        Get statistics about strategy signals.

        Args:
            run_id: If given, every statistic is restricted to this run; otherwise
                statistics cover all stored signals.

        Returns:
            Dictionary with ``total_signals``, ``signal_counts`` (signal type -> count),
            ``average_confidence`` (0.0 when there are no matching signals) and the
            ``run_id`` used as the filter.

        Raises:
            DatabaseOperationError: If any of the queries fail.
        """
        try:
            with self.get_session() as session:

                def _scoped(query):
                    # Apply the same run filter to every aggregate so the numbers agree.
                    if run_id is not None:
                        return query.filter(StrategySignal.run_id == run_id)
                    return query

                # Get basic counts by signal type
                signal_counts = _scoped(
                    session.query(
                        StrategySignal.signal_type,
                        func.count(StrategySignal.id).label('count')
                    )
                ).group_by(StrategySignal.signal_type)
                counts_dict = {signal_type: count for signal_type, count in signal_counts.all()}

                # Get total signals
                total_signals = _scoped(session.query(StrategySignal)).count()

                # Get average confidence (previously ignored run_id)
                avg_confidence = _scoped(
                    session.query(func.avg(StrategySignal.confidence))
                ).scalar()

                return {
                    'total_signals': total_signals,
                    'signal_counts': counts_dict,
                    # AVG over zero rows is NULL -> None; report 0.0 in that case.
                    'average_confidence': float(avg_confidence) if avg_confidence is not None else 0.0,
                    'run_id': run_id
                }
        except Exception as e:
            self.log_error(f"Error getting strategy signal stats: {e}")
            raise DatabaseOperationError(f"Failed to get strategy signal stats: {e}") from e

    # Data Retention and Cleanup

    def cleanup_old_strategy_data(self, days_to_keep: int = 30) -> Dict[str, Any]:
        """
        Clean up old strategy runs to prevent table bloat.

        Only runs with ``status == 'completed'`` whose ``created_at`` is older than
        the cutoff are deleted.

        Args:
            days_to_keep: Number of days to retain data

        Returns:
            Dictionary with ``deleted_runs`` (int) and ``cutoff_date`` (ISO-8601 string)

        Raises:
            DatabaseOperationError: If the cleanup fails.
        """
        try:
            # timezone.utc (not datetime.timezone) -- ``datetime`` here is the class.
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)

            with self.get_session() as session:
                # NOTE(review): bulk Query.delete() bypasses ORM relationship cascades.
                # Signals of deleted runs are removed only if the strategy_signals FK is
                # declared ON DELETE CASCADE in the database schema -- confirm.
                deleted_runs = session.query(StrategyRun).filter(
                    StrategyRun.created_at < cutoff_date,
                    StrategyRun.status == 'completed'  # Only delete completed runs
                ).delete(synchronize_session=False)

                session.commit()

            self.log_info(f"Cleaned up {deleted_runs} old strategy runs and their signals")

            return {
                'deleted_runs': deleted_runs,
                'cutoff_date': cutoff_date.isoformat()
            }
        except Exception as e:
            self.log_error(f"Error cleaning up old strategy data: {e}")
            raise DatabaseOperationError(f"Failed to cleanup old strategy data: {e}") from e