"""Unit tests for the Redis manager classes and helpers in database.redis_manager."""

import asyncio
import unittest
from unittest.mock import MagicMock, AsyncMock, patch

from database.redis_manager import (
    RedisConfig,
    SyncRedisManager,
    AsyncRedisManager,
    publish_market_data,
    get_sync_redis_manager,
)


class TestRedisManagers(unittest.TestCase):
    """Tests for SyncRedisManager, AsyncRedisManager and publish_market_data.

    The tests either patch the ``redis`` client classes or inject a mock
    client / mock global manager, so they are intended to run without a
    live Redis server.
    """

    def setUp(self):
        """Create a fresh default RedisConfig for each test."""
        self.config = RedisConfig()

    @patch('redis.Redis')
    @patch('redis.ConnectionPool')
    def test_sync_manager_initialization(self, mock_pool, mock_redis):
        """Test that SyncRedisManager builds a pool and client, then pings Redis."""
        # Decorators apply bottom-up, so the ConnectionPool mock arrives first.
        mock_redis_instance = mock_redis.return_value

        manager = SyncRedisManager(self.config)
        manager.initialize()

        mock_pool.assert_called_once_with(**self.config.get_pool_kwargs())
        mock_redis.assert_called_once_with(connection_pool=mock_pool.return_value)
        mock_redis_instance.ping.assert_called_once()
        self.assertIsNotNone(manager.client)

    @patch('redis.asyncio.Redis')
    @patch('redis.asyncio.ConnectionPool')
    def test_async_manager_initialization(self, mock_pool, mock_redis_class):
        """Test that AsyncRedisManager builds a pool and client, then awaits a ping."""

        async def run_test():
            # AsyncMock so that awaiting the client's methods (e.g. ping) works.
            mock_redis_instance = AsyncMock()
            mock_redis_class.return_value = mock_redis_instance

            manager = AsyncRedisManager(self.config)
            await manager.initialize()

            mock_pool.assert_called_once_with(**self.config.get_pool_kwargs())
            mock_redis_class.assert_called_once_with(connection_pool=mock_pool.return_value)
            mock_redis_instance.ping.assert_awaited_once()
            self.assertIsNotNone(manager.async_client)

        asyncio.run(run_test())

    def test_sync_caching(self):
        """Test set, get, and delete operations for SyncRedisManager."""
        manager = SyncRedisManager(self.config)
        # Inject a mock client instead of calling initialize(); the assertions
        # below presumably reach it through ``manager.client``.
        manager._redis_client = MagicMock()

        # set: the value is expected to reach Redis JSON-encoded.
        manager.set("key1", {"data": "value1"}, ex=60)
        manager.client.set.assert_called_once_with("key1", '{"data": "value1"}', ex=60)

        # get: the stored JSON string is expected to be decoded back.
        manager.client.get.return_value = '{"data": "value1"}'
        result = manager.get("key1")
        manager.client.get.assert_called_once_with("key1")
        self.assertEqual(result, {"data": "value1"})

        # delete
        manager.delete("key1")
        manager.client.delete.assert_called_once_with("key1")

    def test_async_caching(self):
        """Test async set, get, and delete for AsyncRedisManager."""

        async def run_test():
            manager = AsyncRedisManager(self.config)
            # Inject an async mock client instead of calling initialize().
            manager._async_redis_client = AsyncMock()

            # set: a plain string is expected to be JSON-encoded ('"value2"').
            await manager.set("key2", "value2", ex=30)
            manager.async_client.set.assert_awaited_once_with("key2", '"value2"', ex=30)

            # get: the JSON string is expected to be decoded back to "value2".
            manager.async_client.get.return_value = '"value2"'
            result = await manager.get("key2")
            manager.async_client.get.assert_awaited_once_with("key2")
            self.assertEqual(result, "value2")

            # delete
            await manager.delete("key2")
            manager.async_client.delete.assert_awaited_once_with("key2")

        asyncio.run(run_test())

    @patch('database.redis_manager.sync_redis_manager', new_callable=MagicMock)
    def test_publish_market_data_convenience_func(self, mock_global_manager):
        """Test that publish_market_data publishes to the symbol's OHLCV channel."""
        symbol = "BTC/USDT"
        data = {"price": 100}
        base_channel = "market_data:ohlcv"
        expected_channel = "market_data:ohlcv:BTC/USDT"

        # A bare MagicMock returns the same object from get_symbol_channel for
        # any arguments, so comparing against that object would not verify
        # which channel was chosen. Give the channels helper concrete values
        # and check the exact delegation instead.
        channels = MagicMock()
        channels.market_data_ohlcv = base_channel
        channels.get_symbol_channel.return_value = expected_channel
        mock_global_manager.channels = channels

        publish_market_data(symbol, data)

        channels.get_symbol_channel.assert_called_once_with(base_channel, symbol)
        mock_global_manager.publish.assert_called_once_with(expected_channel, data)


if __name__ == '__main__':
    unittest.main()