Refactor database operations and enhance repository structure
- Introduced a modular repository structure by creating separate repository classes for `Bot`, `MarketData`, and `RawTrade`, improving code organization and maintainability. - Updated the `DatabaseOperations` class to utilize the new repository classes, enhancing the abstraction of database interactions. - Refactored the `.env` file to update database connection parameters and add new logging and health monitoring configurations. - Modified the `okx_config.json` to change default timeframes for trading pairs, aligning with updated requirements. - Added comprehensive unit tests for the new repository classes, ensuring robust functionality and reliability. These changes improve the overall architecture of the database layer, making it more scalable and easier to manage.
This commit is contained in:
parent
848119e2cb
commit
028371a0e1
23
.env
23
.env
@ -1,15 +1,15 @@
|
||||
# Database Configuration
|
||||
POSTGRES_DB=dashboard
|
||||
POSTGRES_USER=dashboard
|
||||
POSTGRES_PASSWORD=dashboard123
|
||||
POSTGRES_PASSWORD=sdkjfh534^jh
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_PORT=5432
|
||||
DATABASE_URL=postgresql://dashboard:dashboard123@localhost:5432/dashboard
|
||||
POSTGRES_PORT=5434
|
||||
DATABASE_URL=postgresql://dashboard:sdkjfh534^jh@localhost:5434/dashboard
|
||||
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_PASSWORD=redis987secure
|
||||
|
||||
# OKX API Configuration
|
||||
OKX_API_KEY=your_okx_api_key_here
|
||||
@ -29,10 +29,21 @@ DASH_DEBUG=true
|
||||
|
||||
# Bot Configuration
|
||||
MAX_CONCURRENT_BOTS=5
|
||||
BOT_UPDATE_INTERVAL=2 # seconds
|
||||
BOT_UPDATE_INTERVAL=2
|
||||
DEFAULT_VIRTUAL_BALANCE=10000
|
||||
|
||||
# Data Configuration
|
||||
MARKET_DATA_SYMBOLS=BTC-USDT,ETH-USDT,LTC-USDT
|
||||
HISTORICAL_DATA_DAYS=30
|
||||
CHART_UPDATE_INTERVAL=2000 # milliseconds
|
||||
CHART_UPDATE_INTERVAL=2000
|
||||
|
||||
# Logging
|
||||
VERBOSE_LOGGING = true
|
||||
LOG_CLEANUP=true
|
||||
LOG_MAX_FILES=30
|
||||
|
||||
# Health monitoring
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL=30
|
||||
MAX_SILENCE_DURATION=300
|
||||
MAX_RECONNECT_ATTEMPTS=5
|
||||
RECONNECT_DELAY=5
|
||||
@ -17,7 +17,7 @@
|
||||
"factory": {
|
||||
"use_factory_pattern": true,
|
||||
"default_data_types": ["trade", "orderbook"],
|
||||
"default_timeframes": ["5s", "30s", "1m", "5m", "15m", "1h"],
|
||||
"default_timeframes": ["1s", "5s", "1m", "5m", "15m", "1h"],
|
||||
"batch_create": true
|
||||
},
|
||||
"trading_pairs": [
|
||||
@ -25,7 +25,7 @@
|
||||
"symbol": "BTC-USDT",
|
||||
"enabled": true,
|
||||
"data_types": ["trade", "orderbook"],
|
||||
"timeframes": ["5s", "1m", "5m", "15m", "1h"],
|
||||
"timeframes": ["1s", "5s", "1m", "5m", "15m", "1h"],
|
||||
"channels": {
|
||||
"trades": "trades",
|
||||
"orderbook": "books5",
|
||||
@ -36,7 +36,7 @@
|
||||
"symbol": "ETH-USDT",
|
||||
"enabled": true,
|
||||
"data_types": ["trade", "orderbook"],
|
||||
"timeframes": ["5s", "1m", "5m", "15m", "1h"],
|
||||
"timeframes": ["1s", "5s", "1m", "5m", "15m", "1h"],
|
||||
"channels": {
|
||||
"trades": "trades",
|
||||
"orderbook": "books5",
|
||||
|
||||
@ -13,8 +13,7 @@ from sqlalchemy import (
|
||||
UniqueConstraint, text
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# Create base class for all models
|
||||
|
||||
@ -1,422 +1,21 @@
|
||||
"""
|
||||
Database Operations Module
|
||||
|
||||
This module provides centralized database operations for all tables,
|
||||
following the Repository pattern to abstract SQL complexity from business logic.
|
||||
|
||||
Benefits:
|
||||
- Centralized database operations
|
||||
- Clean API for different tables
|
||||
- Easy to test and maintain
|
||||
- Database implementation can change without affecting business logic
|
||||
- Consistent error handling and logging
|
||||
This module provides a centralized `DatabaseOperations` class that serves as the
|
||||
main entry point for all database interactions. It follows the Repository pattern
|
||||
by composing individual repository classes, each responsible for a specific table.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy import text
|
||||
|
||||
from .connection import get_db_manager
|
||||
from .models import MarketData, RawTrade
|
||||
from data.common.data_types import OHLCVCandle, StandardizedTrade
|
||||
from data.base_collector import MarketDataPoint, DataType
|
||||
|
||||
|
||||
class DatabaseOperationError(Exception):
|
||||
"""Custom exception for database operation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
"""Base class for all repository classes."""
|
||||
|
||||
def __init__(self, logger: Optional[logging.Logger] = None):
|
||||
"""Initialize repository with optional logger."""
|
||||
self.logger = logger
|
||||
self._db_manager = get_db_manager()
|
||||
self._db_manager.initialize()
|
||||
|
||||
def log_info(self, message: str) -> None:
|
||||
"""Log info message if logger is available."""
|
||||
if self.logger:
|
||||
self.logger.info(message)
|
||||
|
||||
def log_debug(self, message: str) -> None:
|
||||
"""Log debug message if logger is available."""
|
||||
if self.logger:
|
||||
self.logger.debug(message)
|
||||
|
||||
def log_error(self, message: str) -> None:
|
||||
"""Log error message if logger is available."""
|
||||
if self.logger:
|
||||
self.logger.error(message)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self):
|
||||
"""Get database session with automatic cleanup."""
|
||||
if not self._db_manager:
|
||||
raise DatabaseOperationError("Database manager not initialized")
|
||||
|
||||
with self._db_manager.get_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
class MarketDataRepository(BaseRepository):
|
||||
"""Repository for market_data table operations."""
|
||||
|
||||
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
|
||||
"""
|
||||
Insert or update a candle in the market_data table.
|
||||
|
||||
Args:
|
||||
candle: OHLCV candle to store
|
||||
force_update: If True, update existing records; if False, ignore duplicates
|
||||
|
||||
Returns:
|
||||
True if operation successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Use right-aligned timestamp (end_time) following industry standard
|
||||
candle_timestamp = candle.end_time
|
||||
|
||||
with self.get_session() as session:
|
||||
if force_update:
|
||||
# Update existing records with new data
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO UPDATE SET
|
||||
open = EXCLUDED.open,
|
||||
high = EXCLUDED.high,
|
||||
low = EXCLUDED.low,
|
||||
close = EXCLUDED.close,
|
||||
volume = EXCLUDED.volume,
|
||||
trades_count = EXCLUDED.trades_count
|
||||
""")
|
||||
action = "Updated"
|
||||
else:
|
||||
# Ignore duplicates, keep existing records
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO NOTHING
|
||||
""")
|
||||
action = "Stored"
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': candle.exchange,
|
||||
'symbol': candle.symbol,
|
||||
'timeframe': candle.timeframe,
|
||||
'timestamp': candle_timestamp,
|
||||
'open': float(candle.open),
|
||||
'high': float(candle.high),
|
||||
'low': float(candle.low),
|
||||
'close': float(candle.close),
|
||||
'volume': float(candle.volume),
|
||||
'trades_count': candle.trade_count
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to store candle: {e}")
|
||||
|
||||
def get_candles(self,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
exchange: str = "okx") -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve candles from the database.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
start_time: Start timestamp
|
||||
end_time: End timestamp
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
List of candle dictionaries
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
|
||||
candles = []
|
||||
for row in result:
|
||||
candles.append({
|
||||
'exchange': row.exchange,
|
||||
'symbol': row.symbol,
|
||||
'timeframe': row.timeframe,
|
||||
'timestamp': row.timestamp,
|
||||
'open': row.open,
|
||||
'high': row.high,
|
||||
'low': row.low,
|
||||
'close': row.close,
|
||||
'volume': row.volume,
|
||||
'trades_count': row.trades_count,
|
||||
'created_at': row.created_at
|
||||
})
|
||||
|
||||
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
|
||||
return candles
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error retrieving candles: {e}")
|
||||
raise DatabaseOperationError(f"Failed to retrieve candles: {e}")
|
||||
|
||||
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the latest candle for a symbol and timeframe.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Latest candle dictionary or None
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe
|
||||
})
|
||||
|
||||
row = result.fetchone()
|
||||
if row:
|
||||
return {
|
||||
'exchange': row.exchange,
|
||||
'symbol': row.symbol,
|
||||
'timeframe': row.timeframe,
|
||||
'timestamp': row.timestamp,
|
||||
'open': row.open,
|
||||
'high': row.high,
|
||||
'low': row.low,
|
||||
'close': row.close,
|
||||
'volume': row.volume,
|
||||
'trades_count': row.trades_count,
|
||||
'created_at': row.created_at
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}")
|
||||
|
||||
|
||||
class RawTradeRepository(BaseRepository):
|
||||
"""Repository for raw_trades table operations."""
|
||||
|
||||
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
|
||||
"""
|
||||
Insert a market data point into raw_trades table.
|
||||
|
||||
Args:
|
||||
data_point: Market data point to store
|
||||
|
||||
Returns:
|
||||
True if operation successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': data_point.exchange,
|
||||
'symbol': data_point.symbol,
|
||||
'timestamp': data_point.timestamp,
|
||||
'data_type': data_point.data_type.value,
|
||||
'raw_data': json.dumps(data_point.data)
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error storing raw data for {data_point.symbol}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to store raw data: {e}")
|
||||
|
||||
def insert_raw_websocket_data(self,
|
||||
exchange: str,
|
||||
symbol: str,
|
||||
data_type: str,
|
||||
raw_data: Dict[str, Any],
|
||||
timestamp: Optional[datetime] = None) -> bool:
|
||||
"""
|
||||
Insert raw WebSocket data for debugging purposes.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
symbol: Trading symbol
|
||||
data_type: Type of data (e.g., 'raw_trades', 'raw_orderbook')
|
||||
raw_data: Raw data dictionary
|
||||
timestamp: Optional timestamp (defaults to now)
|
||||
|
||||
Returns:
|
||||
True if operation successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'data_type': data_type,
|
||||
'raw_data': json.dumps(raw_data)
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error storing raw WebSocket data for {symbol}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to store raw WebSocket data: {e}")
|
||||
|
||||
def get_raw_trades(self,
|
||||
symbol: str,
|
||||
data_type: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
exchange: str = "okx",
|
||||
limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve raw trades from the database.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
data_type: Data type filter
|
||||
start_time: Start timestamp
|
||||
end_time: End timestamp
|
||||
exchange: Exchange name
|
||||
limit: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
List of raw trade dictionaries
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
FROM raw_trades
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND data_type = :data_type
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'data_type': data_type,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
|
||||
trades = []
|
||||
for row in result:
|
||||
trades.append({
|
||||
'id': row.id,
|
||||
'exchange': row.exchange,
|
||||
'symbol': row.symbol,
|
||||
'timestamp': row.timestamp,
|
||||
'data_type': row.data_type,
|
||||
'raw_data': row.raw_data,
|
||||
'created_at': row.created_at
|
||||
})
|
||||
|
||||
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
|
||||
return trades
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error retrieving raw trades for {symbol}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to retrieve raw trades: {e}")
|
||||
|
||||
from .repositories import (
|
||||
BotRepository,
|
||||
MarketDataRepository,
|
||||
RawTradeRepository,
|
||||
DatabaseOperationError,
|
||||
)
|
||||
|
||||
class DatabaseOperations:
|
||||
"""
|
||||
@ -431,6 +30,7 @@ class DatabaseOperations:
|
||||
self.logger = logger
|
||||
|
||||
# Initialize repositories
|
||||
self.bots = BotRepository(logger)
|
||||
self.market_data = MarketDataRepository(logger)
|
||||
self.raw_trades = RawTradeRepository(logger)
|
||||
|
||||
@ -442,8 +42,8 @@ class DatabaseOperations:
|
||||
True if database is healthy, False otherwise
|
||||
"""
|
||||
try:
|
||||
# We use one of the repositories to get a session
|
||||
with self.market_data.get_session() as session:
|
||||
# Simple query to test connection
|
||||
result = session.execute(text("SELECT 1"))
|
||||
return result.fetchone() is not None
|
||||
except Exception as e:
|
||||
@ -462,20 +62,17 @@ class DatabaseOperations:
|
||||
stats = {
|
||||
'healthy': self.health_check(),
|
||||
'repositories': {
|
||||
'bots': 'BotRepository',
|
||||
'market_data': 'MarketDataRepository',
|
||||
'raw_trades': 'RawTradeRepository'
|
||||
}
|
||||
}
|
||||
|
||||
# Get table counts
|
||||
# Use a single session for all counts for efficiency
|
||||
with self.market_data.get_session() as session:
|
||||
# Market data count
|
||||
result = session.execute(text("SELECT COUNT(*) FROM market_data"))
|
||||
stats['candle_count'] = result.fetchone()[0]
|
||||
|
||||
# Raw trades count
|
||||
result = session.execute(text("SELECT COUNT(*) FROM raw_trades"))
|
||||
stats['raw_trade_count'] = result.fetchone()[0]
|
||||
stats['bot_count'] = session.execute(text("SELECT COUNT(*) FROM bots")).scalar_one()
|
||||
stats['candle_count'] = session.execute(text("SELECT COUNT(*) FROM market_data")).scalar_one()
|
||||
stats['raw_trade_count'] = session.execute(text("SELECT COUNT(*) FROM raw_trades")).scalar_one()
|
||||
|
||||
return stats
|
||||
|
||||
@ -484,11 +81,9 @@ class DatabaseOperations:
|
||||
self.logger.error(f"Error getting database stats: {e}")
|
||||
return {'healthy': False, 'error': str(e)}
|
||||
|
||||
|
||||
# Singleton instance for global access
|
||||
_db_operations_instance: Optional[DatabaseOperations] = None
|
||||
|
||||
|
||||
def get_database_operations(logger: Optional[logging.Logger] = None) -> DatabaseOperations:
|
||||
"""
|
||||
Get the global database operations instance.
|
||||
@ -500,17 +95,6 @@ def get_database_operations(logger: Optional[logging.Logger] = None) -> Database
|
||||
DatabaseOperations instance
|
||||
"""
|
||||
global _db_operations_instance
|
||||
|
||||
if _db_operations_instance is None:
|
||||
_db_operations_instance = DatabaseOperations(logger)
|
||||
|
||||
return _db_operations_instance
|
||||
|
||||
|
||||
__all__ = [
|
||||
'DatabaseOperationError',
|
||||
'MarketDataRepository',
|
||||
'RawTradeRepository',
|
||||
'DatabaseOperations',
|
||||
'get_database_operations'
|
||||
]
|
||||
return _db_operations_instance
|
||||
15
database/repositories/__init__.py
Normal file
15
database/repositories/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""
|
||||
This package contains all the repository classes for database operations.
|
||||
"""
|
||||
from .base_repository import BaseRepository, DatabaseOperationError
|
||||
from .bot_repository import BotRepository
|
||||
from .market_data_repository import MarketDataRepository
|
||||
from .raw_trade_repository import RawTradeRepository
|
||||
|
||||
__all__ = [
|
||||
"BaseRepository",
|
||||
"DatabaseOperationError",
|
||||
"BotRepository",
|
||||
"MarketDataRepository",
|
||||
"RawTradeRepository",
|
||||
]
|
||||
46
database/repositories/base_repository.py
Normal file
46
database/repositories/base_repository.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Base repository for all other repository classes."""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
from ..connection import get_db_manager
|
||||
|
||||
|
||||
class DatabaseOperationError(Exception):
|
||||
"""Custom exception for database operation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
"""Base class for all repository classes."""
|
||||
|
||||
def __init__(self, logger: Optional[logging.Logger] = None):
|
||||
"""Initialize repository with optional logger."""
|
||||
self.logger = logger
|
||||
self._db_manager = get_db_manager()
|
||||
self._db_manager.initialize()
|
||||
|
||||
def log_info(self, message: str) -> None:
|
||||
"""Log info message if logger is available."""
|
||||
if self.logger:
|
||||
self.logger.info(message)
|
||||
|
||||
def log_debug(self, message: str) -> None:
|
||||
"""Log debug message if logger is available."""
|
||||
if self.logger:
|
||||
self.logger.debug(message)
|
||||
|
||||
def log_error(self, message: str) -> None:
|
||||
"""Log error message if logger is available."""
|
||||
if self.logger:
|
||||
self.logger.error(message)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self):
|
||||
"""Get database session with automatic cleanup."""
|
||||
if not self._db_manager:
|
||||
raise DatabaseOperationError("Database manager not initialized")
|
||||
|
||||
with self._db_manager.get_session() as session:
|
||||
yield session
|
||||
74
database/repositories/bot_repository.py
Normal file
74
database/repositories/bot_repository.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""Repository for bots table operations."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from ..models import Bot
|
||||
from .base_repository import BaseRepository, DatabaseOperationError
|
||||
|
||||
|
||||
class BotRepository(BaseRepository):
|
||||
"""Repository for bots table operations."""
|
||||
|
||||
def add(self, bot_data: Dict[str, Any]) -> Bot:
|
||||
"""Add a new bot to the database."""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
new_bot = Bot(**bot_data)
|
||||
session.add(new_bot)
|
||||
session.commit()
|
||||
session.refresh(new_bot)
|
||||
self.log_info(f"Added new bot: {new_bot.name}")
|
||||
return new_bot
|
||||
except Exception as e:
|
||||
self.log_error(f"Error adding bot: {e}")
|
||||
raise DatabaseOperationError(f"Failed to add bot: {e}")
|
||||
|
||||
def get_by_id(self, bot_id: int) -> Optional[Bot]:
|
||||
"""Get a bot by its ID."""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
return session.query(Bot).filter(Bot.id == bot_id).first()
|
||||
except Exception as e:
|
||||
self.log_error(f"Error getting bot by ID {bot_id}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to get bot by ID: {e}")
|
||||
|
||||
def get_by_name(self, name: str) -> Optional[Bot]:
|
||||
"""Get a bot by its name."""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
return session.query(Bot).filter(Bot.name == name).first()
|
||||
except Exception as e:
|
||||
self.log_error(f"Error getting bot by name {name}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to get bot by name: {e}")
|
||||
|
||||
def update(self, bot_id: int, update_data: Dict[str, Any]) -> Optional[Bot]:
|
||||
"""Update a bot's information."""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
bot = session.query(Bot).filter(Bot.id == bot_id).first()
|
||||
if bot:
|
||||
for key, value in update_data.items():
|
||||
setattr(bot, key, value)
|
||||
session.commit()
|
||||
session.refresh(bot)
|
||||
self.log_info(f"Updated bot {bot_id}")
|
||||
return bot
|
||||
return None
|
||||
except Exception as e:
|
||||
self.log_error(f"Error updating bot {bot_id}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to update bot: {e}")
|
||||
|
||||
def delete(self, bot_id: int) -> bool:
|
||||
"""Delete a bot by its ID."""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
bot = session.query(Bot).filter(Bot.id == bot_id).first()
|
||||
if bot:
|
||||
session.delete(bot)
|
||||
session.commit()
|
||||
self.log_info(f"Deleted bot {bot_id}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log_error(f"Error deleting bot {bot_id}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to delete bot: {e}")
|
||||
154
database/repositories/market_data_repository.py
Normal file
154
database/repositories/market_data_repository.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""Repository for market_data table operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import text
|
||||
|
||||
from ..models import MarketData
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from .base_repository import BaseRepository, DatabaseOperationError
|
||||
|
||||
|
||||
class MarketDataRepository(BaseRepository):
|
||||
"""Repository for market_data table operations."""
|
||||
|
||||
def upsert_candle(self, candle: OHLCVCandle, force_update: bool = False) -> bool:
|
||||
"""
|
||||
Insert or update a candle in the market_data table.
|
||||
"""
|
||||
try:
|
||||
candle_timestamp = candle.end_time
|
||||
|
||||
with self.get_session() as session:
|
||||
if force_update:
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO UPDATE SET
|
||||
open = EXCLUDED.open,
|
||||
high = EXCLUDED.high,
|
||||
low = EXCLUDED.low,
|
||||
close = EXCLUDED.close,
|
||||
volume = EXCLUDED.volume,
|
||||
trades_count = EXCLUDED.trades_count
|
||||
""")
|
||||
action = "Updated"
|
||||
else:
|
||||
query = text("""
|
||||
INSERT INTO market_data (
|
||||
exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timeframe, :timestamp,
|
||||
:open, :high, :low, :close, :volume, :trades_count,
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (exchange, symbol, timeframe, timestamp)
|
||||
DO NOTHING
|
||||
""")
|
||||
action = "Stored"
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': candle.exchange,
|
||||
'symbol': candle.symbol,
|
||||
'timeframe': candle.timeframe,
|
||||
'timestamp': candle_timestamp,
|
||||
'open': float(candle.open),
|
||||
'high': float(candle.high),
|
||||
'low': float(candle.low),
|
||||
'close': float(candle.close),
|
||||
'volume': float(candle.volume),
|
||||
'trades_count': candle.trade_count
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"{action} candle: {candle.symbol} {candle.timeframe} at {candle_timestamp} (force_update={force_update})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to store candle: {e}")
|
||||
|
||||
def get_candles(self,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
exchange: str = "okx") -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve candles from the database.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
|
||||
candles = [dict(row._mapping) for row in result]
|
||||
|
||||
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
|
||||
return candles
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error retrieving candles: {e}")
|
||||
raise DatabaseOperationError(f"Failed to retrieve candles: {e}")
|
||||
|
||||
def get_latest_candle(self, symbol: str, timeframe: str, exchange: str = "okx") -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the latest candle for a symbol and timeframe.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT exchange, symbol, timeframe, timestamp,
|
||||
open, high, low, close, volume, trades_count,
|
||||
created_at
|
||||
FROM market_data
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND timeframe = :timeframe
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe
|
||||
})
|
||||
|
||||
row = result.fetchone()
|
||||
if row:
|
||||
return dict(row._mapping)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error retrieving latest candle for {symbol} {timeframe}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to retrieve latest candle: {e}")
|
||||
128
database/repositories/raw_trade_repository.py
Normal file
128
database/repositories/raw_trade_repository.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Repository for raw_trades table operations."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import text
|
||||
|
||||
from ..models import RawTrade
|
||||
from data.base_collector import MarketDataPoint
|
||||
from .base_repository import BaseRepository, DatabaseOperationError
|
||||
|
||||
|
||||
class RawTradeRepository(BaseRepository):
|
||||
"""Repository for raw_trades table operations."""
|
||||
|
||||
def insert_market_data_point(self, data_point: MarketDataPoint) -> bool:
|
||||
"""
|
||||
Insert a market data point into raw_trades table.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': data_point.exchange,
|
||||
'symbol': data_point.symbol,
|
||||
'timestamp': data_point.timestamp,
|
||||
'data_type': data_point.data_type.value,
|
||||
'raw_data': json.dumps(data_point.data)
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw {data_point.data_type.value} data for {data_point.symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error storing raw data for {data_point.symbol}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to store raw data: {e}")
|
||||
|
||||
def insert_raw_websocket_data(self,
|
||||
exchange: str,
|
||||
symbol: str,
|
||||
data_type: str,
|
||||
raw_data: Dict[str, Any],
|
||||
timestamp: Optional[datetime] = None) -> bool:
|
||||
"""
|
||||
Insert raw WebSocket data for debugging purposes.
|
||||
"""
|
||||
try:
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
INSERT INTO raw_trades (
|
||||
exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
) VALUES (
|
||||
:exchange, :symbol, :timestamp, :data_type, :raw_data, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'data_type': data_type,
|
||||
'raw_data': json.dumps(raw_data)
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
self.log_debug(f"Stored raw WebSocket data: {data_type} for {symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error storing raw WebSocket data for {symbol}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to store raw WebSocket data: {e}")
|
||||
|
||||
def get_raw_trades(self,
|
||||
symbol: str,
|
||||
data_type: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
exchange: str = "okx",
|
||||
limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve raw trades from the database.
|
||||
"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
query = text("""
|
||||
SELECT id, exchange, symbol, timestamp, data_type, raw_data, created_at
|
||||
FROM raw_trades
|
||||
WHERE exchange = :exchange
|
||||
AND symbol = :symbol
|
||||
AND data_type = :data_type
|
||||
AND timestamp >= :start_time
|
||||
AND timestamp <= :end_time
|
||||
ORDER BY timestamp ASC
|
||||
""")
|
||||
|
||||
if limit:
|
||||
query_str = str(query.compile(compile_kwargs={"literal_binds": True})) + f" LIMIT {limit}"
|
||||
query = text(query_str)
|
||||
|
||||
result = session.execute(query, {
|
||||
'exchange': exchange,
|
||||
'symbol': symbol,
|
||||
'data_type': data_type,
|
||||
'start_time': start_time,
|
||||
'end_time': end_time
|
||||
})
|
||||
|
||||
trades = [dict(row._mapping) for row in result]
|
||||
|
||||
self.log_info(f"Retrieved {len(trades)} raw trades for {symbol} {data_type}")
|
||||
return trades
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"Error retrieving raw trades for {symbol}: {e}")
|
||||
raise DatabaseOperationError(f"Failed to retrieve raw trades: {e}")
|
||||
@ -37,11 +37,11 @@ The Database Operations module (`database/operations.py`) provides a clean, cent
|
||||
│ └─────────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌─────────────────┐ ┌─────────────────┐ ┌──────────────┐ │
|
||||
│ │MarketDataRepo │ │RawTradeRepo │ │ Future │ │
|
||||
│ │ │ │ │ │ Repositories │ │
|
||||
│ │ • upsert_candle │ │ • insert_data │ │ • OrderBook │ │
|
||||
│ │ • get_candles │ │ • get_trades │ │ • UserTrades │ │
|
||||
│ │ • get_latest │ │ • raw_websocket │ │ • Positions │ │
|
||||
│ │MarketDataRepo │ │RawTradeRepo │ │ BotRepo │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ │ • upsert_candle │ │ • insert_data │ │ • add │ │
|
||||
│ │ • get_candles │ │ • get_trades │ │ • get_by_id │ │
|
||||
│ │ • get_latest │ │ • raw_websocket │ │ • update/delete│ │
|
||||
│ └─────────────────┘ └─────────────────┘ └──────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
@ -118,8 +118,9 @@ async def main():
|
||||
|
||||
# Check statistics
|
||||
stats = db.get_stats()
|
||||
print(f"Total bots: {stats['bot_count']}")
|
||||
print(f"Total candles: {stats['candle_count']}")
|
||||
print(f"Total raw trades: {stats['trade_count']}")
|
||||
print(f"Total raw trades: {stats['raw_trade_count']}")
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
@ -148,8 +149,9 @@ Get comprehensive database statistics.
|
||||
|
||||
```python
|
||||
stats = db.get_stats()
|
||||
print(f"Bots: {stats['bot_count']:,}")
|
||||
print(f"Candles: {stats['candle_count']:,}")
|
||||
print(f"Raw trades: {stats['trade_count']:,}")
|
||||
print(f"Raw trades: {stats['raw_trade_count']:,}")
|
||||
print(f"Health: {stats['healthy']}")
|
||||
```
|
||||
|
||||
@ -212,6 +214,81 @@ else:
|
||||
print("No candles found")
|
||||
```
|
||||
|
||||
### BotRepository
|
||||
|
||||
Repository for `bots` table operations.
|
||||
|
||||
#### Methods
|
||||
|
||||
##### `add(bot_data: Dict[str, Any]) -> Bot`
|
||||
|
||||
Adds a new bot to the database.
|
||||
|
||||
**Parameters:**
|
||||
- `bot_data`: Dictionary containing the bot's attributes (`name`, `strategy_name`, etc.)
|
||||
|
||||
**Returns:** The newly created `Bot` object.
|
||||
|
||||
```python
|
||||
from decimal import Decimal
|
||||
|
||||
bot_data = {
|
||||
"name": "MyTestBot",
|
||||
"strategy_name": "SimpleMACD",
|
||||
"symbol": "BTC-USDT",
|
||||
"timeframe": "1h",
|
||||
"status": "inactive",
|
||||
"virtual_balance": Decimal("10000"),
|
||||
}
|
||||
new_bot = db.bots.add(bot_data)
|
||||
print(f"Added bot with ID: {new_bot.id}")
|
||||
```
|
||||
|
||||
##### `get_by_id(bot_id: int) -> Optional[Bot]`
|
||||
|
||||
Retrieves a bot by its unique ID.
|
||||
|
||||
```python
|
||||
bot = db.bots.get_by_id(1)
|
||||
if bot:
|
||||
print(f"Found bot: {bot.name}")
|
||||
```
|
||||
|
||||
##### `get_by_name(name: str) -> Optional[Bot]`
|
||||
|
||||
Retrieves a bot by its unique name.
|
||||
|
||||
```python
|
||||
bot = db.bots.get_by_name("MyTestBot")
|
||||
if bot:
|
||||
print(f"Found bot with ID: {bot.id}")
|
||||
```
|
||||
|
||||
##### `update(bot_id: int, update_data: Dict[str, Any]) -> Optional[Bot]`
|
||||
|
||||
Updates an existing bot's attributes.
|
||||
|
||||
```python
|
||||
from datetime import datetime, timezone
|
||||
|
||||
update_payload = {"status": "active", "last_heartbeat": datetime.now(timezone.utc)}
|
||||
updated_bot = db.bots.update(1, update_payload)
|
||||
if updated_bot:
|
||||
print(f"Bot status updated to: {updated_bot.status}")
|
||||
```
|
||||
|
||||
##### `delete(bot_id: int) -> bool`
|
||||
|
||||
Deletes a bot from the database.
|
||||
|
||||
**Returns:** `True` if deletion was successful, `False` otherwise.
|
||||
|
||||
```python
|
||||
success = db.bots.delete(1)
|
||||
if success:
|
||||
print("Bot deleted successfully.")
|
||||
```
|
||||
|
||||
### RawTradeRepository
|
||||
|
||||
Repository for `raw_trades` table operations (raw WebSocket data).
|
||||
|
||||
247
tests/database/test_database_operations.py
Normal file
247
tests/database/test_database_operations.py
Normal file
@ -0,0 +1,247 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import asyncio
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
from database.operations import get_database_operations
|
||||
from database.models import Bot
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from data.base_collector import MarketDataPoint, DataType
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test module."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def db_ops():
|
||||
"""Fixture to provide database operations."""
|
||||
# We will need to make sure the test database is configured and running
|
||||
operations = get_database_operations()
|
||||
yield operations
|
||||
# Teardown logic can be added here if needed, e.g., operations.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBotRepository:
|
||||
"""Tests for the BotRepository."""
|
||||
|
||||
async def test_add_and_get_bot(self, db_ops):
|
||||
"""
|
||||
Test adding a new bot and retrieving it to verify basic repository functionality.
|
||||
"""
|
||||
# Define a new bot
|
||||
bot_name = "test_bot_01"
|
||||
new_bot = {
|
||||
"name": bot_name,
|
||||
"strategy_name": "test_strategy",
|
||||
"symbol": "BTC-USDT",
|
||||
"timeframe": "1h",
|
||||
"status": "inactive",
|
||||
"virtual_balance": Decimal("10000"),
|
||||
}
|
||||
|
||||
# Clean up any existing bot with the same name
|
||||
existing_bot = db_ops.bots.get_by_name(bot_name)
|
||||
if existing_bot:
|
||||
db_ops.bots.delete(existing_bot.id)
|
||||
|
||||
# Add the bot
|
||||
added_bot = db_ops.bots.add(new_bot)
|
||||
|
||||
# Assertions to check if the bot was added correctly
|
||||
assert added_bot is not None
|
||||
assert added_bot.id is not None
|
||||
assert added_bot.name == bot_name
|
||||
assert added_bot.status == "inactive"
|
||||
|
||||
# Retrieve the bot by ID
|
||||
retrieved_bot = db_ops.bots.get_by_id(added_bot.id)
|
||||
|
||||
# Assertions to check if the bot was retrieved correctly
|
||||
assert retrieved_bot is not None
|
||||
assert retrieved_bot.id == added_bot.id
|
||||
assert retrieved_bot.name == bot_name
|
||||
|
||||
# Clean up the created bot
|
||||
db_ops.bots.delete(added_bot.id)
|
||||
|
||||
# Verify it's deleted
|
||||
deleted_bot = db_ops.bots.get_by_id(added_bot.id)
|
||||
assert deleted_bot is None
|
||||
|
||||
async def test_update_bot(self, db_ops):
|
||||
"""Test updating an existing bot's status."""
|
||||
bot_name = "test_bot_for_update"
|
||||
bot_data = {
|
||||
"name": bot_name,
|
||||
"strategy_name": "test_strategy",
|
||||
"symbol": "ETH-USDT",
|
||||
"timeframe": "5m",
|
||||
"status": "active",
|
||||
}
|
||||
# Ensure clean state
|
||||
existing_bot = db_ops.bots.get_by_name(bot_name)
|
||||
if existing_bot:
|
||||
db_ops.bots.delete(existing_bot.id)
|
||||
|
||||
# Add a bot to update
|
||||
bot_to_update = db_ops.bots.add(bot_data)
|
||||
|
||||
# Update the bot's status
|
||||
update_data = {"status": "paused"}
|
||||
updated_bot = db_ops.bots.update(bot_to_update.id, update_data)
|
||||
|
||||
# Assertions
|
||||
assert updated_bot is not None
|
||||
assert updated_bot.status == "paused"
|
||||
|
||||
# Clean up
|
||||
db_ops.bots.delete(bot_to_update.id)
|
||||
|
||||
async def test_get_nonexistent_bot(self, db_ops):
|
||||
"""Test that fetching a non-existent bot returns None."""
|
||||
non_existent_bot = db_ops.bots.get_by_id(999999)
|
||||
assert non_existent_bot is None
|
||||
|
||||
async def test_delete_bot(self, db_ops):
|
||||
"""Test deleting a bot."""
|
||||
bot_name = "test_bot_for_delete"
|
||||
bot_data = {
|
||||
"name": bot_name,
|
||||
"strategy_name": "delete_strategy",
|
||||
"symbol": "LTC-USDT",
|
||||
"timeframe": "15m",
|
||||
}
|
||||
# Ensure clean state
|
||||
existing_bot = db_ops.bots.get_by_name(bot_name)
|
||||
if existing_bot:
|
||||
db_ops.bots.delete(existing_bot.id)
|
||||
|
||||
# Add a bot to delete
|
||||
bot_to_delete = db_ops.bots.add(bot_data)
|
||||
|
||||
# Delete the bot
|
||||
delete_result = db_ops.bots.delete(bot_to_delete.id)
|
||||
assert delete_result is True
|
||||
|
||||
# Verify it's gone
|
||||
retrieved_bot = db_ops.bots.get_by_id(bot_to_delete.id)
|
||||
assert retrieved_bot is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestMarketDataRepository:
|
||||
"""Tests for the MarketDataRepository."""
|
||||
|
||||
async def test_upsert_and_get_candle(self, db_ops):
|
||||
"""Test upserting and retrieving a candle."""
|
||||
# Use a fixed timestamp for deterministic tests
|
||||
base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
candle = OHLCVCandle(
|
||||
start_time=base_time,
|
||||
end_time=base_time + timedelta(hours=1),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
trade_count=10,
|
||||
timeframe="1h",
|
||||
symbol="BTC-USDT-TUG", # Unique symbol for test
|
||||
exchange="okx"
|
||||
)
|
||||
|
||||
# Upsert the candle
|
||||
success = db_ops.market_data.upsert_candle(candle)
|
||||
assert success is True
|
||||
|
||||
# Retrieve the candle using a time range
|
||||
start_time = base_time + timedelta(hours=1)
|
||||
end_time = base_time + timedelta(hours=1)
|
||||
|
||||
retrieved_candles = db_ops.market_data.get_candles(
|
||||
symbol="BTC-USDT-TUG",
|
||||
timeframe="1h",
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
|
||||
assert len(retrieved_candles) >= 1
|
||||
retrieved_candle = retrieved_candles[0]
|
||||
assert retrieved_candle["symbol"] == "BTC-USDT-TUG"
|
||||
assert retrieved_candle["close"] == candle.close
|
||||
assert retrieved_candle["timestamp"] == candle.end_time
|
||||
|
||||
async def test_get_latest_candle(self, db_ops):
|
||||
"""Test fetching the latest candle."""
|
||||
base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
|
||||
symbol = "ETH-USDT-TGLC" # Unique symbol for test
|
||||
|
||||
# Insert a few candles with increasing timestamps
|
||||
for i in range(3):
|
||||
candle = OHLCVCandle(
|
||||
start_time=base_time + timedelta(minutes=i*5),
|
||||
end_time=base_time + timedelta(minutes=(i+1)*5),
|
||||
open=Decimal("1200") + i,
|
||||
high=Decimal("1210") + i,
|
||||
low=Decimal("1190") + i,
|
||||
close=Decimal("1205") + i,
|
||||
volume=Decimal("1000"),
|
||||
trade_count=20+i,
|
||||
timeframe="5m",
|
||||
symbol=symbol,
|
||||
exchange="okx"
|
||||
)
|
||||
db_ops.market_data.upsert_candle(candle)
|
||||
|
||||
latest_candle = db_ops.market_data.get_latest_candle(
|
||||
symbol=symbol,
|
||||
timeframe="5m"
|
||||
)
|
||||
|
||||
assert latest_candle is not None
|
||||
assert latest_candle["symbol"] == symbol
|
||||
assert latest_candle["timeframe"] == "5m"
|
||||
assert latest_candle["close"] == Decimal("1207")
|
||||
assert latest_candle["timestamp"] == base_time + timedelta(minutes=15)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRawTradeRepository:
|
||||
"""Tests for the RawTradeRepository."""
|
||||
|
||||
async def test_insert_and_get_raw_trade(self, db_ops):
|
||||
"""Test inserting and retrieving a raw trade data point."""
|
||||
base_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
|
||||
symbol = "XRP-USDT-TIRT" # Unique symbol for test
|
||||
|
||||
data_point = MarketDataPoint(
|
||||
symbol=symbol,
|
||||
data_type=DataType.TRADE,
|
||||
data={"price": "2.5", "qty": "100"},
|
||||
timestamp=base_time,
|
||||
exchange="okx"
|
||||
)
|
||||
|
||||
# Insert raw data
|
||||
success = db_ops.raw_trades.insert_market_data_point(data_point)
|
||||
assert success is True
|
||||
|
||||
# Retrieve raw data
|
||||
start_time = base_time - timedelta(seconds=1)
|
||||
end_time = base_time + timedelta(seconds=1)
|
||||
|
||||
raw_trades = db_ops.raw_trades.get_raw_trades(
|
||||
symbol=symbol,
|
||||
data_type=DataType.TRADE.value,
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
|
||||
assert len(raw_trades) >= 1
|
||||
retrieved_trade = raw_trades[0]
|
||||
assert retrieved_trade["symbol"] == symbol
|
||||
assert retrieved_trade["data_type"] == DataType.TRADE.value
|
||||
assert retrieved_trade["raw_data"] == data_point.data
|
||||
Loading…
x
Reference in New Issue
Block a user