- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps. - Added `install_cron.sh` for setting up a cron job to run the daily training script. - Created `setup_schedule.sh` for configuring systemd timers for daily training tasks. - Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling. - Updated `pyproject.toml` to include the `rich` dependency for UI functionality. - Enhanced `.gitignore` to exclude model and log files. - Added database support for trade persistence and metrics calculation. - Updated README with installation and usage instructions for the new features.
236 lines
7.3 KiB
Python
236 lines
7.3 KiB
Python
"""Metrics calculation from trade database."""
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional
|
|
|
|
from .database import TradingDatabase
|
|
|
|
# Module-level logger; not referenced by the code visible in this module chunk.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class PeriodMetrics:
    """Aggregated trading statistics for one reporting window.

    A fresh instance carries only the window description; every numeric
    field starts at zero and is overwritten by ``MetricsCalculator`` when
    closed trades fall inside the window.
    """

    # --- window description ---------------------------------------------
    period_name: str  # label, e.g. "Daily", "Weekly", "Monthly", "All Time"
    start_time: Optional[str]  # ISO-8601 lower bound; None means unbounded
    end_time: Optional[str]  # ISO-8601 upper bound; None means unbounded

    # --- profit and trade counts ------------------------------------------
    total_pnl: float = 0.0  # sum of pnl_usd over the window
    total_trades: int = 0
    winning_trades: int = 0  # trades with pnl_usd > 0
    losing_trades: int = 0  # trades with pnl_usd < 0; break-even counts as neither
    win_rate: float = 0.0  # percentage in 0-100, not a 0-1 fraction

    # --- duration, risk and distribution ----------------------------------
    avg_trade_duration_hours: float = 0.0
    max_drawdown: float = 0.0  # largest peak-to-trough drop of cumulative PnL
    max_drawdown_pct: float = 0.0  # NOTE(review): not set by MetricsCalculator; stays 0.0
    best_trade: float = 0.0
    worst_trade: float = 0.0
    avg_win: float = 0.0
    avg_loss: float = 0.0  # mean of the negative pnl_usd values, hence <= 0
