Standardize database operations to use SQLAlchemy ORM

- Updated the `MarketDataRepository` and `RawTradeRepository` classes to exclusively utilize SQLAlchemy ORM for all database interactions, enhancing maintainability and type safety.
- Removed raw SQL queries in favor of ORM methods, ensuring a consistent and database-agnostic approach across the repository layer.
- Revised documentation to reflect these changes, emphasizing the importance of using the ORM for database operations.

These modifications improve the overall architecture of the database layer, making it more scalable and easier to manage.
This commit is contained in:
Vasily.onl 2025-06-06 22:07:19 +08:00
parent 028371a0e1
commit b30c16bc33
4 changed files with 122 additions and 154 deletions

View File

@ -65,7 +65,7 @@ The platform is a **monolithic application** built with Python, designed for rap
- **Logging**: A unified logging system is available in `utils/logger.py` and should be used across all components for consistent output.
- **Type Hinting**: Mandatory for all function signatures (parameters and return values) for clarity and static analysis.
- **Error Handling**: Custom, specific exceptions should be defined (e.g., `DataCollectorError`). Use `try...except` blocks to handle potential failures gracefully and provide informative error messages.
- **Database Access**: A `DatabaseManager` in `database/connection.py` provides a centralized way to handle database sessions and connections. All database operations should ideally go through an operations/repository layer.
- **Database Access**: All database operations must go through the repository layer, accessible via `database.operations.get_database_operations()`. The repositories exclusively use the **SQLAlchemy ORM** for all queries to ensure type safety, maintainability, and consistency. Raw SQL is strictly forbidden in the repository layer to maintain database-agnostic flexibility.
## 4. Current Implementation Status

View File

@ -2,7 +2,9 @@
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlalchemy import text
from sqlalchemy import desc
from sqlalchemy.dialects.postgresql import insert
from ..models import MarketData
from data.common.data_types import OHLCVCandle
@ -14,65 +16,51 @@ class MarketDataRepository(BaseRepository):
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
"""
Insert or update a candle in the market_data table.
Insert or update a candle in the market_data table using the ORM.
"""
try:
candle_timestamp = candle.end_time
with self.get_session() as session:
if force_update:
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO UPDATE SET
open = EXCLUDED.open,
high = EXCLUDED.high,
low = EXCLUDED.low,
close = EXCLUDED.close,
volume = EXCLUDED.volume,
trades_count = EXCLUDED.trades_count
""")
action = "Updated"
else:
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO NOTHING
""")
action = "Stored"
session.execute(query, {
values = {
'exchange': candle.exchange,
'symbol': candle.symbol,
'timeframe': candle.timeframe,
'timestamp': candle_timestamp,
'open': float(candle.open),
'high': float(candle.high),
'low': float(candle.low),
'close': float(candle.close),
'volume': float(candle.volume),
'timestamp': candle.end_time,
'open': candle.open,
'high': candle.high,
'low': candle.low,
'close': candle.close,
'volume': candle.volume,
'trades_count': candle.trade_count
})
}
stmt = insert(MarketData).values(values)
if force_update:
update_stmt = stmt.on_conflict_do_update(
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'],
set_={
'open': stmt.excluded.open,
'high': stmt.excluded.high,
'low': stmt.excluded.low,
'close': stmt.excluded.close,
'volume': stmt.excluded.volume,
'trades_count': stmt.excluded.trades_count
}
)
action = "Updated"
final_stmt = update_stmt
else:
ignore_stmt = stmt.on_conflict_do_nothing(
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp']
)
action = "Stored"
final_stmt = ignore_stmt
session.execute(final_stmt)
session.commit()
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})")
return True
except Exception as e:
@ -86,32 +74,32 @@ class MarketDataRepository(BaseRepository):
end_time: datetime,
exchange: str = "okx") -> List[Dict[str, Any]]:
"""
Retrieve candles from the database.
Retrieve candles from the database using the ORM.
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
query = (
session.query(MarketData)
.filter(
MarketData.exchange == exchange,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe,
MarketData.timestamp >= start_time,
MarketData.timestamp <= end_time
)
.order_by(MarketData.timestamp.asc())
)
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe,
'start_time': start_time,
'end_time': end_time
})
results = query.all()
candles = [dict(row._mapping) for row in result]
candles = [
{
"exchange": r.exchange, "symbol": r.symbol, "timeframe": r.timeframe,
"timestamp": r.timestamp, "open": r.open, "high": r.high,
"low": r.low, "close": r.close, "volume": r.volume,
"trades_count": r.trades_count, "created_at": r.created_at
} for r in results
]
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
return candles
@ -122,31 +110,28 @@ class MarketDataRepository(BaseRepository):
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
"""
Get the latest candle for a symbol and timeframe.
Get the latest candle for a symbol and timeframe using the ORM.
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
ORDER BY timestamp DESC
LIMIT 1
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe
})
row = result.fetchone()
if row:
return dict(row._mapping)
latest = (
session.query(MarketData)
.filter(
MarketData.exchange == exchange,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe
)
.order_by(MarketData.timestamp.desc())
.first()
)
if latest:
return {
"exchange": latest.exchange, "symbol": latest.symbol, "timeframe": latest.timeframe,
"timestamp": latest.timestamp, "open": latest.open, "high": latest.high,
"low": latest.low, "close": latest.close, "volume": latest.volume,
"trades_count": latest.trades_count, "created_at": latest.created_at
}
return None
except Exception as e:

