"""Repository for market_data table operations."""

from datetime import datetime
from typing import List, Optional, Dict, Any

from sqlalchemy import desc
from sqlalchemy.dialects.postgresql import insert

from ..models import MarketData
from data.common.data_types import OHLCVCandle
from .base_repository import BaseRepository, DatabaseOperationError


class MarketDataRepository(BaseRepository):
    """Repository for market_data table operations."""

    # Columns that identify one candle row; used as the ON CONFLICT target.
    # NOTE(review): PostgreSQL requires a unique index/constraint on exactly
    # these columns for ON CONFLICT to work — confirm against the schema.
    _CONFLICT_COLUMNS = ['exchange', 'symbol', 'timeframe', 'timestamp']

    # OHLCV columns that are overwritten when an existing row is force-updated.
    _UPDATABLE_COLUMNS = ('open', 'high', 'low', 'close', 'volume', 'trades_count')

    @staticmethod
    def _row_to_dict(row: MarketData) -> Dict[str, Any]:
        """Convert a MarketData ORM row into a plain dictionary."""
        return {
            "exchange": row.exchange,
            "symbol": row.symbol,
            "timeframe": row.timeframe,
            "timestamp": row.timestamp,
            "open": row.open,
            "high": row.high,
            "low": row.low,
            "close": row.close,
            "volume": row.volume,
            "trades_count": row.trades_count,
            "created_at": row.created_at,
        }

    def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
        """
        Insert a candle into market_data, optionally overwriting an existing row.

        Uses a PostgreSQL ``INSERT ... ON CONFLICT`` statement. The row's
        ``timestamp`` column is taken from ``candle.end_time``, so candles are
        keyed by their end time.

        Args:
            candle: The candle to persist.
            force_update: If True, an existing row with the same
                (exchange, symbol, timeframe, timestamp) has its OHLCV values and
                trades_count overwritten. If False, the existing row is left
                untouched (ON CONFLICT DO NOTHING).

        Returns:
            True once the statement has been executed and committed. The
            return value is also True when the insert was skipped due to a
            conflict.

        Raises:
            DatabaseOperationError: If building, executing or committing the
                statement fails. The original exception is chained as
                ``__cause__``.
        """
        try:
            with self.get_session() as session:
                values = {
                    'exchange': candle.exchange,
                    'symbol': candle.symbol,
                    'timeframe': candle.timeframe,
                    'timestamp': candle.end_time,
                    'open': candle.open,
                    'high': candle.high,
                    'low': candle.low,
                    'close': candle.close,
                    'volume': candle.volume,
                    'trades_count': candle.trade_count,
                }
                stmt = insert(MarketData).values(values)

                if force_update:
                    # `excluded` refers to the row that was proposed for insertion.
                    final_stmt = stmt.on_conflict_do_update(
                        index_elements=self._CONFLICT_COLUMNS,
                        set_={col: stmt.excluded[col] for col in self._UPDATABLE_COLUMNS},
                    )
                    action = "Updated"
                else:
                    final_stmt = stmt.on_conflict_do_nothing(
                        index_elements=self._CONFLICT_COLUMNS
                    )
                    action = "Stored"

                session.execute(final_stmt)
                session.commit()

                self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})")
                return True

        except Exception as e:
            self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
            raise DatabaseOperationError(f"Failed to store candle: {e}") from e

    def get_candles(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, exchange: str = "okx") -> List[Dict[str, Any]]:
        """
        Retrieve candles from the database using the ORM.

        Args:
            symbol: Trading symbol to filter on.
            timeframe: Candle timeframe to filter on.
            start_time: Inclusive lower bound on the ``timestamp`` column.
            end_time: Inclusive upper bound on the ``timestamp`` column.
            exchange: Exchange to filter on (defaults to "okx").

        Returns:
            A list of candle dicts ordered by ascending timestamp. The list is
            empty if no rows match.

        Raises:
            DatabaseOperationError: If the query fails. The original exception
                is chained as ``__cause__``.
        """
        try:
            with self.get_session() as session:
                query = (
                    session.query(MarketData)
                    .filter(
                        MarketData.exchange == exchange,
                        MarketData.symbol == symbol,
                        MarketData.timeframe == timeframe,
                        MarketData.timestamp >= start_time,
                        MarketData.timestamp <= end_time,
                    )
                    .order_by(MarketData.timestamp.asc())
                )
                # Convert while the session is still open so that attribute
                # access does not hit a closed session.
                candles = [self._row_to_dict(r) for r in query.all()]

                self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
                return candles

        except Exception as e:
            self.log_error(f"Error retrieving candles: {e}")
            raise DatabaseOperationError(f"Failed to retrieve candles: {e}") from e

    def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
        """
        Get the latest candle for a symbol and timeframe using the ORM.

        Args:
            symbol: Trading symbol to filter on.
            timeframe: Candle timeframe to filter on.
            exchange: Exchange to filter on (defaults to "okx").

        Returns:
            The candle dict with the greatest ``timestamp``, or None if no
            rows match.

        Raises:
            DatabaseOperationError: If the query fails. The original exception
                is chained as ``__cause__``.
        """
        try:
            with self.get_session() as session:
                latest = (
                    session.query(MarketData)
                    .filter(
                        MarketData.exchange == exchange,
                        MarketData.symbol == symbol,
                        MarketData.timeframe == timeframe,
                    )
                    .order_by(MarketData.timestamp.desc())
                    .first()
                )
                if latest is None:
                    return None
                return self._row_to_dict(latest)

        except Exception as e:
            self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
            raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}") from e