""" Database Operations Module This module provides centralized database operations for all tables, following the Repository pattern to abstract SQL complexity from business logic. Benefits: - Centralized database operations - Clean API for different tables - Easy to test and maintain - Database implementation can change without affecting business logic - Consistent error handling and logging """ from datetime import datetime from decimal import Decimal from typing import List, Optional, Dict, Any, Union from contextlib import contextmanager import logging import json from sqlalchemy import text from .connection import get_db_manager from .models import MarketData, RawTrade from data.common.data_types import OHLCVCandle, StandardizedTrade from data.base_collector import MarketDataPoint, DataType class DatabaseOperationError(Exception): """Custom exception for database operation errors.""" pass class BaseRepository: """Base class for all repository classes.""" def __init__(self, logger: Optional[logging.Logger] = None): """Initialize repository with optional logger.""" self.logger = logger self._db_manager = get_db_manager() self._db_manager.initialize() def log_info(self, message: str) -> None: """Log info message if logger is available.""" if self.logger: self.logger.info(message) def log_debug(self, message: str) -> None: """Log debug message if logger is available.""" if self.logger: self.logger.debug(message) def log_error(self, message: str) -> None: """Log error message if logger is available.""" if self.logger: self.logger.error(message) @contextmanager def get_session(self): """Get database session with automatic cleanup.""" if not self._db_manager: raise DatabaseOperationError("Database manager not initialized") with self._db_manager.get_session() as session: yield session class MarketDataRepository(BaseRepository): """Repository for market_data table operations.""" def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool: """ Insert or update a candle in the market_data table. Args: candle: OHLCV candle to store force_update: If True, update existing records; if False, ignore duplicates Returns: True if operation successful, False otherwise """ try: # Use right-aligned timestamp (end_time) following industry standard candle_timestamp = candle.end_time with self.get_session() as session: if force_update: # Update existing records with new data query = text(""" INSERT INTO market_data ( exchange, symbol, timeframe, timestamp, open, high, low, close, volume, trades_count, created_at ) VALUES ( :exchange, :symbol, :timeframe, :timestamp, :open, :high, :low, :close, :volume, :trades_count, NOW() ) ON CONFLICT (exchange, symbol, timeframe, timestamp) DO UPDATE SET open = EXCLUDED.open, high = EXCLUDED.high, low = EXCLUDED.low, close = EXCLUDED.close, volume = EXCLUDED.volume, trades_count = EXCLUDED.trades_count """) action = "Updated" else: # Ignore duplicates, keep existing records query = text(""" INSERT INTO market_data ( exchange, symbol, timeframe, timestamp, open, high, low, close, volume, trades_count, created_at ) VALUES ( :exchange, :symbol, :timeframe, :timestamp, :open, :high, :low, :close, :volume, :trades_count, NOW() ) ON CONFLICT (exchange, symbol, timeframe, timestamp) DO NOTHING """) action = "Stored" session.execute(query, { 'exchange': candle.exchange, 'symbol': candle.symbol, 'timeframe': candle.timeframe, 'timestamp': candle_timestamp, 'open': float(candle.open), 'high': float(candle.high), 'low': float(candle.low), 'close': float(candle.close), 'volume': float(candle.volume), 'trades_count': candle.trade_count }) session.commit() self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})") return True except Exception as e: self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}") raise DatabaseOperationError(f"Failed to store candle: {e}") def get_candles(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, exchange: str = "okx") -> List[Dict[str, Any]]: """ Retrieve candles from the database. Args: symbol: Trading symbol timeframe: Candle timeframe start_time: Start timestamp end_time: End timestamp exchange: Exchange name Returns: List of candle dictionaries """ try: with self.get_session() as session: query = text(""" SELECT exchange, symbol, timeframe, timestamp, open, high, low, close, volume, trades_count, created_at, updated_at FROM market_data WHERE exchange = :exchange AND symbol = :symbol AND timeframe = :timeframe AND timestamp >= :start_time AND timestamp <= :end_time ORDER BY timestamp ASC """) result = session.execute(query, { 'exchange': exchange, 'symbol': symbol, 'timeframe': timeframe, 'start_time': start_time, 'end_time': end_time }) candles = [] for row in result: candles.append({ 'exchange': row.exchange, 'symbol': row.symbol, 'timeframe': row.timeframe, 'timestamp': row.timestamp, 'open': row.open, 'high': row.high, 'low': row.low, 'close': row.close, 'volume': row.volume, 'trades_count': row.trades_count, 'created_at': row.created_at, 'updated_at': row.updated_at }) self.log_info(f"Retrieved {len(candles)} candles for {symbol} {timeframe}") return candles except Exception as e: self.log_error(f"Error retrieving candles for {symbol} {timeframe}: {e}") raise DatabaseOperationError(f"Failed to retrieve candles: {e}") def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]: """ Get the latest candle for a symbol and timeframe. Args: symbol: Trading symbol timeframe: Candle timeframe exchange: Exchange name Returns: Latest candle dictionary or None """ try: with self.get_session() as session: query = text(""" SELECT exchange, symbol, timeframe, timestamp, open, high, low, close, volume, trades_count, created_at, updated_at FROM market_data WHERE exchange = :exchange AND symbol = :symbol AND timeframe = :timeframe ORDER BY timestamp DESC LIMIT 1 """) result = session.execute(query, { 'exchange': exchange, 'symbol': symbol, 'timeframe': timeframe }) row = result.fetchone() if row: return { 'exchange': row.exchange, 'symbol': row.symbol, 'timeframe': row.timeframe, 'timestamp': row.timestamp, 'open': row.open, 'high': row.high, 'low': row.low, 'close': row.close, 'volume': row.volume, 'trades_count': row.trades_count, 'created_at': row.created_at, 'updated_at': row.updated_at } return None except Exception as e: self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}") raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}") class RawTradeRepository(BaseRepository): """Repository for raw_trades table operations.""" def insert_market_data_point(self, data_point: MarketDataPoint) -> bool: """ Insert a market data point into raw_trades table. Args: data_point: Market data point to store Returns: True if operation successful, False otherwise """ try: with self.get_session() as session: query = text(""" INSERT INTO raw_trades ( exchange, symbol, timestamp, data_type, raw_data, created_at ) VALUES ( :exchange, :symbol, :timestamp, :data_type, :raw_data, NOW() ) """) session.execute(query, { 'exchange': data_point.exchange, 'symbol': data_point.symbol, 'timestamp': data_point.timestamp, 'data_type': data_point.data_type.value, 'raw_data': json.dumps(data_point.data) }) session.commit() self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}") return True except Exception as e: self.log_error(f"Error storing raw data for {data_point.symbol}: {e}") raise DatabaseOperationError(f"Failed to store raw data: {e}") def insert_raw_websocket_data(self, exchange: str, symbol: str, data_type: str, raw_data: Dict[str, Any], timestamp: Optional[datetime] = None) -> bool: """ Insert raw WebSocket data for debugging purposes. Args: exchange: Exchange name symbol: Trading symbol data_type: Type of data (e.g., 'raw_trades', 'raw_orderbook') raw_data: Raw data dictionary timestamp: Optional timestamp (defaults to now) Returns: True if operation successful, False otherwise """ try: if timestamp is None: timestamp = datetime.now() with self.get_session() as session: query = text(""" INSERT INTO raw_trades ( exchange, symbol, timestamp, data_type, raw_data, created_at ) VALUES ( :exchange, :symbol, :timestamp, :data_type, :raw_data, NOW() ) """) session.execute(query, { 'exchange': exchange, 'symbol': symbol, 'timestamp': timestamp, 'data_type': data_type, 'raw_data': json.dumps(raw_data) }) session.commit() self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}") return True except Exception as e: self.log_error(f"Error storing raw WebSocket data for {symbol}: {e}") raise DatabaseOperationError(f"Failed to store raw WebSocket data: {e}") def get_raw_trades(self, symbol: str, data_type: str, start_time: datetime, end_time: datetime, exchange: str = "okx", limit: Optional[int] = None) -> List[Dict[str, Any]]: """ Retrieve raw trades from the database. Args: symbol: Trading symbol data_type: Data type filter start_time: Start timestamp end_time: End timestamp exchange: Exchange name limit: Maximum number of records to return Returns: List of raw trade dictionaries """ try: with self.get_session() as session: query = text(""" SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at FROM raw_trades WHERE exchange = :exchange AND symbol = :symbol AND data_type = :data_type AND timestamp >= :start_time AND timestamp <= :end_time ORDER BY timestamp ASC """) if limit: query += f" LIMIT {limit}" result = session.execute(query, { 'exchange': exchange, 'symbol': symbol, 'data_type': data_type, 'start_time': start_time, 'end_time': end_time }) trades = [] for row in result: trades.append({ 'id': row.id, 'exchange': row.exchange, 'symbol': row.symbol, 'timestamp': row.timestamp, 'data_type': row.data_type, 'raw_data': row.raw_data, 'created_at': row.created_at }) self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}") return trades except Exception as e: self.log_error(f"Error retrieving raw trades for {symbol}: {e}") raise DatabaseOperationError(f"Failed to retrieve raw trades: {e}") class DatabaseOperations: """ Main database operations manager that provides access to all repositories. This is the main entry point for database operations, providing a centralized interface to all table-specific repositories. """ def __init__(self, logger: Optional[logging.Logger] = None): """Initialize database operations with optional logger.""" self.logger = logger # Initialize repositories self.market_data = MarketDataRepository(logger) self.raw_trades = RawTradeRepository(logger) def health_check(self) -> bool: """ Perform a health check on database connections. Returns: True if database is healthy, False otherwise """ try: with self.market_data.get_session() as session: # Simple query to test connection result = session.execute(text("SELECT 1")) return result.fetchone() is not None except Exception as e: if self.logger: self.logger.error(f"Database health check failed: {e}") return False def get_stats(self) -> Dict[str, Any]: """ Get database statistics. Returns: Dictionary containing database statistics """ try: stats = { 'healthy': self.health_check(), 'repositories': { 'market_data': 'MarketDataRepository', 'raw_trades': 'RawTradeRepository' } } # Get table counts with self.market_data.get_session() as session: # Market data count result = session.execute(text("SELECT COUNT(*) FROM market_data")) stats['candle_count'] = result.fetchone()[0] # Raw trades count result = session.execute(text("SELECT COUNT(*) FROM raw_trades")) stats['raw_trade_count'] = result.fetchone()[0] return stats except Exception as e: if self.logger: self.logger.error(f"Error getting database stats: {e}") return {'healthy': False, 'error': str(e)} # Singleton instance for global access _db_operations_instance: Optional[DatabaseOperations] = None def get_database_operations(logger: Optional[logging.Logger] = None) -> DatabaseOperations: """ Get the global database operations instance. Args: logger: Optional logger for database operations Returns: DatabaseOperations instance """ global _db_operations_instance if _db_operations_instance is None: _db_operations_instance = DatabaseOperations(logger) return _db_operations_instance __all__ = [ 'DatabaseOperationError', 'MarketDataRepository', 'RawTradeRepository', 'DatabaseOperations', 'get_database_operations' ]