Enhance OKX WebSocket client with improved task management and error handling
- Implemented enhanced task synchronization to prevent race conditions during WebSocket operations.
- Introduced reconnection locking to avoid concurrent reconnection attempts.
- Improved error handling in message processing and reconnection logic, ensuring graceful shutdown and task management.
- Added unit tests to verify the stability and reliability of the WebSocket client under concurrent operations.
This commit is contained in:
@@ -122,9 +122,11 @@ class OKXWebSocketClient:
|
||||
self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = []
|
||||
self._subscriptions: Dict[str, OKXSubscription] = {}
|
||||
|
||||
# Tasks
|
||||
# Enhanced task management
|
||||
self._ping_task: Optional[asyncio.Task] = None
|
||||
self._message_handler_task: Optional[asyncio.Task] = None
|
||||
self._reconnection_lock = asyncio.Lock() # Prevent concurrent reconnections
|
||||
self._tasks_stopping = False # Flag to prevent task overlap
|
||||
|
||||
# Statistics
|
||||
self._stats = {
|
||||
@@ -380,6 +382,15 @@ class OKXWebSocketClient:
|
||||
|
||||
async def _start_background_tasks(self) -> None:
|
||||
"""Start background tasks for ping and message handling."""
|
||||
# Ensure no tasks are currently stopping
|
||||
if self._tasks_stopping:
|
||||
if self.logger:
|
||||
self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
|
||||
return
|
||||
|
||||
# Cancel any existing tasks first
|
||||
await self._stop_background_tasks()
|
||||
|
||||
# Start ping task
|
||||
self._ping_task = asyncio.create_task(self._ping_loop())
|
||||
|
||||
@@ -390,22 +401,53 @@ class OKXWebSocketClient:
|
||||
self.logger.debug(f"{self.component_name}: Started background tasks")
|
||||
|
||||
async def _stop_background_tasks(self) -> None:
    """
    Stop background tasks with proper synchronization.

    Cancels the ping task and the message handler task (when they exist and
    are still running) and waits up to 5 seconds for them to finish.
    ``_tasks_stopping`` is set for the duration of the call and always reset
    on exit, including early returns and errors.

    The task that is executing this coroutine is never cancelled or awaited:
    a task cannot wait for its own completion, and cancelling it would abort
    the caller (for example a message handler that stops the tasks before
    reconnecting) with ``CancelledError``.
    """
    self._tasks_stopping = True

    try:
        running_task = asyncio.current_task()

        # Collect the live tasks to cancel, skipping the caller's own task.
        tasks = [
            task
            for task in (self._ping_task, self._message_handler_task)
            if task is not None and not task.done() and task is not running_task
        ]

        if not tasks:
            if self.logger:
                self.logger.debug(f"{self.component_name}: No background tasks to stop")
            return

        if self.logger:
            self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks")

        # Cancel all tasks
        for task in tasks:
            task.cancel()

        # Wait for all tasks to complete with timeout. return_exceptions=True
        # turns each task's CancelledError into a result instead of raising it.
        try:
            await asyncio.wait_for(
                asyncio.gather(*tasks, return_exceptions=True),
                timeout=5.0,
            )
        except asyncio.TimeoutError:
            if self.logger:
                self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running")
        except Exception as e:
            if self.logger:
                self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}")

        # Clear task references
        self._ping_task = None
        self._message_handler_task = None

        if self.logger:
            self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")

    finally:
        self._tasks_stopping = False
