TCPDashboard/tests/database/test_database_operations.py
2025-06-12 13:43:05 +08:00

247 lines
8.4 KiB
Python

import pytest
import pytest_asyncio
import asyncio
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from database.operations import get_database_operations
from database.models import Bot
from data.common.data_types import OHLCVCandle
from data.collector.base_collector import MarketDataPoint, DataType
@pytest.fixture(scope="module")
def event_loop():
    """Create one fresh asyncio event loop shared by all tests in this module.

    NOTE(review): this relies on overriding pytest-asyncio's legacy
    ``event_loop`` fixture to widen the loop to module scope; newer
    pytest-asyncio releases deprecate that pattern in favour of loop-scope
    options — confirm against the pinned pytest-asyncio version.

    Yields:
        asyncio.AbstractEventLoop: a new event loop, closed at module teardown.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), but avoids the
    # event-loop policy API, which is deprecated as of Python 3.14.
    loop = asyncio.new_event_loop()
    yield loop
    try:
        # Finalize async generators left open by tests before closing the
        # loop, mirroring the shutdown sequence of asyncio.run().
        loop.run_until_complete(loop.shutdown_asyncgens())
    finally:
        # Always release the loop, even if shutting down generators fails.
        loop.close()
@pytest.fixture(scope="module")
def db_ops():
    """Provide the database operations object shared by all tests in this module.

    The fixture is synchronous: obtaining the operations object involves no
    awaiting, so there is no reason to bind it to an event loop. Sync fixtures
    can be requested from both sync and async tests.

    NOTE(review): assumes the test database is configured and reachable before
    the tests run — confirm how the test environment provisions it.

    Yields:
        The object returned by ``get_database_operations()``; the tests in this
        module use its ``bots``, ``market_data`` and ``raw_trades`` attributes.
    """
    operations = get_database_operations()
    yield operations
    # NOTE(review): no teardown yet. If the operations object holds connections
    # or sessions, release them here (e.g. via a close() method, if one exists).
@pytest.mark.asyncio
class TestBotRepository:
    """Tests for the BotRepository exposed as ``db_ops.bots``."""

    @staticmethod
    def _remove_bot_by_name(db_ops, name):
        """Delete the bot named *name* if it exists, so each test starts and ends clean."""
        existing_bot = db_ops.bots.get_by_name(name)
        if existing_bot:
            db_ops.bots.delete(existing_bot.id)

    async def test_add_and_get_bot(self, db_ops):
        """
        Test adding a new bot and retrieving it to verify basic repository functionality.
        """
        bot_name = "test_bot_01"
        new_bot = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "BTC-USDT",
            "timeframe": "1h",
            "status": "inactive",
            "virtual_balance": Decimal("10000"),
        }
        # Remove leftovers from earlier (possibly failed) runs.
        self._remove_bot_by_name(db_ops, bot_name)
        try:
            added_bot = db_ops.bots.add(new_bot)

            # The bot was added and got an id assigned.
            assert added_bot is not None
            assert added_bot.id is not None
            assert added_bot.name == bot_name
            assert added_bot.status == "inactive"

            # The bot can be read back by id.
            retrieved_bot = db_ops.bots.get_by_id(added_bot.id)
            assert retrieved_bot is not None
            assert retrieved_bot.id == added_bot.id
            assert retrieved_bot.name == bot_name

            # Deleting it makes it unreachable by id.
            db_ops.bots.delete(added_bot.id)
            assert db_ops.bots.get_by_id(added_bot.id) is None
        finally:
            # Runs even when an assertion above fails, so the bot is not leaked;
            # a no-op when the test body already deleted it.
            self._remove_bot_by_name(db_ops, bot_name)

    async def test_update_bot(self, db_ops):
        """Test updating an existing bot's status and that the change is persisted."""
        bot_name = "test_bot_for_update"
        bot_data = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "ETH-USDT",
            "timeframe": "5m",
            "status": "active",
        }
        self._remove_bot_by_name(db_ops, bot_name)
        try:
            bot_to_update = db_ops.bots.add(bot_data)

            updated_bot = db_ops.bots.update(bot_to_update.id, {"status": "paused"})
            assert updated_bot is not None
            assert updated_bot.status == "paused"

            # Re-read the bot so the test checks stored state, not just the
            # object update() returned.
            reloaded_bot = db_ops.bots.get_by_id(bot_to_update.id)
            assert reloaded_bot is not None
            assert reloaded_bot.status == "paused"
        finally:
            self._remove_bot_by_name(db_ops, bot_name)

    async def test_get_nonexistent_bot(self, db_ops):
        """Test that fetching a non-existent bot returns None."""
        # 999999 is assumed never to be a real id in the test database.
        assert db_ops.bots.get_by_id(999999) is None

    async def test_delete_bot(self, db_ops):
        """Test deleting a bot."""
        bot_name = "test_bot_for_delete"
        bot_data = {
            "name": bot_name,
            "strategy_name": "delete_strategy",
            "symbol": "LTC-USDT",
            "timeframe": "15m",
        }
        self._remove_bot_by_name(db_ops, bot_name)
        try:
            bot_to_delete = db_ops.bots.add(bot_data)

            # delete() reports success with True.
            delete_result = db_ops.bots.delete(bot_to_delete.id)
            assert delete_result is True

            # The bot is gone afterwards.
            assert db_ops.bots.get_by_id(bot_to_delete.id) is None
        finally:
            # Ensures cleanup if delete() failed and left the row behind.
            self._remove_bot_by_name(db_ops, bot_name)
