Add daily model training scripts and terminal UI for live trading

- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps.
- Added `install_cron.sh` for setting up a cron job to run the daily training script.
- Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks.
- Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling.
- Updated `pyproject.toml` to include the `rich` dependency for UI functionality.
- Enhanced `.gitignore` to exclude model and log files.
- Added database support for trade persistence and metrics calculation.
- Updated README with installation and usage instructions for the new features.
This commit is contained in:
2026-01-18 11:08:57 +08:00
parent 35992ee374
commit b5550f4ff4
27 changed files with 3582 additions and 113 deletions

View File

@@ -0,0 +1,13 @@
"""Database module for live trading persistence."""
from .database import get_db, init_db
from .models import Trade, DailySummary, Session
from .metrics import MetricsCalculator
__all__ = [
"get_db",
"init_db",
"Trade",
"DailySummary",
"Session",
"MetricsCalculator",
]

325
live_trading/db/database.py Normal file
View File

@@ -0,0 +1,325 @@
"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager
from .models import Trade, DailySummary, Session
logger = logging.getLogger(__name__)
# Database schema
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL,
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT,
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL,
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""
_db_instance: Optional["TradingDatabase"] = None
class TradingDatabase:
"""SQLite database for trade persistence."""
def __init__(self, db_path: Path):
self.db_path = db_path
self._connection: Optional[sqlite3.Connection] = None
@property
def connection(self) -> sqlite3.Connection:
"""Get or create database connection."""
if self._connection is None:
self._connection = sqlite3.connect(
str(self.db_path),
check_same_thread=False,
)
self._connection.row_factory = sqlite3.Row
return self._connection
def init_schema(self) -> None:
"""Initialize database schema."""
with self.connection:
self.connection.executescript(SCHEMA)
logger.info(f"Database initialized at {self.db_path}")
def close(self) -> None:
"""Close database connection."""
if self._connection:
self._connection.close()
self._connection = None
@contextmanager
def transaction(self):
"""Context manager for database transactions."""
try:
yield self.connection
self.connection.commit()
except Exception:
self.connection.rollback()
raise
def insert_trade(self, trade: Trade) -> int:
"""
Insert a new trade record.
Args:
trade: Trade object to insert
Returns:
Row ID of inserted trade
"""
sql = """
INSERT INTO trades (
trade_id, symbol, side, entry_price, exit_price,
size, size_usdt, pnl_usd, pnl_pct, entry_time,
exit_time, hold_duration_hours, reason,
order_id_entry, order_id_exit
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
trade.trade_id,
trade.symbol,
trade.side,
trade.entry_price,
trade.exit_price,
trade.size,
trade.size_usdt,
trade.pnl_usd,
trade.pnl_pct,
trade.entry_time,
trade.exit_time,
trade.hold_duration_hours,
trade.reason,
trade.order_id_entry,
trade.order_id_exit,
),
)
return cursor.lastrowid
def update_trade(self, trade_id: str, **kwargs) -> bool:
"""
Update an existing trade record.
Args:
trade_id: Trade ID to update
**kwargs: Fields to update
Returns:
True if trade was updated
"""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), trade_id)
)
return cursor.rowcount > 0
def get_trade(self, trade_id: str) -> Optional[Trade]:
"""Get a trade by ID."""
sql = "SELECT * FROM trades WHERE trade_id = ?"
row = self.connection.execute(sql, (trade_id,)).fetchone()
if row:
return Trade(**dict(row))
return None
def get_trades(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
limit: Optional[int] = None,
) -> list[Trade]:
"""
Get trades within a time range.
Args:
start_time: ISO format start time filter
end_time: ISO format end time filter
limit: Maximum number of trades to return
Returns:
List of Trade objects
"""
conditions = []
params = []
if start_time:
conditions.append("entry_time >= ?")
params.append(start_time)
if end_time:
conditions.append("entry_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions) if conditions else "1=1"
limit_clause = f"LIMIT {limit}" if limit else ""
sql = f"""
SELECT * FROM trades
WHERE {where_clause}
ORDER BY entry_time DESC
{limit_clause}
"""
rows = self.connection.execute(sql, params).fetchall()
return [Trade(**dict(row)) for row in rows]
def get_all_trades(self) -> list[Trade]:
"""Get all trades."""
sql = "SELECT * FROM trades ORDER BY entry_time DESC"
rows = self.connection.execute(sql).fetchall()
return [Trade(**dict(row)) for row in rows]
def count_trades(self) -> int:
"""Get total number of trades."""
sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
return self.connection.execute(sql).fetchone()[0]
def upsert_daily_summary(self, summary: DailySummary) -> None:
"""Insert or update daily summary."""
sql = """
INSERT INTO daily_summary (
date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(date) DO UPDATE SET
total_trades = excluded.total_trades,
winning_trades = excluded.winning_trades,
total_pnl_usd = excluded.total_pnl_usd,
max_drawdown_usd = excluded.max_drawdown_usd,
updated_at = CURRENT_TIMESTAMP
"""
with self.transaction():
self.connection.execute(
sql,
(
summary.date,
summary.total_trades,
summary.winning_trades,
summary.total_pnl_usd,
summary.max_drawdown_usd,
),
)
def get_daily_summary(self, date: str) -> Optional[DailySummary]:
"""Get daily summary for a specific date."""
sql = "SELECT * FROM daily_summary WHERE date = ?"
row = self.connection.execute(sql, (date,)).fetchone()
if row:
return DailySummary(**dict(row))
return None
def insert_session(self, session: Session) -> int:
"""Insert a new session record."""
sql = """
INSERT INTO sessions (
start_time, end_time, starting_balance,
ending_balance, total_pnl, total_trades
) VALUES (?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
session.start_time,
session.end_time,
session.starting_balance,
session.ending_balance,
session.total_pnl,
session.total_trades,
),
)
return cursor.lastrowid
def update_session(self, session_id: int, **kwargs) -> bool:
"""Update an existing session."""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), session_id)
)
return cursor.rowcount > 0
def get_latest_session(self) -> Optional[Session]:
"""Get the most recent session."""
sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
row = self.connection.execute(sql).fetchone()
if row:
return Session(**dict(row))
return None
def init_db(db_path: Path) -> TradingDatabase:
"""
Initialize the database.
Args:
db_path: Path to the SQLite database file
Returns:
TradingDatabase instance
"""
global _db_instance
_db_instance = TradingDatabase(db_path)
_db_instance.init_schema()
return _db_instance
def get_db() -> Optional[TradingDatabase]:
"""Get the global database instance."""
return _db_instance

