# TCPDashboard/tests/database/test_database_operations.py
# Vasily.onl — commit 028371a0e1: Refactor database operations and enhance repository structure
# - Introduced a modular repository structure by creating separate repository classes for `Bot`, `MarketData`, and `RawTrade`, improving code organization and maintainability.
# - Updated the `DatabaseOperations` class to use the new repository classes, improving the abstraction of database interactions.
# - Refactored the `.env` file to update database connection parameters and add new logging and health-monitoring configuration.
# - Modified `okx_config.json` to change the default timeframes for trading pairs, in line with updated requirements.
# - Added comprehensive unit tests for the new repository classes to ensure robust functionality and reliability.
#
# These changes improve the overall architecture of the database layer, making it more scalable and easier to manage.
# Date: 2025-06-06 21:54:45 +08:00
#
# 247 lines · 8.4 KiB · Python

import pytest
import pytest_asyncio
import asyncio
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from database.operations import get_database_operations
from database.models import Bot
from data.common.data_types import OHLCVCandle
from data.base_collector import MarketDataPoint, DataType
@pytest.fixture(scope="module")
def event_loop():
    """Provide a single fresh asyncio event loop shared by all tests in this module.

    The async ``db_ops`` fixture is module-scoped, so it presumably needs a
    loop that lives at least as long as the module (pytest-asyncio's default
    loop is function-scoped). The loop is closed when the module finishes.
    """
    policy = asyncio.get_event_loop_policy()
    module_loop = policy.new_event_loop()
    yield module_loop
    # Teardown: runs after the last test in the module has used the loop.
    module_loop.close()
@pytest_asyncio.fixture(scope="module")
async def db_ops():
    """Yield the database-operations facade shared by this module's tests.

    Tests reach the repositories through attributes of the yielded object
    (``bots``, ``market_data``, ``raw_trades``). A configured, running test
    database is required.
    """
    # NOTE(review): nothing is released after the yield; add teardown here
    # (e.g. a close() call) if get_database_operations() holds resources.
    yield get_database_operations()
@pytest.mark.asyncio
class TestBotRepository:
    """Tests for the BotRepository exposed as ``db_ops.bots``.

    Each test removes any leftover bot with its name before starting. It also
    removes the bot it created in a ``finally`` block, so a failing assertion
    does not leave rows behind in the test database.
    """

    @staticmethod
    def _remove_bot_named(db_ops, bot_name):
        """Delete the bot called ``bot_name`` if one exists; a no-op otherwise."""
        existing_bot = db_ops.bots.get_by_name(bot_name)
        if existing_bot:
            db_ops.bots.delete(existing_bot.id)

    async def test_add_and_get_bot(self, db_ops):
        """
        Test adding a new bot and retrieving it to verify basic repository functionality.
        """
        bot_name = "test_bot_01"
        new_bot = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "BTC-USDT",
            "timeframe": "1h",
            "status": "inactive",
            "virtual_balance": Decimal("10000"),
        }
        # Start from a clean state in case an earlier run left this bot behind.
        self._remove_bot_named(db_ops, bot_name)
        try:
            added_bot = db_ops.bots.add(new_bot)

            # The bot was created with the requested fields.
            assert added_bot is not None
            assert added_bot.id is not None
            assert added_bot.name == bot_name
            assert added_bot.status == "inactive"

            # It can be read back by ID.
            retrieved_bot = db_ops.bots.get_by_id(added_bot.id)
            assert retrieved_bot is not None
            assert retrieved_bot.id == added_bot.id
            assert retrieved_bot.name == bot_name

            # Deleting it makes it unreachable by ID.
            db_ops.bots.delete(added_bot.id)
            assert db_ops.bots.get_by_id(added_bot.id) is None
        finally:
            # Runs even when an assertion above fails; harmless if already deleted.
            self._remove_bot_named(db_ops, bot_name)

    async def test_update_bot(self, db_ops):
        """Test updating an existing bot's status, and that the change is persisted."""
        bot_name = "test_bot_for_update"
        bot_data = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "ETH-USDT",
            "timeframe": "5m",
            "status": "active",
        }
        self._remove_bot_named(db_ops, bot_name)
        try:
            bot_to_update = db_ops.bots.add(bot_data)

            updated_bot = db_ops.bots.update(bot_to_update.id, {"status": "paused"})
            assert updated_bot is not None
            assert updated_bot.status == "paused"

            # Re-read the row so we check what was stored, not just the
            # object handed back by update().
            persisted_bot = db_ops.bots.get_by_id(bot_to_update.id)
            assert persisted_bot is not None
            assert persisted_bot.status == "paused"
        finally:
            self._remove_bot_named(db_ops, bot_name)

    async def test_get_nonexistent_bot(self, db_ops):
        """Test that fetching a non-existent bot returns None."""
        # NOTE(review): assumes no bot with id 999999 exists in the test database.
        non_existent_bot = db_ops.bots.get_by_id(999999)
        assert non_existent_bot is None

    async def test_delete_bot(self, db_ops):
        """Test that delete() reports success and the bot is no longer retrievable."""
        bot_name = "test_bot_for_delete"
        bot_data = {
            "name": bot_name,
            "strategy_name": "delete_strategy",
            "symbol": "LTC-USDT",
            "timeframe": "15m",
        }
        self._remove_bot_named(db_ops, bot_name)
        try:
            bot_to_delete = db_ops.bots.add(bot_data)

            delete_result = db_ops.bots.delete(bot_to_delete.id)
            assert delete_result is True

            retrieved_bot = db_ops.bots.get_by_id(bot_to_delete.id)
            assert retrieved_bot is None
        finally:
            self._remove_bot_named(db_ops, bot_name)