@pytest.mark.asyncio
class TestMarketDataRepository:
    """Tests for the MarketDataRepository exposed as ``db_ops.market_data``."""

    async def test_upsert_and_get_candle(self, db_ops):
        """Test upserting a candle and reading it back by its timestamp."""
        # Use a fixed timestamp for deterministic tests
        base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        symbol = "BTC-USDT-TUG"  # Unique symbol for test
        candle = OHLCVCandle(
            start_time=base_time,
            end_time=base_time + timedelta(hours=1),
            open=Decimal("50000"),
            high=Decimal("51000"),
            low=Decimal("49000"),
            close=Decimal("50500"),
            volume=Decimal("100"),
            trade_count=10,
            timeframe="1h",
            symbol=symbol,
            exchange="okx",
        )

        success = db_ops.market_data.upsert_candle(candle)
        assert success is True

        # The returned "timestamp" is the candle's end_time (asserted below).
        # Query a small window around it instead of the zero-width
        # [end_time, end_time] range, which only matches if get_candles treats
        # both bounds as inclusive.
        window = timedelta(seconds=1)
        retrieved_candles = db_ops.market_data.get_candles(
            symbol=symbol,
            timeframe="1h",
            start_time=candle.end_time - window,
            end_time=candle.end_time + window,
        )

        # Pick the candle we wrote rather than trusting result ordering.
        matching = [c for c in retrieved_candles if c["timestamp"] == candle.end_time]
        assert len(matching) >= 1
        retrieved_candle = matching[0]
        assert retrieved_candle["symbol"] == symbol
        assert retrieved_candle["close"] == candle.close

    async def test_get_latest_candle(self, db_ops):
        """Test that the most recent of several candles is returned as the latest."""
        base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
        symbol = "ETH-USDT-TGLC"  # Unique symbol for test

        # Insert three consecutive 5m candles; candle i ends at base_time + (i+1)*5m.
        for i in range(3):
            candle = OHLCVCandle(
                start_time=base_time + timedelta(minutes=i * 5),
                end_time=base_time + timedelta(minutes=(i + 1) * 5),
                open=Decimal("1200") + i,
                high=Decimal("1210") + i,
                low=Decimal("1190") + i,
                close=Decimal("1205") + i,
                volume=Decimal("1000"),
                trade_count=20 + i,
                timeframe="5m",
                symbol=symbol,
                exchange="okx",
            )
            # Fail here on a bad insert rather than later with a confusing
            # "latest candle" mismatch.
            assert db_ops.market_data.upsert_candle(candle) is True

        latest_candle = db_ops.market_data.get_latest_candle(
            symbol=symbol,
            timeframe="5m",
        )

        # The last candle (i=2) closes at 1205 + 2 and ends at base_time + 15m.
        assert latest_candle is not None
        assert latest_candle["symbol"] == symbol
        assert latest_candle["timeframe"] == "5m"
        assert latest_candle["close"] == Decimal("1207")
        assert latest_candle["timestamp"] == base_time + timedelta(minutes=15)
@pytest.mark.asyncio
class TestRawTradeRepository:
    """Exercise the RawTradeRepository reachable as ``db_ops.raw_trades``."""

    async def test_insert_and_get_raw_trade(self, db_ops):
        """Store one raw trade data point, then find it again within a narrow time window."""
        trade_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
        trade_symbol = "XRP-USDT-TIRT"  # Unique symbol for test
        point = MarketDataPoint(
            symbol=trade_symbol,
            data_type=DataType.TRADE,
            data={"price": "2.5", "qty": "100"},
            timestamp=trade_time,
            exchange="okx",
        )

        # The insert reports success with True.
        inserted = db_ops.raw_trades.insert_market_data_point(point)
        assert inserted is True

        # Look one second either side of the trade's timestamp.
        half_window = timedelta(seconds=1)
        lower_bound = trade_time - half_window
        upper_bound = trade_time + half_window
        rows = db_ops.raw_trades.get_raw_trades(
            symbol=trade_symbol,
            data_type=DataType.TRADE.value,
            start_time=lower_bound,
            end_time=upper_bound,
        )

        assert len(rows) >= 1
        first_row = rows[0]
        assert first_row["symbol"] == trade_symbol
        assert first_row["data_type"] == DataType.TRADE.value
        assert first_row["raw_data"] == point.data