Refactor database operations and enhance repository structure

- Introduced a modular repository structure by creating separate repository classes for `Bot`, `MarketData`, and `RawTrade`, improving code organization and maintainability.
- Updated the `DatabaseOperations` class to utilize the new repository classes, enhancing the abstraction of database interactions.
- Refactored the `.env` file to update database connection parameters and add new logging and health monitoring configurations.
- Modified the `okx_config.json` to change default timeframes for trading pairs, aligning with updated requirements.
- Added comprehensive unit tests for the new repository classes, ensuring robust functionality and reliability.

These changes improve the overall architecture of the database layer, making it more scalable and easier to manage.
This commit is contained in:
Vasily.onl
2025-06-06 21:54:45 +08:00
parent 848119e2cb
commit 028371a0e1
11 changed files with 787 additions and 452 deletions

View File

@@ -0,0 +1,15 @@
"""
This package contains all the repository classes for database operations.
"""
from .base_repository import BaseRepository, DatabaseOperationError
from .bot_repository import BotRepository
from .market_data_repository import MarketDataRepository
from .raw_trade_repository import RawTradeRepository
__all__ = [
"BaseRepository",
"DatabaseOperationError",
"BotRepository",
"MarketDataRepository",
"RawTradeRepository",
]

View File

@@ -0,0 +1,46 @@
"""Base repository for all other repository classes."""
import logging
from contextlib import contextmanager
from typing import Optional
from ..connection import get_db_manager
class DatabaseOperationError(Exception):
"""Custom exception for database operation errors."""
pass
class BaseRepository:
"""Base class for all repository classes."""
def __init__(self, logger: Optional[logging.Logger] = None):
"""Initialize repository with optional logger."""
self.logger = logger
self._db_manager = get_db_manager()
self._db_manager.initialize()
def log_info(self, message: str) -> None:
"""Log info message if logger is available."""
if self.logger:
self.logger.info(message)
def log_debug(self, message: str) -> None:
"""Log debug message if logger is available."""
if self.logger:
self.logger.debug(message)
def log_error(self, message: str) -> None:
"""Log error message if logger is available."""
if self.logger:
self.logger.error(message)
@contextmanager
def get_session(self):
"""Get database session with automatic cleanup."""
if not self._db_manager:
raise DatabaseOperationError("Database manager not initialized")
with self._db_manager.get_session() as session:
yield session

View File

@@ -0,0 +1,74 @@
"""Repository for bots table operations."""
from typing import Dict, Any, Optional
from ..models import Bot
from .base_repository import BaseRepository, DatabaseOperationError
class BotRepository(BaseRepository):
"""Repository for bots table operations."""
def add(self, bot_data: Dict[str, Any]) -> Bot:
"""Add a new bot to the database."""
try:
with self.get_session() as session:
new_bot = Bot(**bot_data)
session.add(new_bot)
session.commit()
session.refresh(new_bot)
self.log_info(f"Added new bot: {new_bot.name}")
return new_bot
except Exception as e:
self.log_error(f"Error adding bot: {e}")
raise DatabaseOperationError(f"Failed to add bot: {e}")
def get_by_id(self, bot_id: int) -> Optional[Bot]:
"""Get a bot by its ID."""
try:
with self.get_session() as session:
return session.query(Bot).filter(Bot.id == bot_id).first()
except Exception as e:
self.log_error(f"Error getting bot by ID {bot_id}: {e}")
raise DatabaseOperationError(f"Failed to get bot by ID: {e}")
def get_by_name(self, name: str) -> Optional[Bot]:
"""Get a bot by its name."""
try:
with self.get_session() as session:
return session.query(Bot).filter(Bot.name == name).first()
except Exception as e:
self.log_error(f"Error getting bot by name {name}: {e}")
raise DatabaseOperationError(f"Failed to get bot by name: {e}")
def update(self, bot_id: int, update_data: Dict[str, Any]) -> Optional[Bot]:
"""Update a bot's information."""
try:
with self.get_session() as session:
bot = session.query(Bot).filter(Bot.id == bot_id).first()
if bot:
for key, value in update_data.items():
setattr(bot, key, value)
session.commit()
session.refresh(bot)
self.log_info(f"Updated bot {bot_id}")
return bot
return None
except Exception as e:
self.log_error(f"Error updating bot {bot_id}: {e}")
raise DatabaseOperationError(f"Failed to update bot: {e}")
def delete(self, bot_id: int) -> bool:
"""Delete a bot by its ID."""
try:
with self.get_session() as session:
bot = session.query(Bot).filter(Bot.id == bot_id).first()
if bot:
session.delete(bot)
session.commit()
self.log_info(f"Deleted bot {bot_id}")
return True
return False
except Exception as e:
self.log_error(f"Error deleting bot {bot_id}: {e}")
raise DatabaseOperationError(f"Failed to delete bot: {e}")

