TCPDashboard/database/repositories/market_data_repository.py

273 lines
12 KiB
Python
Raw Normal View History

"""Repository for market_data table operations."""
from datetime import datetime
from typing import List, Optional, Dict, Any
import pandas as pd
from sqlalchemy import desc
from sqlalchemy.dialects.postgresql import insert
from ..models import MarketData
from data.common.data_types import OHLCVCandle
from .base_repository import BaseRepository, DatabaseOperationError
2025-06-13 16:49:29 +08:00
from tqdm import tqdm
class MarketDataRepository(BaseRepository):
    """Repository for market_data table operations.

    Writes use PostgreSQL ``INSERT ... ON CONFLICT`` keyed on
    (exchange, symbol, timeframe, timestamp). Reads return plain dicts or a
    pandas DataFrame. Any failure is logged and re-raised as
    ``DatabaseOperationError`` with the original exception chained as the cause.
    """

    # Unique key used as the ON CONFLICT target for every upsert.
    _CONFLICT_COLUMNS = ('exchange', 'symbol', 'timeframe', 'timestamp')
    # Value columns overwritten when force_update=True.
    _UPDATE_COLUMNS = ('open', 'high', 'low', 'close', 'volume', 'trades_count')
    # Columns (besides the timestamp index) returned by get_candles_df.
    _DF_COLUMNS = ('open', 'high', 'low', 'close', 'volume', 'trades_count')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _candle_to_values(candle: OHLCVCandle) -> Dict[str, Any]:
        """Map an OHLCVCandle onto a market_data row dict keyed by column name."""
        # The candle's end_time is what gets stored as the row timestamp.
        return {
            'exchange': candle.exchange,
            'symbol': candle.symbol,
            'timeframe': candle.timeframe,
            'timestamp': candle.end_time,
            'open': candle.open,
            'high': candle.high,
            'low': candle.low,
            'close': candle.close,
            'volume': candle.volume,
            'trades_count': candle.trade_count,
        }

    @classmethod
    def _build_upsert_stmt(cls, values: Any, force_update: bool):
        """Build the INSERT ... ON CONFLICT statement plus a verb for logging.

        Args:
            values: A single row dict or a list of row dicts.
            force_update: When True, overwrite the value columns of rows that
                already exist; otherwise leave existing rows untouched.

        Returns:
            A ``(statement, action)`` tuple where action is "Updated" or "Stored".
        """
        stmt = insert(MarketData).values(values)
        conflict_target = list(cls._CONFLICT_COLUMNS)
        if force_update:
            # ``stmt.excluded`` refers to the row that was proposed for insertion.
            update_stmt = stmt.on_conflict_do_update(
                index_elements=conflict_target,
                set_={col: getattr(stmt.excluded, col) for col in cls._UPDATE_COLUMNS},
            )
            return update_stmt, "Updated"
        return stmt.on_conflict_do_nothing(index_elements=conflict_target), "Stored"

    @classmethod
    def _dedupe_rows(cls, rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Collapse rows that share a conflict key, keeping the last one's values.

        PostgreSQL rejects an ``ON CONFLICT DO UPDATE`` statement that would
        touch the same target row twice, so duplicates inside one batch must be
        merged first. Each key keeps the position of its first occurrence.
        """
        unique: Dict[tuple, Dict[str, Any]] = {}
        for row in rows:
            unique[tuple(row[col] for col in cls._CONFLICT_COLUMNS)] = row
        return list(unique.values())

    @staticmethod
    def _row_to_dict(row: "MarketData") -> Dict[str, Any]:
        """Convert a MarketData ORM row into a plain dict."""
        return {
            "exchange": row.exchange, "symbol": row.symbol, "timeframe": row.timeframe,
            "timestamp": row.timestamp, "open": row.open, "high": row.high,
            "low": row.low, "close": row.close, "volume": row.volume,
            "trades_count": row.trades_count, "created_at": row.created_at,
        }

    @staticmethod
    def _range_query(session: Any, symbol: str, timeframe: str,
                     start_time: datetime, end_time: datetime, exchange: str):
        """Build the query for candles with start_time <= timestamp <= end_time, oldest first."""
        return (
            session.query(MarketData)
            .filter(
                MarketData.exchange == exchange,
                MarketData.symbol == symbol,
                MarketData.timeframe == timeframe,
                MarketData.timestamp >= start_time,
                MarketData.timestamp <= end_time,
            )
            .order_by(MarketData.timestamp.asc())
        )

    # ------------------------------------------------------------------ writes

    def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
        """
        Insert or update a single candle in the market_data table.

        Args:
            candle: The candle to store; its ``end_time`` becomes the row timestamp.
            force_update: Overwrite OHLCV/trade values of an existing row when
                True; otherwise an existing row is left unchanged.

        Returns:
            True once the statement has been executed and committed.

        Raises:
            DatabaseOperationError: If building or executing the statement fails.
        """
        try:
            with self.get_session() as session:
                stmt, action = self._build_upsert_stmt(self._candle_to_values(candle), force_update)
                session.execute(stmt)
                session.commit()
                self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle.end_time} (force_update={force_update})")
                return True
        except Exception as e:
            self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
            raise DatabaseOperationError(f"Failed to store candle: {e}") from e

    def upsert_candles_batch(self, candles: List[OHLCVCandle], force_update: bool = False,
                             batch_size: int = 1000, show_progress: bool = True) -> int:
        """
        Insert or update multiple candles in the market_data table in batches.

        Each batch runs in its own session and is committed separately. If a
        later batch fails, earlier batches stay committed.

        Args:
            candles: The candles to store.
            force_update: Overwrite existing rows when True. Candles that share a
                key within one batch are merged first, and the last one wins.
            batch_size: Maximum number of candles per INSERT statement; must be >= 1.
            show_progress: Display a tqdm progress bar while inserting.

        Returns:
            The number of input candles processed (duplicates included).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            DatabaseOperationError: If any batch fails to execute.
        """
        if batch_size < 1:
            # A non-positive step would otherwise either crash range() or
            # silently skip every candle.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        total_processed = 0
        try:
            batch_starts = range(0, len(candles), batch_size)
            for start in tqdm(batch_starts, desc="Inserting candles in batches", disable=not show_progress):
                batch = candles[start:start + batch_size]
                rows = [self._candle_to_values(candle) for candle in batch]
                if force_update:
                    rows = self._dedupe_rows(rows)
                with self.get_session() as session:
                    stmt, action = self._build_upsert_stmt(rows, force_update)
                    session.execute(stmt)
                    session.commit()
                total_processed += len(batch)
                self.log_debug(f"{action} {len(batch)} candles in batch. Total processed: {total_processed}")
            return total_processed
        except Exception as e:
            self.log_error(f"Error storing candles in batch: {e}")
            raise DatabaseOperationError(f"Failed to store candles in batch: {e}") from e

    # ------------------------------------------------------------------ reads

    def get_candles(self,
                    symbol: str,
                    timeframe: str,
                    start_time: datetime,
                    end_time: datetime,
                    exchange: str = "okx") -> List[Dict[str, Any]]:
        """
        Retrieve candles in the inclusive range [start_time, end_time].

        Args:
            symbol: The trading symbol.
            timeframe: The candle timeframe.
            start_time: Inclusive lower bound on the row timestamp.
            end_time: Inclusive upper bound on the row timestamp.
            exchange: The exchange name (default: 'okx').

        Returns:
            One dict per row, ordered by timestamp ascending; empty if none match.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        self.log_debug(f"DB: get_candles called with: symbol={symbol}, timeframe={timeframe}, start_time={start_time}, end_time={end_time}, exchange={exchange}")
        try:
            with self.get_session() as session:
                query = self._range_query(session, symbol, timeframe, start_time, end_time, exchange)
                candles = [self._row_to_dict(row) for row in query.all()]
                self.log_debug(f"DB: Retrieved {len(candles)} candles for {symbol} {timeframe} from {start_time} to {end_time}")
                return candles
        except Exception as e:
            self.log_error(f"Error retrieving candles: {e}")
            raise DatabaseOperationError(f"Failed to retrieve candles: {e}") from e

    def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
        """
        Get the candle with the greatest timestamp for a symbol and timeframe.

        Returns:
            The row as a dict, or None if no candle exists.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            with self.get_session() as session:
                latest = (
                    session.query(MarketData)
                    .filter(
                        MarketData.exchange == exchange,
                        MarketData.symbol == symbol,
                        MarketData.timeframe == timeframe,
                    )
                    .order_by(MarketData.timestamp.desc())
                    .first()
                )
                return self._row_to_dict(latest) if latest is not None else None
        except Exception as e:
            self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
            raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}") from e

    def get_candles_df(self,
                       symbol: str,
                       timeframe: str,
                       start_time: datetime,
                       end_time: datetime,
                       exchange: str = "okx") -> pd.DataFrame:
        """
        Retrieve candles from the database as a Pandas DataFrame.

        Args:
            symbol: The trading symbol (e.g., 'BTC-USDT').
            timeframe: The timeframe of the candles (e.g., '1h').
            start_time: The start datetime for the data (inclusive).
            end_time: The end datetime for the data (inclusive).
            exchange: The exchange name (default: 'okx').

        Returns:
            A DataFrame indexed by ``timestamp`` with float columns open/high/
            low/close/volume and an int ``trades_count`` column (NULL becomes 0).
            When no rows match, an empty frame with the same columns and an
            empty ``timestamp`` DatetimeIndex is returned.

        Raises:
            DatabaseOperationError: If the query or conversion fails.
        """
        try:
            with self.get_session() as session:
                query = self._range_query(session, symbol, timeframe, start_time, end_time, exchange)
                records = [
                    {
                        "timestamp": r.timestamp,
                        "open": float(r.open),
                        "high": float(r.high),
                        "low": float(r.low),
                        "close": float(r.close),
                        "volume": float(r.volume),
                        "trades_count": int(r.trades_count) if r.trades_count else 0,
                    }
                    for r in query.all()
                ]
                if records:
                    df = pd.DataFrame(records)
                    df['timestamp'] = pd.to_datetime(df['timestamp'])
                    df = df.set_index('timestamp')
                else:
                    # Keep the schema stable so callers can select columns
                    # without special-casing empty results.
                    df = pd.DataFrame(
                        columns=list(self._DF_COLUMNS),
                        index=pd.DatetimeIndex([], name='timestamp'),
                    )
                self.log_debug(f"Retrieved {len(df)} candles as DataFrame for {symbol} {timeframe}")
                return df
        except Exception as e:
            self.log_error(f"Error retrieving candles as DataFrame: {e}")
            raise DatabaseOperationError(f"Failed to retrieve candles as DataFrame: {e}") from e

    # ------------------------------------------------------------------ deletes

    def delete_candles_before_timestamp(self, timestamp: datetime) -> int:
        """
        Delete every candle whose timestamp is strictly older than ``timestamp``.

        The filter is on timestamp only, so this applies across all exchanges,
        symbols and timeframes.

        Returns:
            The number of rows deleted.

        Raises:
            DatabaseOperationError: If the delete fails.
        """
        try:
            with self.get_session() as session:
                deleted_count = (
                    session.query(MarketData)
                    .filter(MarketData.timestamp < timestamp)
                    .delete(synchronize_session=False)
                )
                session.commit()
                self.logger.warning(f"Deleted {deleted_count} candles older than {timestamp}")
                return deleted_count
        except Exception as e:
            self.log_error(f"Error deleting candles older than {timestamp}: {e}")
            raise DatabaseOperationError(f"Failed to delete candles: {e}") from e