Standardize database operations to use SQLAlchemy ORM
- Updated the `MarketDataRepository` and `RawTradeRepository` classes to exclusively utilize SQLAlchemy ORM for all database interactions, enhancing maintainability and type safety. - Removed raw SQL queries in favor of ORM methods, ensuring a consistent and database-agnostic approach across the repository layer. - Revised documentation to reflect these changes, emphasizing the importance of using the ORM for database operations. These modifications improve the overall architecture of the database layer, making it more scalable and easier to manage.
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import text
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from ..models import MarketData
|
||||
from data.common.data_types import OHLCVCandle
|
||||
@@ -14,65 +16,51 @@ class MarketDataRepository(BaseRepository):
|
||||
|
||||
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
|
||||
"""
|
||||
Insert or update a candle in the market_data table.
|
||||
Insert or update a candle in the market_data table using the ORM.
|
||||
"""
|
||||
try:
|
||||
candle_timestamp = candle.end_time
|
||||
|
||||
with self.get_session() as session:
|
||||
if force_update:
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO UPDATE SET
|
||||
open = EXCLUDED.open,
|
||||
high = EXCLUDED.high,
|
||||
low = EXCLUDED.low,
|
||||
close = EXCLUDED.close,
|
||||
volume = EXCLUDED.volume,
|
||||
trades_count = EXCLUDED.trades_count
|
||||
""")
|
||||
action = "Updated"
|
||||
else:
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO NOTHING
|
||||
""")
|
||||
action = "Stored"
|
||||
|
||||
session.execute(query, {
|
||||
values = {
|
||||
'exchange': candle.exchange,
|
||||
'symbol': candle.symbol,
|
||||
'timeframe': candle.timeframe,
|
||||
'timestamp': candle_timestamp,
|
||||
'open': float(candle.open),
|
||||
'high': float(candle.high),
|
||||
'low': float(candle.low),
|
||||
'close': float(candle.close),
|
||||
'volume': float(candle.volume),
|
||||
'timestamp': candle.end_time,
|
||||
'open': candle.open,
|
||||
'high': candle.high,
|
||||
'low': candle.low,
|
||||
'close': candle.close,
|
||||
'volume': candle.volume,
|
||||
'trades_count': candle.trade_count
|
||||
})
|
||||
}
|
||||
|
||||
stmt = insert(MarketData).values(values)
|
||||
|
||||
if force_update:
|
||||
update_stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'],
|
||||
set_={
|
||||
'open': stmt.excluded.open,
|
||||
'high': stmt.excluded.high,
|
||||
'low': stmt.excluded.low,
|
||||
'close': stmt.excluded.close,
|
||||
'volume': stmt.excluded.volume,
|
||||
'trades_count': stmt.excluded.trades_count
|
||||
}
|
||||
)
|
||||
action = "Updated"
|
||||
final_stmt = update_stmt
|
||||
else:
|
||||
ignore_stmt = stmt.on_conflict_do_nothing(
|
||||
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp']
|
||||
)
|
||||
action = "Stored"
|
||||
final_stmt = ignore_stmt
|
||||
|
||||
session.execute(final_stmt)
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
|
||||
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -86,32 +74,32 @@ class MarketDataRepository(BaseRepository):
|
||||
end_time: datetime,
|
||||
exchange: str = "okx") -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve candles from the database.
|
||||
Retrieve candles from the database using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
query = (
|
||||
session.query(MarketData)
|
||||
.filter(
|
||||
MarketData.exchange == exchange,
|
||||
MarketData.symbol == symbol,
|
||||
MarketData.timeframe == timeframe,
|
||||
MarketData.timestamp >= start_time,
|
||||
MarketData.timestamp <= end_time
|
||||
)
|
||||
.order_by(MarketData.timestamp.asc())
|
||||
)
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
results = query.all()
|
||||
|
||||
candles = [dict(row._mapping) for row in result]
|
||||
candles = [
|
||||
{
|
||||
"exchange": r.exchange, "symbol": r.symbol, "timeframe": r.timeframe,
|
||||
"timestamp": r.timestamp, "open": r.open, "high": r.high,
|
||||
"low": r.low, "close": r.close, "volume": r.volume,
|
||||
"trades_count": r.trades_count, "created_at": r.created_at
|
||||
} for r in results
|
||||
]
|
||||
|
||||
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
|
||||
return candles
|
||||
@@ -122,31 +110,28 @@ class MarketDataRepository(BaseRepository):
|
||||
|
||||
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the latest candle for a symbol and timeframe.
|
||||
Get the latest candle for a symbol and timeframe using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe
|
||||
})
|
||||
|
||||
row = result.fetchone()
|
||||
if row:
|
||||
return dict(row._mapping)
|
||||
latest = (
|
||||
session.query(MarketData)
|
||||
.filter(
|
||||
MarketData.exchange == exchange,
|
||||
MarketData.symbol == symbol,
|
||||
MarketData.timeframe == timeframe
|
||||
)
|
||||
.order_by(MarketData.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if latest:
|
||||
return {
|
||||
"exchange": latest.exchange, "symbol": latest.symbol, "timeframe": latest.timeframe,
|
||||
"timestamp": latest.timestamp, "open": latest.open, "high": latest.high,
|
||||
"low": latest.low, "close": latest.close, "volume": latest.volume,
|
||||
"trades_count": latest.trades_count, "created_at": latest.created_at
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Repository for raw_trades table operations."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import text
|
||||
|
||||
from sqlalchemy import desc
|
||||
|
||||
from ..models import RawTrade
|
||||
from data.base_collector import MarketDataPoint
|
||||
@@ -15,26 +15,18 @@ class RawTradeRepository(BaseRepository):
|
||||
|
||||
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
|
||||
"""
|
||||
Insert a market data point into raw_trades table.
|
||||
Insert a market data point into raw_trades table using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': data_point.exchange,
|
||||
'symbol': data_point.symbol,
|
||||
'timestamp': data_point.timestamp,
|
||||
'data_type': data_point.data_type.value,
|
||||
'raw_data': json.dumps(data_point.data)
|
||||
})
|
||||
|
||||
new_trade = RawTrade(
|
||||
exchange=data_point.exchange,
|
||||
symbol=data_point.symbol,
|
||||
timestamp=data_point.timestamp,
|
||||
data_type=data_point.data_type.value,
|
||||
raw_data=data_point.data
|
||||
)
|
||||
session.add(new_trade)
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
|
||||
@@ -51,29 +43,18 @@ class RawTradeRepository(BaseRepository):
|
||||
raw_data: Dict[str, Any],
|
||||
timestamp: Optional[datetime] = None) -> bool:
|
||||
"""
|
||||
Insert raw WebSocket data for debugging purposes.
|
||||
Insert raw WebSocket data for debugging purposes using the ORM.
|
||||
"""
|
||||
try:
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'data_type': data_type,
|
||||
'raw_data': json.dumps(raw_data)
|
||||
})
|
||||
|
||||
new_trade = RawTrade(
|
||||
exchange=exchange,
|
||||
symbol=symbol,
|
||||
timestamp=timestamp or datetime.now(datetime.timezone.utc),
|
||||
data_type=data_type,
|
||||
raw_data=raw_data
|
||||
)
|
||||
session.add(new_trade)
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
|
||||
@@ -91,34 +72,34 @@ class RawTradeRepository(BaseRepository):
|
||||
exchange: str = "okx",
|
||||
limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve raw trades from the database.
|
||||
Retrieve raw trades from the database using the ORM.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
FROM raw_trades
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND data_type = :data_type
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
query = (
|
||||
session.query(RawTrade)
|
||||
.filter(
|
||||
RawTrade.exchange == exchange,
|
||||
RawTrade.symbol == symbol,
|
||||
RawTrade.data_type == data_type,
|
||||
RawTrade.timestamp >= start_time,
|
||||
RawTrade.timestamp <= end_time
|
||||
)
|
||||
.order_by(RawTrade.timestamp.asc())
|
||||
)
|
||||
|
||||
if limit:
|
||||
query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}"
|
||||
query = text(query_str)
|
||||
query = query.limit(limit)
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'data_type': data_type,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
|
||||
trades = [dict(row._mapping) for row in result]
|
||||
results = query.all()
|
||||
|
||||
trades = [
|
||||
{
|
||||
"id": r.id, "exchange": r.exchange, "symbol": r.symbol,
|
||||
"timestamp": r.timestamp, "data_type": r.data_type,
|
||||
"raw_data": r.raw_data, "created_at": r.created_at
|
||||
} for r in results
|
||||
]
|
||||
|
||||
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
|
||||
return trades
|
||||
|
||||
Reference in New Issue
Block a user