#!/usr/bin/env python3
"""
Simple test to verify recursion fix in WebSocket task management.

Runs two scenarios against the OKX WebSocket client:

1. Rapid connect/subscribe/disconnect cycles on fresh clients.
2. Several clients connected, then disconnected concurrently.

A scenario fails only if a ``RecursionError`` is observed; other errors
(for example, no network access) are reported but tolerated. The process
exits with status 0 when both scenarios pass and 1 otherwise.
"""

import asyncio
import sys
from pathlib import Path

# Add project root to path so the project packages imported below resolve
# when this file is run directly as a script.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from data.exchanges.okx.websocket import (  # noqa: E402
    OKXChannelType,
    OKXSubscription,
    OKXWebSocketClient,
)
from utils.logger import get_logger  # noqa: E402


def _recursion_errors(results):
    """Return the RecursionError instances in an ``asyncio.gather`` result list.

    With ``return_exceptions=True``, gather hands exceptions back as list
    items instead of raising them, so an enclosing ``except RecursionError``
    never sees them; they must be picked out of the results explicitly.
    """
    return [r for r in results if isinstance(r, RecursionError)]


async def _disconnect_quietly(client, cycle_label):
    """Best-effort disconnect used on error paths so a failed cycle does not leak a client.

    Args:
        client: The client to disconnect.
        cycle_label: 1-based cycle number, used only in printed messages.

    Returns:
        True if ``disconnect()`` raised ``RecursionError``, otherwise False.
        Any other exception is printed and swallowed so that cleanup never
        masks the error that triggered it.
    """
    try:
        await client.disconnect()
    except RecursionError as e:
        print(f" ❌ Recursion error during cleanup in cycle {cycle_label}: {e}")
        return True
    except Exception as e:
        print(f" ⚠️ Cleanup error in cycle {cycle_label}: {e}")
    return False


async def test_rapid_connection_cycles():
    """Test rapid connect/disconnect cycles to verify no recursion errors.

    Each of 5 cycles creates a fresh client, connects, subscribes to the
    BTC-USDT trades channel, idles for 0.5 s and disconnects.

    Returns:
        False as soon as any cycle raises ``RecursionError`` (including while
        cleaning up after a failed cycle), True otherwise. Connection failures
        and other errors are reported and the next cycle is attempted.
    """
    logger = get_logger("recursion_test", verbose=False)

    print("🧪 Testing WebSocket Recursion Fix")
    print("=" * 40)

    completed = 0
    for cycle in range(5):
        label = cycle + 1
        print(f"\n🔄 Cycle {label}/5: Rapid connect/disconnect")

        ws_client = OKXWebSocketClient(
            component_name=f"test_client_{cycle}",
            max_reconnect_attempts=2,
            logger=logger,
        )
        # True while the client is connected but not yet handed to the
        # disconnect() under test; lets us tear it down if subscribe/sleep fail.
        needs_cleanup = False
        recursion_seen = False

        try:
            # Connect
            if not await ws_client.connect():
                print(f" ❌ Connection failed in cycle {label}")
                continue
            needs_cleanup = True

            # Subscribe
            subscriptions = [
                OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT")
            ]
            await ws_client.subscribe(subscriptions)

            # Quick activity
            await asyncio.sleep(0.5)

            # Disconnect (this should not cause recursion). The flag is cleared
            # first so a failure here is reported below rather than retried.
            needs_cleanup = False
            await ws_client.disconnect()

            completed += 1
            print(f" ✅ Cycle {label} completed successfully")

        except RecursionError as e:
            print(f" ❌ Recursion error in cycle {label}: {e}")
            recursion_seen = True
        except Exception as e:
            print(f" ⚠️ Other error in cycle {label}: {e}")
            # Continue with other cycles

        # Don't leave a live connection (and its background work) behind when
        # the cycle failed between connect() and disconnect().
        if needs_cleanup and await _disconnect_quietly(ws_client, label):
            recursion_seen = True
        if recursion_seen:
            return False

        # Small delay between cycles
        await asyncio.sleep(0.2)

    if completed == 0:
        print("\n⚠️ No cycle completed, so the disconnect path was not exercised")
    print("\n✅ All cycles completed without recursion errors")
    return True


