- Updated the `MarketDataRepository` and `RawTradeRepository` classes to use SQLAlchemy exclusively for all database interactions, improving maintainability and type safety.
- Removed raw SQL queries in favor of ORM methods, giving the repository layer a consistent approach. (Upserts still rely on SQLAlchemy's PostgreSQL dialect — `insert(...).on_conflict_do_update()` / `on_conflict_do_nothing()` — so that path is PostgreSQL-specific rather than database-agnostic.)
- Revised the documentation to reflect these changes and to emphasize using the ORM for database operations.

These modifications improve the overall architecture of the database layer, making it more scalable and easier to manage.
139 lines
5.7 KiB
Python
139 lines
5.7 KiB
Python
"""Repository for market_data table operations."""
|
|
|
|
from datetime import datetime
|
|
from typing import List, Optional, Dict, Any
|
|
|
|
from sqlalchemy import desc
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
|
from ..models import MarketData
|
|
from data.common.data_types import OHLCVCandle
|
|
from .base_repository import BaseRepository, DatabaseOperationError
|
|
|
|
|
|
class MarketDataRepository(BaseRepository):
    """Repository for market_data table operations.

    Reads go through ``session.query(MarketData)``. Writes use SQLAlchemy's
    PostgreSQL-dialect ``insert`` with ``ON CONFLICT`` handling, so
    ``upsert_candle`` requires a PostgreSQL backend.
    """

    # Conflict target for upserts. Assumes a unique index/constraint exists on
    # exactly these market_data columns — TODO confirm against the schema.
    _CANDLE_KEY_COLUMNS = ('exchange', 'symbol', 'timeframe', 'timestamp')

    @staticmethod
    def _candle_row_to_dict(row: MarketData) -> Dict[str, Any]:
        """Convert a ``MarketData`` row into the plain dict returned by the getters.

        Callers invoke this inside their ``with self.get_session()`` block so
        attribute access happens while the row is still attached to its session.
        """
        return {
            "exchange": row.exchange,
            "symbol": row.symbol,
            "timeframe": row.timeframe,
            "timestamp": row.timestamp,
            "open": row.open,
            "high": row.high,
            "low": row.low,
            "close": row.close,
            "volume": row.volume,
            "trades_count": row.trades_count,
            "created_at": row.created_at,
        }

    def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
        """
        Insert or update a candle in the market_data table.

        The row key is (exchange, symbol, timeframe, timestamp), where
        ``timestamp`` is taken from ``candle.end_time``.

        Args:
            candle: The candle to persist.
            force_update: When True, an existing row with the same key has its
                open/high/low/close/volume/trades_count overwritten
                (ON CONFLICT DO UPDATE). When False, an existing row is left
                untouched (ON CONFLICT DO NOTHING); note the debug log still
                reports "Stored" in that case.

        Returns:
            True once the statement has been executed and committed.

        Raises:
            DatabaseOperationError: If building, executing or committing the
                statement fails; the underlying exception is chained as the cause.
        """
        try:
            with self.get_session() as session:
                values = {
                    'exchange': candle.exchange,
                    'symbol': candle.symbol,
                    'timeframe': candle.timeframe,
                    'timestamp': candle.end_time,
                    'open': candle.open,
                    'high': candle.high,
                    'low': candle.low,
                    'close': candle.close,
                    'volume': candle.volume,
                    # The candle attribute is `trade_count`; the column is `trades_count`.
                    'trades_count': candle.trade_count,
                }

                insert_stmt = insert(MarketData).values(values)
                key_columns = list(self._CANDLE_KEY_COLUMNS)

                if force_update:
                    # `excluded` refers to the row that was proposed for insertion.
                    excluded = insert_stmt.excluded
                    final_stmt = insert_stmt.on_conflict_do_update(
                        index_elements=key_columns,
                        set_={
                            'open': excluded.open,
                            'high': excluded.high,
                            'low': excluded.low,
                            'close': excluded.close,
                            'volume': excluded.volume,
                            'trades_count': excluded.trades_count,
                        },
                    )
                    action = "Updated"
                else:
                    final_stmt = insert_stmt.on_conflict_do_nothing(
                        index_elements=key_columns
                    )
                    action = "Stored"

                session.execute(final_stmt)
                session.commit()

                self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})")
                return True

        except Exception as e:
            self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
            # Chain the cause so the original traceback is not lost.
            raise DatabaseOperationError(f"Failed to store candle: {e}") from e

    def get_candles(self,
                    symbol: str,
                    timeframe: str,
                    start_time: datetime,
                    end_time: datetime,
                    exchange: str = "okx") -> List[Dict[str, Any]]:
        """
        Retrieve candles for a symbol/timeframe within a time window.

        Args:
            symbol: Trading symbol to match.
            timeframe: Candle timeframe to match.
            start_time: Inclusive lower bound on ``timestamp``.
            end_time: Inclusive upper bound on ``timestamp``.
            exchange: Exchange to match (defaults to "okx").

        Returns:
            A list of candle dicts ordered by ``timestamp`` ascending. The list
            is empty if no rows match.

        Raises:
            DatabaseOperationError: If the query fails; the underlying exception
                is chained as the cause.
        """
        try:
            with self.get_session() as session:
                rows = (
                    session.query(MarketData)
                    .filter(
                        MarketData.exchange == exchange,
                        MarketData.symbol == symbol,
                        MarketData.timeframe == timeframe,
                        MarketData.timestamp >= start_time,
                        MarketData.timestamp <= end_time,
                    )
                    .order_by(MarketData.timestamp.asc())
                    .all()
                )

                candles = [self._candle_row_to_dict(row) for row in rows]

                self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
                return candles

        except Exception as e:
            self.log_error(f"Error retrieving candles: {e}")
            raise DatabaseOperationError(f"Failed to retrieve candles: {e}") from e

    def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
        """
        Get the most recent candle (greatest ``timestamp``) for a symbol and timeframe.

        Args:
            symbol: Trading symbol to match.
            timeframe: Candle timeframe to match.
            exchange: Exchange to match (defaults to "okx").

        Returns:
            The candle as a dict, or None if no matching row exists.

        Raises:
            DatabaseOperationError: If the query fails; the underlying exception
                is chained as the cause.
        """
        try:
            with self.get_session() as session:
                latest = (
                    session.query(MarketData)
                    .filter(
                        MarketData.exchange == exchange,
                        MarketData.symbol == symbol,
                        MarketData.timeframe == timeframe,
                    )
                    .order_by(MarketData.timestamp.desc())
                    .first()
                )

                if latest is None:
                    return None
                return self._candle_row_to_dict(latest)

        except Exception as e:
            self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
            raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}") from e