"""SQLite database connection and operations.""" import sqlite3 import logging from pathlib import Path from typing import Optional from contextlib import contextmanager from .models import Trade, DailySummary, Session logger = logging.getLogger(__name__) # Database schema SCHEMA = """ -- Trade history table CREATE TABLE IF NOT EXISTS trades ( id INTEGER PRIMARY KEY AUTOINCREMENT, trade_id TEXT UNIQUE NOT NULL, symbol TEXT NOT NULL, side TEXT NOT NULL, entry_price REAL NOT NULL, exit_price REAL, size REAL NOT NULL, size_usdt REAL NOT NULL, pnl_usd REAL, pnl_pct REAL, entry_time TEXT NOT NULL, exit_time TEXT, hold_duration_hours REAL, reason TEXT, order_id_entry TEXT, order_id_exit TEXT, created_at TEXT DEFAULT CURRENT_TIMESTAMP ); -- Daily summary table CREATE TABLE IF NOT EXISTS daily_summary ( id INTEGER PRIMARY KEY AUTOINCREMENT, date TEXT UNIQUE NOT NULL, total_trades INTEGER DEFAULT 0, winning_trades INTEGER DEFAULT 0, total_pnl_usd REAL DEFAULT 0, max_drawdown_usd REAL DEFAULT 0, updated_at TEXT DEFAULT CURRENT_TIMESTAMP ); -- Session metadata CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, start_time TEXT NOT NULL, end_time TEXT, starting_balance REAL, ending_balance REAL, total_pnl REAL, total_trades INTEGER DEFAULT 0 ); -- Indexes for common queries CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time); CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time); CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date); """ _db_instance: Optional["TradingDatabase"] = None class TradingDatabase: """SQLite database for trade persistence.""" def __init__(self, db_path: Path): self.db_path = db_path self._connection: Optional[sqlite3.Connection] = None @property def connection(self) -> sqlite3.Connection: """Get or create database connection.""" if self._connection is None: self._connection = sqlite3.connect( str(self.db_path), check_same_thread=False, ) self._connection.row_factory = sqlite3.Row return self._connection def init_schema(self) -> None: """Initialize database schema.""" with self.connection: self.connection.executescript(SCHEMA) logger.info(f"Database initialized at {self.db_path}") def close(self) -> None: """Close database connection.""" if self._connection: self._connection.close() self._connection = None @contextmanager def transaction(self): """Context manager for database transactions.""" try: yield self.connection self.connection.commit() except Exception: self.connection.rollback() raise def insert_trade(self, trade: Trade) -> int: """ Insert a new trade record. Args: trade: Trade object to insert Returns: Row ID of inserted trade """ sql = """ INSERT INTO trades ( trade_id, symbol, side, entry_price, exit_price, size, size_usdt, pnl_usd, pnl_pct, entry_time, exit_time, hold_duration_hours, reason, order_id_entry, order_id_exit ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ with self.transaction(): cursor = self.connection.execute( sql, ( trade.trade_id, trade.symbol, trade.side, trade.entry_price, trade.exit_price, trade.size, trade.size_usdt, trade.pnl_usd, trade.pnl_pct, trade.entry_time, trade.exit_time, trade.hold_duration_hours, trade.reason, trade.order_id_entry, trade.order_id_exit, ), ) return cursor.lastrowid def update_trade(self, trade_id: str, **kwargs) -> bool: """ Update an existing trade record. Args: trade_id: Trade ID to update **kwargs: Fields to update Returns: True if trade was updated """ if not kwargs: return False set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys()) sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?" with self.transaction(): cursor = self.connection.execute( sql, (*kwargs.values(), trade_id) ) return cursor.rowcount > 0 def get_trade(self, trade_id: str) -> Optional[Trade]: """Get a trade by ID.""" sql = "SELECT * FROM trades WHERE trade_id = ?" row = self.connection.execute(sql, (trade_id,)).fetchone() if row: return Trade(**dict(row)) return None def get_trades( self, start_time: Optional[str] = None, end_time: Optional[str] = None, limit: Optional[int] = None, ) -> list[Trade]: """ Get trades within a time range. Args: start_time: ISO format start time filter end_time: ISO format end time filter limit: Maximum number of trades to return Returns: List of Trade objects """ conditions = [] params = [] if start_time: conditions.append("entry_time >= ?") params.append(start_time) if end_time: conditions.append("entry_time <= ?") params.append(end_time) where_clause = " AND ".join(conditions) if conditions else "1=1" limit_clause = f"LIMIT {limit}" if limit else "" sql = f""" SELECT * FROM trades WHERE {where_clause} ORDER BY entry_time DESC {limit_clause} """ rows = self.connection.execute(sql, params).fetchall() return [Trade(**dict(row)) for row in rows] def get_all_trades(self) -> list[Trade]: """Get all trades.""" sql = "SELECT * FROM trades ORDER BY entry_time DESC" rows = self.connection.execute(sql).fetchall() return [Trade(**dict(row)) for row in rows] def count_trades(self) -> int: """Get total number of trades.""" sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL" return self.connection.execute(sql).fetchone()[0] def upsert_daily_summary(self, summary: DailySummary) -> None: """Insert or update daily summary.""" sql = """ INSERT INTO daily_summary ( date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd ) VALUES (?, ?, ?, ?, ?) ON CONFLICT(date) DO UPDATE SET total_trades = excluded.total_trades, winning_trades = excluded.winning_trades, total_pnl_usd = excluded.total_pnl_usd, max_drawdown_usd = excluded.max_drawdown_usd, updated_at = CURRENT_TIMESTAMP """ with self.transaction(): self.connection.execute( sql, ( summary.date, summary.total_trades, summary.winning_trades, summary.total_pnl_usd, summary.max_drawdown_usd, ), ) def get_daily_summary(self, date: str) -> Optional[DailySummary]: """Get daily summary for a specific date.""" sql = "SELECT * FROM daily_summary WHERE date = ?" row = self.connection.execute(sql, (date,)).fetchone() if row: return DailySummary(**dict(row)) return None def insert_session(self, session: Session) -> int: """Insert a new session record.""" sql = """ INSERT INTO sessions ( start_time, end_time, starting_balance, ending_balance, total_pnl, total_trades ) VALUES (?, ?, ?, ?, ?, ?) """ with self.transaction(): cursor = self.connection.execute( sql, ( session.start_time, session.end_time, session.starting_balance, session.ending_balance, session.total_pnl, session.total_trades, ), ) return cursor.lastrowid def update_session(self, session_id: int, **kwargs) -> bool: """Update an existing session.""" if not kwargs: return False set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys()) sql = f"UPDATE sessions SET {set_clause} WHERE id = ?" with self.transaction(): cursor = self.connection.execute( sql, (*kwargs.values(), session_id) ) return cursor.rowcount > 0 def get_latest_session(self) -> Optional[Session]: """Get the most recent session.""" sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1" row = self.connection.execute(sql).fetchone() if row: return Session(**dict(row)) return None def init_db(db_path: Path) -> TradingDatabase: """ Initialize the database. Args: db_path: Path to the SQLite database file Returns: TradingDatabase instance """ global _db_instance _db_instance = TradingDatabase(db_path) _db_instance.init_schema() return _db_instance def get_db() -> Optional[TradingDatabase]: """Get the global database instance.""" return _db_instance