|
||||
|
||||
async def _ping_loop(self) -> None:
|
||||
"""Background task for sending ping messages."""
|
||||
@@ -435,58 +477,91 @@ class OKXWebSocketClient:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _message_handler(self) -> None:
    """
    Background task for handling incoming messages with enhanced error handling.

    Loops while the client is connected and no task shutdown is in progress,
    receiving one message at a time (1 second receive timeout) and passing it
    to ``_process_message``. When the connection closes outside of a shutdown,
    the handler makes a single automatic reconnection attempt, bounded by
    ``max_reconnect_attempts``, and then exits.

    Cancellation ends the handler quietly; any other error inside the loop is
    logged and the loop continues after a 1 second pause.
    """
    if self.logger:
        self.logger.debug(f"{self.component_name}: Message handler started")

    try:
        while self.is_connected and not self._tasks_stopping:
            try:
                if not self._websocket or self._tasks_stopping:
                    break

                # Receive message with timeout so the loop conditions are
                # re-evaluated at least once per second.
                try:
                    message = await asyncio.wait_for(
                        self._websocket.recv(),
                        timeout=1.0,
                    )
                except asyncio.TimeoutError:
                    continue  # No message received, continue loop

                # Check if we're still supposed to be running
                if self._tasks_stopping:
                    break

                # Process message
                await self._process_message(message)

            except ConnectionClosed as e:
                if self._tasks_stopping:
                    break  # Expected during shutdown

                if self.logger:
                    self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}")
                self._connection_state = ConnectionState.DISCONNECTED

                if (self._reconnect_attempts >= self.max_reconnect_attempts
                        or self._tasks_stopping):
                    if self.logger:
                        self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
                    break

                # reconnect() serialises itself on _reconnection_lock, and
                # asyncio.Lock is not reentrant, so the lock must NOT be held
                # around the reconnect() call (that would deadlock). A held
                # lock means another reconnection is already running; leave
                # the recovery to it instead of queueing a second one.
                if self._reconnection_lock.locked():
                    if self.logger:
                        self.logger.info(f"{self.component_name}: Reconnection already in progress, handler exiting")
                    break

                self._reconnect_attempts += 1
                if self.logger:
                    self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")

                # Detach this task from the client first: stopping the
                # background tasks cancels whatever is registered as
                # _message_handler_task, which would otherwise cancel this
                # very task before it gets the chance to reconnect.
                if self._message_handler_task is asyncio.current_task():
                    self._message_handler_task = None

                # Stop current tasks properly
                await self._stop_background_tasks()

                if await self.reconnect():
                    if self.logger:
                        self.logger.info(f"{self.component_name}: Automatic reconnection successful")
                else:
                    if self.logger:
                        self.logger.error(f"{self.component_name}: Automatic reconnection failed")

                # Either way this handler is finished. NOTE(review): the
                # original comment says reconnect() restarts the background
                # tasks - confirm against connect().
                break

            except asyncio.CancelledError:
                if self.logger:
                    self.logger.debug(f"{self.component_name}: Message handler cancelled")
                break
            except Exception as e:
                if self._tasks_stopping:
                    break
                if self.logger:
                    self.logger.error(f"{self.component_name}: Error in message handler: {e}")
                await asyncio.sleep(1)

    except asyncio.CancelledError:
        if self.logger:
            self.logger.debug(f"{self.component_name}: Message handler task cancelled")
    except Exception as e:
        if self.logger:
            self.logger.error(f"{self.component_name}: Fatal error in message handler: {e}")
    finally:
        if self.logger:
            self.logger.debug(f"{self.component_name}: Message handler exiting")
|
||||
|
||||
async def _send_message(self, message: Dict[str, Any]) -> None:
|
||||
"""
|
||||
@@ -626,34 +701,40 @@ class OKXWebSocketClient:
|
||||
|
||||
async def reconnect(self, delay: float = 1.0) -> bool:
    """
    Reconnect to WebSocket with enhanced synchronization.

    The whole sequence runs under ``_reconnection_lock`` so that only one
    reconnection is in flight at a time. ``asyncio.Lock`` is not reentrant:
    a caller that already holds ``_reconnection_lock`` must release it before
    awaiting this method, otherwise the call never returns.

    Args:
        delay: Seconds to wait between disconnecting and connecting again.
            Defaults to 1.0.

    Returns:
        True if reconnection successful, False otherwise
    """
    async with self._reconnection_lock:
        if self.logger:
            self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket")
        self._connection_state = ConnectionState.RECONNECTING
        self._stats['reconnections'] += 1

        # Store current subscriptions before disconnect.
        # NOTE(review): presumably disconnect() clears _subscriptions - confirm.
        stored_subscriptions = list(self._subscriptions.values())

        # Disconnect first with proper cleanup
        await self.disconnect()

        # Wait a moment before reconnecting
        await asyncio.sleep(delay)

        # Attempt to reconnect
        success = await self.connect()

        if success:
            # Re-subscribe to previous subscriptions
            if stored_subscriptions:
                if self.logger:
                    self.logger.info(f"{self.component_name}: Re-subscribing to {len(stored_subscriptions)} channels")
                await self.subscribe(stored_subscriptions)

            # Reset reconnect attempts on successful reconnection
            self._reconnect_attempts = 0

        return success
|
||||
|
||||
def __repr__(self) -> str:
    """Return a debug representation showing the connection state value and the subscription count."""
    state = self._connection_state.value
    subscription_count = len(self._subscriptions)
    return f"<OKXWebSocketClient(state={state}, subscriptions={subscription_count})>"
|
||||
Reference in New Issue
Block a user