#!/usr/bin/env python3
"""
Test script to verify WebSocket race condition fixes.

This script tests the enhanced task management and synchronization in the
OKX WebSocket client to ensure no more recv() concurrency errors.
"""

import asyncio
import sys
from pathlib import Path

# Add project root to path: this script lives one directory below the project
# root, so the project packages imported below are not importable otherwise.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from data.exchanges.okx.websocket import (  # noqa: E402  (needs sys.path tweak)
    OKXChannelType,
    OKXSubscription,
    OKXWebSocketClient,
)
from utils.logger import get_logger  # noqa: E402  (needs sys.path tweak)


async def _safe_disconnect(client):
    """Disconnect *client*, reporting (not raising) any cleanup error.

    Cleanup runs from ``finally`` blocks; letting an exception escape there
    would replace the test's real result with an unrelated failure.
    ``Exception`` (not a bare ``except``) is caught so that cancellation and
    Ctrl+C still propagate.
    """
    try:
        await client.disconnect()
    except Exception as e:  # best-effort cleanup
        print(f"⚠️ Cleanup error during disconnect: {e!r}")


def _report_exceptions(operation, results):
    """Print every exception captured by ``asyncio.gather(return_exceptions=True)``.

    Args:
        operation: Short label for the gathered operation (e.g. "connect").
        results: The list returned by ``asyncio.gather``.

    Returns:
        The number of results that are exceptions. A non-zero count means an
        operation raised instead of returning a status, which is the symptom
        the concurrency test is meant to catch.
    """
    error_count = 0
    for index, result in enumerate(results):
        if isinstance(result, BaseException):
            error_count += 1
            print(f" ❌ {operation} task {index} raised: {result!r}")
    return error_count


async def test_websocket_reconnection_stability():
    """Test WebSocket reconnection without race conditions.

    Steps: connect, subscribe to two BTC-USDT channels, force three
    reconnections, verify the subscriptions survived, then listen for
    10 seconds and print the client's statistics.

    Returns:
        True if every step succeeded, False otherwise. Connection,
        subscription and reconnection failures abort immediately. A
        subscription-count mismatch is recorded as a failure, but the
        stability monitor still runs.
    """
    logger = get_logger("websocket_test", verbose=True)

    print("🧪 Testing WebSocket Race Condition Fixes")
    print("=" * 50)

    ws_client = OKXWebSocketClient(
        component_name="test_ws_client",
        ping_interval=25.0,
        max_reconnect_attempts=3,
        logger=logger,
    )
    passed = True

    try:
        # Test 1: Basic connection
        print("\n📡 Test 1: Basic Connection")
        if not await ws_client.connect():
            print("❌ Initial connection failed")
            return False
        print("✅ Initial connection successful")

        # Test 2: Subscribe to channels
        print("\n📡 Test 2: Channel Subscription")
        subscriptions = [
            OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT"),
            OKXSubscription(OKXChannelType.BOOKS5.value, "BTC-USDT"),
        ]
        if not await ws_client.subscribe(subscriptions):
            print("❌ Subscription failed")
            return False
        print("✅ Subscription successful")

        # Test 3: Force reconnection to test race condition fixes
        print("\n📡 Test 3: Force Reconnection (Race Condition Test)")
        for attempt in range(1, 4):
            print(f" Reconnection attempt {attempt}/3...")
            if not await ws_client.reconnect():
                print(f" ❌ Reconnection {attempt} failed")
                return False
            print(f" ✅ Reconnection {attempt} successful")
            await asyncio.sleep(2)  # Wait between reconnections

        # Test 4: Verify subscriptions are maintained
        print("\n📡 Test 4: Subscription Persistence")
        current_subs = ws_client.get_subscriptions()
        expected = len(subscriptions)
        if len(current_subs) == expected:
            print("✅ Subscriptions persisted after reconnections")
        else:
            print(
                f"❌ Subscription count mismatch: expected {expected}, "
                f"got {len(current_subs)}"
            )
            # Previously only printed; a lost subscription must fail the test.
            passed = False

        # Test 5: Monitor for a few seconds to catch any errors
        print("\n📡 Test 5: Stability Monitor (10 seconds)")
        message_count = 0

        def message_callback(message):
            # Progress indicator only; the authoritative count is in get_stats().
            nonlocal message_count
            message_count += 1
            if message_count % 10 == 0:
                print(f" 📊 Processed {message_count} messages")

        ws_client.add_message_callback(message_callback)
        await asyncio.sleep(10)

        stats = ws_client.get_stats()
        print("\n📊 Final Statistics:")
        print(f" Messages received: {stats['messages_received']}")
        print(f" Reconnections: {stats['reconnections']}")
        print(f" Connection state: {stats['connection_state']}")

        if stats['messages_received'] > 0:
            print("✅ Receiving data successfully")
        else:
            print("⚠️ No messages received (may be normal for low-activity symbols)")

        return passed

    except Exception as e:
        print(f"❌ Test failed with exception: {e}")
        logger.error(f"Test exception: {e}")
        return False
    finally:
        await _safe_disconnect(ws_client)
        print("\n🧹 Cleanup completed")