View File

@@ -0,0 +1,154 @@
"""Repository for market_data table operations."""
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlalchemy import text
from ..models import MarketData
from data.common.data_types import OHLCVCandle
from .base_repository import BaseRepository, DatabaseOperationError
class MarketDataRepository(BaseRepository):
"""Repository for market_data table operations."""
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
"""
Insert or update a candle in the market_data table.
"""
try:
candle_timestamp = candle.end_time
with self.get_session() as session:
if force_update:
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO UPDATE SET
open = EXCLUDED.open,
high = EXCLUDED.high,
low = EXCLUDED.low,
close = EXCLUDED.close,
volume = EXCLUDED.volume,
trades_count = EXCLUDED.trades_count
""")
action = "Updated"
else:
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO NOTHING
""")
action = "Stored"
session.execute(query, {
'exchange': candle.exchange,
'symbol': candle.symbol,
'timeframe': candle.timeframe,
'timestamp': candle_timestamp,
'open': float(candle.open),
'high': float(candle.high),
'low': float(candle.low),
'close': float(candle.close),
'volume': float(candle.volume),
'trades_count': candle.trade_count
})
session.commit()
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
return True
except Exception as e:
self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
raise DatabaseOperationError(f"Failed to store candle: {e}")
def get_candles(self,
symbol: str,
timeframe: str,
start_time: datetime,
end_time: datetime,
exchange: str = "okx") -> List[Dict[str, Any]]:
"""
Retrieve candles from the database.
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe,
'start_time': start_time,
'end_time': end_time
})
candles = [dict(row._mapping) for row in result]
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
return candles
except Exception as e:
self.log_error(f"Error retrieving candles: {e}")
raise DatabaseOperationError(f"Failed to retrieve candles: {e}")
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
"""
Get the latest candle for a symbol and timeframe.
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
ORDER BY timestamp DESC
LIMIT 1
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe
})
row = result.fetchone()
if row:
return dict(row._mapping)
return None
except Exception as e:
self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}")

View File

@@ -0,0 +1,128 @@
"""Repository for raw_trades table operations."""
import json
from datetime import datetime
from typing import Dict, Any, Optional, List
from sqlalchemy import text
from ..models import RawTrade
from data.base_collector import MarketDataPoint
from .base_repository import BaseRepository, DatabaseOperationError
class RawTradeRepository(BaseRepository):
"""Repository for raw_trades table operations."""
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
"""
Insert a market data point into raw_trades table.
"""
try:
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': data_point.exchange,
'symbol': data_point.symbol,
'timestamp': data_point.timestamp,
'data_type': data_point.data_type.value,
'raw_data': json.dumps(data_point.data)
})
session.commit()
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
return True
except Exception as e:
self.log_error(f"Error storing raw data for {data_point.symbol}: {e}")
raise DatabaseOperationError(f"Failed to store raw data: {e}")
def insert_raw_websocket_data(self,
exchange: str,
symbol: str,
data_type: str,
raw_data: Dict[str, Any],
timestamp: Optional[datetime] = None) -> bool:
"""
Insert raw WebSocket data for debugging purposes.
"""
try:
if timestamp is None:
timestamp = datetime.now()
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timestamp': timestamp,
'data_type': data_type,
'raw_data': json.dumps(raw_data)
})
session.commit()
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
return True
except Exception as e:
self.log_error(f"Error storing raw WebSocket data for {symbol}: {e}")
raise DatabaseOperationError(f"Failed to store raw WebSocket data: {e}")
def get_raw_trades(self,
symbol: str,
data_type: str,
start_time: datetime,
end_time: datetime,
exchange: str = "okx",
limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Retrieve raw trades from the database.
"""
try:
with self.get_session() as session:
query = text("""
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
FROM raw_trades
WHERE exchange = :exchange
AND symbol = :symbol
AND data_type = :data_type
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
if limit:
query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}"
query = text(query_str)
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'data_type': data_type,
'start_time': start_time,
'end_time': end_time
})
trades = [dict(row._mapping) for row in result]
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
return trades
except Exception as e:
self.log_error(f"Error retrieving raw trades for {symbol}: {e}")
raise DatabaseOperationError(f"Failed to retrieve raw trades: {e}")