235
live_trading/db/metrics.py Normal file
View File

@@ -0,0 +1,235 @@
"""Metrics calculation from trade database."""
import logging
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Optional
from .database import TradingDatabase
logger = logging.getLogger(__name__)
@dataclass
class PeriodMetrics:
"""Trading metrics for a time period."""
period_name: str
start_time: Optional[str]
end_time: Optional[str]
total_pnl: float = 0.0
total_trades: int = 0
winning_trades: int = 0
losing_trades: int = 0
win_rate: float = 0.0
avg_trade_duration_hours: float = 0.0
max_drawdown: float = 0.0
max_drawdown_pct: float = 0.0
best_trade: float = 0.0
worst_trade: float = 0.0
avg_win: float = 0.0
avg_loss: float = 0.0
class MetricsCalculator:
"""Calculate trading metrics from database."""
def __init__(self, db: TradingDatabase):
self.db = db
def get_all_time_metrics(self) -> PeriodMetrics:
"""Get metrics for all trades ever."""
return self._calculate_metrics("All Time", None, None)
def get_monthly_metrics(self) -> PeriodMetrics:
"""Get metrics for current month."""
now = datetime.now(timezone.utc)
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Monthly",
start.isoformat(),
now.isoformat(),
)
def get_weekly_metrics(self) -> PeriodMetrics:
"""Get metrics for current week (Monday to now)."""
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
start = now - timedelta(days=days_since_monday)
start = start.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Weekly",
start.isoformat(),
now.isoformat(),
)
def get_daily_metrics(self) -> PeriodMetrics:
"""Get metrics for today (UTC)."""
now = datetime.now(timezone.utc)
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Daily",
start.isoformat(),
now.isoformat(),
)
def _calculate_metrics(
self,
period_name: str,
start_time: Optional[str],
end_time: Optional[str],
) -> PeriodMetrics:
"""
Calculate metrics for a time period.
Args:
period_name: Name of the period
start_time: ISO format start time (None for all time)
end_time: ISO format end time (None for all time)
Returns:
PeriodMetrics object
"""
metrics = PeriodMetrics(
period_name=period_name,
start_time=start_time,
end_time=end_time,
)
# Build query conditions
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
# Get aggregate metrics
sql = f"""
SELECT
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
COALESCE(SUM(pnl_usd), 0) as total_pnl,
COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
COALESCE(MAX(pnl_usd), 0) as best_trade,
COALESCE(MIN(pnl_usd), 0) as worst_trade,
COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
FROM trades
WHERE {where_clause}
"""
row = self.db.connection.execute(sql, params).fetchone()
if row and row["total_trades"] > 0:
metrics.total_trades = row["total_trades"]
metrics.winning_trades = row["winning_trades"] or 0
metrics.losing_trades = row["losing_trades"] or 0
metrics.total_pnl = row["total_pnl"]
metrics.avg_trade_duration_hours = row["avg_duration"]
metrics.best_trade = row["best_trade"]
metrics.worst_trade = row["worst_trade"]
metrics.avg_win = row["avg_win"]
metrics.avg_loss = row["avg_loss"]
if metrics.total_trades > 0:
metrics.win_rate = (
metrics.winning_trades / metrics.total_trades * 100
)
# Calculate max drawdown
metrics.max_drawdown = self._calculate_max_drawdown(
start_time, end_time
)
return metrics
def _calculate_max_drawdown(
self,
start_time: Optional[str],
end_time: Optional[str],
) -> float:
"""Calculate maximum drawdown for a period."""
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
sql = f"""
SELECT pnl_usd
FROM trades
WHERE {where_clause}
ORDER BY exit_time
"""
rows = self.db.connection.execute(sql, params).fetchall()
if not rows:
return 0.0
cumulative = 0.0
peak = 0.0
max_drawdown = 0.0
for row in rows:
pnl = row["pnl_usd"] or 0.0
cumulative += pnl
peak = max(peak, cumulative)
drawdown = peak - cumulative
max_drawdown = max(max_drawdown, drawdown)
return max_drawdown
def has_monthly_data(self) -> bool:
"""Check if we have data spanning more than current month."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
month_start = now.replace(day=1, hour=0, minute=0, second=0)
return first_trade < month_start
def has_weekly_data(self) -> bool:
"""Check if we have data spanning more than current week."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
week_start = now - timedelta(days=days_since_monday)
week_start = week_start.replace(hour=0, minute=0, second=0)
return first_trade < week_start
def get_session_start_balance(self) -> Optional[float]:
"""Get starting balance from latest session."""
sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
row = self.db.connection.execute(sql).fetchone()
return row["starting_balance"] if row else None