|
|
|
|
|
|
class MetricsCalculator:
    """Calculate trading metrics from the ``trades`` table of a database.

    Period boundaries are computed in UTC and passed to SQLite as ISO-8601
    strings, so window filtering compares them against ``exit_time`` as
    text. NOTE(review): this presumes ``exit_time`` is stored in a format
    that sorts lexically in time order - confirm against the writer.
    """

    def __init__(self, db: TradingDatabase):
        # Only ``db.connection`` is used. Rows are indexed by column name,
        # which presumes a name-aware row factory such as sqlite3.Row.
        self.db = db

    # ------------------------------------------------------------------
    # Public period getters
    # ------------------------------------------------------------------

    def get_all_time_metrics(self) -> PeriodMetrics:
        """Get metrics for every closed trade ever recorded."""
        return self._calculate_metrics("All Time", None, None)

    def get_monthly_metrics(self) -> PeriodMetrics:
        """Get metrics for the current UTC month (1st at 00:00 up to now)."""
        now = datetime.now(timezone.utc)
        return self._calculate_metrics(
            "Monthly",
            self._month_start(now).isoformat(),
            now.isoformat(),
        )

    def get_weekly_metrics(self) -> PeriodMetrics:
        """Get metrics for the current UTC week (Monday 00:00 up to now)."""
        now = datetime.now(timezone.utc)
        return self._calculate_metrics(
            "Weekly",
            self._week_start(now).isoformat(),
            now.isoformat(),
        )

    def get_daily_metrics(self) -> PeriodMetrics:
        """Get metrics for today (UTC midnight up to now)."""
        now = datetime.now(timezone.utc)
        return self._calculate_metrics(
            "Daily",
            self._day_start(now).isoformat(),
            now.isoformat(),
        )

    # ------------------------------------------------------------------
    # Period boundaries, shared by the getters and the has_*_data checks
    # so both always agree on where a period starts.
    # ------------------------------------------------------------------

    @staticmethod
    def _month_start(now: datetime) -> datetime:
        """Return midnight on the first day of ``now``'s month."""
        return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def _week_start(now: datetime) -> datetime:
        """Return midnight on the Monday of ``now``'s week."""
        monday = now - timedelta(days=now.weekday())
        return monday.replace(hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def _day_start(now: datetime) -> datetime:
        """Return midnight at the start of ``now``'s day."""
        return now.replace(hour=0, minute=0, second=0, microsecond=0)

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_where(
        start_time: Optional[str],
        end_time: Optional[str],
    ) -> tuple:
        """Build the WHERE clause and parameters for closed trades in a window.

        Args:
            start_time: Inclusive ISO lower bound on ``exit_time``; falsy = none.
            end_time: Inclusive ISO upper bound on ``exit_time``; falsy = none.

        Returns:
            ``(where_clause, params)``. The clause contains only fixed SQL
            fragments; the bounds are returned as bound parameters, so
            interpolating the clause into an f-string is not an injection risk.
        """
        conditions = ["exit_time IS NOT NULL"]
        params = []
        if start_time:
            conditions.append("exit_time >= ?")
            params.append(start_time)
        if end_time:
            conditions.append("exit_time <= ?")
            params.append(end_time)
        return " AND ".join(conditions), params

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse a stored ISO-8601 timestamp into a timezone-aware datetime.

        A trailing ``Z`` is rewritten to ``+00:00`` because
        ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 on.
        Naive values are assumed to be UTC - the convention used by every
        boundary in this class - so they can be compared with
        ``datetime.now(timezone.utc)`` instead of raising ``TypeError``.
        """
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _first_exit_time(self) -> Optional[datetime]:
        """Return the earliest ``exit_time`` among closed trades, or None."""
        sql = """
            SELECT MIN(exit_time) as first_trade
            FROM trades
            WHERE exit_time IS NOT NULL
        """
        row = self.db.connection.execute(sql).fetchone()
        if not row or not row["first_trade"]:
            return None
        return self._parse_timestamp(row["first_trade"])

    # ------------------------------------------------------------------
    # Metric computation
    # ------------------------------------------------------------------

    def _calculate_metrics(
        self,
        period_name: str,
        start_time: Optional[str],
        end_time: Optional[str],
    ) -> PeriodMetrics:
        """
        Calculate metrics for closed trades whose ``exit_time`` is in a window.

        Args:
            period_name: Name of the period
            start_time: ISO format start time (None for all time)
            end_time: ISO format end time (None for all time)

        Returns:
            PeriodMetrics object. Numeric fields stay at their zero defaults
            when the window contains no closed trades. ``max_drawdown_pct``
            is never populated here (NOTE(review): it would need a capital
            base such as a starting balance to be meaningful).
        """
        metrics = PeriodMetrics(
            period_name=period_name,
            start_time=start_time,
            end_time=end_time,
        )

        where_clause, params = self._build_where(start_time, end_time)

        # One aggregate pass; COALESCE keeps NULL-producing aggregates at 0.
        sql = f"""
            SELECT
                COUNT(*) as total_trades,
                SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
                SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
                COALESCE(SUM(pnl_usd), 0) as total_pnl,
                COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
                COALESCE(MAX(pnl_usd), 0) as best_trade,
                COALESCE(MIN(pnl_usd), 0) as worst_trade,
                COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
                COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
            FROM trades
            WHERE {where_clause}
        """
        row = self.db.connection.execute(sql, params).fetchone()

        if not row or not row["total_trades"]:
            # No closed trades in the window: keep the zero defaults and skip
            # the drawdown query (it would return 0.0 anyway).
            return metrics

        metrics.total_trades = row["total_trades"]
        metrics.winning_trades = row["winning_trades"] or 0
        metrics.losing_trades = row["losing_trades"] or 0
        metrics.total_pnl = row["total_pnl"]
        metrics.avg_trade_duration_hours = row["avg_duration"]
        metrics.best_trade = row["best_trade"]
        metrics.worst_trade = row["worst_trade"]
        metrics.avg_win = row["avg_win"]
        metrics.avg_loss = row["avg_loss"]

        # Expressed as a percentage (0-100); total_trades > 0 is guaranteed here.
        metrics.win_rate = metrics.winning_trades / metrics.total_trades * 100

        metrics.max_drawdown = self._calculate_max_drawdown(start_time, end_time)
        return metrics

    def _calculate_max_drawdown(
        self,
        start_time: Optional[str],
        end_time: Optional[str],
    ) -> float:
        """Return the largest peak-to-trough decline of cumulative PnL.

        Trades are replayed in ``exit_time`` order. The running total starts
        at 0 and the peak starts at 0, so losses before any gain count as
        drawdown. NULL ``pnl_usd`` values contribute 0.
        """
        where_clause, params = self._build_where(start_time, end_time)
        sql = f"""
            SELECT pnl_usd
            FROM trades
            WHERE {where_clause}
            ORDER BY exit_time
        """
        rows = self.db.connection.execute(sql, params).fetchall()

        cumulative = 0.0
        peak = 0.0
        max_drawdown = 0.0
        for row in rows:
            cumulative += row["pnl_usd"] or 0.0
            peak = max(peak, cumulative)
            max_drawdown = max(max_drawdown, peak - cumulative)
        return max_drawdown

    # ------------------------------------------------------------------
    # Data availability checks
    # ------------------------------------------------------------------

    def has_monthly_data(self) -> bool:
        """Return True if the earliest closed trade predates the current UTC month."""
        first_trade = self._first_exit_time()
        if first_trade is None:
            return False
        return first_trade < self._month_start(datetime.now(timezone.utc))

    def has_weekly_data(self) -> bool:
        """Return True if the earliest closed trade predates the current UTC week."""
        first_trade = self._first_exit_time()
        if first_trade is None:
            return False
        return first_trade < self._week_start(datetime.now(timezone.utc))

    def get_session_start_balance(self) -> Optional[float]:
        """Return ``starting_balance`` of the session with the highest id, or None."""
        sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
        row = self.db.connection.execute(sql).fetchone()
        return row["starting_balance"] if row else None
|