Refactor Redis management and enhance system health callbacks
- Replaced the `RedisManager` class with a more modular `SyncRedisManager` and `AsyncRedisManager`, improving the separation of synchronous and asynchronous operations. - Updated the `system_health.py` callbacks to utilize the new `get_sync_redis_manager` function for Redis interactions, simplifying the connection process. - Enhanced error handling and logging in Redis status checks, providing clearer feedback on connection issues. - Revised the setup documentation to reflect changes in Redis connection testing, ensuring clarity for users. These updates improve the maintainability and reliability of Redis interactions within the system, aligning with best practices for modular design.
This commit is contained in:
parent
1466223b85
commit
fe9d8e75ed
@ -12,7 +12,7 @@ from dash import Output, Input, State, html, callback_context, no_update
|
||||
import dash_bootstrap_components as dbc
|
||||
from utils.logger import get_logger
|
||||
from database.connection import DatabaseManager
|
||||
from database.redis_manager import RedisManager
|
||||
from database.redis_manager import get_sync_redis_manager
|
||||
|
||||
logger = get_logger("system_health_callbacks")
|
||||
|
||||
@ -235,13 +235,16 @@ def _get_database_quick_status() -> dbc.Badge:
|
||||
def _get_redis_quick_status() -> dbc.Badge:
|
||||
"""Get quick Redis status."""
|
||||
try:
|
||||
redis_manager = RedisManager()
|
||||
redis_manager = get_sync_redis_manager()
|
||||
redis_manager.initialize()
|
||||
if redis_manager.test_connection():
|
||||
# This check is simplified as initialize() would raise an error on failure.
|
||||
# For a more explicit check, a dedicated test_connection could be added to SyncRedisManager.
|
||||
if redis_manager.client.ping():
|
||||
return dbc.Badge("Connected", color="success", className="me-1")
|
||||
else:
|
||||
return dbc.Badge("Error", color="danger", className="me-1")
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(f"Redis quick status check failed: {e}")
|
||||
return dbc.Badge("Error", color="danger", className="me-1")
|
||||
|
||||
|
||||
@ -418,38 +421,52 @@ def _get_database_statistics() -> html.Div:
|
||||
|
||||
|
||||
def _get_redis_status() -> html.Div:
|
||||
"""Get Redis status."""
|
||||
"""Get detailed Redis server status."""
|
||||
try:
|
||||
redis_manager = RedisManager()
|
||||
redis_manager = get_sync_redis_manager()
|
||||
redis_manager.initialize()
|
||||
info = redis_manager.get_info()
|
||||
|
||||
if not redis_manager.client.ping():
|
||||
raise ConnectionError("Redis server is not responding.")
|
||||
|
||||
info = redis_manager.client.info()
|
||||
status_badge = dbc.Badge("Connected", color="success", className="me-1")
|
||||
|
||||
return html.Div([
|
||||
dbc.Row([
|
||||
dbc.Col(dbc.Badge("Redis Connected", color="success"), width="auto"),
|
||||
dbc.Col(f"Checked: {datetime.now().strftime('%H:%M:%S')}", className="text-muted")
|
||||
], align="center", className="mb-2"),
|
||||
html.P(f"Host: {redis_manager.config.host}:{redis_manager.config.port}", className="mb-0")
|
||||
html.H5("Redis Status"),
|
||||
status_badge,
|
||||
html.P(f"Version: {info.get('redis_version', 'N/A')}"),
|
||||
html.P(f"Mode: {info.get('redis_mode', 'N/A')}")
|
||||
])
|
||||
|
||||
except Exception as e:
|
||||
return dbc.Alert(f"Error connecting to Redis: {e}", color="danger")
|
||||
logger.error(f"Failed to get Redis status: {e}")
|
||||
return html.Div([
|
||||
html.H5("Redis Status"),
|
||||
dbc.Badge("Error", color="danger", className="me-1"),
|
||||
dbc.Alert(f"Error: {e}", color="danger", dismissable=True)
|
||||
])
|
||||
|
||||
|
||||
def _get_redis_statistics() -> html.Div:
|
||||
"""Get Redis statistics."""
|
||||
"""Get detailed Redis statistics."""
|
||||
try:
|
||||
redis_manager = RedisManager()
|
||||
redis_manager = get_sync_redis_manager()
|
||||
redis_manager.initialize()
|
||||
info = redis_manager.get_info()
|
||||
|
||||
if not redis_manager.client.ping():
|
||||
raise ConnectionError("Redis server is not responding.")
|
||||
|
||||
info = redis_manager.client.info()
|
||||
|
||||
return html.Div([
|
||||
dbc.Row([dbc.Col("Memory Used:"), dbc.Col(info.get('used_memory_human', 'N/A'), className="text-end")]),
|
||||
dbc.Row([dbc.Col("Connected Clients:"), dbc.Col(info.get('connected_clients', 'N/A'), className="text-end")]),
|
||||
dbc.Row([dbc.Col("Uptime (hours):"), dbc.Col(f"{info.get('uptime_in_seconds', 0) // 3600}", className="text-end")])
|
||||
html.H5("Redis Statistics"),
|
||||
html.P(f"Connected Clients: {info.get('connected_clients', 'N/A')}"),
|
||||
html.P(f"Memory Used: {info.get('used_memory_human', 'N/A')}"),
|
||||
html.P(f"Total Commands Processed: {info.get('total_commands_processed', 'N/A')}")
|
||||
])
|
||||
except Exception as e:
|
||||
return dbc.Alert(f"Error loading Redis stats: {e}", color="danger")
|
||||
logger.error(f"Failed to get Redis statistics: {e}")
|
||||
return dbc.Alert(f"Error: {e}", color="danger", dismissable=True)
|
||||
|
||||
|
||||
def _get_system_performance_metrics() -> html.Div:
|
||||
|
||||
@ -1,476 +1,291 @@
|
||||
"""
|
||||
Redis Manager for Crypto Trading Bot Platform
|
||||
Provides Redis connection, pub/sub messaging, and caching utilities
|
||||
Redis Manager for Crypto Trading Bot Platform.
|
||||
Provides Redis connection, pub/sub messaging, and caching utilities.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, List, Callable, Union
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Type
|
||||
|
||||
# Load environment variables from .env file if it exists
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
env_file = Path(__file__).parent.parent / '.env'
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file)
|
||||
except ImportError:
|
||||
# dotenv not available, proceed without it
|
||||
pass
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
import redis
|
||||
import redis.asyncio as redis_async
|
||||
from redis.exceptions import ConnectionError, TimeoutError, RedisError
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisConfig:
|
||||
"""Redis configuration class"""
|
||||
class RedisConfig(BaseSettings):
|
||||
"""Redis configuration class using Pydantic for validation."""
|
||||
|
||||
def __init__(self):
|
||||
self.host = os.getenv('REDIS_HOST', 'localhost')
|
||||
self.port = int(os.getenv('REDIS_PORT', '6379'))
|
||||
self.password = os.getenv('REDIS_PASSWORD', '')
|
||||
self.db = int(os.getenv('REDIS_DB', '0'))
|
||||
|
||||
# Connection settings
|
||||
self.socket_timeout = int(os.getenv('REDIS_SOCKET_TIMEOUT', '5'))
|
||||
self.socket_connect_timeout = int(os.getenv('REDIS_CONNECT_TIMEOUT', '5'))
|
||||
self.socket_keepalive = os.getenv('REDIS_KEEPALIVE', 'true').lower() == 'true'
|
||||
self.socket_keepalive_options = {}
|
||||
|
||||
# Pool settings
|
||||
self.max_connections = int(os.getenv('REDIS_MAX_CONNECTIONS', '20'))
|
||||
self.retry_on_timeout = os.getenv('REDIS_RETRY_ON_TIMEOUT', 'true').lower() == 'true'
|
||||
|
||||
# Channel prefixes for organization
|
||||
self.channel_prefix = os.getenv('REDIS_CHANNEL_PREFIX', 'crypto_bot')
|
||||
|
||||
logger.info(f"Redis configuration initialized for: {self.host}:{self.port}")
|
||||
REDIS_HOST: str = 'localhost'
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_PASSWORD: str = ''
|
||||
REDIS_DB: int = 0
|
||||
|
||||
# Connection settings
|
||||
REDIS_SOCKET_TIMEOUT: int = 5
|
||||
REDIS_CONNECT_TIMEOUT: int = 5
|
||||
REDIS_KEEPALIVE: bool = True
|
||||
|
||||
# Pool settings
|
||||
REDIS_MAX_CONNECTIONS: int = 20
|
||||
REDIS_RETRY_ON_TIMEOUT: bool = True
|
||||
|
||||
# Channel prefixes for organization
|
||||
REDIS_CHANNEL_PREFIX: str = 'crypto_bot'
|
||||
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "ignore"
|
||||
}
|
||||
|
||||
def get_connection_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get Redis connection configuration"""
|
||||
"""Get Redis connection configuration."""
|
||||
kwargs = {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'db': self.db,
|
||||
'socket_timeout': self.socket_timeout,
|
||||
'socket_connect_timeout': self.socket_connect_timeout,
|
||||
'socket_keepalive': self.socket_keepalive,
|
||||
'socket_keepalive_options': self.socket_keepalive_options,
|
||||
'retry_on_timeout': self.retry_on_timeout,
|
||||
'decode_responses': True, # Automatically decode responses to strings
|
||||
'host': self.REDIS_HOST,
|
||||
'port': self.REDIS_PORT,
|
||||
'db': self.REDIS_DB,
|
||||
'socket_timeout': self.REDIS_SOCKET_TIMEOUT,
|
||||
'socket_connect_timeout': self.REDIS_CONNECT_TIMEOUT,
|
||||
'socket_keepalive': self.REDIS_KEEPALIVE,
|
||||
'socket_keepalive_options': {},
|
||||
'retry_on_timeout': self.REDIS_RETRY_ON_TIMEOUT,
|
||||
'decode_responses': True,
|
||||
}
|
||||
|
||||
if self.password:
|
||||
kwargs['password'] = self.password
|
||||
|
||||
if self.REDIS_PASSWORD:
|
||||
kwargs['password'] = self.REDIS_PASSWORD
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_pool_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get Redis connection pool configuration"""
|
||||
"""Get Redis connection pool configuration."""
|
||||
kwargs = self.get_connection_kwargs()
|
||||
kwargs['max_connections'] = self.max_connections
|
||||
kwargs['max_connections'] = self.REDIS_MAX_CONNECTIONS
|
||||
return kwargs
|
||||
|
||||
|
||||
class RedisChannels:
|
||||
"""Redis channel definitions for organized messaging"""
|
||||
"""Redis channel definitions for organized messaging."""
|
||||
|
||||
def __init__(self, prefix: str = 'crypto_bot'):
|
||||
self.prefix = prefix
|
||||
|
||||
# Market data channels
|
||||
self.market_data = f"{prefix}:market_data"
|
||||
self.market_data_raw = f"{prefix}:market_data:raw"
|
||||
self.market_data_ohlcv = f"{prefix}:market_data:ohlcv"
|
||||
|
||||
# Bot channels
|
||||
self.bot_signals = f"{prefix}:bot:signals"
|
||||
self.bot_trades = f"{prefix}:bot:trades"
|
||||
self.bot_status = f"{prefix}:bot:status"
|
||||
self.bot_performance = f"{prefix}:bot:performance"
|
||||
|
||||
# System channels
|
||||
self.system_health = f"{prefix}:system:health"
|
||||
self.system_alerts = f"{prefix}:system:alerts"
|
||||
|
||||
# Dashboard channels
|
||||
self.dashboard_updates = f"{prefix}:dashboard:updates"
|
||||
self.dashboard_commands = f"{prefix}:dashboard:commands"
|
||||
|
||||
def get_symbol_channel(self, base_channel: str, symbol: str) -> str:
|
||||
"""Get symbol-specific channel"""
|
||||
"""Get symbol-specific channel."""
|
||||
return f"{base_channel}:{symbol}"
|
||||
|
||||
def get_bot_channel(self, base_channel: str, bot_id: int) -> str:
|
||||
"""Get bot-specific channel"""
|
||||
"""Get bot-specific channel."""
|
||||
return f"{base_channel}:{bot_id}"
|
||||
|
||||
|
||||
class RedisManager:
|
||||
"""
|
||||
Redis manager with connection pooling and pub/sub messaging
|
||||
"""
|
||||
class BaseRedisManager:
|
||||
"""Base class for Redis managers, handling config and channels."""
|
||||
|
||||
def __init__(self, config: Optional[RedisConfig] = None):
|
||||
self.config = config or RedisConfig()
|
||||
self.channels = RedisChannels(self.config.channel_prefix)
|
||||
|
||||
# Synchronous Redis client
|
||||
self._redis_client: Optional[redis.Redis] = None
|
||||
self.channels = RedisChannels(self.config.REDIS_CHANNEL_PREFIX)
|
||||
|
||||
|
||||
class SyncRedisManager(BaseRedisManager):
|
||||
"""Synchronous Redis manager for standard operations."""
|
||||
|
||||
def __init__(self, config: Optional[RedisConfig] = None):
|
||||
super().__init__(config)
|
||||
self._connection_pool: Optional[redis.ConnectionPool] = None
|
||||
|
||||
# Asynchronous Redis client
|
||||
self._async_redis_client: Optional[redis_async.Redis] = None
|
||||
self._async_connection_pool: Optional[redis_async.ConnectionPool] = None
|
||||
|
||||
# Pub/sub clients
|
||||
self._redis_client: Optional[redis.Redis] = None
|
||||
self._pubsub_client: Optional[redis.client.PubSub] = None
|
||||
self._async_pubsub_client: Optional[redis_async.client.PubSub] = None
|
||||
|
||||
# Subscription handlers
|
||||
self._message_handlers: Dict[str, List[Callable]] = {}
|
||||
self._async_message_handlers: Dict[str, List[Callable]] = {}
|
||||
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize Redis connections"""
|
||||
"""Initialize synchronous Redis connection."""
|
||||
try:
|
||||
logger.info("Initializing Redis connection...")
|
||||
|
||||
# Create connection pool
|
||||
logger.info("Initializing sync Redis connection...")
|
||||
self._connection_pool = redis.ConnectionPool(**self.config.get_pool_kwargs())
|
||||
self._redis_client = redis.Redis(connection_pool=self._connection_pool)
|
||||
|
||||
# Test connection
|
||||
self._redis_client.ping()
|
||||
logger.info("Redis connection initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Redis: {e}")
|
||||
logger.info("Sync Redis connection initialized successfully.")
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize sync Redis: {e}")
|
||||
raise
|
||||
|
||||
async def initialize_async(self) -> None:
|
||||
"""Initialize async Redis connections"""
|
||||
try:
|
||||
logger.info("Initializing async Redis connection...")
|
||||
|
||||
# Create async connection pool
|
||||
self._async_connection_pool = redis_async.ConnectionPool(**self.config.get_pool_kwargs())
|
||||
self._async_redis_client = redis_async.Redis(connection_pool=self._async_connection_pool)
|
||||
|
||||
# Test connection
|
||||
await self._async_redis_client.ping()
|
||||
logger.info("Async Redis connection initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize async Redis: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@property
|
||||
def client(self) -> redis.Redis:
|
||||
"""Get synchronous Redis client"""
|
||||
"""Get synchronous Redis client."""
|
||||
if not self._redis_client:
|
||||
raise RuntimeError("Redis not initialized. Call initialize() first.")
|
||||
raise RuntimeError("Sync Redis not initialized. Call initialize() first.")
|
||||
return self._redis_client
|
||||
|
||||
|
||||
def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> int:
|
||||
"""Publish message to a channel."""
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message, default=str)
|
||||
return self.client.publish(channel, message)
|
||||
|
||||
def set(self, key: str, value: Any, ex: Optional[int] = None) -> None:
|
||||
"""Set a key-value pair with an optional expiration."""
|
||||
self.client.set(key, json.dumps(value, default=str), ex=ex)
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""Get a value by key."""
|
||||
value = self.client.get(key)
|
||||
return json.loads(value) if value else None
|
||||
|
||||
def delete(self, *keys: str) -> int:
|
||||
"""Delete one or more keys."""
|
||||
return self.client.delete(*keys)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close Redis connections."""
|
||||
if self._connection_pool:
|
||||
self._connection_pool.disconnect()
|
||||
logger.info("Sync Redis connections closed.")
|
||||
|
||||
|
||||
class AsyncRedisManager(BaseRedisManager):
|
||||
"""Asynchronous Redis manager for asyncio operations."""
|
||||
|
||||
def __init__(self, config: Optional[RedisConfig] = None):
|
||||
super().__init__(config)
|
||||
self._async_connection_pool: Optional[redis_async.ConnectionPool] = None
|
||||
self._async_redis_client: Optional[redis_async.Redis] = None
|
||||
self._async_pubsub_client: Optional[redis_async.client.PubSub] = None
|
||||
self._async_message_handlers: Dict[str, List[Callable]] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize asynchronous Redis connection."""
|
||||
try:
|
||||
logger.info("Initializing async Redis connection...")
|
||||
self._async_connection_pool = redis_async.ConnectionPool(**self.config.get_pool_kwargs())
|
||||
self._async_redis_client = redis_async.Redis(connection_pool=self._async_connection_pool)
|
||||
await self._async_redis_client.ping()
|
||||
logger.info("Async Redis connection initialized successfully.")
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize async Redis: {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
def async_client(self) -> redis_async.Redis:
|
||||
"""Get asynchronous Redis client"""
|
||||
"""Get asynchronous Redis client."""
|
||||
if not self._async_redis_client:
|
||||
raise RuntimeError("Async Redis not initialized. Call initialize_async() first.")
|
||||
raise RuntimeError("Async Redis not initialized. Call initialize() first.")
|
||||
return self._async_redis_client
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""Test Redis connection"""
|
||||
try:
|
||||
self.client.ping()
|
||||
logger.info("Redis connection test successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis connection test failed: {e}")
|
||||
return False
|
||||
|
||||
async def test_connection_async(self) -> bool:
|
||||
"""Test async Redis connection"""
|
||||
try:
|
||||
await self.async_client.ping()
|
||||
logger.info("Async Redis connection test successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async Redis connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> int:
|
||||
"""
|
||||
Publish message to channel
|
||||
|
||||
Args:
|
||||
channel: Redis channel name
|
||||
message: Message to publish (string or dict that will be JSON serialized)
|
||||
|
||||
Returns:
|
||||
Number of clients that received the message
|
||||
"""
|
||||
try:
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message, default=str)
|
||||
|
||||
result = self.client.publish(channel, message)
|
||||
logger.debug(f"Published message to {channel}: {result} clients received")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to publish message to {channel}: {e}")
|
||||
raise
|
||||
|
||||
async def publish_async(self, channel: str, message: Union[str, Dict[str, Any]]) -> int:
|
||||
"""
|
||||
Publish message to channel (async)
|
||||
|
||||
Args:
|
||||
channel: Redis channel name
|
||||
message: Message to publish (string or dict that will be JSON serialized)
|
||||
|
||||
Returns:
|
||||
Number of clients that received the message
|
||||
"""
|
||||
try:
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message, default=str)
|
||||
|
||||
result = await self.async_client.publish(channel, message)
|
||||
logger.debug(f"Published message to {channel}: {result} clients received")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to publish message to {channel}: {e}")
|
||||
raise
|
||||
|
||||
def subscribe(self, channels: Union[str, List[str]], handler: Callable[[str, str], None]) -> None:
|
||||
"""
|
||||
Subscribe to Redis channels with message handler
|
||||
|
||||
Args:
|
||||
channels: Channel name or list of channel names
|
||||
handler: Function to handle received messages (channel, message)
|
||||
"""
|
||||
if isinstance(channels, str):
|
||||
channels = [channels]
|
||||
|
||||
for channel in channels:
|
||||
if channel not in self._message_handlers:
|
||||
self._message_handlers[channel] = []
|
||||
self._message_handlers[channel].append(handler)
|
||||
|
||||
logger.info(f"Registered handler for channels: {channels}")
|
||||
|
||||
async def subscribe_async(self, channels: Union[str, List[str]], handler: Callable[[str, str], None]) -> None:
|
||||
"""
|
||||
Subscribe to Redis channels with message handler (async)
|
||||
|
||||
Args:
|
||||
channels: Channel name or list of channel names
|
||||
handler: Function to handle received messages (channel, message)
|
||||
"""
|
||||
if isinstance(channels, str):
|
||||
channels = [channels]
|
||||
|
||||
for channel in channels:
|
||||
if channel not in self._async_message_handlers:
|
||||
self._async_message_handlers[channel] = []
|
||||
self._async_message_handlers[channel].append(handler)
|
||||
|
||||
logger.info(f"Registered async handler for channels: {channels}")
|
||||
|
||||
def start_subscriber(self) -> None:
|
||||
"""Start synchronous message subscriber"""
|
||||
if not self._message_handlers:
|
||||
logger.warning("No message handlers registered")
|
||||
return
|
||||
|
||||
try:
|
||||
self._pubsub_client = self.client.pubsub()
|
||||
|
||||
# Subscribe to all channels with handlers
|
||||
for channel in self._message_handlers.keys():
|
||||
self._pubsub_client.subscribe(channel)
|
||||
|
||||
logger.info(f"Started subscriber for channels: {list(self._message_handlers.keys())}")
|
||||
|
||||
# Message processing loop
|
||||
for message in self._pubsub_client.listen():
|
||||
if message['type'] == 'message':
|
||||
channel = message['channel']
|
||||
data = message['data']
|
||||
|
||||
# Call all handlers for this channel
|
||||
if channel in self._message_handlers:
|
||||
for handler in self._message_handlers[channel]:
|
||||
try:
|
||||
handler(channel, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message handler for {channel}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message subscriber: {e}")
|
||||
raise
|
||||
|
||||
async def start_subscriber_async(self) -> None:
|
||||
"""Start asynchronous message subscriber"""
|
||||
if not self._async_message_handlers:
|
||||
logger.warning("No async message handlers registered")
|
||||
return
|
||||
|
||||
try:
|
||||
self._async_pubsub_client = self.async_client.pubsub()
|
||||
|
||||
# Subscribe to all channels with handlers
|
||||
for channel in self._async_message_handlers.keys():
|
||||
await self._async_pubsub_client.subscribe(channel)
|
||||
|
||||
logger.info(f"Started async subscriber for channels: {list(self._async_message_handlers.keys())}")
|
||||
|
||||
# Message processing loop
|
||||
async for message in self._async_pubsub_client.listen():
|
||||
if message['type'] == 'message':
|
||||
channel = message['channel']
|
||||
data = message['data']
|
||||
|
||||
# Call all handlers for this channel
|
||||
if channel in self._async_message_handlers:
|
||||
for handler in self._async_message_handlers[channel]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(channel, data)
|
||||
else:
|
||||
handler(channel, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in async message handler for {channel}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in async message subscriber: {e}")
|
||||
raise
|
||||
|
||||
def stop_subscriber(self) -> None:
|
||||
"""Stop synchronous message subscriber"""
|
||||
if self._pubsub_client:
|
||||
self._pubsub_client.close()
|
||||
self._pubsub_client = None
|
||||
logger.info("Stopped message subscriber")
|
||||
|
||||
async def stop_subscriber_async(self) -> None:
|
||||
"""Stop asynchronous message subscriber"""
|
||||
if self._async_pubsub_client:
|
||||
await self._async_pubsub_client.close()
|
||||
self._async_pubsub_client = None
|
||||
logger.info("Stopped async message subscriber")
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""Get Redis server information"""
|
||||
try:
|
||||
return self.client.info()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Redis info: {e}")
|
||||
return {}
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close Redis connections"""
|
||||
try:
|
||||
self.stop_subscriber()
|
||||
|
||||
if self._connection_pool:
|
||||
self._connection_pool.disconnect()
|
||||
|
||||
logger.info("Redis connections closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Redis connections: {e}")
|
||||
|
||||
async def close_async(self) -> None:
|
||||
"""Close async Redis connections"""
|
||||
try:
|
||||
await self.stop_subscriber_async()
|
||||
|
||||
if self._async_connection_pool:
|
||||
await self._async_connection_pool.disconnect()
|
||||
|
||||
logger.info("Async Redis connections closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing async Redis connections: {e}")
|
||||
|
||||
async def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> int:
|
||||
"""Publish message to a channel asynchronously."""
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message, default=str)
|
||||
return await self.async_client.publish(channel, message)
|
||||
|
||||
async def set(self, key: str, value: Any, ex: Optional[int] = None) -> None:
|
||||
"""Set a key-value pair asynchronously."""
|
||||
await self.async_client.set(key, json.dumps(value, default=str), ex=ex)
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Get a value by key asynchronously."""
|
||||
value = await self.async_client.get(key)
|
||||
return json.loads(value) if value else None
|
||||
|
||||
async def delete(self, *keys: str) -> int:
|
||||
"""Delete one or more keys asynchronously."""
|
||||
return await self.async_client.delete(*keys)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close async Redis connections."""
|
||||
if self._async_connection_pool:
|
||||
await self._async_connection_pool.disconnect()
|
||||
logger.info("Async Redis connections closed.")
|
||||
|
||||
|
||||
# Global Redis manager instance
|
||||
redis_manager = RedisManager()
|
||||
# Global instances (to be managed carefully, e.g., via a factory or DI)
|
||||
sync_redis_manager = SyncRedisManager()
|
||||
async_redis_manager = AsyncRedisManager()
|
||||
|
||||
|
||||
def get_redis_manager() -> RedisManager:
|
||||
"""Get global Redis manager instance"""
|
||||
return redis_manager
|
||||
def get_sync_redis_manager() -> SyncRedisManager:
|
||||
"""Get the global synchronous Redis manager instance."""
|
||||
return sync_redis_manager
|
||||
|
||||
|
||||
def init_redis(config: Optional[RedisConfig] = None) -> RedisManager:
|
||||
def get_async_redis_manager() -> AsyncRedisManager:
|
||||
"""Get the global asynchronous Redis manager instance."""
|
||||
return async_redis_manager
|
||||
|
||||
|
||||
def init_redis(config: Optional[RedisConfig] = None) -> SyncRedisManager:
|
||||
"""
|
||||
Initialize global Redis manager
|
||||
Initialize global sync Redis manager.
|
||||
|
||||
Args:
|
||||
config: Optional Redis configuration
|
||||
config: Optional Redis configuration.
|
||||
|
||||
Returns:
|
||||
RedisManager instance
|
||||
SyncRedisManager instance.
|
||||
"""
|
||||
global redis_manager
|
||||
global sync_redis_manager
|
||||
if config:
|
||||
redis_manager = RedisManager(config)
|
||||
redis_manager.initialize()
|
||||
return redis_manager
|
||||
sync_redis_manager = SyncRedisManager(config)
|
||||
sync_redis_manager.initialize()
|
||||
return sync_redis_manager
|
||||
|
||||
|
||||
async def init_redis_async(config: Optional[RedisConfig] = None) -> RedisManager:
|
||||
async def init_redis_async(config: Optional[RedisConfig] = None) -> AsyncRedisManager:
|
||||
"""
|
||||
Initialize global Redis manager (async)
|
||||
Initialize global async Redis manager.
|
||||
|
||||
Args:
|
||||
config: Optional Redis configuration
|
||||
config: Optional Redis configuration.
|
||||
|
||||
Returns:
|
||||
RedisManager instance
|
||||
AsyncRedisManager instance.
|
||||
"""
|
||||
global redis_manager
|
||||
global async_redis_manager
|
||||
if config:
|
||||
redis_manager = RedisManager(config)
|
||||
await redis_manager.initialize_async()
|
||||
return redis_manager
|
||||
async_redis_manager = AsyncRedisManager(config)
|
||||
await async_redis_manager.initialize()
|
||||
return async_redis_manager
|
||||
|
||||
|
||||
# Convenience functions for common operations
|
||||
def publish_market_data(symbol: str, data: Dict[str, Any]) -> int:
|
||||
"""Publish market data to symbol-specific channel"""
|
||||
channel = redis_manager.channels.get_symbol_channel(redis_manager.channels.market_data_ohlcv, symbol)
|
||||
return redis_manager.publish(channel, data)
|
||||
"""Publish market data to symbol-specific channel."""
|
||||
channel = sync_redis_manager.channels.get_symbol_channel(sync_redis_manager.channels.market_data_ohlcv, symbol)
|
||||
return sync_redis_manager.publish(channel, data)
|
||||
|
||||
|
||||
def publish_bot_signal(bot_id: int, signal_data: Dict[str, Any]) -> int:
|
||||
"""Publish bot signal to bot-specific channel"""
|
||||
channel = redis_manager.channels.get_bot_channel(redis_manager.channels.bot_signals, bot_id)
|
||||
return redis_manager.publish(channel, signal_data)
|
||||
"""Publish bot signal to bot-specific channel."""
|
||||
channel = sync_redis_manager.channels.get_bot_channel(sync_redis_manager.channels.bot_signals, bot_id)
|
||||
return sync_redis_manager.publish(channel, signal_data)
|
||||
|
||||
|
||||
def publish_bot_trade(bot_id: int, trade_data: Dict[str, Any]) -> int:
|
||||
"""Publish bot trade to bot-specific channel"""
|
||||
channel = redis_manager.channels.get_bot_channel(redis_manager.channels.bot_trades, bot_id)
|
||||
return redis_manager.publish(channel, trade_data)
|
||||
"""Publish bot trade to bot-specific channel."""
|
||||
channel = sync_redis_manager.channels.get_bot_channel(sync_redis_manager.channels.bot_trades, bot_id)
|
||||
return sync_redis_manager.publish(channel, trade_data)
|
||||
|
||||
|
||||
def publish_system_health(health_data: Dict[str, Any]) -> int:
|
||||
"""Publish system health status"""
|
||||
return redis_manager.publish(redis_manager.channels.system_health, health_data)
|
||||
"""Publish system health status."""
|
||||
return sync_redis_manager.publish(sync_redis_manager.channels.system_health, health_data)
|
||||
|
||||
|
||||
def publish_dashboard_update(update_data: Dict[str, Any]) -> int:
|
||||
"""Publish dashboard update"""
|
||||
return redis_manager.publish(redis_manager.channels.dashboard_updates, update_data)
|
||||
"""Publish dashboard update."""
|
||||
return sync_redis_manager.publish(sync_redis_manager.channels.dashboard_updates, update_data)
|
||||
@ -337,34 +337,25 @@ Create a quick test script:
|
||||
```python
|
||||
# test_connection.py
|
||||
import os
|
||||
import psycopg2
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
from database.connection import DatabaseManager
|
||||
|
||||
# Load environment variables
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Test PostgreSQL
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=os.getenv('POSTGRES_HOST'),
|
||||
port=os.getenv('POSTGRES_PORT'),
|
||||
database=os.getenv('POSTGRES_DB'),
|
||||
user=os.getenv('POSTGRES_USER'),
|
||||
password=os.getenv('POSTGRES_PASSWORD')
|
||||
)
|
||||
print("✅ PostgreSQL connection successful!")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"❌ PostgreSQL connection failed: {e}")
|
||||
# Test Database
|
||||
db = DatabaseManager()
|
||||
db.initialize()
|
||||
if db.test_connection():
|
||||
print("✅ Database connection successful!")
|
||||
db.close()
|
||||
|
||||
# Test Redis
|
||||
from database.redis_manager import get_sync_redis_manager
|
||||
|
||||
try:
|
||||
r = redis.Redis(
|
||||
host=os.getenv('REDIS_HOST'),
|
||||
port=int(os.getenv('REDIS_PORT')),
|
||||
password=os.getenv('REDIS_PASSWORD')
|
||||
)
|
||||
r.ping()
|
||||
redis_manager = get_sync_redis_manager()
|
||||
redis_manager.initialize()
|
||||
print("✅ Redis connection successful!")
|
||||
except Exception as e:
|
||||
print(f"❌ Redis connection failed: {e}")
|
||||
|
||||
108
tests/database/test_redis_manager.py
Normal file
108
tests/database/test_redis_manager.py
Normal file
@ -0,0 +1,108 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from database.redis_manager import (
|
||||
RedisConfig,
|
||||
SyncRedisManager,
|
||||
AsyncRedisManager,
|
||||
publish_market_data,
|
||||
get_sync_redis_manager
|
||||
)
|
||||
|
||||
|
||||
class TestRedisManagers(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Set up mock configs and managers for each test."""
|
||||
self.config = RedisConfig()
|
||||
|
||||
@patch('redis.Redis')
|
||||
@patch('redis.ConnectionPool')
|
||||
def test_sync_manager_initialization(self, mock_pool, mock_redis):
|
||||
"""Test that SyncRedisManager initializes correctly."""
|
||||
mock_redis_instance = mock_redis.return_value
|
||||
manager = SyncRedisManager(self.config)
|
||||
manager.initialize()
|
||||
|
||||
mock_pool.assert_called_once_with(**self.config.get_pool_kwargs())
|
||||
mock_redis.assert_called_once_with(connection_pool=mock_pool.return_value)
|
||||
mock_redis_instance.ping.assert_called_once()
|
||||
self.assertIsNotNone(manager.client)
|
||||
|
||||
@patch('redis.asyncio.Redis')
|
||||
@patch('redis.asyncio.ConnectionPool')
|
||||
def test_async_manager_initialization(self, mock_pool, mock_redis_class):
|
||||
"""Test that AsyncRedisManager initializes correctly."""
|
||||
async def run_test():
|
||||
mock_redis_instance = AsyncMock()
|
||||
mock_redis_class.return_value = mock_redis_instance
|
||||
|
||||
manager = AsyncRedisManager(self.config)
|
||||
await manager.initialize()
|
||||
|
||||
mock_pool.assert_called_once_with(**self.config.get_pool_kwargs())
|
||||
mock_redis_class.assert_called_once_with(connection_pool=mock_pool.return_value)
|
||||
mock_redis_instance.ping.assert_awaited_once()
|
||||
self.assertIsNotNone(manager.async_client)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_sync_caching(self):
|
||||
"""Test set, get, and delete operations for SyncRedisManager."""
|
||||
manager = SyncRedisManager(self.config)
|
||||
manager._redis_client = MagicMock()
|
||||
|
||||
# Test set
|
||||
manager.set("key1", {"data": "value1"}, ex=60)
|
||||
manager.client.set.assert_called_once_with("key1", '{"data": "value1"}', ex=60)
|
||||
|
||||
# Test get
|
||||
manager.client.get.return_value = '{"data": "value1"}'
|
||||
result = manager.get("key1")
|
||||
self.assertEqual(result, {"data": "value1"})
|
||||
|
||||
# Test delete
|
||||
manager.delete("key1")
|
||||
manager.client.delete.assert_called_once_with("key1")
|
||||
|
||||
def test_async_caching(self):
|
||||
"""Test async set, get, and delete for AsyncRedisManager."""
|
||||
async def run_test():
|
||||
manager = AsyncRedisManager(self.config)
|
||||
manager._async_redis_client = AsyncMock()
|
||||
|
||||
# Test set
|
||||
await manager.set("key2", "value2", ex=30)
|
||||
manager.async_client.set.assert_awaited_once_with("key2", '"value2"', ex=30)
|
||||
|
||||
# Test get
|
||||
manager.async_client.get.return_value = '"value2"'
|
||||
result = await manager.get("key2")
|
||||
self.assertEqual(result, "value2")
|
||||
|
||||
# Test delete
|
||||
await manager.delete("key2")
|
||||
manager.async_client.delete.assert_awaited_once_with("key2")
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
@patch('database.redis_manager.sync_redis_manager', new_callable=MagicMock)
|
||||
def test_publish_market_data_convenience_func(self, mock_global_manager):
|
||||
"""Test the publish_market_data convenience function."""
|
||||
symbol = "BTC/USDT"
|
||||
data = {"price": 100}
|
||||
|
||||
# This setup is needed because the global manager is patched
|
||||
mock_global_manager.channels = get_sync_redis_manager().channels
|
||||
|
||||
publish_market_data(symbol, data)
|
||||
|
||||
expected_channel = mock_global_manager.channels.get_symbol_channel(
|
||||
mock_global_manager.channels.market_data_ohlcv, symbol
|
||||
)
|
||||
mock_global_manager.publish.assert_called_once_with(expected_channel, data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user