108 lines
3.9 KiB
Python
108 lines
3.9 KiB
Python
|
|
import asyncio
|
||
|
|
import unittest
|
||
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||
|
|
|
||
|
|
from database.redis_manager import (
|
||
|
|
RedisConfig,
|
||
|
|
SyncRedisManager,
|
||
|
|
AsyncRedisManager,
|
||
|
|
publish_market_data,
|
||
|
|
get_sync_redis_manager
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestRedisManagers(unittest.IsolatedAsyncioTestCase):
    """Unit tests for the sync/async Redis managers and convenience helpers.

    Redis clients and connection pools are replaced with mocks in every test.
    ``IsolatedAsyncioTestCase`` runs each ``async def`` test on its own fresh
    event loop; plain ``def`` tests run as ordinary synchronous tests.
    """

    def setUp(self):
        """Create a default ``RedisConfig`` shared by each test."""
        self.config = RedisConfig()

    @patch('redis.Redis')
    @patch('redis.ConnectionPool')
    def test_sync_manager_initialization(self, mock_pool, mock_redis):
        """SyncRedisManager builds a pool from the config, wraps it in a client, and pings."""
        # Decorators apply bottom-up, so the innermost patch (ConnectionPool)
        # arrives as the first mock argument.
        mock_redis_instance = mock_redis.return_value

        manager = SyncRedisManager(self.config)
        manager.initialize()

        mock_pool.assert_called_once_with(**self.config.get_pool_kwargs())
        mock_redis.assert_called_once_with(connection_pool=mock_pool.return_value)
        mock_redis_instance.ping.assert_called_once()
        self.assertIsNotNone(manager.client)

    @patch('redis.asyncio.Redis')
    @patch('redis.asyncio.ConnectionPool')
    async def test_async_manager_initialization(self, mock_pool, mock_redis_class):
        """AsyncRedisManager builds a pool from the config, wraps it in a client, and awaits ping."""
        # The manager awaits ping(), so the client instance must be an AsyncMock.
        mock_redis_instance = AsyncMock()
        mock_redis_class.return_value = mock_redis_instance

        manager = AsyncRedisManager(self.config)
        await manager.initialize()

        mock_pool.assert_called_once_with(**self.config.get_pool_kwargs())
        mock_redis_class.assert_called_once_with(connection_pool=mock_pool.return_value)
        mock_redis_instance.ping.assert_awaited_once()
        self.assertIsNotNone(manager.async_client)

    def test_sync_caching(self):
        """Test set, get, and delete operations for SyncRedisManager.

        The expected payload strings are the JSON text of the values passed in.
        """
        manager = SyncRedisManager(self.config)
        # Inject a mock client directly instead of calling initialize().
        manager._redis_client = MagicMock()

        # set: the value is stored in serialized form with the expiry forwarded.
        manager.set("key1", {"data": "value1"}, ex=60)
        manager.client.set.assert_called_once_with("key1", '{"data": "value1"}', ex=60)

        # get: the stored text is deserialized back into the original value.
        manager.client.get.return_value = '{"data": "value1"}'
        result = manager.get("key1")
        self.assertEqual(result, {"data": "value1"})

        # delete: forwarded to the client unchanged.
        manager.delete("key1")
        manager.client.delete.assert_called_once_with("key1")

    async def test_async_caching(self):
        """Test async set, get, and delete for AsyncRedisManager."""
        manager = AsyncRedisManager(self.config)
        # Inject an awaitable mock client directly instead of calling initialize().
        manager._async_redis_client = AsyncMock()

        # set: even a plain string is stored in serialized (quoted) form.
        await manager.set("key2", "value2", ex=30)
        manager.async_client.set.assert_awaited_once_with("key2", '"value2"', ex=30)

        # get: the stored text is deserialized back into the original value.
        manager.async_client.get.return_value = '"value2"'
        result = await manager.get("key2")
        self.assertEqual(result, "value2")

        # delete: forwarded to the client unchanged.
        await manager.delete("key2")
        manager.async_client.delete.assert_awaited_once_with("key2")

    @patch('database.redis_manager.sync_redis_manager', new_callable=MagicMock)
    def test_publish_market_data_convenience_func(self, mock_global_manager):
        """publish_market_data publishes to the symbol's OHLCV channel on the global manager."""
        symbol = "BTC/USDT"
        data = {"price": 100}
        # Arbitrary sentinel; only its identity matters for the assertions.
        channel = "test-ohlcv-channel:BTC/USDT"

        channels = mock_global_manager.channels
        # A bare MagicMock returns the same object for *any* arguments, so
        # comparing against another get_symbol_channel(...) call would pass even
        # if the wrong base channel or symbol were used. Pin a concrete return
        # value and check the call arguments explicitly instead.
        channels.get_symbol_channel.return_value = channel

        publish_market_data(symbol, data)

        channels.get_symbol_channel.assert_called_once_with(
            channels.market_data_ohlcv, symbol
        )
        mock_global_manager.publish.assert_called_once_with(channel, data)
|
||
|
|
|
||
|
|
|
||
|
|
# Entry point for running this module directly as a script; test discovery
# (e.g. ``python -m unittest``) imports the module and skips this branch.
if __name__ == '__main__':
    unittest.main(module='__main__')
|