View File

@ -1,9 +1,9 @@
"""Repository for raw_trades table operations."""
import json
from datetime import datetime
from typing import Dict, Any, Optional, List
from sqlalchemy import text
from sqlalchemy import desc
from ..models import RawTrade
from data.base_collector import MarketDataPoint
@ -15,26 +15,18 @@ class RawTradeRepository(BaseRepository):
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
"""
Insert a market data point into raw_trades table.
Insert a market data point into raw_trades table using the ORM.
"""
try:
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': data_point.exchange,
'symbol': data_point.symbol,
'timestamp': data_point.timestamp,
'data_type': data_point.data_type.value,
'raw_data': json.dumps(data_point.data)
})
new_trade = RawTrade(
exchange=data_point.exchange,
symbol=data_point.symbol,
timestamp=data_point.timestamp,
data_type=data_point.data_type.value,
raw_data=data_point.data
)
session.add(new_trade)
session.commit()
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
@ -51,29 +43,18 @@ class RawTradeRepository(BaseRepository):
raw_data: Dict[str, Any],
timestamp: Optional[datetime] = None) -> bool:
"""
Insert raw WebSocket data for debugging purposes.
Insert raw WebSocket data for debugging purposes using the ORM.
"""
try:
if timestamp is None:
timestamp = datetime.now()
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timestamp': timestamp,
'data_type': data_type,
'raw_data': json.dumps(raw_data)
})
new_trade = RawTrade(
exchange=exchange,
symbol=symbol,
timestamp=timestamp or datetime.now(datetime.timezone.utc),
data_type=data_type,
raw_data=raw_data
)
session.add(new_trade)
session.commit()
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
@ -91,34 +72,34 @@ class RawTradeRepository(BaseRepository):
exchange: str = "okx",
limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Retrieve raw trades from the database.
Retrieve raw trades from the database using the ORM.
"""
try:
with self.get_session() as session:
query = text("""
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
FROM raw_trades
WHERE exchange = :exchange
AND symbol = :symbol
AND data_type = :data_type
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
query = (
session.query(RawTrade)
.filter(
RawTrade.exchange == exchange,
RawTrade.symbol == symbol,
RawTrade.data_type == data_type,
RawTrade.timestamp >= start_time,
RawTrade.timestamp <= end_time
)
.order_by(RawTrade.timestamp.asc())
)
if limit:
query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}"
query = text(query_str)
query = query.limit(limit)
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'data_type': data_type,
'start_time': start_time,
'end_time': end_time
})
trades = [dict(row._mapping) for row in result]
results = query.all()
trades = [
{
"id": r.id, "exchange": r.exchange, "symbol": r.symbol,
"timestamp": r.timestamp, "data_type": r.data_type,
"raw_data": r.raw_data, "created_at": r.created_at
} for r in results
]
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
return trades

View File

@ -435,6 +435,8 @@ candle = OHLCVCandle(...) # Create candle object
success = db.market_data.upsert_candle(candle)
```
The entire repository layer has been standardized to use the SQLAlchemy ORM internally, ensuring a consistent, maintainable, and database-agnostic approach. Raw SQL is avoided in favor of type-safe ORM queries.
## Performance Considerations
### Connection Pooling