diff --git a/data/exchanges/okx/websocket.py b/data/exchanges/okx/websocket.py
index 11d3f75..c10e8b0 100644
--- a/data/exchanges/okx/websocket.py
+++ b/data/exchanges/okx/websocket.py
@@ -388,56 +388,83 @@ class OKXWebSocketClient:
                 self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
             return
 
-        # Cancel any existing tasks first
+        # Check if tasks are already running
+        if (self._ping_task and not self._ping_task.done() and
+            self._message_handler_task and not self._message_handler_task.done()):
+            if self.logger:
+                self.logger.debug(f"{self.component_name}: Background tasks already running")
+            return
+
+        # Cancel any existing tasks first (safety measure)
         await self._stop_background_tasks()
 
-        # Start ping task
-        self._ping_task = asyncio.create_task(self._ping_loop())
+        # Ensure we're still supposed to start tasks after stopping
+        if self._tasks_stopping or not self.is_connected:
+            if self.logger:
+                self.logger.debug(f"{self.component_name}: Aborting task start - stopping or disconnected")
+            return
 
-        # Start message handler task
-        self._message_handler_task = asyncio.create_task(self._message_handler())
-
-        if self.logger:
-            self.logger.debug(f"{self.component_name}: Started background tasks")
+        try:
+            # Start ping task
+            self._ping_task = asyncio.create_task(self._ping_loop())
+
+            # Start message handler task
+            self._message_handler_task = asyncio.create_task(self._message_handler())
+
+            if self.logger:
+                self.logger.debug(f"{self.component_name}: Started background tasks")
+
+        except Exception as e:
+            if self.logger:
+                self.logger.error(f"{self.component_name}: Error starting background tasks: {e}")
+            # Clean up on failure
+            await self._stop_background_tasks()
 
     async def _stop_background_tasks(self) -> None:
-        """Stop background tasks with proper synchronization."""
+        """Stop background tasks with proper synchronization - simplified approach."""
         self._tasks_stopping = True
 
         try:
-            tasks = []
-
             # Collect tasks to cancel
-            if self._ping_task and not self._ping_task.done():
-                tasks.append(self._ping_task)
-            if self._message_handler_task and not self._message_handler_task.done():
-                tasks.append(self._message_handler_task)
+            tasks_to_cancel = []
 
-            if not tasks:
+            if self._ping_task and not self._ping_task.done():
+                tasks_to_cancel.append(('ping_task', self._ping_task))
+            if self._message_handler_task and not self._message_handler_task.done():
+                tasks_to_cancel.append(('message_handler_task', self._message_handler_task))
+
+            if not tasks_to_cancel:
                 if self.logger:
                     self.logger.debug(f"{self.component_name}: No background tasks to stop")
                 return
 
             if self.logger:
-                self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks")
+                self.logger.debug(f"{self.component_name}: Stopping {len(tasks_to_cancel)} background tasks")
 
-            # Cancel all tasks
-            for task in tasks:
-                task.cancel()
-
-            # Wait for all tasks to complete with timeout
-            if tasks:
+            # Cancel tasks individually to avoid recursion
+            for task_name, task in tasks_to_cancel:
                 try:
-                    await asyncio.wait_for(
-                        asyncio.gather(*tasks, return_exceptions=True),
-                        timeout=5.0
-                    )
-                except asyncio.TimeoutError:
-                    if self.logger:
-                        self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running")
+                    if not task.done():
+                        task.cancel()
+                        if self.logger:
+                            self.logger.debug(f"{self.component_name}: Cancelled {task_name}")
                 except Exception as e:
                     if self.logger:
-                        self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}")
+                        self.logger.debug(f"{self.component_name}: Error cancelling {task_name}: {e}")
+
+            # Wait for tasks to complete individually with shorter timeouts
+            for task_name, task in tasks_to_cancel:
+                try:
+                    await asyncio.wait_for(task, timeout=2.0)
+                except asyncio.TimeoutError:
+                    if self.logger:
+                        self.logger.warning(f"{self.component_name}: {task_name} shutdown timeout")
+                except asyncio.CancelledError:
+                    # Expected when task is cancelled
+                    pass
+                except Exception as e:
+                    if self.logger:
+                        self.logger.debug(f"{self.component_name}: {task_name} shutdown exception: {e}")
 
             # Clear task references
             self._ping_task = None
