Add daily model training scripts and terminal UI for live trading
- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps. - Added `install_cron.sh` for setting up a cron job to run the daily training script. - Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks. - Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling. - Updated `pyproject.toml` to include the `rich` dependency for UI functionality. - Enhanced `.gitignore` to exclude model and log files. - Added database support for trade persistence and metrics calculation. - Updated README with installation and usage instructions for the new features.
This commit is contained in:
191
live_trading/db/migrations.py
Normal file
191
live_trading/db/migrations.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Database migrations and CSV import."""
|
||||
import csv
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from .database import TradingDatabase
|
||||
from .models import Trade, DailySummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
|
||||
"""
|
||||
Migrate trades from CSV file to SQLite database.
|
||||
|
||||
Args:
|
||||
db: TradingDatabase instance
|
||||
csv_path: Path to trade_log.csv
|
||||
|
||||
Returns:
|
||||
Number of trades migrated
|
||||
"""
|
||||
if not csv_path.exists():
|
||||
logger.info("No CSV file to migrate")
|
||||
return 0
|
||||
|
||||
# Check if database already has trades
|
||||
existing_count = db.count_trades()
|
||||
if existing_count > 0:
|
||||
logger.info(
|
||||
f"Database already has {existing_count} trades, skipping migration"
|
||||
)
|
||||
return 0
|
||||
|
||||
migrated = 0
|
||||
try:
|
||||
with open(csv_path, "r", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
for row in reader:
|
||||
trade = _csv_row_to_trade(row)
|
||||
if trade:
|
||||
try:
|
||||
db.insert_trade(trade)
|
||||
migrated += 1
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to migrate trade {row.get('trade_id')}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"Migrated {migrated} trades from CSV to database")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CSV migration failed: {e}")
|
||||
|
||||
return migrated
|
||||
|
||||
|
||||
def _csv_row_to_trade(row: dict) -> Trade | None:
|
||||
"""Convert a CSV row to a Trade object."""
|
||||
try:
|
||||
return Trade(
|
||||
trade_id=row["trade_id"],
|
||||
symbol=row["symbol"],
|
||||
side=row["side"],
|
||||
entry_price=float(row["entry_price"]),
|
||||
exit_price=_safe_float(row.get("exit_price")),
|
||||
size=float(row["size"]),
|
||||
size_usdt=float(row["size_usdt"]),
|
||||
pnl_usd=_safe_float(row.get("pnl_usd")),
|
||||
pnl_pct=_safe_float(row.get("pnl_pct")),
|
||||
entry_time=row["entry_time"],
|
||||
exit_time=row.get("exit_time") or None,
|
||||
hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
|
||||
reason=row.get("reason") or None,
|
||||
order_id_entry=row.get("order_id_entry") or None,
|
||||
order_id_exit=row.get("order_id_exit") or None,
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning(f"Invalid CSV row: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _safe_float(value: str | None) -> float | None:
|
||||
"""Safely convert string to float."""
|
||||
if value is None or value == "":
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def rebuild_daily_summaries(db: TradingDatabase) -> int:
|
||||
"""
|
||||
Rebuild daily summary table from trades.
|
||||
|
||||
Args:
|
||||
db: TradingDatabase instance
|
||||
|
||||
Returns:
|
||||
Number of daily summaries created
|
||||
"""
|
||||
sql = """
|
||||
SELECT
|
||||
DATE(exit_time) as date,
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
|
||||
SUM(pnl_usd) as total_pnl_usd
|
||||
FROM trades
|
||||
WHERE exit_time IS NOT NULL
|
||||
GROUP BY DATE(exit_time)
|
||||
ORDER BY date
|
||||
"""
|
||||
|
||||
rows = db.connection.execute(sql).fetchall()
|
||||
count = 0
|
||||
|
||||
for row in rows:
|
||||
summary = DailySummary(
|
||||
date=row["date"],
|
||||
total_trades=row["total_trades"],
|
||||
winning_trades=row["winning_trades"],
|
||||
total_pnl_usd=row["total_pnl_usd"] or 0.0,
|
||||
max_drawdown_usd=0.0, # Calculated separately
|
||||
)
|
||||
db.upsert_daily_summary(summary)
|
||||
count += 1
|
||||
|
||||
# Calculate max drawdowns
|
||||
_calculate_daily_drawdowns(db)
|
||||
|
||||
logger.info(f"Rebuilt {count} daily summaries")
|
||||
return count
|
||||
|
||||
|
||||
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
|
||||
"""Calculate and update max drawdown for each day."""
|
||||
sql = """
|
||||
SELECT trade_id, DATE(exit_time) as date, pnl_usd
|
||||
FROM trades
|
||||
WHERE exit_time IS NOT NULL
|
||||
ORDER BY exit_time
|
||||
"""
|
||||
|
||||
rows = db.connection.execute(sql).fetchall()
|
||||
|
||||
# Track cumulative PnL and drawdown per day
|
||||
daily_drawdowns: dict[str, float] = {}
|
||||
cumulative_pnl = 0.0
|
||||
peak_pnl = 0.0
|
||||
|
||||
for row in rows:
|
||||
date = row["date"]
|
||||
pnl = row["pnl_usd"] or 0.0
|
||||
|
||||
cumulative_pnl += pnl
|
||||
peak_pnl = max(peak_pnl, cumulative_pnl)
|
||||
drawdown = peak_pnl - cumulative_pnl
|
||||
|
||||
if date not in daily_drawdowns:
|
||||
daily_drawdowns[date] = 0.0
|
||||
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
|
||||
|
||||
# Update daily summaries with drawdown
|
||||
for date, drawdown in daily_drawdowns.items():
|
||||
db.connection.execute(
|
||||
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
|
||||
(drawdown, date),
|
||||
)
|
||||
db.connection.commit()
|
||||
|
||||
|
||||
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
|
||||
"""
|
||||
Run all migrations.
|
||||
|
||||
Args:
|
||||
db: TradingDatabase instance
|
||||
csv_path: Path to trade_log.csv for migration
|
||||
"""
|
||||
logger.info("Running database migrations...")
|
||||
|
||||
# Migrate CSV data if exists
|
||||
migrate_csv_to_db(db, csv_path)
|
||||
|
||||
# Rebuild daily summaries
|
||||
rebuild_daily_summaries(db)
|
||||
|
||||
logger.info("Migrations complete")
|
||||
Reference in New Issue
Block a user