- Introduced a modular repository structure by creating separate repository classes for `Bot`, `MarketData`, and `RawTrade`, improving code organization and maintainability.
- Updated the `DatabaseOperations` class to use the new repository classes, improving the abstraction of database interactions.
- Refactored the `.env` file to update the database connection parameters and add new logging and health-monitoring configuration.
- Modified `okx_config.json` to change the default timeframes for trading pairs to match updated requirements.
- Added unit tests for the new repository classes to verify their functionality and reliability.

These changes improve the overall architecture of the database layer, making it more scalable and easier to manage.
247 lines
8.4 KiB
Python
247 lines
8.4 KiB
Python
import pytest
|
|
import pytest_asyncio
|
|
import asyncio
|
|
from datetime import datetime, timezone, timedelta
|
|
from decimal import Decimal
|
|
|
|
from database.operations import get_database_operations
|
|
from database.models import Bot
|
|
from data.common.data_types import OHLCVCandle
|
|
from data.base_collector import MarketDataPoint, DataType
|
|
|
|
@pytest.fixture(scope="module")
def event_loop():
    """Provide a single fresh asyncio event loop shared by all tests in this module.

    This overrides pytest-asyncio's default (per-test) ``event_loop`` fixture
    with a module-scoped one, presumably so that module-scoped async fixtures
    such as ``db_ops`` can run on a loop that outlives individual tests.
    The loop is closed when the module's tests are finished.
    """
    module_loop = asyncio.new_event_loop()
    yield module_loop
    module_loop.close()
|
|
|
|
@pytest_asyncio.fixture(scope="module")
async def db_ops():
    """Yield the database-operations object shared by every test in this module.

    The test database must already be configured and running. Nothing is
    torn down after the tests; if the operations object ever needs closing
    (e.g. ``operations.close()``), add that after the ``yield``.
    """
    yield get_database_operations()
|
|
|
|
@pytest.mark.asyncio
class TestBotRepository:
    """Tests for the BotRepository exposed as ``db_ops.bots``.

    Each test first removes any leftover bot with its fixed name, and deletes
    the bot it created in a ``finally`` block so that a failing assertion does
    not leave rows behind in the test database.
    """

    @staticmethod
    def _delete_bot_named(db_ops, name):
        """Delete the bot called ``name`` if one exists (e.g. left over from an earlier run)."""
        existing_bot = db_ops.bots.get_by_name(name)
        if existing_bot is not None:
            db_ops.bots.delete(existing_bot.id)

    @staticmethod
    def _delete_bot_if_present(db_ops, bot_id):
        """Delete the bot with ``bot_id`` unless it has already been removed."""
        if db_ops.bots.get_by_id(bot_id) is not None:
            db_ops.bots.delete(bot_id)

    async def test_add_and_get_bot(self, db_ops):
        """
        Test adding a new bot and retrieving it to verify basic repository functionality.
        """
        bot_name = "test_bot_01"
        new_bot = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "BTC-USDT",
            "timeframe": "1h",
            "status": "inactive",
            "virtual_balance": Decimal("10000"),
        }

        # Clean up any existing bot with the same name
        self._delete_bot_named(db_ops, bot_name)

        added_bot = db_ops.bots.add(new_bot)
        # Checked before the try block: the cleanup in ``finally`` needs a valid id.
        assert added_bot is not None
        assert added_bot.id is not None
        try:
            assert added_bot.name == bot_name
            assert added_bot.status == "inactive"

            # Retrieve the bot by ID
            retrieved_bot = db_ops.bots.get_by_id(added_bot.id)
            assert retrieved_bot is not None
            assert retrieved_bot.id == added_bot.id
            assert retrieved_bot.name == bot_name

            # Deleting is part of what this test verifies.
            db_ops.bots.delete(added_bot.id)
            assert db_ops.bots.get_by_id(added_bot.id) is None
        finally:
            # No-op when the explicit delete above already succeeded.
            self._delete_bot_if_present(db_ops, added_bot.id)

    async def test_update_bot(self, db_ops):
        """Test updating an existing bot's status."""
        bot_name = "test_bot_for_update"
        bot_data = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "ETH-USDT",
            "timeframe": "5m",
            "status": "active",
        }
        # Ensure clean state
        self._delete_bot_named(db_ops, bot_name)

        bot_to_update = db_ops.bots.add(bot_data)
        assert bot_to_update is not None
        try:
            updated_bot = db_ops.bots.update(bot_to_update.id, {"status": "paused"})
            assert updated_bot is not None
            assert updated_bot.status == "paused"

            # Re-read the row so the test checks the stored value, not just the
            # object returned by update().
            persisted_bot = db_ops.bots.get_by_id(bot_to_update.id)
            assert persisted_bot is not None
            assert persisted_bot.status == "paused"
        finally:
            db_ops.bots.delete(bot_to_update.id)

    async def test_get_nonexistent_bot(self, db_ops):
        """Test that fetching a non-existent bot returns None."""
        # Assumes no bot with this id exists in the test database.
        non_existent_bot = db_ops.bots.get_by_id(999999)
        assert non_existent_bot is None

    async def test_delete_bot(self, db_ops):
        """Test deleting a bot."""
        bot_name = "test_bot_for_delete"
        bot_data = {
            "name": bot_name,
            "strategy_name": "delete_strategy",
            "symbol": "LTC-USDT",
            "timeframe": "15m",
        }
        # Ensure clean state
        self._delete_bot_named(db_ops, bot_name)

        bot_to_delete = db_ops.bots.add(bot_data)
        assert bot_to_delete is not None
        try:
            delete_result = db_ops.bots.delete(bot_to_delete.id)
            assert delete_result is True

            # Verify it's gone
            assert db_ops.bots.get_by_id(bot_to_delete.id) is None
        finally:
            # Only acts if the delete under test did not actually remove the row.
            self._delete_bot_if_present(db_ops, bot_to_delete.id)
