- Introduced `train_daily.sh` to automate daily model retraining, including the data download and model training steps.
- Added `install_cron.sh` to set up a cron job that runs the daily training script.
- Created `setup_schedule.sh` to configure systemd timers for the daily training tasks.
- Implemented a Rich-based terminal UI for real-time monitoring of trading performance, including a metrics display and log handling.
- Updated `pyproject.toml` to include the `rich` dependency required by the UI.
- Extended `.gitignore` to exclude model and log files.
- Added database support for trade persistence and metrics calculation.
- Updated the README with installation and usage instructions for the new features.
326 lines
9.7 KiB
Python
326 lines
9.7 KiB
Python
"""SQLite database connection and operations."""
|
|
import sqlite3
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from contextlib import contextmanager
|
|
|
|
from .models import Trade, DailySummary, Session
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Database schema
|
|
SCHEMA = """
|
|
-- Trade history table
|
|
CREATE TABLE IF NOT EXISTS trades (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
trade_id TEXT UNIQUE NOT NULL,
|
|
symbol TEXT NOT NULL,
|
|
side TEXT NOT NULL,
|
|
entry_price REAL NOT NULL,
|
|
exit_price REAL,
|
|
size REAL NOT NULL,
|
|
size_usdt REAL NOT NULL,
|
|
pnl_usd REAL,
|
|
pnl_pct REAL,
|
|
entry_time TEXT NOT NULL,
|
|
exit_time TEXT,
|
|
hold_duration_hours REAL,
|
|
reason TEXT,
|
|
order_id_entry TEXT,
|
|
order_id_exit TEXT,
|
|
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
|
);
|
|
|
|
-- Daily summary table
|
|
CREATE TABLE IF NOT EXISTS daily_summary (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
date TEXT UNIQUE NOT NULL,
|
|
total_trades INTEGER DEFAULT 0,
|
|
winning_trades INTEGER DEFAULT 0,
|
|
total_pnl_usd REAL DEFAULT 0,
|
|
max_drawdown_usd REAL DEFAULT 0,
|
|
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
|
|
);
|
|
|
|
-- Session metadata
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
start_time TEXT NOT NULL,
|
|
end_time TEXT,
|
|
starting_balance REAL,
|
|
ending_balance REAL,
|
|
total_pnl REAL,
|
|
total_trades INTEGER DEFAULT 0
|
|
);
|
|
|
|
-- Indexes for common queries
|
|
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
|
|
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
|
|
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
|
|
"""
|
|
|
|
_db_instance: Optional["TradingDatabase"] = None
|
|
|
|
|
|
class TradingDatabase:
    """SQLite database for trade persistence.

    Wraps one lazily opened ``sqlite3.Connection`` and provides helpers for
    the ``trades``, ``daily_summary`` and ``sessions`` tables.
    """

    def __init__(self, db_path: Path):
        """
        Create a database handle; no connection is opened yet.

        Args:
            db_path: Path to the SQLite database file
        """
        self.db_path = db_path
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def connection(self) -> sqlite3.Connection:
        """
        Get the shared connection, opening it on first access.

        Rows come back as ``sqlite3.Row`` so they can be converted to dicts
        keyed by column name. After ``close()`` the next access opens a new
        connection.
        """
        if self._connection is None:
            # NOTE(review): check_same_thread=False permits use from other
            # threads, but nothing here serialises that access; callers that
            # share this object across threads must synchronise themselves.
            self._connection = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            self._connection.row_factory = sqlite3.Row
        return self._connection

    def init_schema(self) -> None:
        """Create the tables and indexes defined in ``SCHEMA`` if missing."""
        with self.connection:
            self.connection.executescript(SCHEMA)
        logger.info("Database initialized at %s", self.db_path)

    def close(self) -> None:
        """Close the connection if one is open (it reopens lazily on next use)."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    @contextmanager
    def transaction(self):
        """
        Run the ``with`` body as a single transaction on the shared connection.

        The transaction is committed when the body finishes normally. On any
        exception the pending changes are rolled back before the exception is
        re-raised. This includes ``KeyboardInterrupt`` and other
        ``BaseException`` subclasses, so an interrupted write cannot be
        committed later by an unrelated transaction on the same connection.

        Nesting is not supported: an inner block's commit also commits
        whatever the outer block has written so far.

        Yields:
            The shared ``sqlite3.Connection``.
        """
        try:
            yield self.connection
            self.connection.commit()
        except BaseException:
            self.connection.rollback()
            raise

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """
        Quote ``name`` as an SQLite identifier, doubling embedded quotes.

        Column names cannot be bound as ``?`` parameters. Quoting them stops a
        key such as ``"pnl_usd = 0, exit_time"`` from being executed as SQL.
        An unknown name then fails with "no such column".
        """
        return '"' + name.replace('"', '""') + '"'

    def _update_row(self, table: str, key_column: str, key, fields: dict) -> bool:
        """
        Set ``fields`` on the row of ``table`` whose ``key_column`` equals ``key``.

        ``table`` and ``key_column`` must be trusted literals; only the
        caller-supplied column names in ``fields`` are quoted.

        Returns:
            True if a row was updated; False if ``fields`` is empty or no row
            matched ``key``.
        """
        if not fields:
            return False

        set_clause = ", ".join(
            f"{self._quote_identifier(column)} = ?" for column in fields
        )
        sql = f"UPDATE {table} SET {set_clause} WHERE {key_column} = ?"

        with self.transaction():
            cursor = self.connection.execute(sql, (*fields.values(), key))
        return cursor.rowcount > 0

    def insert_trade(self, trade: "Trade") -> int:
        """
        Insert a new trade record.

        Args:
            trade: Trade object to insert

        Returns:
            Row ID of inserted trade
        """
        sql = """
            INSERT INTO trades (
                trade_id, symbol, side, entry_price, exit_price,
                size, size_usdt, pnl_usd, pnl_pct, entry_time,
                exit_time, hold_duration_hours, reason,
                order_id_entry, order_id_exit
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            trade.trade_id,
            trade.symbol,
            trade.side,
            trade.entry_price,
            trade.exit_price,
            trade.size,
            trade.size_usdt,
            trade.pnl_usd,
            trade.pnl_pct,
            trade.entry_time,
            trade.exit_time,
            trade.hold_duration_hours,
            trade.reason,
            trade.order_id_entry,
            trade.order_id_exit,
        )
        with self.transaction():
            cursor = self.connection.execute(sql, params)
        return cursor.lastrowid

    def update_trade(self, trade_id: str, **kwargs) -> bool:
        """
        Update an existing trade record.

        Args:
            trade_id: Trade ID to update
            **kwargs: Column names mapped to their new values

        Returns:
            True if trade was updated; False if no fields were given or no
            trade has this ID

        Raises:
            sqlite3.OperationalError: If a keyword is not a column of ``trades``.
        """
        return self._update_row("trades", "trade_id", trade_id, kwargs)

    def get_trade(self, trade_id: str) -> Optional["Trade"]:
        """Get a trade by ID, or None if it does not exist."""
        sql = "SELECT * FROM trades WHERE trade_id = ?"
        row = self.connection.execute(sql, (trade_id,)).fetchone()
        if row:
            return Trade(**dict(row))
        return None

    def get_trades(
        self,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> list["Trade"]:
        """
        Get trades whose entry_time is within a range, newest first.

        Args:
            start_time: ISO format start time filter (inclusive); empty or
                None means no lower bound
            end_time: ISO format end time filter (inclusive); empty or None
                means no upper bound
            limit: Maximum number of trades to return; None means no limit,
                0 returns an empty list

        Returns:
            List of Trade objects ordered by entry_time descending
        """
        conditions = []
        params: list = []

        # entry_time is TEXT, so these are string comparisons (see SCHEMA).
        if start_time:
            conditions.append("entry_time >= ?")
            params.append(start_time)
        if end_time:
            conditions.append("entry_time <= ?")
            params.append(end_time)

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        sql = f"SELECT * FROM trades WHERE {where_clause} ORDER BY entry_time DESC"

        # Bind LIMIT instead of formatting it into the SQL. Test against None
        # so that limit=0 means "no rows" rather than "no limit".
        if limit is not None:
            sql += " LIMIT ?"
            params.append(int(limit))

        rows = self.connection.execute(sql, params).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def get_all_trades(self) -> list["Trade"]:
        """Get all trades, newest entry_time first."""
        sql = "SELECT * FROM trades ORDER BY entry_time DESC"
        rows = self.connection.execute(sql).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def count_trades(self) -> int:
        """Get the number of closed trades (those with an exit_time set).

        Open trades (exit_time IS NULL) are not counted.
        """
        sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
        return self.connection.execute(sql).fetchone()[0]

    def upsert_daily_summary(self, summary: "DailySummary") -> None:
        """Insert the summary for its date, or overwrite the existing one.

        On conflict the counters are replaced (not added to) and updated_at
        is refreshed.
        """
        sql = """
            INSERT INTO daily_summary (
                date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
            ) VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(date) DO UPDATE SET
                total_trades = excluded.total_trades,
                winning_trades = excluded.winning_trades,
                total_pnl_usd = excluded.total_pnl_usd,
                max_drawdown_usd = excluded.max_drawdown_usd,
                updated_at = CURRENT_TIMESTAMP
        """
        params = (
            summary.date,
            summary.total_trades,
            summary.winning_trades,
            summary.total_pnl_usd,
            summary.max_drawdown_usd,
        )
        with self.transaction():
            self.connection.execute(sql, params)

    def get_daily_summary(self, date: str) -> Optional["DailySummary"]:
        """Get daily summary for a specific date, or None if absent."""
        sql = "SELECT * FROM daily_summary WHERE date = ?"
        row = self.connection.execute(sql, (date,)).fetchone()
        if row:
            return DailySummary(**dict(row))
        return None

    def insert_session(self, session: "Session") -> int:
        """Insert a new session record and return its row ID."""
        sql = """
            INSERT INTO sessions (
                start_time, end_time, starting_balance,
                ending_balance, total_pnl, total_trades
            ) VALUES (?, ?, ?, ?, ?, ?)
        """
        params = (
            session.start_time,
            session.end_time,
            session.starting_balance,
            session.ending_balance,
            session.total_pnl,
            session.total_trades,
        )
        with self.transaction():
            cursor = self.connection.execute(sql, params)
        return cursor.lastrowid

    def update_session(self, session_id: int, **kwargs) -> bool:
        """
        Update an existing session.

        Args:
            session_id: Row ID of the session to update
            **kwargs: Column names mapped to their new values

        Returns:
            True if the session was updated; False if no fields were given
            or no session has this ID

        Raises:
            sqlite3.OperationalError: If a keyword is not a column of ``sessions``.
        """
        return self._update_row("sessions", "id", session_id, kwargs)

    def get_latest_session(self) -> Optional["Session"]:
        """Get the session with the highest row ID, or None if there is none."""
        sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
        row = self.connection.execute(sql).fetchone()
        if row:
            return Session(**dict(row))
        return None
|
|
|
|
|
|
def init_db(db_path: Path) -> TradingDatabase:
    """
    Create the module-wide database, ensure its schema exists, and return it.

    Calling this again replaces the previous instance. The old instance's
    connection is closed so it is not leaked; a stale reference to it still
    works because TradingDatabase reopens its connection lazily.

    Args:
        db_path: Path to the SQLite database file

    Returns:
        The new TradingDatabase instance, which get_db() returns from now on

    Raises:
        Any exception raised while opening the database or creating the
        schema propagates. In that case the new connection is closed, and the
        previously installed instance (if any) is left in place.
    """
    global _db_instance

    database = TradingDatabase(db_path)
    try:
        database.init_schema()
    except Exception:
        # Don't leave a half-initialised connection open behind us.
        database.close()
        raise

    previous = _db_instance
    _db_instance = database
    if previous is not None and previous is not database:
        previous.close()
    return database
|
|
|
|
|
|
def get_db() -> Optional[TradingDatabase]:
    """
    Return the module-wide database created by ``init_db``.

    Returns:
        The TradingDatabase most recently installed by ``init_db``, or None
        if ``init_db`` has not been called yet.
    """
    current = _db_instance
    return current
|