- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps. - Added `install_cron.sh` for setting up a cron job to run the daily training script. - Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks. - Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling. - Updated `pyproject.toml` to include the `rich` dependency for UI functionality. - Enhanced `.gitignore` to exclude model and log files. - Added database support for trade persistence and metrics calculation. - Updated README with installation and usage instructions for the new features.
192 lines
5.2 KiB
Python
192 lines
5.2 KiB
Python
"""Database migrations and CSV import."""
|
|
import csv
|
|
import logging
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
from .database import TradingDatabase
|
|
from .models import Trade, DailySummary
|
|
|
|
# Module-level logger, named after this module, used by every function below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
    """
    Import the trades recorded in a CSV log into the database.

    The import is a one-shot operation: it is skipped when the CSV file is
    missing or when the database already contains at least one trade.
    Rows that cannot be parsed or inserted are skipped individually; any
    other failure is logged and ends the import early.

    Args:
        db: TradingDatabase instance
        csv_path: Path to trade_log.csv

    Returns:
        Number of trades inserted (0 when the import was skipped)
    """
    if not csv_path.exists():
        logger.info("No CSV file to migrate")
        return 0

    # Only an empty database is populated.
    # NOTE(review): if an earlier run stopped partway and its inserts were
    # persisted, this check will skip the remaining rows - confirm that is
    # acceptable.
    already_stored = db.count_trades()
    if already_stored > 0:
        logger.info(
            f"Database already has {already_stored} trades, skipping migration"
        )
        return 0

    imported = 0
    try:
        # No encoding is passed, so the platform default is used to decode.
        with open(csv_path, "r", newline="") as handle:
            for record in csv.DictReader(handle):
                parsed = _csv_row_to_trade(record)
                if not parsed:
                    # The parser has already logged why the row was rejected.
                    continue
                try:
                    db.insert_trade(parsed)
                except Exception as insert_err:
                    logger.warning(
                        f"Failed to migrate trade {record.get('trade_id')}: {insert_err}"
                    )
                else:
                    imported += 1

        logger.info(f"Migrated {imported} trades from CSV to database")
    except Exception as err:
        # Best-effort import: report the failure and return what was inserted.
        logger.error(f"CSV migration failed: {err}")

    return imported
|
|
|
|
|
|
def _csv_row_to_trade(row: dict) -> Trade | None:
    """Convert a CSV row to a Trade object.

    Required columns are ``trade_id``, ``symbol``, ``side``, ``entry_price``,
    ``size``, ``size_usdt`` and ``entry_time``; the three numeric ones must
    parse as floats. All other columns are optional and become ``None`` when
    absent, empty or (for numeric fields) unparseable.

    Args:
        row: Mapping of column name to raw cell value, as yielded by
            ``csv.DictReader``.

    Returns:
        The parsed Trade, or None if the row is invalid (a warning is logged).
    """
    try:
        return Trade(
            trade_id=row["trade_id"],
            symbol=row["symbol"],
            side=row["side"],
            entry_price=float(row["entry_price"]),
            exit_price=_safe_float(row.get("exit_price")),
            size=float(row["size"]),
            size_usdt=float(row["size_usdt"]),
            pnl_usd=_safe_float(row.get("pnl_usd")),
            pnl_pct=_safe_float(row.get("pnl_pct")),
            entry_time=row["entry_time"],
            # Empty strings are normalised to None for the optional text fields.
            exit_time=row.get("exit_time") or None,
            hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
            reason=row.get("reason") or None,
            order_id_entry=row.get("order_id_entry") or None,
            order_id_exit=row.get("order_id_exit") or None,
        )
    # TypeError covers float(None): csv.DictReader fills the missing cells of
    # a short row with None, and that must reject only this row rather than
    # escape to the caller and abort the whole import.
    except (KeyError, ValueError, TypeError) as e:
        logger.warning(f"Invalid CSV row: {e}")
        return None
|
|
|
|
|
|
def _safe_float(value: str | None) -> float | None:
|
|
"""Safely convert string to float."""
|
|
if value is None or value == "":
|
|
return None
|
|
try:
|
|
return float(value)
|
|
except ValueError:
|
|
return None
|
|
|
|
|
|
def rebuild_daily_summaries(db: TradingDatabase) -> int:
    """
    Recompute the per-day summary rows from the trades table.

    Trades with a non-null exit_time are grouped by exit date. For each date
    the trade count, the number of trades with pnl_usd > 0 and the summed
    pnl_usd are upserted as a DailySummary. The drawdown column is written
    as 0.0 here and filled in by a second pass over the trades.

    Args:
        db: TradingDatabase instance

    Returns:
        Number of daily summaries upserted
    """
    per_day_query = """
        SELECT
            DATE(exit_time) as date,
            COUNT(*) as total_trades,
            SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
            SUM(pnl_usd) as total_pnl_usd
        FROM trades
        WHERE exit_time IS NOT NULL
        GROUP BY DATE(exit_time)
        ORDER BY date
    """
    per_day = db.connection.execute(per_day_query).fetchall()

    for day in per_day:
        db.upsert_daily_summary(
            DailySummary(
                date=day["date"],
                total_trades=day["total_trades"],
                winning_trades=day["winning_trades"],
                # SUM() is NULL when every pnl_usd in the group is NULL.
                total_pnl_usd=day["total_pnl_usd"] or 0.0,
                max_drawdown_usd=0.0,  # Placeholder; set by the drawdown pass
            )
        )

    # Second pass: overwrite the placeholder drawdowns.
    _calculate_daily_drawdowns(db)

    rebuilt = len(per_day)
    logger.info(f"Rebuilt {rebuilt} daily summaries")
    return rebuilt
|
|
|
|
|
|
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
|
|
"""Calculate and update max drawdown for each day."""
|
|
sql = """
|
|
SELECT trade_id, DATE(exit_time) as date, pnl_usd
|
|
FROM trades
|
|
WHERE exit_time IS NOT NULL
|
|
ORDER BY exit_time
|
|
"""
|
|
|
|
rows = db.connection.execute(sql).fetchall()
|
|
|
|
# Track cumulative PnL and drawdown per day
|
|
daily_drawdowns: dict[str, float] = {}
|
|
cumulative_pnl = 0.0
|
|
peak_pnl = 0.0
|
|
|
|
for row in rows:
|
|
date = row["date"]
|
|
pnl = row["pnl_usd"] or 0.0
|
|
|
|
cumulative_pnl += pnl
|
|
peak_pnl = max(peak_pnl, cumulative_pnl)
|
|
drawdown = peak_pnl - cumulative_pnl
|
|
|
|
if date not in daily_drawdowns:
|
|
daily_drawdowns[date] = 0.0
|
|
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
|
|
|
|
# Update daily summaries with drawdown
|
|
for date, drawdown in daily_drawdowns.items():
|
|
db.connection.execute(
|
|
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
|
|
(drawdown, date),
|
|
)
|
|
db.connection.commit()
|
|
|
|
|
|
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
    """
    Run every migration step in order.

    Imports trades from the CSV log, then rebuilds the daily summary table.
    The counts returned by the individual steps are discarded.

    Args:
        db: TradingDatabase instance
        csv_path: Path to trade_log.csv for migration
    """
    logger.info("Running database migrations...")

    # Summaries are derived from the trades table, so the CSV import has to
    # run before they are rebuilt.
    migrate_csv_to_db(db, csv_path)
    rebuild_daily_summaries(db)

    logger.info("Migrations complete")
|