Refactor database operations and enhance repository structure

- Introduced a modular repository structure by creating separate repository classes for `Bot`, `MarketData`, and `RawTrade`, improving code organization and maintainability.
- Updated the `DatabaseOperations` class to utilize the new repository classes, enhancing the abstraction of database interactions.
- Refactored the `.env` file to update database connection parameters and add new logging and health monitoring configurations.
- Modified the `okx_config.json` to change default timeframes for trading pairs, aligning with updated requirements.
- Added comprehensive unit tests for the new repository classes, ensuring robust functionality and reliability.

These changes improve the overall architecture of the database layer, making it more scalable and easier to manage.
This commit is contained in:
Vasily.onl 2025-06-06 21:54:45 +08:00
parent 848119e2cb
commit 028371a0e1
11 changed files with 787 additions and 452 deletions

23
.env
View File

@ -1,15 +1,15 @@
# Database Configuration
POSTGRES_DB=dashboard
POSTGRES_USER=dashboard
POSTGRES_PASSWORD=dashboard123
POSTGRES_PASSWORD=sdkjfh534^jh
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
DATABASE_URL=postgresql://dashboard:dashboard123@localhost:5432/dashboard
POSTGRES_PORT=5434
DATABASE_URL=postgresql://dashboard:sdkjfh534^jh@localhost:5434/dashboard
# Redis Configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=
REDIS_PASSWORD=redis987secure
# OKX API Configuration
OKX_API_KEY=your_okx_api_key_here
@ -29,10 +29,21 @@ DASH_DEBUG=true
# Bot Configuration
MAX_CONCURRENT_BOTS=5
BOT_UPDATE_INTERVAL=2 # seconds
BOT_UPDATE_INTERVAL=2
DEFAULT_VIRTUAL_BALANCE=10000
# Data Configuration
MARKET_DATA_SYMBOLS=BTC-USDT,ETH-USDT,LTC-USDT
HISTORICAL_DATA_DAYS=30
CHART_UPDATE_INTERVAL=2000 # milliseconds
CHART_UPDATE_INTERVAL=2000
# Logging
VERBOSE_LOGGING = true
LOG_CLEANUP=true
LOG_MAX_FILES=30
# Health monitoring
DEFAULT_HEALTH_CHECK_INTERVAL=30
MAX_SILENCE_DURATION=300
MAX_RECONNECT_ATTEMPTS=5
RECONNECT_DELAY=5

View File

