Files
lowkey_backtest/live_trading/db/database.py
Simon Moisy b5550f4ff4 Add daily model training scripts and terminal UI for live trading
- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps.
- Added `install_cron.sh` for setting up a cron job to run the daily training script.
- Created `setup_schedule.sh` for configuring systemd timers for daily training tasks.
- Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling.
- Updated `pyproject.toml` to include the `rich` dependency for UI functionality.
- Enhanced `.gitignore` to exclude model and log files.
- Added database support for trade persistence and metrics calculation.
- Updated README with installation and usage instructions for the new features.
2026-01-18 11:08:57 +08:00

326 lines
9.7 KiB
Python

"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager
from .models import Trade, DailySummary, Session
# Module-level logger; TradingDatabase.init_schema logs the database path through it.
logger = logging.getLogger(__name__)
# Database schema
# Executed as a script by TradingDatabase.init_schema (via executescript). Every
# CREATE statement uses IF NOT EXISTS, so re-running it against an existing
# database leaves already-present tables and indexes untouched.
# NOTE(review): daily_summary.date is declared UNIQUE, which already gives it an
# implicit index, so idx_daily_summary_date looks redundant — confirm before removing.
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL,
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT,
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL,
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""
# Process-wide database instance: None until init_db() assigns it; get_db()
# returns whatever is stored here.
_db_instance: Optional["TradingDatabase"] = None
class TradingDatabase:
    """SQLite persistence for trades, daily summaries and trading sessions.

    A single connection is opened lazily on first use of ``connection`` and is
    reused by every method until ``close`` is called.
    """

    # Columns that update_trade / update_session may modify. SQL identifiers
    # cannot be bound as ``?`` parameters, so keyword names are validated
    # against these sets before being interpolated into the UPDATE statement;
    # this stops arbitrary ``**kwargs`` keys from injecting SQL. Keep these in
    # sync with the ``trades`` and ``sessions`` tables in SCHEMA.
    _TRADE_COLUMNS = frozenset({
        "id", "trade_id", "symbol", "side", "entry_price", "exit_price",
        "size", "size_usdt", "pnl_usd", "pnl_pct", "entry_time",
        "exit_time", "hold_duration_hours", "reason", "order_id_entry",
        "order_id_exit", "created_at",
    })
    _SESSION_COLUMNS = frozenset({
        "id", "start_time", "end_time", "starting_balance",
        "ending_balance", "total_pnl", "total_trades",
    })

    def __init__(self, db_path: Path):
        """
        Args:
            db_path: Path of the SQLite database file. The connection is not
                opened until it is first needed.
        """
        self.db_path = db_path
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def connection(self) -> sqlite3.Connection:
        """Return the shared connection, opening it on first access.

        Rows come back as ``sqlite3.Row`` so they can be turned into keyword
        arguments with ``dict(row)``.
        """
        if self._connection is None:
            # NOTE(review): check_same_thread=False lets several threads share
            # this one connection, but nothing in this class serialises access
            # to it — confirm callers never run transactions concurrently.
            self._connection = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            self._connection.row_factory = sqlite3.Row
        return self._connection

    def init_schema(self) -> None:
        """Create the tables and indexes defined in SCHEMA if they are missing."""
        with self.connection:
            self.connection.executescript(SCHEMA)
        logger.info("Database initialized at %s", self.db_path)

    def close(self) -> None:
        """Close the connection if it is open.

        A later access to ``connection`` transparently opens a new one.
        """
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    @contextmanager
    def transaction(self):
        """Run a block of statements as one transaction.

        Commits when the block exits normally; rolls back and re-raises if it
        raises.

        Yields:
            The shared ``sqlite3.Connection``.
        """
        # Bind once so commit/rollback act on the same connection the block used.
        conn = self.connection
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise

    @staticmethod
    def _set_clause(fields: dict[str, object], allowed: frozenset[str], table: str) -> str:
        """Build ``col = ?, ...`` for *fields*, rejecting names not in *allowed*."""
        unknown = [name for name in fields if name not in allowed]
        if unknown:
            raise ValueError(
                f"Unknown {table} column(s): {', '.join(sorted(unknown))}"
            )
        return ", ".join(f"{name} = ?" for name in fields)

    def insert_trade(self, trade: "Trade") -> int:
        """
        Insert a new trade record.

        Args:
            trade: Trade object to insert.

        Returns:
            Row ID of the inserted trade.
        """
        sql = """
            INSERT INTO trades (
                trade_id, symbol, side, entry_price, exit_price,
                size, size_usdt, pnl_usd, pnl_pct, entry_time,
                exit_time, hold_duration_hours, reason,
                order_id_entry, order_id_exit
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            trade.trade_id,
            trade.symbol,
            trade.side,
            trade.entry_price,
            trade.exit_price,
            trade.size,
            trade.size_usdt,
            trade.pnl_usd,
            trade.pnl_pct,
            trade.entry_time,
            trade.exit_time,
            trade.hold_duration_hours,
            trade.reason,
            trade.order_id_entry,
            trade.order_id_exit,
        )
        with self.transaction() as conn:
            cursor = conn.execute(sql, params)
        return cursor.lastrowid

    def update_trade(self, trade_id: str, **kwargs) -> bool:
        """
        Update columns of an existing trade record.

        Args:
            trade_id: Trade ID of the row to update.
            **kwargs: Column name -> new value for each field to update.

        Returns:
            True if a row was updated; False if no fields were given or no
            trade has this ID.

        Raises:
            ValueError: If a keyword is not a column of the ``trades`` table.
        """
        if not kwargs:
            return False
        set_clause = self._set_clause(kwargs, self._TRADE_COLUMNS, "trades")
        sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"
        with self.transaction() as conn:
            cursor = conn.execute(sql, (*kwargs.values(), trade_id))
        return cursor.rowcount > 0

    def get_trade(self, trade_id: str) -> Optional["Trade"]:
        """Return the trade with this trade ID, or None if there is none."""
        sql = "SELECT * FROM trades WHERE trade_id = ?"
        row = self.connection.execute(sql, (trade_id,)).fetchone()
        return Trade(**dict(row)) if row else None

    def get_trades(
        self,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> list["Trade"]:
        """
        Get trades whose entry_time falls within a range, newest first.

        Args:
            start_time: ISO format lower bound on entry_time (inclusive);
                ignored when falsy.
            end_time: ISO format upper bound on entry_time (inclusive);
                ignored when falsy.
            limit: Maximum number of trades to return; a falsy value
                (None or 0) means no limit.

        Returns:
            List of Trade objects ordered by entry_time descending.
        """
        conditions: list[str] = []
        params: list[object] = []
        if start_time:
            conditions.append("entry_time >= ?")
            params.append(start_time)
        if end_time:
            conditions.append("entry_time <= ?")
            params.append(end_time)
        where_clause = " AND ".join(conditions) if conditions else "1=1"
        sql = f"SELECT * FROM trades WHERE {where_clause} ORDER BY entry_time DESC"
        if limit:
            # Bind the limit rather than formatting it into the SQL text.
            sql += " LIMIT ?"
            params.append(limit)
        rows = self.connection.execute(sql, params).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def get_all_trades(self) -> list["Trade"]:
        """Return every trade, ordered by entry_time descending."""
        return self.get_trades()

    def count_trades(self) -> int:
        """Return the number of closed trades (rows with a non-NULL exit_time)."""
        sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
        return self.connection.execute(sql).fetchone()[0]

    def upsert_daily_summary(self, summary: "DailySummary") -> None:
        """Insert the summary for its date, or overwrite the existing one.

        On conflict the counters and P&L fields are replaced and updated_at is
        refreshed to the current timestamp.
        """
        sql = """
            INSERT INTO daily_summary (
                date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
            ) VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(date) DO UPDATE SET
                total_trades = excluded.total_trades,
                winning_trades = excluded.winning_trades,
                total_pnl_usd = excluded.total_pnl_usd,
                max_drawdown_usd = excluded.max_drawdown_usd,
                updated_at = CURRENT_TIMESTAMP
        """
        params = (
            summary.date,
            summary.total_trades,
            summary.winning_trades,
            summary.total_pnl_usd,
            summary.max_drawdown_usd,
        )
        with self.transaction() as conn:
            conn.execute(sql, params)

    def get_daily_summary(self, date: str) -> Optional["DailySummary"]:
        """Return the summary stored for *date*, or None if there is none."""
        sql = "SELECT * FROM daily_summary WHERE date = ?"
        row = self.connection.execute(sql, (date,)).fetchone()
        return DailySummary(**dict(row)) if row else None

    def insert_session(self, session: "Session") -> int:
        """Insert a new session record and return its row ID."""
        sql = """
            INSERT INTO sessions (
                start_time, end_time, starting_balance,
                ending_balance, total_pnl, total_trades
            ) VALUES (?, ?, ?, ?, ?, ?)
        """
        params = (
            session.start_time,
            session.end_time,
            session.starting_balance,
            session.ending_balance,
            session.total_pnl,
            session.total_trades,
        )
        with self.transaction() as conn:
            cursor = conn.execute(sql, params)
        return cursor.lastrowid

    def update_session(self, session_id: int, **kwargs) -> bool:
        """
        Update columns of an existing session.

        Args:
            session_id: Row ID of the session to update.
            **kwargs: Column name -> new value for each field to update.

        Returns:
            True if a row was updated; False if no fields were given or no
            session has this ID.

        Raises:
            ValueError: If a keyword is not a column of the ``sessions`` table.
        """
        if not kwargs:
            return False
        set_clause = self._set_clause(kwargs, self._SESSION_COLUMNS, "sessions")
        sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"
        with self.transaction() as conn:
            cursor = conn.execute(sql, (*kwargs.values(), session_id))
        return cursor.rowcount > 0

    def get_latest_session(self) -> Optional["Session"]:
        """Return the session with the highest row ID, or None if there are none."""
        sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
        row = self.connection.execute(sql).fetchone()
        return Session(**dict(row)) if row else None
def init_db(db_path: Path) -> TradingDatabase:
    """
    Create the process-wide database and ensure its schema exists.

    The new instance is stored as the module-level database (the one returned
    by ``get_db``) before its schema is created.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        The newly created TradingDatabase.
    """
    global _db_instance
    database = TradingDatabase(db_path)
    _db_instance = database
    database.init_schema()
    return database
def get_db() -> Optional[TradingDatabase]:
    """
    Return the process-wide database instance.

    Returns:
        The TradingDatabase stored by the most recent ``init_db`` call, or
        None if ``init_db`` has not run yet.
    """
    instance = _db_instance
    return instance