View File

@@ -0,0 +1,191 @@
"""Database migrations and CSV import."""
import csv
import logging
from pathlib import Path
from datetime import datetime
from .database import TradingDatabase
from .models import Trade, DailySummary
logger = logging.getLogger(__name__)
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
"""
Migrate trades from CSV file to SQLite database.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv
Returns:
Number of trades migrated
"""
if not csv_path.exists():
logger.info("No CSV file to migrate")
return 0
# Check if database already has trades
existing_count = db.count_trades()
if existing_count > 0:
logger.info(
f"Database already has {existing_count} trades, skipping migration"
)
return 0
migrated = 0
try:
with open(csv_path, "r", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
trade = _csv_row_to_trade(row)
if trade:
try:
db.insert_trade(trade)
migrated += 1
except Exception as e:
logger.warning(
f"Failed to migrate trade {row.get('trade_id')}: {e}"
)
logger.info(f"Migrated {migrated} trades from CSV to database")
except Exception as e:
logger.error(f"CSV migration failed: {e}")
return migrated
def _csv_row_to_trade(row: dict) -> Trade | None:
"""Convert a CSV row to a Trade object."""
try:
return Trade(
trade_id=row["trade_id"],
symbol=row["symbol"],
side=row["side"],
entry_price=float(row["entry_price"]),
exit_price=_safe_float(row.get("exit_price")),
size=float(row["size"]),
size_usdt=float(row["size_usdt"]),
pnl_usd=_safe_float(row.get("pnl_usd")),
pnl_pct=_safe_float(row.get("pnl_pct")),
entry_time=row["entry_time"],
exit_time=row.get("exit_time") or None,
hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
reason=row.get("reason") or None,
order_id_entry=row.get("order_id_entry") or None,
order_id_exit=row.get("order_id_exit") or None,
)
except (KeyError, ValueError) as e:
logger.warning(f"Invalid CSV row: {e}")
return None
def _safe_float(value: str | None) -> float | None:
"""Safely convert string to float."""
if value is None or value == "":
return None
try:
return float(value)
except ValueError:
return None
def rebuild_daily_summaries(db: TradingDatabase) -> int:
"""
Rebuild daily summary table from trades.
Args:
db: TradingDatabase instance
Returns:
Number of daily summaries created
"""
sql = """
SELECT
DATE(exit_time) as date,
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(pnl_usd) as total_pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
GROUP BY DATE(exit_time)
ORDER BY date
"""
rows = db.connection.execute(sql).fetchall()
count = 0
for row in rows:
summary = DailySummary(
date=row["date"],
total_trades=row["total_trades"],
winning_trades=row["winning_trades"],
total_pnl_usd=row["total_pnl_usd"] or 0.0,
max_drawdown_usd=0.0, # Calculated separately
)
db.upsert_daily_summary(summary)
count += 1
# Calculate max drawdowns
_calculate_daily_drawdowns(db)
logger.info(f"Rebuilt {count} daily summaries")
return count
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
"""Calculate and update max drawdown for each day."""
sql = """
SELECT trade_id, DATE(exit_time) as date, pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
ORDER BY exit_time
"""
rows = db.connection.execute(sql).fetchall()
# Track cumulative PnL and drawdown per day
daily_drawdowns: dict[str, float] = {}
cumulative_pnl = 0.0
peak_pnl = 0.0
for row in rows:
date = row["date"]
pnl = row["pnl_usd"] or 0.0
cumulative_pnl += pnl
peak_pnl = max(peak_pnl, cumulative_pnl)
drawdown = peak_pnl - cumulative_pnl
if date not in daily_drawdowns:
daily_drawdowns[date] = 0.0
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
# Update daily summaries with drawdown
for date, drawdown in daily_drawdowns.items():
db.connection.execute(
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
(drawdown, date),
)
db.connection.commit()
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
"""
Run all migrations.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv for migration
"""
logger.info("Running database migrations...")
# Migrate CSV data if exists
migrate_csv_to_db(db, csv_path)
# Rebuild daily summaries
rebuild_daily_summaries(db)
logger.info("Migrations complete")

69
live_trading/db/models.py Normal file
View File

@@ -0,0 +1,69 @@
"""Data models for trade persistence."""
from dataclasses import dataclass, asdict
from typing import Optional
from datetime import datetime
@dataclass
class Trade:
"""Represents a completed trade."""
trade_id: str
symbol: str
side: str
entry_price: float
size: float
size_usdt: float
entry_time: str
exit_price: Optional[float] = None
pnl_usd: Optional[float] = None
pnl_pct: Optional[float] = None
exit_time: Optional[str] = None
hold_duration_hours: Optional[float] = None
reason: Optional[str] = None
order_id_entry: Optional[str] = None
order_id_exit: Optional[str] = None
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@classmethod
def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
"""Create Trade from database row."""
data = dict(zip(columns, row))
return cls(**data)
@dataclass
class DailySummary:
"""Daily trading summary."""
date: str
total_trades: int = 0
winning_trades: int = 0
total_pnl_usd: float = 0.0
max_drawdown_usd: float = 0.0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class Session:
"""Trading session metadata."""
start_time: str
end_time: Optional[str] = None
starting_balance: Optional[float] = None
ending_balance: Optional[float] = None
total_pnl: Optional[float] = None
total_trades: int = 0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)