@ -17,7 +17,7 @@
"factory": {
"use_factory_pattern": true,
"default_data_types": ["trade", "orderbook"],
"default_timeframes": ["5s", "30s", "1m", "5m", "15m", "1h"],
"default_timeframes": ["1s", "5s", "1m", "5m", "15m", "1h"],
"batch_create": true
},
"trading_pairs": [
@ -25,7 +25,7 @@
"symbol": "BTC-USDT",
"enabled": true,
"data_types": ["trade", "orderbook"],
"timeframes": ["5s", "1m", "5m", "15m", "1h"],
"timeframes": ["1s", "5s", "1m", "5m", "15m", "1h"],
"channels": {
"trades": "trades",
"orderbook": "books5",
@ -36,7 +36,7 @@
"symbol": "ETH-USDT",
"enabled": true,
"data_types": ["trade", "orderbook"],
"timeframes": ["5s", "1m", "5m", "15m", "1h"],
"timeframes": ["1s", "5s", "1m", "5m", "15m", "1h"],
"channels": {
"trades": "trades",
"orderbook": "books5",

View File

@ -13,8 +13,7 @@ from sqlalchemy import (
UniqueConstraint, text
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql import func
# Create base class for all models

View File

@ -1,422 +1,21 @@
"""
Database Operations Module
This module provides centralized database operations for all tables,
following the Repository pattern to abstract SQL complexity from business logic.
Benefits:
- Centralized database operations
- Clean API for different tables
- Easy to test and maintain
- Database implementation can change without affecting business logic
- Consistent error handling and logging
This module provides a centralized `DatabaseOperations` class that serves as the
main entry point for all database interactions. It follows the Repository pattern
by composing individual repository classes, each responsible for a specific table.
"""
from datetime import datetime
from decimal import Decimal
from typing import List, Optional, Dict, Any, Union
from contextlib import contextmanager
import logging
import json
from typing import Optional, Dict, Any
from sqlalchemy import text
from .connection import get_db_manager
from .models import MarketData, RawTrade
from data.common.data_types import OHLCVCandle, StandardizedTrade
from data.base_collector import MarketDataPoint, DataType
class DatabaseOperationError(Exception):
"""Custom exception for database operation errors."""
pass
class BaseRepository:
"""Base class for all repository classes."""
def __init__(self, logger: Optional[logging.Logger] = None):
"""Initialize repository with optional logger."""
self.logger = logger
self._db_manager = get_db_manager()
self._db_manager.initialize()
def log_info(self, message: str) -> None:
"""Log info message if logger is available."""
if self.logger:
self.logger.info(message)
def log_debug(self, message: str) -> None:
"""Log debug message if logger is available."""
if self.logger:
self.logger.debug(message)
def log_error(self, message: str) -> None:
"""Log error message if logger is available."""
if self.logger:
self.logger.error(message)
@contextmanager
def get_session(self):
"""Get database session with automatic cleanup."""
if not self._db_manager:
raise DatabaseOperationError("Database manager not initialized")
with self._db_manager.get_session() as session:
yield session
class MarketDataRepository(BaseRepository):
"""Repository for market_data table operations."""
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
"""
Insert or update a candle in the market_data table.
Args:
candle: OHLCV candle to store
force_update: If True, update existing records; if False, ignore duplicates
Returns:
True if operation successful, False otherwise
"""
try:
# Use right-aligned timestamp (end_time) following industry standard
candle_timestamp = candle.end_time
with self.get_session() as session:
if force_update:
# Update existing records with new data
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO UPDATE SET
open = EXCLUDED.open,
high = EXCLUDED.high,
low = EXCLUDED.low,
close = EXCLUDED.close,
volume = EXCLUDED.volume,
trades_count = EXCLUDED.trades_count
""")
action = "Updated"
else:
# Ignore duplicates, keep existing records
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO NOTHING
""")
action = "Stored"
session.execute(query, {
'exchange': candle.exchange,
'symbol': candle.symbol,
'timeframe': candle.timeframe,
'timestamp': candle_timestamp,
'open': float(candle.open),
'high': float(candle.high),
'low': float(candle.low),
'close': float(candle.close),
'volume': float(candle.volume),
'trades_count': candle.trade_count
})
session.commit()
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
return True
except Exception as e:
self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
raise DatabaseOperationError(f"Failed to store candle: {e}")
def get_candles(self,
symbol: str,
timeframe: str,
start_time: datetime,
end_time: datetime,
exchange: str = "okx") -> List[Dict[str, Any]]:
"""
Retrieve candles from the database.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
start_time: Start timestamp
end_time: End timestamp
exchange: Exchange name
Returns:
List of candle dictionaries
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe,
'start_time': start_time,
'end_time': end_time
})
candles = []
for row in result:
candles.append({
'exchange': row.exchange,
'symbol': row.symbol,
'timeframe': row.timeframe,
'timestamp': row.timestamp,
'open': row.open,
'high': row.high,
'low': row.low,
'close': row.close,
'volume': row.volume,
'trades_count': row.trades_count,
'created_at': row.created_at
})
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
return candles
except Exception as e:
self.log_error(f"Error retrieving candles: {e}")
raise DatabaseOperationError(f"Failed to retrieve candles: {e}")
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
"""
Get the latest candle for a symbol and timeframe.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
exchange: Exchange name
Returns:
Latest candle dictionary or None
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
ORDER BY timestamp DESC
LIMIT 1
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe
})
row = result.fetchone()
if row:
return {
'exchange': row.exchange,
'symbol': row.symbol,
'timeframe': row.timeframe,
'timestamp': row.timestamp,
'open': row.open,
'high': row.high,
'low': row.low,
'close': row.close,
'volume': row.volume,
'trades_count': row.trades_count,
'created_at': row.created_at
}
return None
except Exception as e:
self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}")
class RawTradeRepository(BaseRepository):
"""Repository for raw_trades table operations."""
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
"""
Insert a market data point into raw_trades table.
Args:
data_point: Market data point to store
Returns:
True if operation successful, False otherwise
"""
try:
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': data_point.exchange,
'symbol': data_point.symbol,
'timestamp': data_point.timestamp,
'data_type': data_point.data_type.value,
'raw_data': json.dumps(data_point.data)
})
session.commit()
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
return True
except Exception as e:
self.log_error(f"Error storing raw data for {data_point.symbol}: {e}")
raise DatabaseOperationError(f"Failed to store raw data: {e}")
def insert_raw_websocket_data(self,
exchange: str,
symbol: str,
data_type: str,
raw_data: Dict[str, Any],
timestamp: Optional[datetime] = None) -> bool:
"""
Insert raw WebSocket data for debugging purposes.
Args:
exchange: Exchange name
symbol: Trading symbol
data_type: Type of data (e.g., 'raw_trades', 'raw_orderbook')
raw_data: Raw data dictionary
timestamp: Optional timestamp (defaults to now)
Returns:
True if operation successful, False otherwise
"""
try:
if timestamp is None:
timestamp = datetime.now()
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timestamp': timestamp,
'data_type': data_type,
'raw_data': json.dumps(raw_data)
})
session.commit()
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
return True
except Exception as e:
self.log_error(f"Error storing raw WebSocket data for {symbol}: {e}")
raise DatabaseOperationError(f"Failed to store raw WebSocket data: {e}")
def get_raw_trades(self,
symbol: str,
data_type: str,
start_time: datetime,
end_time: datetime,
exchange: str = "okx",
limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Retrieve raw trades from the database.
Args:
symbol: Trading symbol
data_type: Data type filter
start_time: Start timestamp
end_time: End timestamp
exchange: Exchange name
limit: Maximum number of records to return
Returns:
List of raw trade dictionaries
"""
try:
with self.get_session() as session:
query = text("""
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
FROM raw_trades
WHERE exchange = :exchange
AND symbol = :symbol
AND data_type = :data_type
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
if limit:
query += f" LIMIT {limit}"
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'data_type': data_type,
'start_time': start_time,
'end_time': end_time
})
trades = []
for row in result:
trades.append({
'id': row.id,
'exchange': row.exchange,
'symbol': row.symbol,
'timestamp': row.timestamp,
'data_type': row.data_type,
'raw_data': row.raw_data,
'created_at': row.created_at
})
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
return trades
except Exception as e:
self.log_error(f"Error retrieving raw trades for {symbol}: {e}")
raise DatabaseOperationError(f"Failed to retrieve raw trades: {e}")
from .repositories import (
BotRepository,
MarketDataRepository,
RawTradeRepository,
DatabaseOperationError,
)
class DatabaseOperations:
"""
@ -431,6 +30,7 @@ class DatabaseOperations:
self.logger = logger
# Initialize repositories
self.bots = BotRepository(logger)
self.market_data = MarketDataRepository(logger)
self.raw_trades = RawTradeRepository(logger)
@ -442,8 +42,8 @@ class DatabaseOperations:
True if database is healthy, False otherwise
"""
try:
# We use one of the repositories to get a session
with self.market_data.get_session() as session:
# Simple query to test connection
result = session.execute(text("SELECT 1"))
return result.fetchone() is not None
except Exception as e:
@ -462,20 +62,17 @@ class DatabaseOperations:
stats = {
'healthy': self.health_check(),
'repositories': {
'bots': 'BotRepository',
'market_data': 'MarketDataRepository',
'raw_trades': 'RawTradeRepository'
}
}
# Get table counts
# Use a single session for all counts for efficiency
with self.market_data.get_session() as session:
# Market data count
result = session.execute(text("SELECT COUNT(*) FROM market_data"))
stats['candle_count'] = result.fetchone()[0]
# Raw trades count
result = session.execute(text("SELECT COUNT(*) FROM raw_trades"))
stats['raw_trade_count'] = result.fetchone()[0]
stats['bot_count'] = session.execute(text("SELECT COUNT(*) FROM bots")).scalar_one()
stats['candle_count'] = session.execute(text("SELECT COUNT(*) FROM market_data")).scalar_one()
stats['raw_trade_count'] = session.execute(text("SELECT COUNT(*) FROM raw_trades")).scalar_one()
return stats
@ -484,11 +81,9 @@ class DatabaseOperations:
self.logger.error(f"Error getting database stats: {e}")
return {'healthy': False, 'error': str(e)}
# Singleton instance for global access
_db_operations_instance: Optional[DatabaseOperations] = None
def get_database_operations(logger: Optional[logging.Logger] = None) -> DatabaseOperations:
"""
Get the global database operations instance.
@ -500,17 +95,6 @@ def get_database_operations(logger: Optional[logging.Logger] = None) -> Database
DatabaseOperations instance
"""
global _db_operations_instance
if _db_operations_instance is None:
_db_operations_instance = DatabaseOperations(logger)
return _db_operations_instance
__all__ = [
'DatabaseOperationError',
'MarketDataRepository',
'RawTradeRepository',
'DatabaseOperations',
'get_database_operations'
]
return _db_operations_instance

