Standardize database operations to use SQLAlchemy ORM
- Updated the `MarketDataRepository` and `RawTradeRepository` classes to exclusively utilize SQLAlchemy ORM for all database interactions, enhancing maintainability and type safety. - Removed raw SQL queries in favor of ORM methods, ensuring a consistent and database-agnostic approach across the repository layer. - Revised documentation to reflect these changes, emphasizing the importance of using the ORM for database operations. These modifications improve the overall architecture of the database layer, making it more scalable and easier to manage.
This commit is contained in:
parent
028371a0e1
commit
b30c16bc33
@ -65,7 +65,7 @@ The platform is a **monolithic application** built with Python, designed for rap
|
||||
- **Logging**: A unified logging system is available in `utils/logger.py` and should be used across all components for consistent output.
|
||||
- **Type Hinting**: Mandatory for all function signatures (parameters and return values) for clarity and static analysis.
|
||||
- **Error Handling**: Custom, specific exceptions should be defined (e.g., `DataCollectorError`). Use `try...except` blocks to handle potential failures gracefully and provide informative error messages.
|
||||
- **Database Access**: A `DatabaseManager` in `database/connection.py` provides a centralized way to handle database sessions and connections. All database operations should ideally go through an operations/repository layer.
|
||||
- **Database Access**: All database operations must go through the repository layer, accessible via `database.operations.get_database_operations()`. The repositories exclusively use the **SQLAlchemy ORM** for all queries to ensure type safety, maintainability, and consistency. Raw SQL is strictly forbidden in the repository layer to maintain database-agnostic flexibility.
|
||||
|
||||
## 4. Current Implementation Status
|
||||
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import text
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from ..models import MarketData
|
||||
from data.common.data_types import OHLCVCandle
|
||||
@ -14,65 +16,51 @@ class MarketDataRepository(BaseRepository):
|
||||
|
||||
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
|
||||
"""
|
||||
Insert or update a candle in the market_data table.
|
||||
Insert or update a candle in the market_data table using the ORM.
|
||||
"""
|
||||
try:
|
||||
candle_timestamp = candle.end_time
|
||||
|
||||
with self.get_session() as session:
|
||||
if force_update:
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO UPDATE SET
|
||||
open = EXCLUDED.open,
|
||||
high = EXCLUDED.high,
|
||||
low = EXCLUDED.low,
|
||||
close = EXCLUDED.close,
|
||||
volume = EXCLUDED.volume,
|
||||
trades_count = EXCLUDED.trades_count
|
||||
""")
|
||||
action = "Updated"
|
||||
else:
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO NOTHING
|
||||
""")
|
||||
action = "Stored"
|
||||
|
||||
session.execute(query, {
|
||||
values = {
|
||||
'exchange': candle.exchange,
|
||||
'symbol': candle.symbol,
|
||||
'timeframe': candle.timeframe,
|
||||
'timestamp': candle_timestamp,
|
||||
'open': float(candle.open),
|
||||
'high': float(candle.high),
|
||||
'low': float(candle.low),
|
||||
'close': float(candle.close),
|
||||
'volume': float(candle.volume),
|
||||
'timestamp': candle.end_time,
|
||||
'open': candle.open,
|
||||
'high': candle.high,
|
||||
'low': candle.low,
|
||||
'close': candle.close,
|
||||
'volume': candle.volume,
|
||||
'trades_count': candle.trade_count
|
||||
})
|
||||
}
|
||||
|
||||
stmt = insert(MarketData).values(values)
|
||||
|
||||
if force_update:
|
||||
update_stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'],
|
||||
set_={
|
||||
'open': stmt.excluded.open,
|
||||
'high': stmt.excluded.high,
|
||||
'low': stmt.excluded.low,
|
||||
'close': stmt.excluded.close,
|
||||
'volume': stmt.excluded.volume,
|
||||
'trades_count': stmt.excluded.trades_count
|
||||
}
|
||||
)
|
||||
action = "Updated"
|
||||
final_stmt = update_stmt
|
||||
else:
|
||||
ignore_stmt = stmt.on_conflict_do_nothing(
|
||||
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp']
|
||||
)
|
||||
action = "Stored"
|
||||
final_stmt = ignore_stmt
|
||||
|
||||
session.execute(final_stmt)
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
|
||||
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@ -86,32 +74,32 @@ class MarketDataRepository(BaseRepository):
|
||||
end_time: datetime,
|
||||
exchange: str = "okx") -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve candles from the database.
|
||||
Retrieve candles from the database using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
query = (
|
||||
session.query(MarketData)
|
||||
.filter(
|
||||
MarketData.exchange == exchange,
|
||||
MarketData.symbol == symbol,
|
||||
MarketData.timeframe == timeframe,
|
||||
MarketData.timestamp >= start_time,
|
||||
MarketData.timestamp <= end_time
|
||||
)
|
||||
.order_by(MarketData.timestamp.asc())
|
||||
)
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
results = query.all()
|
||||
|
||||
candles = [dict(row._mapping) for row in result]
|
||||
candles = [
|
||||
{
|
||||
"exchange": r.exchange, "symbol": r.symbol, "timeframe": r.timeframe,
|
||||
"timestamp": r.timestamp, "open": r.open, "high": r.high,
|
||||
"low": r.low, "close": r.close, "volume": r.volume,
|
||||
"trades_count": r.trades_count, "created_at": r.created_at
|
||||
} for r in results
|
||||
]
|
||||
|
||||
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
|
||||
return candles
|
||||
@ -122,31 +110,28 @@ class MarketDataRepository(BaseRepository):
|
||||
|
||||
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the latest candle for a symbol and timeframe.
|
||||
Get the latest candle for a symbol and timeframe using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe
|
||||
})
|
||||
|
||||
row = result.fetchone()
|
||||
if row:
|
||||
return dict(row._mapping)
|
||||
latest = (
|
||||
session.query(MarketData)
|
||||
.filter(
|
||||
MarketData.exchange == exchange,
|
||||
MarketData.symbol == symbol,
|
||||
MarketData.timeframe == timeframe
|
||||
)
|
||||
.order_by(MarketData.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if latest:
|
||||
return {
|
||||
"exchange": latest.exchange, "symbol": latest.symbol, "timeframe": latest.timeframe,
|
||||
"timestamp": latest.timestamp, "open": latest.open, "high": latest.high,
|
||||
"low": latest.low, "close": latest.close, "volume": latest.volume,
|
||||
"trades_count": latest.trades_count, "created_at": latest.created_at
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Repository for raw_trades table operations."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import text
|
||||
|
||||
from sqlalchemy import desc
|
||||
|
||||
from ..models import RawTrade
|
||||
from data.base_collector import MarketDataPoint
|
||||
@ -15,26 +15,18 @@ class RawTradeRepository(BaseRepository):
|
||||
|
||||
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
|
||||
"""
|
||||
Insert a market data point into raw_trades table.
|
||||
Insert a market data point into raw_trades table using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': data_point.exchange,
|
||||
'symbol': data_point.symbol,
|
||||
'timestamp': data_point.timestamp,
|
||||
'data_type': data_point.data_type.value,
|
||||
'raw_data': json.dumps(data_point.data)
|
||||
})
|
||||
|
||||
new_trade = RawTrade(
|
||||
exchange=data_point.exchange,
|
||||
symbol=data_point.symbol,
|
||||
timestamp=data_point.timestamp,
|
||||
data_type=data_point.data_type.value,
|
||||
raw_data=data_point.data
|
||||
)
|
||||
session.add(new_trade)
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
|
||||
@ -51,29 +43,18 @@ class RawTradeRepository(BaseRepository):
|
||||
raw_data: Dict[str, Any],
|
||||
timestamp: Optional[datetime] = None) -> bool:
|
||||
"""
|
||||
Insert raw WebSocket data for debugging purposes.
|
||||
Insert raw WebSocket data for debugging purposes using the ORM.
|
||||
"""
|
||||
try:
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'data_type': data_type,
|
||||
'raw_data': json.dumps(raw_data)
|
||||
})
|
||||
|
||||
new_trade = RawTrade(
|
||||
exchange=exchange,
|
||||
symbol=symbol,
|
||||
timestamp=timestamp or datetime.now(datetime.timezone.utc),
|
||||
data_type=data_type,
|
||||
raw_data=raw_data
|
||||
)
|
||||
session.add(new_trade)
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
|
||||
@ -91,34 +72,34 @@ class RawTradeRepository(BaseRepository):
|
||||
exchange: str = "okx",
|
||||
limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve raw trades from the database.
|
||||
Retrieve raw trades from the database using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
FROM raw_trades
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND data_type = :data_type
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
query = (
|
||||
session.query(RawTrade)
|
||||
.filter(
|
||||
RawTrade.exchange == exchange,
|
||||
RawTrade.symbol == symbol,
|
||||
RawTrade.data_type == data_type,
|
||||
RawTrade.timestamp >= start_time,
|
||||
RawTrade.timestamp <= end_time
|
||||
)
|
||||
.order_by(RawTrade.timestamp.asc())
|
||||
)
|
||||
|
||||
if limit:
|
||||
query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}"
|
||||
query = text(query_str)
|
||||
query = query.limit(limit)
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'data_type': data_type,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
|
||||
trades = [dict(row._mapping) for row in result]
|
||||
results = query.all()
|
||||
|
||||
trades = [
|
||||
{
|
||||
"id": r.id, "exchange": r.exchange, "symbol": r.symbol,
|
||||
"timestamp": r.timestamp, "data_type": r.data_type,
|
||||
"raw_data": r.raw_data, "created_at": r.created_at
|
||||
} for r in results
|
||||
]
|
||||
|
||||
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
|
||||
return trades
|
||||
|
||||
@ -435,6 +435,8 @@ candle = OHLCVCandle(...) # Create candle object
|
||||
success = db.market_data.upsert_candle(candle)
|
||||
```
|
||||
|
||||
The entire repository layer has been standardized to use the SQLAlchemy ORM internally, ensuring a consistent, maintainable, and database-agnostic approach. Raw SQL is avoided in favor of type-safe ORM queries.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user