async def test_concurrent_operations():
    """Test concurrent WebSocket operations to ensure no race conditions.

    Connects three independent clients concurrently, then reconnects every
    client that reports ``is_connected`` concurrently.

    Returns:
        True only if every client connected and every reconnect succeeded,
        with no exceptions raised by any concurrent operation. Exceptions
        collected by ``asyncio.gather`` are printed rather than silently
        counted as failures.
    """
    print("\n🔄 Testing Concurrent Operations")
    print("=" * 50)

    logger = get_logger("concurrent_test", verbose=False)

    clients = [
        OKXWebSocketClient(component_name=f"test_client_{i}", logger=logger)
        for i in range(3)
    ]

    try:
        # Connect all clients concurrently
        print(f"📡 Connecting {len(clients)} clients concurrently...")
        results = await asyncio.gather(
            *(client.connect() for client in clients), return_exceptions=True
        )
        connect_errors = _report_exceptions("connect", results)
        successful_connections = sum(1 for r in results if r is True)
        all_connected = (
            connect_errors == 0 and successful_connections == len(clients)
        )
        icon = "✅" if all_connected else "❌"
        print(
            f"{icon} {successful_connections}/{len(clients)} "
            f"clients connected successfully"
        )

        # Test concurrent reconnections
        print("\n🔄 Testing concurrent reconnections...")
        reconnect_tasks = [
            client.reconnect() for client in clients if client.is_connected
        ]
        reconnect_errors = 0
        successful_reconnects = 0
        if reconnect_tasks:
            reconnect_results = await asyncio.gather(
                *reconnect_tasks, return_exceptions=True
            )
            reconnect_errors = _report_exceptions("reconnect", reconnect_results)
            successful_reconnects = sum(1 for r in reconnect_results if r is True)
        all_reconnected = (
            reconnect_errors == 0 and successful_reconnects == len(reconnect_tasks)
        )
        icon = "✅" if all_reconnected else "❌"
        print(
            f"{icon} {successful_reconnects}/{len(reconnect_tasks)} "
            f"reconnections successful"
        )

        return all_connected and all_reconnected

    except Exception as e:
        print(f"❌ Concurrent test failed: {e}")
        return False
    finally:
        # Cleanup all clients; one failing disconnect must not skip the rest.
        for client in clients:
            await _safe_disconnect(client)


async def main():
    """Run all WebSocket tests.

    Returns:
        Process exit code: 0 if every test passed, 1 otherwise.
    """
    print("🚀 WebSocket Race Condition Fix Test Suite")
    print("=" * 60)

    try:
        # Test 1: Basic reconnection stability
        test1_success = await test_websocket_reconnection_stability()

        # Test 2: Concurrent operations
        test2_success = await test_concurrent_operations()

        # Summary
        print("\n" + "=" * 60)
        print("📋 Test Summary:")
        print(f" Reconnection Stability: {'✅ PASS' if test1_success else '❌ FAIL'}")
        print(f" Concurrent Operations: {'✅ PASS' if test2_success else '❌ FAIL'}")

        if test1_success and test2_success:
            print("\n🎉 All tests passed! WebSocket race condition fixes working correctly.")
            return 0
        print("\n❌ Some tests failed. Check logs for details.")
        return 1

    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted by user")
        return 1
    except Exception as e:
        print(f"\n💥 Test suite failed with exception: {e}")
        return 1


if __name__ == "__main__":
    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run() cancels main() on Ctrl+C and re-raises
        # KeyboardInterrupt here, so the handler inside main() may not run.
        print("\n⏹️ Tests interrupted by user")
        exit_code = 1
    sys.exit(exit_code)