View File

@ -0,0 +1,15 @@
"""
This package contains all the repository classes for database operations.
"""
from .base_repository import BaseRepository, DatabaseOperationError
from .bot_repository import BotRepository
from .market_data_repository import MarketDataRepository
from .raw_trade_repository import RawTradeRepository
__all__ = [
"BaseRepository",
"DatabaseOperationError",
"BotRepository",
"MarketDataRepository",
"RawTradeRepository",
]

View File

@ -0,0 +1,46 @@
"""Base repository for all other repository classes."""
import logging
from contextlib import contextmanager
from typing import Optional
from ..connection import get_db_manager
class DatabaseOperationError(Exception):
"""Custom exception for database operation errors."""
pass
class BaseRepository:
"""Base class for all repository classes."""
def __init__(self, logger: Optional[logging.Logger] = None):
"""Initialize repository with optional logger."""
self.logger = logger
self._db_manager = get_db_manager()
self._db_manager.initialize()
def log_info(self, message: str) -> None:
"""Log info message if logger is available."""
if self.logger:
self.logger.info(message)
def log_debug(self, message: str) -> None:
"""Log debug message if logger is available."""
if self.logger:
self.logger.debug(message)
def log_error(self, message: str) -> None:
"""Log error message if logger is available."""
if self.logger:
self.logger.error(message)
@contextmanager
def get_session(self):
"""Get database session with automatic cleanup."""
if not self._db_manager:
raise DatabaseOperationError("Database manager not initialized")
with self._db_manager.get_session() as session:
yield session