|
|
|
|
@pytest.mark.asyncio
class TestMarketDataRepository:
    """Tests for the MarketDataRepository exposed as ``db_ops.market_data``.

    Each test writes candles under its own suffixed symbol (e.g. ``-TUG``) so
    its rows do not collide with those of other tests.
    """

    async def test_upsert_and_get_candle(self, db_ops):
        """Test upserting a single 1h candle and retrieving it with ``get_candles``."""
        # Use a fixed timestamp for deterministic tests
        base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        symbol = "BTC-USDT-TUG"  # Unique symbol for test

        candle = OHLCVCandle(
            start_time=base_time,
            end_time=base_time + timedelta(hours=1),
            open=Decimal("50000"),
            high=Decimal("51000"),
            low=Decimal("49000"),
            close=Decimal("50500"),
            volume=Decimal("100"),
            trade_count=10,
            timeframe="1h",
            symbol=symbol,
            exchange="okx",
        )

        success = db_ops.market_data.upsert_candle(candle)
        assert success is True

        # The stored row's "timestamp" is the candle's end_time (asserted
        # below). Query a +/-1s window around it rather than a zero-width
        # range, which would only match if get_candles treats both bounds as
        # inclusive.
        start_time = candle.end_time - timedelta(seconds=1)
        end_time = candle.end_time + timedelta(seconds=1)

        retrieved_candles = db_ops.market_data.get_candles(
            symbol=symbol,
            timeframe="1h",
            start_time=start_time,
            end_time=end_time,
        )

        assert len(retrieved_candles) >= 1
        retrieved_candle = retrieved_candles[0]
        assert retrieved_candle["symbol"] == symbol
        assert retrieved_candle["close"] == candle.close
        assert retrieved_candle["timestamp"] == candle.end_time

    async def test_get_latest_candle(self, db_ops):
        """Test that the most recent of several 5m candles is returned as the latest."""
        base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
        symbol = "ETH-USDT-TGLC"  # Unique symbol for test

        # Insert three candles with increasing timestamps; the last one
        # (i == 2) ends at base_time + 15m and closes at 1205 + 2 = 1207.
        for i in range(3):
            candle = OHLCVCandle(
                start_time=base_time + timedelta(minutes=i * 5),
                end_time=base_time + timedelta(minutes=(i + 1) * 5),
                open=Decimal("1200") + i,
                high=Decimal("1210") + i,
                low=Decimal("1190") + i,
                close=Decimal("1205") + i,
                volume=Decimal("1000"),
                trade_count=20 + i,
                timeframe="5m",
                symbol=symbol,
                exchange="okx",
            )
            # Fail here, rather than at the "latest" assertions, if a write is rejected.
            assert db_ops.market_data.upsert_candle(candle) is True

        latest_candle = db_ops.market_data.get_latest_candle(
            symbol=symbol,
            timeframe="5m",
        )

        assert latest_candle is not None
        assert latest_candle["symbol"] == symbol
        assert latest_candle["timeframe"] == "5m"
        assert latest_candle["close"] == Decimal("1207")
        assert latest_candle["timestamp"] == base_time + timedelta(minutes=15)
|
|
|
|
@pytest.mark.asyncio
class TestRawTradeRepository:
    """Tests for the RawTradeRepository exposed as ``db_ops.raw_trades``."""

    async def test_insert_and_get_raw_trade(self, db_ops):
        """Insert one TRADE data point, then find it again in a +/-1s window around its timestamp."""
        trade_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
        symbol = "XRP-USDT-TIRT"  # Unique symbol for test

        data_point = MarketDataPoint(
            symbol=symbol,
            data_type=DataType.TRADE,
            data={"price": "2.5", "qty": "100"},
            timestamp=trade_time,
            exchange="okx",
        )

        inserted = db_ops.raw_trades.insert_market_data_point(data_point)
        assert inserted is True

        one_second = timedelta(seconds=1)
        matches = db_ops.raw_trades.get_raw_trades(
            symbol=symbol,
            data_type=DataType.TRADE.value,
            start_time=trade_time - one_second,
            end_time=trade_time + one_second,
        )

        assert len(matches) >= 1
        first_match = matches[0]
        assert first_match["symbol"] == symbol
        assert first_match["data_type"] == DataType.TRADE.value
        assert first_match["raw_data"] == data_point.data