diff --git a/CONTEXT.md b/CONTEXT.md index c7873b9..1fc83cc 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -65,7 +65,7 @@ The platform is a **monolithic application** built with Python, designed for rap - **Logging**: A unified logging system is available in `utils/logger.py` and should be used across all components for consistent output. - **Type Hinting**: Mandatory for all function signatures (parameters and return values) for clarity and static analysis. - **Error Handling**: Custom, specific exceptions should be defined (e.g., `DataCollectorError`). Use `try...except` blocks to handle potential failures gracefully and provide informative error messages. -- **Database Access**: A `DatabaseManager` in `database/connection.py` provides a centralized way to handle database sessions and connections. All database operations should ideally go through an operations/repository layer. +- **Database Access**: All database operations must go through the repository layer, accessible via `database.operations.get_database_operations()`. The repositories exclusively use the **SQLAlchemy ORM** for all queries to ensure type safety, maintainability, and consistency. Raw SQL is strictly forbidden in the repository layer to maintain database-agnostic flexibility. The one dialect-specific exception is the candle upsert in `MarketDataRepository.upsert_candle`, which uses the PostgreSQL `INSERT ... ON CONFLICT` construct from `sqlalchemy.dialects.postgresql` and therefore requires PostgreSQL. ## 4. 
Current Implementation Status diff --git a/database/repositories/market_data_repository.py b/database/repositories/market_data_repository.py index af65e13..e151682 100644 --- a/database/repositories/market_data_repository.py +++ b/database/repositories/market_data_repository.py @@ -2,7 +2,9 @@ from datetime import datetime from typing import List, Optional, Dict, Any -from sqlalchemy import text + +from sqlalchemy import desc +from sqlalchemy.dialects.postgresql import insert from ..models import MarketData from data.common.data_types import OHLCVCandle @@ -14,65 +16,51 @@ class MarketDataRepository(BaseRepository): def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool: """ - Insert or update a candle in the market_data table. + Insert or update a candle in the market_data table using the ORM. """ try: - candle_timestamp = candle.end_time - with self.get_session() as session: - if force_update: - query = text(""" - INSERT INTO market_data ( - exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - ) VALUES ( - :exchange, :symbol, :timeframe, :timestamp, - :open, :high, :low, :close, :volume, :trades_count, - NOW() - ) - ON CONFLICT (exchange, symbol, timeframe, timestamp) - DO UPDATE SET - open = EXCLUDED.open, - high = EXCLUDED.high, - low = EXCLUDED.low, - close = EXCLUDED.close, - volume = EXCLUDED.volume, - trades_count = EXCLUDED.trades_count - """) - action = "Updated" - else: - query = text(""" - INSERT INTO market_data ( - exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - ) VALUES ( - :exchange, :symbol, :timeframe, :timestamp, - :open, :high, :low, :close, :volume, :trades_count, - NOW() - ) - ON CONFLICT (exchange, symbol, timeframe, timestamp) - DO NOTHING - """) - action = "Stored" - session.execute(query, { + values = { 'exchange': candle.exchange, 'symbol': candle.symbol, 'timeframe': candle.timeframe, - 'timestamp': 
candle_timestamp, - 'open': float(candle.open), - 'high': float(candle.high), - 'low': float(candle.low), - 'close': float(candle.close), - 'volume': float(candle.volume), + 'timestamp': candle.end_time, + 'open': candle.open, + 'high': candle.high, + 'low': candle.low, + 'close': candle.close, + 'volume': candle.volume, 'trades_count': candle.trade_count - }) + } + stmt = insert(MarketData).values(values) + + if force_update: + update_stmt = stmt.on_conflict_do_update( + index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'], + set_={ + 'open': stmt.excluded.open, + 'high': stmt.excluded.high, + 'low': stmt.excluded.low, + 'close': stmt.excluded.close, + 'volume': stmt.excluded.volume, + 'trades_count': stmt.excluded.trades_count + } + ) + action = "Updated" + final_stmt = update_stmt + else: + ignore_stmt = stmt.on_conflict_do_nothing( + index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'] + ) + action = "Stored" + final_stmt = ignore_stmt + + session.execute(final_stmt) session.commit() - self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})") + self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})") return True except Exception as e: @@ -86,32 +74,32 @@ class MarketDataRepository(BaseRepository): end_time: datetime, exchange: str = "okx") -> List[Dict[str, Any]]: """ - Retrieve candles from the database. + Retrieve candles from the database using the ORM. 
""" try: with self.get_session() as session: - query = text(""" - SELECT exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - FROM market_data - WHERE exchange = :exchange - AND symbol = :symbol - AND timeframe = :timeframe - AND timestamp >= :start_time - AND timestamp <= :end_time - ORDER BY timestamp ASC - """) + query = ( + session.query(MarketData) + .filter( + MarketData.exchange == exchange, + MarketData.symbol == symbol, + MarketData.timeframe == timeframe, + MarketData.timestamp >= start_time, + MarketData.timestamp <= end_time + ) + .order_by(MarketData.timestamp.asc()) + ) - result = session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'timeframe': timeframe, - 'start_time': start_time, - 'end_time': end_time - }) + results = query.all() - candles = [dict(row._mapping) for row in result] + candles = [ + { + "exchange": r.exchange, "symbol": r.symbol, "timeframe": r.timeframe, + "timestamp": r.timestamp, "open": r.open, "high": r.high, + "low": r.low, "close": r.close, "volume": r.volume, + "trades_count": r.trades_count, "created_at": r.created_at + } for r in results + ] self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}") return candles @@ -122,31 +110,28 @@ class MarketDataRepository(BaseRepository): def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]: """ - Get the latest candle for a symbol and timeframe. + Get the latest candle for a symbol and timeframe using the ORM. 
""" try: with self.get_session() as session: - query = text(""" - SELECT exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - FROM market_data - WHERE exchange = :exchange - AND symbol = :symbol - AND timeframe = :timeframe - ORDER BY timestamp DESC - LIMIT 1 - """) - - result = session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'timeframe': timeframe - }) - - row = result.fetchone() - if row: - return dict(row._mapping) + latest = ( + session.query(MarketData) + .filter( + MarketData.exchange == exchange, + MarketData.symbol == symbol, + MarketData.timeframe == timeframe + ) + .order_by(MarketData.timestamp.desc()) + .first() + ) + + if latest: + return { + "exchange": latest.exchange, "symbol": latest.symbol, "timeframe": latest.timeframe, + "timestamp": latest.timestamp, "open": latest.open, "high": latest.high, + "low": latest.low, "close": latest.close, "volume": latest.volume, + "trades_count": latest.trades_count, "created_at": latest.created_at + } return None except Exception as e: diff --git a/database/repositories/raw_trade_repository.py b/database/repositories/raw_trade_repository.py index d30547c..cdeaa85 100644 --- a/database/repositories/raw_trade_repository.py +++ b/database/repositories/raw_trade_repository.py @@ -1,9 +1,9 @@ """Repository for raw_trades table operations.""" -import json -from datetime import datetime +from datetime import datetime, timezone from typing import Dict, Any, Optional, List -from sqlalchemy import text + +from sqlalchemy import desc from ..models import RawTrade from data.base_collector import MarketDataPoint @@ -15,26 +15,18 @@ class RawTradeRepository(BaseRepository): def insert_market_data_point(self, data_point: MarketDataPoint) -> bool: """ - Insert a market data point into raw_trades table. + Insert a market data point into raw_trades table using the ORM. 
""" try: with self.get_session() as session: - query = text(""" - INSERT INTO raw_trades ( - exchange, symbol, timestamp, data_type, raw_data, created_at - ) VALUES ( - :exchange, :symbol, :timestamp, :data_type, :raw_data, NOW() - ) - """) - - session.execute(query, { - 'exchange': data_point.exchange, - 'symbol': data_point.symbol, - 'timestamp': data_point.timestamp, - 'data_type': data_point.data_type.value, - 'raw_data': json.dumps(data_point.data) - }) - + new_trade = RawTrade( + exchange=data_point.exchange, + symbol=data_point.symbol, + timestamp=data_point.timestamp, + data_type=data_point.data_type.value, + raw_data=data_point.data + ) + session.add(new_trade) session.commit() self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}") @@ -51,29 +43,18 @@ class RawTradeRepository(BaseRepository): raw_data: Dict[str, Any], timestamp: Optional[datetime] = None) -> bool: """ - Insert raw WebSocket data for debugging purposes. + Insert raw WebSocket data for debugging purposes using the ORM; a missing timestamp defaults to the current UTC time. """ try: - if timestamp is None: - timestamp = datetime.now() - with self.get_session() as session: - query = text(""" - INSERT INTO raw_trades ( - exchange, symbol, timestamp, data_type, raw_data, created_at - ) VALUES ( - :exchange, :symbol, :timestamp, :data_type, :raw_data, NOW() - ) - """) - - session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'timestamp': timestamp, - 'data_type': data_type, - 'raw_data': json.dumps(raw_data) - }) - + new_trade = RawTrade( + exchange=exchange, + symbol=symbol, + timestamp=timestamp or datetime.now(timezone.utc), + data_type=data_type, + raw_data=raw_data + ) + session.add(new_trade) session.commit() self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}") @@ -91,34 +72,34 @@ class RawTradeRepository(BaseRepository): exchange: str = "okx", limit: Optional[int] = None) -> List[Dict[str, Any]]: """ - Retrieve raw trades from the database. 
+ Retrieve raw trades from the database using the ORM. """ try: with self.get_session() as session: - query = text(""" - SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at - FROM raw_trades - WHERE exchange = :exchange - AND symbol = :symbol - AND data_type = :data_type - AND timestamp >= :start_time - AND timestamp <= :end_time - ORDER BY timestamp ASC - """) + query = ( + session.query(RawTrade) + .filter( + RawTrade.exchange == exchange, + RawTrade.symbol == symbol, + RawTrade.data_type == data_type, + RawTrade.timestamp >= start_time, + RawTrade.timestamp <= end_time + ) + .order_by(RawTrade.timestamp.asc()) + ) if limit: - query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}" - query = text(query_str) + query = query.limit(limit) - result = session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'data_type': data_type, - 'start_time': start_time, - 'end_time': end_time - }) - - trades = [dict(row._mapping) for row in result] + results = query.all() + + trades = [ + { + "id": r.id, "exchange": r.exchange, "symbol": r.symbol, + "timestamp": r.timestamp, "data_type": r.data_type, + "raw_data": r.raw_data, "created_at": r.created_at + } for r in results + ] self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}") return trades diff --git a/docs/modules/database_operations.md b/docs/modules/database_operations.md index 92363d4..fd52f0b 100644 --- a/docs/modules/database_operations.md +++ b/docs/modules/database_operations.md @@ -435,6 +435,8 @@ candle = OHLCVCandle(...) # Create candle object success = db.market_data.upsert_candle(candle) ``` +The entire repository layer has been standardized to use the SQLAlchemy ORM internally, ensuring a consistent, maintainable, and database-agnostic approach. Raw SQL is avoided in favor of type-safe ORM queries. 
Note that `upsert_candle` relies on the PostgreSQL-specific `INSERT ... ON CONFLICT` construct from `sqlalchemy.dialects.postgresql`, so that operation requires PostgreSQL. + ## Performance Considerations ### Connection Pooling