View File

@ -0,0 +1,74 @@
"""Repository for bots table operations."""
from typing import Dict, Any, Optional
from ..models import Bot
from .base_repository import BaseRepository, DatabaseOperationError
class BotRepository(BaseRepository):
"""Repository for bots table operations."""
def add(self, bot_data: Dict[str, Any]) -> Bot:
"""Add a new bot to the database."""
try:
with self.get_session() as session:
new_bot = Bot(**bot_data)
session.add(new_bot)
session.commit()
session.refresh(new_bot)
self.log_info(f"Added new bot: {new_bot.name}")
return new_bot
except Exception as e:
self.log_error(f"Error adding bot: {e}")
raise DatabaseOperationError(f"Failed to add bot: {e}")
def get_by_id(self, bot_id: int) -> Optional[Bot]:
"""Get a bot by its ID."""
try:
with self.get_session() as session:
return session.query(Bot).filter(Bot.id == bot_id).first()
except Exception as e:
self.log_error(f"Error getting bot by ID {bot_id}: {e}")
raise DatabaseOperationError(f"Failed to get bot by ID: {e}")
def get_by_name(self, name: str) -> Optional[Bot]:
"""Get a bot by its name."""
try:
with self.get_session() as session:
return session.query(Bot).filter(Bot.name == name).first()
except Exception as e:
self.log_error(f"Error getting bot by name {name}: {e}")
raise DatabaseOperationError(f"Failed to get bot by name: {e}")
def update(self, bot_id: int, update_data: Dict[str, Any]) -> Optional[Bot]:
"""Update a bot's information."""
try:
with self.get_session() as session:
bot = session.query(Bot).filter(Bot.id == bot_id).first()
if bot:
for key, value in update_data.items():
setattr(bot, key, value)
session.commit()
session.refresh(bot)
self.log_info(f"Updated bot {bot_id}")
return bot
return None
except Exception as e:
self.log_error(f"Error updating bot {bot_id}: {e}")
raise DatabaseOperationError(f"Failed to update bot: {e}")
def delete(self, bot_id: int) -> bool:
"""Delete a bot by its ID."""
try:
with self.get_session() as session:
bot = session.query(Bot).filter(Bot.id == bot_id).first()
if bot:
session.delete(bot)
session.commit()
self.log_info(f"Deleted bot {bot_id}")
return True
return False
except Exception as e:
self.log_error(f"Error deleting bot {bot_id}: {e}")
raise DatabaseOperationError(f"Failed to delete bot: {e}")