@@ -446,6 +473,9 @@ class OKXWebSocketClient:
             if self.logger:
                 self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")
 
+        except Exception as e:
+            if self.logger:
+                self.logger.error(f"{self.component_name}: Error in _stop_background_tasks: {e}")
         finally:
             self._tasks_stopping = False
 
@@ -495,6 +525,9 @@ class OKXWebSocketClient:
                             )
                         except asyncio.TimeoutError:
                             continue  # No message received, continue loop
+                        except asyncio.CancelledError:
+                            # Exit immediately on cancellation
+                            break
 
                         # Check if we're still supposed to be running
                         if self._tasks_stopping:
@@ -512,35 +545,43 @@ class OKXWebSocketClient:
                     self._connection_state = ConnectionState.DISCONNECTED
 
                     # Use lock to prevent concurrent reconnection attempts
-                    async with self._reconnection_lock:
-                        # Double-check we still need to reconnect
-                        if (self._connection_state == ConnectionState.DISCONNECTED and
-                            self._reconnect_attempts < self.max_reconnect_attempts and
-                            not self._tasks_stopping):
-
-                            self._reconnect_attempts += 1
-                            if self.logger:
-                                self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
-
-                            # Stop current tasks properly
-                            await self._stop_background_tasks()
-
-                            # Attempt reconnection with stored subscriptions
-                            stored_subscriptions = list(self._subscriptions.values())
-
-                            if await self.reconnect():
-                                if self.logger:
-                                    self.logger.info(f"{self.component_name}: Automatic reconnection successful")
-                                # The reconnect method will restart tasks, so we exit this handler
-                                break
-                            else:
-                                if self.logger:
-                                    self.logger.error(f"{self.component_name}: Automatic reconnection failed")
-                                break
-                        else:
-                            if self.logger:
-                                self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
-                            break
+                    try:
+                        # wait_for() returns an awaitable, not an async context manager, so
+                        # acquire with a bounded wait here and release explicitly in finally.
+                        await asyncio.wait_for(self._reconnection_lock.acquire(), timeout=5.0)
+                        try:
+                            # Double-check we still need to reconnect
+                            if (self._connection_state == ConnectionState.DISCONNECTED and
+                                self._reconnect_attempts < self.max_reconnect_attempts and
+                                not self._tasks_stopping):
+
+                                self._reconnect_attempts += 1
+                                if self.logger:
+                                    self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
+
+                                # Attempt reconnection (this will handle task cleanup)
+                                if await self.reconnect():
+                                    if self.logger:
+                                        self.logger.info(f"{self.component_name}: Automatic reconnection successful")
+                                    # Exit this handler as reconnect will start new tasks
+                                    break
+                                else:
+                                    if self.logger:
+                                        self.logger.error(f"{self.component_name}: Automatic reconnection failed")
+                                    break
+                            else:
+                                if self.logger:
+                                    self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
+                                break
+                        finally:
+                            self._reconnection_lock.release()
+                    except asyncio.TimeoutError:
+                        if self.logger:
+                            self.logger.warning(f"{self.component_name}: Timeout acquiring reconnection lock")
+                        break
+                    except asyncio.CancelledError:
+                        # Exit immediately on cancellation
+                        break
 
                 except asyncio.CancelledError:
                     if self.logger:
