"""Integration tests for the bot, market-data and raw-trade repositories.

Every test talks to the object returned by ``get_database_operations()``,
so a configured and running test database is required.
"""

import asyncio
from datetime import datetime, timezone, timedelta
from decimal import Decimal

import pytest
import pytest_asyncio

from database.operations import get_database_operations
from database.models import Bot
from data.common.data_types import OHLCVCandle
from data.base_collector import MarketDataPoint, DataType


@pytest.fixture(scope="module")
def event_loop():
    """Create an instance of the default event loop for each test module."""
    # NOTE(review): overriding ``event_loop`` is presumably needed so the
    # module-scoped async ``db_ops`` fixture has a loop of matching scope;
    # newer pytest-asyncio releases deprecate this override - confirm against
    # the pinned version.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest_asyncio.fixture(scope="module")
async def db_ops():
    """Fixture to provide database operations."""
    # We will need to make sure the test database is configured and running
    operations = get_database_operations()
    yield operations
    # Teardown logic can be added here if needed, e.g., operations.close()


def _remove_bot_named(db_ops, bot_name):
    """Delete the bot called ``bot_name`` if one exists, so a test starts clean."""
    existing_bot = db_ops.bots.get_by_name(bot_name)
    if existing_bot:
        db_ops.bots.delete(existing_bot.id)


@pytest.mark.asyncio
class TestBotRepository:
    """Tests for the BotRepository."""

    async def test_add_and_get_bot(self, db_ops):
        """
        Test adding a new bot and retrieving it to verify basic repository
        functionality.
        """
        # Define a new bot
        bot_name = "test_bot_01"
        new_bot = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "BTC-USDT",
            "timeframe": "1h",
            "status": "inactive",
            "virtual_balance": Decimal("10000"),
        }

        # Clean up any existing bot with the same name
        _remove_bot_named(db_ops, bot_name)

        # Add the bot
        added_bot = db_ops.bots.add(new_bot)

        # Checked outside the try block: the cleanup below dereferences
        # ``added_bot.id`` and must not mask this failure.
        assert added_bot is not None

        try:
            # Assertions to check if the bot was added correctly
            assert added_bot.id is not None
            assert added_bot.name == bot_name
            assert added_bot.status == "inactive"

            # Retrieve the bot by ID
            retrieved_bot = db_ops.bots.get_by_id(added_bot.id)

            # Assertions to check if the bot was retrieved correctly
            assert retrieved_bot is not None
            assert retrieved_bot.id == added_bot.id
            assert retrieved_bot.name == bot_name
        finally:
            # Clean up the created bot even when an assertion above fails,
            # so a failing run does not leave the row behind.
            db_ops.bots.delete(added_bot.id)

        # Verify it's deleted
        deleted_bot = db_ops.bots.get_by_id(added_bot.id)
        assert deleted_bot is None

    async def test_update_bot(self, db_ops):
        """Test updating an existing bot's status."""
        bot_name = "test_bot_for_update"
        bot_data = {
            "name": bot_name,
            "strategy_name": "test_strategy",
            "symbol": "ETH-USDT",
            "timeframe": "5m",
            "status": "active",
        }

        # Ensure clean state
        _remove_bot_named(db_ops, bot_name)

        # Add a bot to update
        bot_to_update = db_ops.bots.add(bot_data)

        try:
            # Update the bot's status
            update_data = {"status": "paused"}
            updated_bot = db_ops.bots.update(bot_to_update.id, update_data)

            # Assertions
            assert updated_bot is not None
            assert updated_bot.status == "paused"
        finally:
            # Clean up regardless of the outcome of the assertions above.
            db_ops.bots.delete(bot_to_update.id)

    async def test_get_nonexistent_bot(self, db_ops):
        """Test that fetching a non-existent bot returns None."""
        non_existent_bot = db_ops.bots.get_by_id(999999)
        assert non_existent_bot is None

    async def test_delete_bot(self, db_ops):
        """Test deleting a bot."""
        bot_name = "test_bot_for_delete"
        bot_data = {
            "name": bot_name,
            "strategy_name": "delete_strategy",
            "symbol": "LTC-USDT",
            "timeframe": "15m",
        }

        # Ensure clean state
        _remove_bot_named(db_ops, bot_name)

        # Add a bot to delete
        bot_to_delete = db_ops.bots.add(bot_data)

        # Delete the bot
        delete_result = db_ops.bots.delete(bot_to_delete.id)
        assert delete_result is True

        # Verify it's gone
        retrieved_bot = db_ops.bots.get_by_id(bot_to_delete.id)
        assert retrieved_bot is None


