247 lines
8.4 KiB
Python
247 lines
8.4 KiB
Python
|
|
import pytest
|
||
|
|
import pytest_asyncio
|
||
|
|
import asyncio
|
||
|
|
from datetime import datetime, timezone, timedelta
|
||
|
|
from decimal import Decimal
|
||
|
|
|
||
|
|
from database.operations import get_database_operations
|
||
|
|
from database.models import Bot
|
||
|
|
from data.common.data_types import OHLCVCandle
|
||
|
|
from data.base_collector import MarketDataPoint, DataType
|
||
|
|
|
||
|
|
@pytest.fixture(scope="module")
def event_loop():
    """Provide one fresh asyncio event loop shared by every test in this module.

    A new loop is built from the current event loop policy, handed to the
    tests of this module, and closed after the last of them has finished.
    """
    # NOTE(review): presumably overridden at module scope so the module-scoped
    # async ``db_ops`` fixture and the tests run on the same loop — confirm
    # against the pytest-asyncio version in use (newer releases prefer
    # ``loop_scope`` over overriding this fixture).
    policy = asyncio.get_event_loop_policy()
    module_loop = policy.new_event_loop()
    yield module_loop
    module_loop.close()
|
||
|
|
|
||
|
|
@pytest_asyncio.fixture(scope="module")
async def db_ops():
    """Yield the database-operations object shared by all tests in this module.

    The object comes from ``get_database_operations()``; the tests below use
    its ``bots``, ``market_data`` and ``raw_trades`` attributes.
    """
    # NOTE(review): the test database must already be configured and running.
    yield get_database_operations()
    # No teardown yet; if the operations object ever needs explicit release
    # (e.g. a close() call), do it here.
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestBotRepository:
    """Tests for the BotRepository exposed as ``db_ops.bots``."""

    @staticmethod
    def _ensure_absent(db_ops, bot_name):
        """Delete any bot named *bot_name* left behind by an earlier run."""
        existing_bot = db_ops.bots.get_by_name(bot_name)
        if existing_bot:
            db_ops.bots.delete(existing_bot.id)

    async def test_add_and_get_bot(self, db_ops):
        """
        Test adding a new bot and retrieving it to verify basic repository functionality.
        """
        bot_name = "test_bot_01"
        new_bot = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "BTC-USDT",
            "timeframe": "1h",
            "status": "inactive",
            "virtual_balance": Decimal("10000"),
        }
        self._ensure_absent(db_ops, bot_name)

        added_bot = db_ops.bots.add(new_bot)
        # Checked before the try block so the cleanup below can rely on .id.
        assert added_bot is not None
        assert added_bot.id is not None

        try:
            assert added_bot.name == bot_name
            assert added_bot.status == "inactive"

            # Retrieve the bot by ID and check it matches what was added.
            retrieved_bot = db_ops.bots.get_by_id(added_bot.id)
            assert retrieved_bot is not None
            assert retrieved_bot.id == added_bot.id
            assert retrieved_bot.name == bot_name
        finally:
            # Always remove the row, even when an assertion above fails,
            # so a failing run does not leave test data behind.
            db_ops.bots.delete(added_bot.id)

        # Verify the delete took effect.
        assert db_ops.bots.get_by_id(added_bot.id) is None

    async def test_update_bot(self, db_ops):
        """Test updating an existing bot's status."""
        bot_name = "test_bot_for_update"
        bot_data = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "ETH-USDT",
            "timeframe": "5m",
            "status": "active",
        }
        self._ensure_absent(db_ops, bot_name)

        bot_to_update = db_ops.bots.add(bot_data)
        assert bot_to_update is not None

        try:
            updated_bot = db_ops.bots.update(bot_to_update.id, {"status": "paused"})
            assert updated_bot is not None
            assert updated_bot.status == "paused"

            # Re-read the bot so the test checks the stored row, not only the
            # object that update() handed back.
            persisted_bot = db_ops.bots.get_by_id(bot_to_update.id)
            assert persisted_bot is not None
            assert persisted_bot.status == "paused"
        finally:
            db_ops.bots.delete(bot_to_update.id)

    async def test_get_nonexistent_bot(self, db_ops):
        """Test that fetching a non-existent bot returns None."""
        # NOTE(review): assumes no bot with id 999999 exists in the test DB.
        non_existent_bot = db_ops.bots.get_by_id(999999)
        assert non_existent_bot is None

    async def test_delete_bot(self, db_ops):
        """Test deleting a bot."""
        bot_name = "test_bot_for_delete"
        bot_data = {
            "name": bot_name,
            "strategy_name": "delete_strategy",
            "symbol": "LTC-USDT",
            "timeframe": "15m",
        }
        self._ensure_absent(db_ops, bot_name)

        bot_to_delete = db_ops.bots.add(bot_data)
        assert bot_to_delete is not None

        # delete() is expected to report success with a literal True.
        delete_result = db_ops.bots.delete(bot_to_delete.id)
        assert delete_result is True

        # Verify it's gone.
        assert db_ops.bots.get_by_id(bot_to_delete.id) is None
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestMarketDataRepository:
    """Tests for the MarketDataRepository exposed as ``db_ops.market_data``."""

    async def test_upsert_and_get_candle(self, db_ops):
        """Test upserting and retrieving a candle."""
        # Use a fixed timestamp for deterministic tests.
        base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        symbol = "BTC-USDT-TUG"  # Unique symbol for test

        candle = OHLCVCandle(
            start_time=base_time,
            end_time=base_time + timedelta(hours=1),
            open=Decimal("50000"),
            high=Decimal("51000"),
            low=Decimal("49000"),
            close=Decimal("50500"),
            volume=Decimal("100"),
            trade_count=10,
            timeframe="1h",
            symbol=symbol,
            exchange="okx",
        )

        assert db_ops.market_data.upsert_candle(candle) is True

        # The candle is expected to be stored under its end_time (see the
        # timestamp assertion below), so query exactly that instant.
        # NOTE(review): this zero-width window relies on get_candles treating
        # both bounds as inclusive — confirm against the repository.
        query_time = base_time + timedelta(hours=1)
        retrieved_candles = db_ops.market_data.get_candles(
            symbol=symbol,
            timeframe="1h",
            start_time=query_time,
            end_time=query_time,
        )

        assert len(retrieved_candles) >= 1
        retrieved_candle = retrieved_candles[0]
        assert retrieved_candle["symbol"] == symbol
        assert retrieved_candle["close"] == candle.close
        assert retrieved_candle["timestamp"] == candle.end_time

    async def test_get_latest_candle(self, db_ops):
        """Test fetching the latest candle."""
        base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
        symbol = "ETH-USDT-TGLC"  # Unique symbol for test

        # Insert consecutive 5m candles with increasing timestamps; the last
        # one inserted is the one get_latest_candle should return.
        last_candle = None
        for i in range(3):
            last_candle = OHLCVCandle(
                start_time=base_time + timedelta(minutes=i * 5),
                end_time=base_time + timedelta(minutes=(i + 1) * 5),
                open=Decimal("1200") + i,
                high=Decimal("1210") + i,
                low=Decimal("1190") + i,
                close=Decimal("1205") + i,
                volume=Decimal("1000"),
                trade_count=20 + i,
                timeframe="5m",
                symbol=symbol,
                exchange="okx",
            )
            # Fail here, not at a later assertion, if an insert is rejected.
            assert db_ops.market_data.upsert_candle(last_candle) is True

        latest_candle = db_ops.market_data.get_latest_candle(
            symbol=symbol,
            timeframe="5m",
        )

        assert latest_candle is not None
        assert latest_candle["symbol"] == symbol
        assert latest_candle["timeframe"] == "5m"
        # Expected values come from the last candle inserted above
        # (close 1207, timestamp base_time + 15 minutes).
        assert latest_candle["close"] == last_candle.close
        assert latest_candle["timestamp"] == last_candle.end_time
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestRawTradeRepository:
    """Tests for the RawTradeRepository exposed as ``db_ops.raw_trades``."""

    async def test_insert_and_get_raw_trade(self, db_ops):
        """Test inserting and retrieving a raw trade data point."""
        trade_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
        symbol = "XRP-USDT-TIRT"  # Unique symbol for test
        trade_kind = DataType.TRADE

        data_point = MarketDataPoint(
            symbol=symbol,
            data_type=trade_kind,
            data={"price": "2.5", "qty": "100"},
            timestamp=trade_time,
            exchange="okx",
        )

        # Store the point; the repository must report a literal True.
        assert db_ops.raw_trades.insert_market_data_point(data_point) is True

        # Look it up again in a +/- one second window around its timestamp.
        # The query takes the enum's value, while the insert took the member.
        window = timedelta(seconds=1)
        raw_trades = db_ops.raw_trades.get_raw_trades(
            symbol=symbol,
            data_type=trade_kind.value,
            start_time=trade_time - window,
            end_time=trade_time + window,
        )

        assert len(raw_trades) >= 1
        first_trade = raw_trades[0]
        assert first_trade["symbol"] == symbol
        assert first_trade["data_type"] == trade_kind.value
        assert first_trade["raw_data"] == data_point.data
|