diff --git a/tests/test_recursion_fix.py b/tests/test_recursion_fix.py
new file mode 100644
index 0000000..35702a5
--- /dev/null
+++ b/tests/test_recursion_fix.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+"""
+Simple test to verify recursion fix in WebSocket task management.
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+
+# Add project root to path
+project_root = Path(__file__).parent.parent
+sys.path.insert(0, str(project_root))
+
+from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
+from utils.logger import get_logger
+
+
+async def test_rapid_connection_cycles():
+    """Test rapid connect/disconnect cycles to verify no recursion errors."""
+    logger = get_logger("recursion_test", verbose=False)
+
+    print("🧪 Testing WebSocket Recursion Fix")
+    print("=" * 40)
+
+    for cycle in range(5):
+        print(f"\n🔄 Cycle {cycle + 1}/5: Rapid connect/disconnect")
+
+        ws_client = OKXWebSocketClient(
+            component_name=f"test_client_{cycle}",
+            max_reconnect_attempts=2,
+            logger=logger
+        )
+
+        try:
+            # Connect
+            success = await ws_client.connect()
+            if not success:
+                print(f" ❌ Connection failed in cycle {cycle + 1}")
+                continue
+
+            # Subscribe
+            subscriptions = [
+                OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT")
+            ]
+            await ws_client.subscribe(subscriptions)
+
+            # Quick activity
+            await asyncio.sleep(0.5)
+
+            # Disconnect (this should not cause recursion)
+            await ws_client.disconnect()
+            print(f" ✅ Cycle {cycle + 1} completed successfully")
+
+        except RecursionError as e:
+            print(f" ❌ Recursion error in cycle {cycle + 1}: {e}")
+            return False
+        except Exception as e:
+            print(f" ⚠️ Other error in cycle {cycle + 1}: {e}")
+            # Continue with other cycles
+
+        # Small delay between cycles
+        await asyncio.sleep(0.2)
+
+    print("\n✅ All cycles completed without recursion errors")
+    return True
+
+
+async def test_concurrent_shutdowns():
+    """Test concurrent client shutdowns to verify no recursion."""
+    logger = get_logger("concurrent_shutdown_test", verbose=False)
+
+    print("\n🔄 Testing Concurrent Shutdowns")
+    print("=" * 40)
+
+    # Create multiple clients
+    clients = []
+    for i in range(3):
+        client = OKXWebSocketClient(
+            component_name=f"concurrent_client_{i}",
+            logger=logger
+        )
+        clients.append(client)
+
+    try:
+        # Connect all clients
+        connect_tasks = [client.connect() for client in clients]
+        results = await asyncio.gather(*connect_tasks, return_exceptions=True)
+
+        successful_connections = sum(1 for r in results if r is True)
+        print(f"📡 Connected {successful_connections}/3 clients")
+
+        # Let them run briefly
+        await asyncio.sleep(1)
+
+        # Shutdown all concurrently (this is where recursion might occur)
+        print("🛑 Shutting down all clients concurrently...")
+        shutdown_tasks = [client.disconnect() for client in clients]
+
+        # Use wait_for to prevent hanging
+        try:
+            await asyncio.wait_for(
+                asyncio.gather(*shutdown_tasks, return_exceptions=True),
+                timeout=10.0
+            )
+            print("✅ All clients shut down successfully")
+            return True
+
+        except asyncio.TimeoutError:
+            print("⚠️ Shutdown timeout - but no recursion errors")
+            return True  # Timeout is better than recursion
+
+    except RecursionError as e:
+        print(f"❌ Recursion error during concurrent shutdown: {e}")
+        return False
+    except Exception as e:
+        print(f"⚠️ Other error during test: {e}")
+        return True  # Other errors are acceptable for this test
+
+
+async def main():
+    """Run recursion fix tests."""
+    print("🚀 WebSocket Recursion Fix Test Suite")
+    print("=" * 50)
+
+    try:
+        # Test 1: Rapid cycles
+        test1_success = await test_rapid_connection_cycles()
+
+        # Test 2: Concurrent shutdowns
+        test2_success = await test_concurrent_shutdowns()
+
+        # Summary
+        print("\n" + "=" * 50)
+        print("📋 Test Summary:")
+        print(f" Rapid Cycles: {'✅ PASS' if test1_success else '❌ FAIL'}")
+        print(f" Concurrent Shutdowns: {'✅ PASS' if test2_success else '❌ FAIL'}")
+
+        if test1_success and test2_success:
+            print("\n🎉 All tests passed! Recursion issue fixed.")
+            return 0
+        else:
+            print("\n❌ Some tests failed.")
+            return 1
+
+    except KeyboardInterrupt:
+        print("\n⏹️ Tests interrupted")
+        return 1
+    except Exception as e:
+        print(f"\n💥 Test suite failed: {e}")
+        return 1
+
+
+if __name__ == "__main__":
+    exit_code = asyncio.run(main())
+    sys.exit(exit_code)
\ No newline at end of file