@pytest.mark.asyncio
class TestMarketDataRepository:
    """Tests for the MarketDataRepository exposed as ``db_ops.market_data``.

    Symbols carry a test-specific suffix so they do not collide with real
    market data, and fixed timestamps keep reruns deterministic.
    """

    async def test_upsert_and_get_candle(self, db_ops):
        """Test upserting one candle and retrieving it through a time-range query."""
        symbol = "BTC-USDT-TUG"  # Unique symbol for test
        # Use a fixed timestamp for deterministic tests
        base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        candle = OHLCVCandle(
            start_time=base_time,
            end_time=base_time + timedelta(hours=1),
            open=Decimal("50000"),
            high=Decimal("51000"),
            low=Decimal("49000"),
            close=Decimal("50500"),
            volume=Decimal("100"),
            trade_count=10,
            timeframe="1h",
            symbol=symbol,
            exchange="okx",
        )

        success = db_ops.market_data.upsert_candle(candle)
        assert success is True

        # The stored row's timestamp is the candle's end_time (asserted below).
        # Query a small window around it instead of the zero-width
        # [end_time, end_time] range, so the test does not depend on whether
        # get_candles treats both bounds as inclusive.
        window = timedelta(seconds=1)
        retrieved_candles = db_ops.market_data.get_candles(
            symbol=symbol,
            timeframe="1h",
            start_time=candle.end_time - window,
            end_time=candle.end_time + window,
        )
        assert len(retrieved_candles) >= 1
        retrieved_candle = retrieved_candles[0]
        assert retrieved_candle["symbol"] == symbol
        assert retrieved_candle["close"] == candle.close
        assert retrieved_candle["timestamp"] == candle.end_time

    async def test_get_latest_candle(self, db_ops):
        """Test that get_latest_candle returns the candle with the newest timestamp."""
        base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
        symbol = "ETH-USDT-TGLC"  # Unique symbol for test

        # Insert three consecutive 5m candles; the i-th one closes at 1205 + i
        # and ends at base_time + 5*(i+1) minutes.
        for i in range(3):
            candle = OHLCVCandle(
                start_time=base_time + timedelta(minutes=i * 5),
                end_time=base_time + timedelta(minutes=(i + 1) * 5),
                open=Decimal("1200") + i,
                high=Decimal("1210") + i,
                low=Decimal("1190") + i,
                close=Decimal("1205") + i,
                volume=Decimal("1000"),
                trade_count=20 + i,
                timeframe="5m",
                symbol=symbol,
                exchange="okx",
            )
            # Fail here, not in the later assertions, if setup did not succeed.
            assert db_ops.market_data.upsert_candle(candle) is True

        latest_candle = db_ops.market_data.get_latest_candle(
            symbol=symbol,
            timeframe="5m",
        )
        assert latest_candle is not None
        assert latest_candle["symbol"] == symbol
        assert latest_candle["timeframe"] == "5m"
        # The last candle inserted (i == 2) is the newest one.
        assert latest_candle["close"] == Decimal("1207")
        assert latest_candle["timestamp"] == base_time + timedelta(minutes=15)
@pytest.mark.asyncio
class TestRawTradeRepository:
    """Exercise the RawTradeRepository reachable as ``db_ops.raw_trades``."""

    async def test_insert_and_get_raw_trade(self, db_ops):
        """Store one trade MarketDataPoint, then find it again by symbol, type and time."""
        symbol = "XRP-USDT-TIRT"  # test-only symbol, kept apart from real markets
        trade_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
        trade_type = DataType.TRADE.value

        data_point = MarketDataPoint(
            symbol=symbol,
            data_type=DataType.TRADE,
            data={"price": "2.5", "qty": "100"},
            timestamp=trade_time,
            exchange="okx",
        )
        inserted = db_ops.raw_trades.insert_market_data_point(data_point)
        assert inserted is True

        # Search a +/- 1 second window around the trade's timestamp.
        slack = timedelta(seconds=1)
        found = db_ops.raw_trades.get_raw_trades(
            symbol=symbol,
            data_type=trade_type,
            start_time=trade_time - slack,
            end_time=trade_time + slack,
        )
        assert len(found) >= 1

        first_trade = found[0]
        assert first_trade["symbol"] == symbol
        assert first_trade["data_type"] == trade_type
        assert first_trade["raw_data"] == data_point.data