View File

@ -0,0 +1,154 @@
"""Repository for market_data table operations."""
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlalchemy import text
from ..models import MarketData
from data.common.data_types import OHLCVCandle
from .base_repository import BaseRepository, DatabaseOperationError
class MarketDataRepository(BaseRepository):
"""Repository for market_data table operations."""
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
"""
Insert or update a candle in the market_data table.
"""
try:
candle_timestamp = candle.end_time
with self.get_session() as session:
if force_update:
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO UPDATE SET
open = EXCLUDED.open,
high = EXCLUDED.high,
low = EXCLUDED.low,
close = EXCLUDED.close,
volume = EXCLUDED.volume,
trades_count = EXCLUDED.trades_count
""")
action = "Updated"
else:
query = text("""
INSERT INTO market_data (
exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
) VALUES (
:exchange, :symbol, :timeframe, :timestamp,
:open, :high, :low, :close, :volume, :trades_count,
NOW()
)
ON CONFLICT (exchange, symbol, timeframe, timestamp)
DO NOTHING
""")
action = "Stored"
session.execute(query, {
'exchange': candle.exchange,
'symbol': candle.symbol,
'timeframe': candle.timeframe,
'timestamp': candle_timestamp,
'open': float(candle.open),
'high': float(candle.high),
'low': float(candle.low),
'close': float(candle.close),
'volume': float(candle.volume),
'trades_count': candle.trade_count
})
session.commit()
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
return True
except Exception as e:
self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
raise DatabaseOperationError(f"Failed to store candle: {e}")
def get_candles(self,
symbol: str,
timeframe: str,
start_time: datetime,
end_time: datetime,
exchange: str = "okx") -> List[Dict[str, Any]]:
"""
Retrieve candles from the database.
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe,
'start_time': start_time,
'end_time': end_time
})
candles = [dict(row._mapping) for row in result]
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
return candles
except Exception as e:
self.log_error(f"Error retrieving candles: {e}")
raise DatabaseOperationError(f"Failed to retrieve candles: {e}")
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
"""
Get the latest candle for a symbol and timeframe.
"""
try:
with self.get_session() as session:
query = text("""
SELECT exchange, symbol, timeframe, timestamp,
open, high, low, close, volume, trades_count,
created_at
FROM market_data
WHERE exchange = :exchange
AND symbol = :symbol
AND timeframe = :timeframe
ORDER BY timestamp DESC
LIMIT 1
""")
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timeframe': timeframe
})
row = result.fetchone()
if row:
return dict(row._mapping)
return None
except Exception as e:
self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}")

View File

