From b30c16bc3379c3e0ddc96aae960802d23a1f78a3 Mon Sep 17 00:00:00 2001 From: "Vasily.onl" Date: Fri, 6 Jun 2025 22:07:19 +0800 Subject: [PATCH] Standardize database operations to use SQLAlchemy ORM - Updated the `MarketDataRepository` and `RawTradeRepository` classes to exclusively utilize SQLAlchemy ORM for all database interactions, enhancing maintainability and type safety. - Removed raw SQL queries in favor of ORM methods, ensuring a consistent and database-agnostic approach across the repository layer. - Revised documentation to reflect these changes, emphasizing the importance of using the ORM for database operations. These modifications improve the overall architecture of the database layer, making it more scalable and easier to manage. --- CONTEXT.md | 2 +- .../repositories/market_data_repository.py | 169 ++++++++---------- database/repositories/raw_trade_repository.py | 103 +++++------ docs/modules/database_operations.md | 2 + 4 files changed, 122 insertions(+), 154 deletions(-) diff --git a/CONTEXT.md b/CONTEXT.md index c7873b9..1fc83cc 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -65,7 +65,7 @@ The platform is a **monolithic application** built with Python, designed for rap - **Logging**: A unified logging system is available in `utils/logger.py` and should be used across all components for consistent output. - **Type Hinting**: Mandatory for all function signatures (parameters and return values) for clarity and static analysis. - **Error Handling**: Custom, specific exceptions should be defined (e.g., `DataCollectorError`). Use `try...except` blocks to handle potential failures gracefully and provide informative error messages. -- **Database Access**: A `DatabaseManager` in `database/connection.py` provides a centralized way to handle database sessions and connections. All database operations should ideally go through an operations/repository layer. 
+- **Database Access**: All database operations must go through the repository layer, accessible via `database.operations.get_database_operations()`. The repositories exclusively use the **SQLAlchemy ORM** for all queries to ensure type safety, maintainability, and consistency. Raw SQL is strictly forbidden in the repository layer to keep queries as portable as possible; the one dialect-specific construct is the candle upsert, which uses PostgreSQL's `INSERT ... ON CONFLICT` via `sqlalchemy.dialects.postgresql.insert`, so the layer currently requires PostgreSQL. ## 4. Current Implementation Status diff --git a/database/repositories/market_data_repository.py b/database/repositories/market_data_repository.py index af65e13..e151682 100644 --- a/database/repositories/market_data_repository.py +++ b/database/repositories/market_data_repository.py @@ -2,7 +2,9 @@ from datetime import datetime from typing import List, Optional, Dict, Any -from sqlalchemy import text + +from sqlalchemy import desc +from sqlalchemy.dialects.postgresql import insert from ..models import MarketData from data.common.data_types import OHLCVCandle @@ -14,65 +16,51 @@ class MarketDataRepository(BaseRepository): def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool: """ - Insert or update a candle in the market_data table. + Insert or update a candle in the market_data table using the ORM. 
""" try: - candle_timestamp = candle.end_time - with self.get_session() as session: - if force_update: - query = text(""" - INSERT INTO market_data ( - exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - ) VALUES ( - :exchange, :symbol, :timeframe, :timestamp, - :open, :high, :low, :close, :volume, :trades_count, - NOW() - ) - ON CONFLICT (exchange, symbol, timeframe, timestamp) - DO UPDATE SET - open = EXCLUDED.open, - high = EXCLUDED.high, - low = EXCLUDED.low, - close = EXCLUDED.close, - volume = EXCLUDED.volume, - trades_count = EXCLUDED.trades_count - """) - action = "Updated" - else: - query = text(""" - INSERT INTO market_data ( - exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - ) VALUES ( - :exchange, :symbol, :timeframe, :timestamp, - :open, :high, :low, :close, :volume, :trades_count, - NOW() - ) - ON CONFLICT (exchange, symbol, timeframe, timestamp) - DO NOTHING - """) - action = "Stored" - session.execute(query, { + values = { 'exchange': candle.exchange, 'symbol': candle.symbol, 'timeframe': candle.timeframe, - 'timestamp': candle_timestamp, - 'open': float(candle.open), - 'high': float(candle.high), - 'low': float(candle.low), - 'close': float(candle.close), - 'volume': float(candle.volume), + 'timestamp': candle.end_time, + 'open': candle.open, + 'high': candle.high, + 'low': candle.low, + 'close': candle.close, + 'volume': candle.volume, 'trades_count': candle.trade_count - }) + } + stmt = insert(MarketData).values(values) + + if force_update: + update_stmt = stmt.on_conflict_do_update( + index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'], + set_={ + 'open': stmt.excluded.open, + 'high': stmt.excluded.high, + 'low': stmt.excluded.low, + 'close': stmt.excluded.close, + 'volume': stmt.excluded.volume, + 'trades_count': stmt.excluded.trades_count + } + ) + action = "Updated" + final_stmt = update_stmt + else: + ignore_stmt = 
stmt.on_conflict_do_nothing( + index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'] + ) + action = "Stored" + final_stmt = ignore_stmt + + session.execute(final_stmt) session.commit() - self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})") + self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})") return True except Exception as e: @@ -86,32 +74,32 @@ class MarketDataRepository(BaseRepository): end_time: datetime, exchange: str = "okx") -> List[Dict[str, Any]]: """ - Retrieve candles from the database. + Retrieve candles from the database using the ORM. """ try: with self.get_session() as session: - query = text(""" - SELECT exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - FROM market_data - WHERE exchange = :exchange - AND symbol = :symbol - AND timeframe = :timeframe - AND timestamp >= :start_time - AND timestamp <= :end_time - ORDER BY timestamp ASC - """) + query = ( + session.query(MarketData) + .filter( + MarketData.exchange == exchange, + MarketData.symbol == symbol, + MarketData.timeframe == timeframe, + MarketData.timestamp >= start_time, + MarketData.timestamp <= end_time + ) + .order_by(MarketData.timestamp.asc()) + ) - result = session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'timeframe': timeframe, - 'start_time': start_time, - 'end_time': end_time - }) + results = query.all() - candles = [dict(row._mapping) for row in result] + candles = [ + { + "exchange": r.exchange, "symbol": r.symbol, "timeframe": r.timeframe, + "timestamp": r.timestamp, "open": r.open, "high": r.high, + "low": r.low, "close": r.close, "volume": r.volume, + "trades_count": r.trades_count, "created_at": r.created_at + } for r in results + ] self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}") return candles @@ -122,31 +110,28 @@ class 
MarketDataRepository(BaseRepository): def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]: """ - Get the latest candle for a symbol and timeframe. + Get the latest candle for a symbol and timeframe using the ORM. """ try: with self.get_session() as session: - query = text(""" - SELECT exchange, symbol, timeframe, timestamp, - open, high, low, close, volume, trades_count, - created_at - FROM market_data - WHERE exchange = :exchange - AND symbol = :symbol - AND timeframe = :timeframe - ORDER BY timestamp DESC - LIMIT 1 - """) - - result = session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'timeframe': timeframe - }) - - row = result.fetchone() - if row: - return dict(row._mapping) + latest = ( + session.query(MarketData) + .filter( + MarketData.exchange == exchange, + MarketData.symbol == symbol, + MarketData.timeframe == timeframe + ) + .order_by(MarketData.timestamp.desc()) + .first() + ) + + if latest: + return { + "exchange": latest.exchange, "symbol": latest.symbol, "timeframe": latest.timeframe, + "timestamp": latest.timestamp, "open": latest.open, "high": latest.high, + "low": latest.low, "close": latest.close, "volume": latest.volume, + "trades_count": latest.trades_count, "created_at": latest.created_at + } return None except Exception as e: diff --git a/database/repositories/raw_trade_repository.py b/database/repositories/raw_trade_repository.py index d30547c..cdeaa85 100644 --- a/database/repositories/raw_trade_repository.py +++ b/database/repositories/raw_trade_repository.py @@ -1,9 +1,9 @@ """Repository for raw_trades table operations.""" -import json from datetime import datetime from typing import Dict, Any, Optional, List -from sqlalchemy import text + +from sqlalchemy import desc from ..models import RawTrade from data.base_collector import MarketDataPoint @@ -15,26 +15,18 @@ class RawTradeRepository(BaseRepository): def insert_market_data_point(self, data_point: 
MarketDataPoint) -> bool: """ - Insert a market data point into raw_trades table. + Insert a market data point into raw_trades table using the ORM. """ try: with self.get_session() as session: - query = text(""" - INSERT INTO raw_trades ( - exchange, symbol, timestamp, data_type, raw_data, created_at - ) VALUES ( - :exchange, :symbol, :timestamp, :data_type, :raw_data, NOW() - ) - """) - - session.execute(query, { - 'exchange': data_point.exchange, - 'symbol': data_point.symbol, - 'timestamp': data_point.timestamp, - 'data_type': data_point.data_type.value, - 'raw_data': json.dumps(data_point.data) - }) - + new_trade = RawTrade( + exchange=data_point.exchange, + symbol=data_point.symbol, + timestamp=data_point.timestamp, + data_type=data_point.data_type.value, + raw_data=data_point.data + ) + session.add(new_trade) session.commit() self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}") @@ -51,29 +43,18 @@ class RawTradeRepository(BaseRepository): raw_data: Dict[str, Any], timestamp: Optional[datetime] = None) -> bool: """ - Insert raw WebSocket data for debugging purposes. + Insert raw WebSocket data for debugging purposes using the ORM. 
""" try: - if timestamp is None: - timestamp = datetime.now() - with self.get_session() as session: - query = text(""" - INSERT INTO raw_trades ( - exchange, symbol, timestamp, data_type, raw_data, created_at - ) VALUES ( - :exchange, :symbol, :timestamp, :data_type, :raw_data, NOW() - ) - """) - - session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'timestamp': timestamp, - 'data_type': data_type, - 'raw_data': json.dumps(raw_data) - }) - + new_trade = RawTrade( + exchange=exchange, + symbol=symbol, + timestamp=timestamp or datetime.now().astimezone(),  # tz-aware default (the datetime class has no .timezone attribute) + data_type=data_type, + raw_data=raw_data + ) + session.add(new_trade) session.commit() self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}") @@ -91,34 +72,34 @@ class RawTradeRepository(BaseRepository): exchange: str = "okx", limit: Optional[int] = None) -> List[Dict[str, Any]]: """ - Retrieve raw trades from the database. + Retrieve raw trades from the database using the ORM. """ try: with self.get_session() as session: - query = text(""" - SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at - FROM raw_trades - WHERE exchange = :exchange - AND symbol = :symbol - AND data_type = :data_type - AND timestamp >= :start_time - AND timestamp <= :end_time - ORDER BY timestamp ASC - """) + query = ( + session.query(RawTrade) + .filter( + RawTrade.exchange == exchange, + RawTrade.symbol == symbol, + RawTrade.data_type == data_type, + RawTrade.timestamp >= start_time, + RawTrade.timestamp <= end_time + ) + .order_by(RawTrade.timestamp.asc()) + ) if limit: - query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}" - query = text(query_str) + query = query.limit(limit) - result = session.execute(query, { - 'exchange': exchange, - 'symbol': symbol, - 'data_type': data_type, - 'start_time': start_time, - 'end_time': end_time - }) - - trades = [dict(row._mapping) for row in result] + results = query.all() + + trades = [ + { + "id": 
r.id, "exchange": r.exchange, "symbol": r.symbol, + "timestamp": r.timestamp, "data_type": r.data_type, + "raw_data": r.raw_data, "created_at": r.created_at + } for r in results + ] self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}") return trades diff --git a/docs/modules/database_operations.md b/docs/modules/database_operations.md index 92363d4..fd52f0b 100644 --- a/docs/modules/database_operations.md +++ b/docs/modules/database_operations.md @@ -435,6 +435,8 @@ candle = OHLCVCandle(...) # Create candle object success = db.market_data.upsert_candle(candle) ``` +The entire repository layer has been standardized to use the SQLAlchemy ORM internally, ensuring a consistent and maintainable approach. Raw SQL is avoided in favor of type-safe ORM queries; note that `upsert_candle` relies on the PostgreSQL-specific `INSERT ... ON CONFLICT` construct, so the layer is not fully database-agnostic. + ## Performance Considerations ### Connection Pooling