"""Metrics calculation from trade database.""" import logging from dataclasses import dataclass from datetime import datetime, timezone, timedelta from typing import Optional from .database import TradingDatabase logger = logging.getLogger(__name__) @dataclass class PeriodMetrics: """Trading metrics for a time period.""" period_name: str start_time: Optional[str] end_time: Optional[str] total_pnl: float = 0.0 total_trades: int = 0 winning_trades: int = 0 losing_trades: int = 0 win_rate: float = 0.0 avg_trade_duration_hours: float = 0.0 max_drawdown: float = 0.0 max_drawdown_pct: float = 0.0 best_trade: float = 0.0 worst_trade: float = 0.0 avg_win: float = 0.0 avg_loss: float = 0.0 class MetricsCalculator: """Calculate trading metrics from database.""" def __init__(self, db: TradingDatabase): self.db = db def get_all_time_metrics(self) -> PeriodMetrics: """Get metrics for all trades ever.""" return self._calculate_metrics("All Time", None, None) def get_monthly_metrics(self) -> PeriodMetrics: """Get metrics for current month.""" now = datetime.now(timezone.utc) start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) return self._calculate_metrics( "Monthly", start.isoformat(), now.isoformat(), ) def get_weekly_metrics(self) -> PeriodMetrics: """Get metrics for current week (Monday to now).""" now = datetime.now(timezone.utc) days_since_monday = now.weekday() start = now - timedelta(days=days_since_monday) start = start.replace(hour=0, minute=0, second=0, microsecond=0) return self._calculate_metrics( "Weekly", start.isoformat(), now.isoformat(), ) def get_daily_metrics(self) -> PeriodMetrics: """Get metrics for today (UTC).""" now = datetime.now(timezone.utc) start = now.replace(hour=0, minute=0, second=0, microsecond=0) return self._calculate_metrics( "Daily", start.isoformat(), now.isoformat(), ) def _calculate_metrics( self, period_name: str, start_time: Optional[str], end_time: Optional[str], ) -> PeriodMetrics: """ Calculate metrics for a time period. Args: period_name: Name of the period start_time: ISO format start time (None for all time) end_time: ISO format end time (None for all time) Returns: PeriodMetrics object """ metrics = PeriodMetrics( period_name=period_name, start_time=start_time, end_time=end_time, ) # Build query conditions conditions = ["exit_time IS NOT NULL"] params = [] if start_time: conditions.append("exit_time >= ?") params.append(start_time) if end_time: conditions.append("exit_time <= ?") params.append(end_time) where_clause = " AND ".join(conditions) # Get aggregate metrics sql = f""" SELECT COUNT(*) as total_trades, SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades, SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades, COALESCE(SUM(pnl_usd), 0) as total_pnl, COALESCE(AVG(hold_duration_hours), 0) as avg_duration, COALESCE(MAX(pnl_usd), 0) as best_trade, COALESCE(MIN(pnl_usd), 0) as worst_trade, COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win, COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss FROM trades WHERE {where_clause} """ row = self.db.connection.execute(sql, params).fetchone() if row and row["total_trades"] > 0: metrics.total_trades = row["total_trades"] metrics.winning_trades = row["winning_trades"] or 0 metrics.losing_trades = row["losing_trades"] or 0 metrics.total_pnl = row["total_pnl"] metrics.avg_trade_duration_hours = row["avg_duration"] metrics.best_trade = row["best_trade"] metrics.worst_trade = row["worst_trade"] metrics.avg_win = row["avg_win"] metrics.avg_loss = row["avg_loss"] if metrics.total_trades > 0: metrics.win_rate = ( metrics.winning_trades / metrics.total_trades * 100 ) # Calculate max drawdown metrics.max_drawdown = self._calculate_max_drawdown( start_time, end_time ) return metrics def _calculate_max_drawdown( self, start_time: Optional[str], end_time: Optional[str], ) -> float: """Calculate maximum drawdown for a period.""" conditions = ["exit_time IS NOT NULL"] params = [] if start_time: conditions.append("exit_time >= ?") params.append(start_time) if end_time: conditions.append("exit_time <= ?") params.append(end_time) where_clause = " AND ".join(conditions) sql = f""" SELECT pnl_usd FROM trades WHERE {where_clause} ORDER BY exit_time """ rows = self.db.connection.execute(sql, params).fetchall() if not rows: return 0.0 cumulative = 0.0 peak = 0.0 max_drawdown = 0.0 for row in rows: pnl = row["pnl_usd"] or 0.0 cumulative += pnl peak = max(peak, cumulative) drawdown = peak - cumulative max_drawdown = max(max_drawdown, drawdown) return max_drawdown def has_monthly_data(self) -> bool: """Check if we have data spanning more than current month.""" sql = """ SELECT MIN(exit_time) as first_trade FROM trades WHERE exit_time IS NOT NULL """ row = self.db.connection.execute(sql).fetchone() if not row or not row["first_trade"]: return False first_trade = datetime.fromisoformat(row["first_trade"]) now = datetime.now(timezone.utc) month_start = now.replace(day=1, hour=0, minute=0, second=0) return first_trade < month_start def has_weekly_data(self) -> bool: """Check if we have data spanning more than current week.""" sql = """ SELECT MIN(exit_time) as first_trade FROM trades WHERE exit_time IS NOT NULL """ row = self.db.connection.execute(sql).fetchone() if not row or not row["first_trade"]: return False first_trade = datetime.fromisoformat(row["first_trade"]) now = datetime.now(timezone.utc) days_since_monday = now.weekday() week_start = now - timedelta(days=days_since_monday) week_start = week_start.replace(hour=0, minute=0, second=0) return first_trade < week_start def get_session_start_balance(self) -> Optional[float]: """Get starting balance from latest session.""" sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1" row = self.db.connection.execute(sql).fetchone() return row["starting_balance"] if row else None