@pytest.mark.asyncio
class TestMarketDataRepository:
    """Tests for the MarketDataRepository."""

    async def test_upsert_and_get_candle(self, db_ops):
        """Test upserting and retrieving a candle."""
        # Use a fixed timestamp for deterministic tests
        base_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        candle = OHLCVCandle(
            start_time=base_time,
            end_time=base_time + timedelta(hours=1),
            open=Decimal("50000"),
            high=Decimal("51000"),
            low=Decimal("49000"),
            close=Decimal("50500"),
            volume=Decimal("100"),
            trade_count=10,
            timeframe="1h",
            symbol="BTC-USDT-TUG",  # Unique symbol for test
            exchange="okx",
        )

        # Upsert the candle
        success = db_ops.market_data.upsert_candle(candle)
        assert success is True

        # Retrieve the candle using a time range.
        # NOTE(review): start and end are both the candle's end_time, i.e. a
        # zero-width window; this presumably relies on get_candles treating
        # both bounds as inclusive - confirm against the repository.
        start_time = base_time + timedelta(hours=1)
        end_time = base_time + timedelta(hours=1)

        retrieved_candles = db_ops.market_data.get_candles(
            symbol="BTC-USDT-TUG",
            timeframe="1h",
            start_time=start_time,
            end_time=end_time,
        )

        assert len(retrieved_candles) >= 1
        retrieved_candle = retrieved_candles[0]
        assert retrieved_candle["symbol"] == "BTC-USDT-TUG"
        assert retrieved_candle["close"] == candle.close
        # The stored row is expected to carry the candle's end_time.
        assert retrieved_candle["timestamp"] == candle.end_time

    async def test_get_latest_candle(self, db_ops):
        """Test fetching the latest candle."""
        base_time = datetime(2023, 1, 1, 13, 0, 0, tzinfo=timezone.utc)
        symbol = "ETH-USDT-TGLC"  # Unique symbol for test

        # Insert a few candles with increasing timestamps
        for i in range(3):
            candle = OHLCVCandle(
                start_time=base_time + timedelta(minutes=i*5),
                end_time=base_time + timedelta(minutes=(i+1)*5),
                open=Decimal("1200") + i,
                high=Decimal("1210") + i,
                low=Decimal("1190") + i,
                close=Decimal("1205") + i,
                volume=Decimal("1000"),
                trade_count=20+i,
                timeframe="5m",
                symbol=symbol,
                exchange="okx",
            )
            db_ops.market_data.upsert_candle(candle)

        latest_candle = db_ops.market_data.get_latest_candle(
            symbol=symbol,
            timeframe="5m",
        )

        assert latest_candle is not None
        assert latest_candle["symbol"] == symbol
        assert latest_candle["timeframe"] == "5m"
        # Values of the third candle (i == 2): close 1205 + 2, end at +15 min.
        assert latest_candle["close"] == Decimal("1207")
        assert latest_candle["timestamp"] == base_time + timedelta(minutes=15)


@pytest.mark.asyncio
class TestRawTradeRepository:
    """Tests for the RawTradeRepository."""

    async def test_insert_and_get_raw_trade(self, db_ops):
        """Test inserting and retrieving a raw trade data point."""
        base_time = datetime(2023, 1, 1, 14, 0, 0, tzinfo=timezone.utc)
        symbol = "XRP-USDT-TIRT"  # Unique symbol for test

        data_point = MarketDataPoint(
            symbol=symbol,
            data_type=DataType.TRADE,
            data={"price": "2.5", "qty": "100"},
            timestamp=base_time,
            exchange="okx",
        )

        # Insert raw data
        success = db_ops.raw_trades.insert_market_data_point(data_point)
        assert success is True

        # Retrieve raw data from a two-second window centred on the trade
        start_time = base_time - timedelta(seconds=1)
        end_time = base_time + timedelta(seconds=1)

        raw_trades = db_ops.raw_trades.get_raw_trades(
            symbol=symbol,
            data_type=DataType.TRADE.value,
            start_time=start_time,
            end_time=end_time,
        )

        assert len(raw_trades) >= 1
        retrieved_trade = raw_trades[0]
        assert retrieved_trade["symbol"] == symbol
        assert retrieved_trade["data_type"] == DataType.TRADE.value
        assert retrieved_trade["raw_data"] == data_point.data