@ -0,0 +1,128 @@
"""Repository for raw_trades table operations."""
import json
from datetime import datetime
from typing import Dict, Any, Optional, List
from sqlalchemy import text
from ..models import RawTrade
from data.base_collector import MarketDataPoint
from .base_repository import BaseRepository, DatabaseOperationError
class RawTradeRepository(BaseRepository):
"""Repository for raw_trades table operations."""
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
"""
Insert a market data point into raw_trades table.
"""
try:
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': data_point.exchange,
'symbol': data_point.symbol,
'timestamp': data_point.timestamp,
'data_type': data_point.data_type.value,
'raw_data': json.dumps(data_point.data)
})
session.commit()
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
return True
except Exception as e:
self.log_error(f"Error storing raw data for {data_point.symbol}: {e}")
raise DatabaseOperationError(f"Failed to store raw data: {e}")
def insert_raw_websocket_data(self,
exchange: str,
symbol: str,
data_type: str,
raw_data: Dict[str, Any],
timestamp: Optional[datetime] = None) -> bool:
"""
Insert raw WebSocket data for debugging purposes.
"""
try:
if timestamp is None:
timestamp = datetime.now()
with self.get_session() as session:
query = text("""
INSERT INTO raw_trades (
exchange, symbol, timestamp, data_type, raw_data, created_at
) VALUES (
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
)
""")
session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'timestamp': timestamp,
'data_type': data_type,
'raw_data': json.dumps(raw_data)
})
session.commit()
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
return True
except Exception as e:
self.log_error(f"Error storing raw WebSocket data for {symbol}: {e}")
raise DatabaseOperationError(f"Failed to store raw WebSocket data: {e}")
def get_raw_trades(self,
symbol: str,
data_type: str,
start_time: datetime,
end_time: datetime,
exchange: str = "okx",
limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Retrieve raw trades from the database.
"""
try:
with self.get_session() as session:
query = text("""
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
FROM raw_trades
WHERE exchange = :exchange
AND symbol = :symbol
AND data_type = :data_type
AND timestamp >= :start_time
AND timestamp <= :end_time
ORDER BY timestamp ASC
""")
if limit:
query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}"
query = text(query_str)
result = session.execute(query, {
'exchange': exchange,
'symbol': symbol,
'data_type': data_type,
'start_time': start_time,
'end_time': end_time
})
trades = [dict(row._mapping) for row in result]
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
return trades
except Exception as e:
self.log_error(f"Error retrieving raw trades for {symbol}: {e}")
raise DatabaseOperationError(f"Failed to retrieve raw trades: {e}")

View File

@ -37,11 +37,11 @@ The Database Operations module (`database/operations.py`) provides a clean, cent
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌──────────────┐ │
│ │MarketDataRepo │ │RawTradeRepo │ │ Future │ │
│ │ │ │ │ │ Repositories │ │
│ │ • upsert_candle │ │ • insert_data │ │ • OrderBook │ │
│ │ • get_candles │ │ • get_trades │ │ • UserTrades │ │
│ │ • get_latest │ │ • raw_websocket │ │ • Positions │ │
│ │MarketDataRepo │ │RawTradeRepo │ │ BotRepo │ │
│ │ │ │ │ │ │ │
│ │ • upsert_candle │ │ • insert_data │ │ • add │ │
│ │ • get_candles │ │ • get_trades │ │ • get_by_id │ │
│ │ • get_latest │ │ • raw_websocket │ │ • update/delete│ │
│ └─────────────────┘ └─────────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
@ -118,8 +118,9 @@ async def main():
# Check statistics
stats = db.get_stats()
print(f"Total bots: {stats['bot_count']}")
print(f"Total candles: {stats['candle_count']}")
print(f"Total raw trades: {stats['trade_count']}")
print(f"Total raw trades: {stats['raw_trade_count']}")
asyncio.run(main())
```
@ -148,8 +149,9 @@ Get comprehensive database statistics.
```python
stats = db.get_stats()
print(f"Bots: {stats['bot_count']:,}")
print(f"Candles: {stats['candle_count']:,}")
print(f"Raw trades: {stats['trade_count']:,}")
print(f"Raw trades: {stats['raw_trade_count']:,}")
print(f"Health: {stats['healthy']}")
```
@ -212,6 +214,81 @@ else:
print("No candles found")
```
### BotRepository
Repository for `bots` table operations.
#### Methods
##### `add(bot_data: Dict[str, Any]) -> Bot`
Adds a new bot to the database.
**Parameters:**
- `bot_data`: Dictionary containing the bot's attributes (`name`, `strategy_name`, etc.)
**Returns:** The newly created `Bot` object.
```python
from decimal import Decimal
bot_data = {
"name": "MyTestBot",
"strategy_name": "SimpleMACD",
"symbol": "BTC-USDT",
"timeframe": "1h",
"status": "inactive",
"virtual_balance": Decimal("10000"),
}
new_bot = db.bots.add(bot_data)
print(f"Added bot with ID: {new_bot.id}")
```
##### `get_by_id(bot_id: int) -> Optional[Bot]`
Retrieves a bot by its unique ID.
```python
bot = db.bots.get_by_id(1)
if bot:
print(f"Found bot: {bot.name}")
```
##### `get_by_name(name: str) -> Optional[Bot]`
Retrieves a bot by its unique name.
```python
bot = db.bots.get_by_name("MyTestBot")
if bot:
print(f"Found bot with ID: {bot.id}")
```
##### `update(bot_id: int, update_data: Dict[str, Any]) -> Optional[Bot]`
Updates an existing bot's attributes.
```python
from datetime import datetime, timezone
update_payload = {"status": "active", "last_heartbeat": datetime.now(timezone.utc)}
updated_bot = db.bots.update(1, update_payload)
if updated_bot:
print(f"Bot status updated to: {updated_bot.status}")
```
##### `delete(bot_id: int) -> bool`
Deletes a bot from the database.
**Returns:** `True` if deletion was successful, `False` otherwise.
```python
success = db.bots.delete(1)
if success:
print("Bot deleted successfully.")
```
### RawTradeRepository
Repository for `raw_trades` table operations (raw WebSocket data).

View File

@ -0,0 +1,247 @@
import pytest
import pytest_asyncio
import asyncio
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from database.operations import get_database_operations
from database.models import Bot
from data.common.data_types import OHLCVCandle
from data.base_collector import MarketDataPoint, DataType
@pytest.fixture(scope="module")
def event_loop():
"""Create an instance of the default event loop for each test module."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="module")
async def db_ops():
"""Fixture to provide database operations."""
# We will need to make sure the test database is configured and running
operations = get_database_operations()
yield operations
# Teardown logic can be added here if needed, e.g., operations.close()
@pytest.mark.asyncio
class TestBotRepository:
"""Tests for the BotRepository."""
async def test_add_and_get_bot(self, db_ops):
"""
Test adding a new bot and retrieving it to verify basic repository functionality.
"""
# Define a new bot
bot_name = "test_bot_01"
new_bot = {
"name": bot_name,
"strategy_name": "test_strategy",
"symbol": "BTC-USDT",
"timeframe": "1h",
"status": "inactive",
"virtual_balance": Decimal("10000"),
}
# Clean up any existing bot with the same name
existing_bot = db_ops.bots.get_by_name(bot_name)
if existing_bot:
db_ops.bots.delete(existing_bot.id)
# Add the bot
added_bot = db_ops.bots.add(new_bot)
# Assertions to check if the bot was added correctly
assert added_bot is not None
assert added_bot.id is not None
assert added_bot.name == bot_name
assert added_bot.status == "inactive"
# Retrieve the bot by ID
retrieved_bot = db_ops.bots.get_by_id(added_bot.id)
# Assertions to check if the bot was retrieved correctly
assert retrieved_bot is not None
assert retrieved_bot.id == added_bot.id
assert retrieved_bot.name == bot_name
# Clean up the created bot
db_ops.bots.delete(added_bot.id)
# Verify it's deleted
deleted_bot = db_ops.bots.get_by_id(added_bot.id)
assert deleted_bot is None
async def test_update_bot(self, db_ops):
"""Test updating an existing bot's status."""
bot_name = "test_bot_for_update"
bot_data = {
"name": bot_name,
"strategy_name": "test_strategy",
"symbol": "ETH-USDT",
"timeframe": "5m",
"status": "active",
}
# Ensure clean state
existing_bot = db_ops.bots.get_by_name(bot_name)
if existing_bot:
db_ops.bots.delete(existing_bot.id)
# Add a bot to update
bot_to_update = db_ops.bots.add(bot_data)
# Update the bot's status
update_data = {"status": "paused"}
updated_bot = db_ops.bots.update(bot_to_update.id, update_data)
# Assertions
assert updated_bot is not None
assert updated_bot.status == "paused"
# Clean up
db_ops.bots.delete(bot_to_update.id)
async def test_get_nonexistent_bot(self, db_ops):
"""Test that fetching a non-existent bot returns None."""
non_existent_bot = db_ops.bots.get_by_id(999999)
assert non_existent_bot is None
async def test_delete_bot(self, db_ops):
"""Test deleting a bot."""
bot_name = "test_bot_for_delete"
bot_data = {
"name": bot_name,
"strategy_name": "delete_strategy",
"symbol": "LTC-USDT",
"timeframe": "15m",
}
# Ensure clean state
existing_bot = db_ops.bots.get_by_name(bot_name)
if existing_bot:
db_ops.bots.delete(existing_bot.id)
# Add a bot to delete
bot_to_delete = db_ops.bots.add(bot_data)
# Delete the bot
delete_result = db_ops.bots.delete(bot_to_delete.id)
assert delete_result is True
# Verify it's gone
retrieved_bot = db_ops.bots.get_by_id(bot_to_delete.id)
assert retrieved_bot is None
@pytest.mark.asyncio
class TestMarketDataRepository:
"""Tests for the MarketDataRepository."""
async def test_upsert_and_get_candle(self, db_ops):
"""Test upserting and retrieving a candle."""
# Use a fixed timestamp for deterministic tests
base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
candle = OHLCVCandle(
start_time=base_time,
end_time=base_time + timedelta(hours=1),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49000"),
close=Decimal("50500"),
volume=Decimal("100"),
trade_count=10,
timeframe="1h",
symbol="BTC-USDT-TUG", # Unique symbol for test
exchange="okx"
)
# Upsert the candle
success = db_ops.market_data.upsert_candle(candle)
assert success is True
# Retrieve the candle using a time range
start_time = base_time + timedelta(hours=1)
end_time = base_time + timedelta(hours=1)
retrieved_candles = db_ops.market_data.get_candles(
symbol="BTC-USDT-TUG",
timeframe="1h",
start_time=start_time,
end_time=end_time
)
assert len(retrieved_candles) >= 1
retrieved_candle = retrieved_candles[0]
assert retrieved_candle["symbol"] == "BTC-USDT-TUG"
assert retrieved_candle["close"] == candle.close
assert retrieved_candle["timestamp"] == candle.end_time
async def test_get_latest_candle(self, db_ops):
"""Test fetching the latest candle."""
base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
symbol = "ETH-USDT-TGLC" # Unique symbol for test
# Insert a few candles with increasing timestamps
for i in range(3):
candle = OHLCVCandle(
start_time=base_time + timedelta(minutes=i*5),
end_time=base_time + timedelta(minutes=(i+1)*5),
open=Decimal("1200") + i,
high=Decimal("1210") + i,
low=Decimal("1190") + i,
close=Decimal("1205") + i,
volume=Decimal("1000"),
trade_count=20+i,
timeframe="5m",
symbol=symbol,
exchange="okx"
)
db_ops.market_data.upsert_candle(candle)
latest_candle = db_ops.market_data.get_latest_candle(
symbol=symbol,
timeframe="5m"
)
assert latest_candle is not None
assert latest_candle["symbol"] == symbol
assert latest_candle["timeframe"] == "5m"
assert latest_candle["close"] == Decimal("1207")
assert latest_candle["timestamp"] == base_time + timedelta(minutes=15)
@pytest.mark.asyncio
class TestRawTradeRepository:
"""Tests for the RawTradeRepository."""
async def test_insert_and_get_raw_trade(self, db_ops):
"""Test inserting and retrieving a raw trade data point."""
base_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
symbol = "XRP-USDT-TIRT" # Unique symbol for test
data_point = MarketDataPoint(
symbol=symbol,
data_type=DataType.TRADE,
data={"price": "2.5", "qty": "100"},
timestamp=base_time,
exchange="okx"
)
# Insert raw data
success = db_ops.raw_trades.insert_market_data_point(data_point)
assert success is True
# Retrieve raw data
start_time = base_time - timedelta(seconds=1)
end_time = base_time + timedelta(seconds=1)
raw_trades = db_ops.raw_trades.get_raw_trades(
symbol=symbol,
data_type=DataType.TRADE.value,
start_time=start_time,
end_time=end_time
)
assert len(raw_trades) >= 1
retrieved_trade = raw_trades[0]
assert retrieved_trade["symbol"] == symbol
assert retrieved_trade["data_type"] == DataType.TRADE.value
assert retrieved_trade["raw_data"] == data_point.data