async def test_concurrent_shutdowns():
    """Test concurrent client shutdowns to verify no recursion.

    Connects three clients concurrently, lets them run for 1 s, then
    disconnects all of them concurrently with a 10 s timeout.

    Returns:
        False if any connect() or disconnect() call raised ``RecursionError``,
        True otherwise. A shutdown timeout and other errors are treated as
        acceptable for this test.
    """
    logger = get_logger("concurrent_shutdown_test", verbose=False)

    print("\n🔄 Testing Concurrent Shutdowns")
    print("=" * 40)

    # Create multiple clients
    clients = [
        OKXWebSocketClient(component_name=f"concurrent_client_{i}", logger=logger)
        for i in range(3)
    ]

    try:
        # Connect all clients; per-client exceptions come back as values.
        connect_results = await asyncio.gather(
            *(client.connect() for client in clients), return_exceptions=True
        )
        successful_connections = sum(1 for r in connect_results if r is True)
        print(f"📡 Connected {successful_connections}/{len(clients)} clients")

        connect_recursion = _recursion_errors(connect_results)
        if connect_recursion:
            # Still proceed to the shutdown below so the clients are torn down.
            print(f"❌ Recursion error during concurrent connect: {connect_recursion[0]}")

        # Let them run briefly
        await asyncio.sleep(1)

        # Shutdown all concurrently (this is where recursion might occur)
        print("🛑 Shutting down all clients concurrently...")
        try:
            # wait_for prevents hanging; gather collects per-client errors as
            # values, which are inspected explicitly below.
            shutdown_results = await asyncio.wait_for(
                asyncio.gather(
                    *(client.disconnect() for client in clients),
                    return_exceptions=True,
                ),
                timeout=10.0,
            )
        except asyncio.TimeoutError:
            print("⚠️ Shutdown timeout - but no recursion errors")
            # Timeout is better than recursion; only a connect-phase
            # recursion error fails the test here.
            return not connect_recursion

        shutdown_recursion = _recursion_errors(shutdown_results)
        if shutdown_recursion:
            print(f"❌ Recursion error during concurrent shutdown: {shutdown_recursion[0]}")
            return False
        if connect_recursion:
            return False

        other_errors = [r for r in shutdown_results if isinstance(r, BaseException)]
        if other_errors:
            print(f"⚠️ {len(other_errors)} client(s) raised during shutdown: {other_errors[0]}")
        else:
            print("✅ All clients shut down successfully")
        return True

    except RecursionError as e:
        print(f"❌ Recursion error during concurrent shutdown: {e}")
        return False
    except Exception as e:
        print(f"⚠️ Other error during test: {e}")
        return True  # Other errors are acceptable for this test


async def main():
    """Run recursion fix tests.

    Returns:
        Process exit code: 0 if both scenarios pass, 1 otherwise (including
        interruption or an unexpected error in the suite itself).
    """
    print("🚀 WebSocket Recursion Fix Test Suite")
    print("=" * 50)

    try:
        # Test 1: Rapid cycles
        test1_success = await test_rapid_connection_cycles()

        # Test 2: Concurrent shutdowns
        test2_success = await test_concurrent_shutdowns()

        # Summary
        print("\n" + "=" * 50)
        print("📋 Test Summary:")
        print(f" Rapid Cycles: {'✅ PASS' if test1_success else '❌ FAIL'}")
        print(f" Concurrent Shutdowns: {'✅ PASS' if test2_success else '❌ FAIL'}")

        if test1_success and test2_success:
            print("\n🎉 All tests passed! Recursion issue fixed.")
            return 0
        print("\n❌ Some tests failed.")
        return 1

    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted")
        return 1
    except Exception as e:
        print(f"\n💥 Test suite failed: {e}")
        return 1


if __name__ == "__main__":
    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run() cancels the main task on Ctrl-C and
        # re-raises KeyboardInterrupt here, bypassing the handler inside main().
        print("\n⏹️ Tests interrupted")
        exit_code = 1
    sys.exit(exit_code)