TCPDashboard/tests/test_websocket_race_condition_fix.py
Vasily.onl 01cea1d5e5 Enhance OKX WebSocket client with improved task management and error handling
- Implemented enhanced task synchronization to prevent race conditions during WebSocket operations.
- Introduced reconnection locking to avoid concurrent reconnection attempts.
- Improved error handling in message processing and reconnection logic, ensuring graceful shutdown and task management.
- Added unit tests to verify the stability and reliability of the WebSocket client under concurrent operations.
2025-06-02 23:14:04 +08:00

205 lines
6.8 KiB
Python

#!/usr/bin/env python3
"""
Test script to verify WebSocket race condition fixes.
This script tests the enhanced task management and synchronization
in the OKX WebSocket client to ensure no more recv() concurrency errors.
"""
import asyncio
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
from utils.logger import get_logger
async def test_websocket_reconnection_stability():
    """Exercise a single OKX WebSocket client through repeated reconnections.

    Steps:
        1. Connect.
        2. Subscribe to the TRADES and BOOKS5 channels for BTC-USDT.
        3. Force three consecutive ``reconnect()`` calls, pausing 2 s after each.
        4. Check that ``get_subscriptions()`` still reports every subscription.
        5. Register a counting message callback, listen for 10 s and print the
           client's statistics.

    The client is always disconnected in the ``finally`` clause.

    Returns:
        True if every step succeeded; False if any step failed or raised.
        Receiving zero messages in step 5 only prints a warning.
    """
    logger = get_logger("websocket_test", verbose=True)
    print("🧪 Testing WebSocket Race Condition Fixes")
    print("=" * 50)

    # Create WebSocket client
    ws_client = OKXWebSocketClient(
        component_name="test_ws_client",
        ping_interval=25.0,
        max_reconnect_attempts=3,
        logger=logger,
    )

    try:
        # Test 1: Basic connection
        print("\n📡 Test 1: Basic Connection")
        success = await ws_client.connect()
        if success:
            print("✅ Initial connection successful")
        else:
            print("❌ Initial connection failed")
            return False

        # Test 2: Subscribe to channels
        print("\n📡 Test 2: Channel Subscription")
        subscriptions = [
            OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT"),
            OKXSubscription(OKXChannelType.BOOKS5.value, "BTC-USDT"),
        ]
        success = await ws_client.subscribe(subscriptions)
        if success:
            print("✅ Subscription successful")
        else:
            print("❌ Subscription failed")
            return False

        # Test 3: Force reconnection to test race condition fixes
        print("\n📡 Test 3: Force Reconnection (Race Condition Test)")
        for i in range(3):
            print(f" Reconnection attempt {i+1}/3...")
            success = await ws_client.reconnect()
            if success:
                print(f" ✅ Reconnection {i+1} successful")
                await asyncio.sleep(2)  # Wait between reconnections
            else:
                print(f" ❌ Reconnection {i+1} failed")
                return False

        # Test 4: Verify subscriptions are maintained. Derive the expected
        # count from what was subscribed instead of hard-coding it, and treat a
        # mismatch as a failure (previously it was only printed).
        print("\n📡 Test 4: Subscription Persistence")
        expected_count = len(subscriptions)
        current_subs = ws_client.get_subscriptions()
        if len(current_subs) == expected_count:
            print("✅ Subscriptions persisted after reconnections")
        else:
            print(
                f"❌ Subscription count mismatch: expected {expected_count}, "
                f"got {len(current_subs)}"
            )
            return False

        # Test 5: Monitor for a few seconds to catch any errors
        print("\n📡 Test 5: Stability Monitor (10 seconds)")
        message_count = 0

        def message_callback(message):
            # Count every delivered message; report progress every 10.
            nonlocal message_count
            message_count += 1
            if message_count % 10 == 0:
                print(f" 📊 Processed {message_count} messages")

        ws_client.add_message_callback(message_callback)
        await asyncio.sleep(10)

        stats = ws_client.get_stats()
        print("\n📊 Final Statistics:")
        print(f" Messages received: {stats['messages_received']}")
        print(f" Reconnections: {stats['reconnections']}")
        print(f" Connection state: {stats['connection_state']}")

        if stats['messages_received'] > 0:
            print("✅ Receiving data successfully")
        else:
            print("⚠️ No messages received (may be normal for low-activity symbols)")

        return True

    except Exception as e:
        print(f"❌ Test failed with exception: {e}")
        logger.error(f"Test exception: {e}")
        return False

    finally:
        # Cleanup
        await ws_client.disconnect()
        print("\n🧹 Cleanup completed")
async def test_concurrent_operations():
    """Connect and then reconnect several OKX WebSocket clients concurrently.

    Three independent clients are connected with ``asyncio.gather``. Every
    client that then reports ``is_connected`` is reconnected concurrently.
    Because ``gather(..., return_exceptions=True)`` returns raised errors as
    ordinary results, those results are inspected explicitly. Every exception
    is printed and fails the check, and so does any attempted reconnection
    that does not return True. A connect that merely returns a non-True value
    is reported in the count but tolerated.

    All clients are disconnected in the ``finally`` clause. Cleanup errors are
    reported rather than silently swallowed.

    Returns:
        True if no connect/reconnect raised and every attempted reconnection
        succeeded; False otherwise.
    """
    print("\n🔄 Testing Concurrent Operations")
    print("=" * 50)

    logger = get_logger("concurrent_test", verbose=False)

    # Create multiple clients
    clients = [
        OKXWebSocketClient(component_name=f"test_client_{i}", logger=logger)
        for i in range(3)
    ]

    try:
        # Connect all clients concurrently
        print("📡 Connecting 3 clients concurrently...")
        results = await asyncio.gather(
            *(client.connect() for client in clients), return_exceptions=True
        )
        successful_connections = sum(1 for r in results if r is True)
        print(f"{successful_connections}/3 clients connected successfully")

        # return_exceptions=True hides raised errors inside `results`; surface
        # them, since they are exactly the concurrency failures under test.
        errors = [r for r in results if isinstance(r, BaseException)]
        for err in errors:
            print(f"❌ Concurrent connect raised: {err!r}")

        # Test concurrent reconnections
        print("\n🔄 Testing concurrent reconnections...")
        reconnect_tasks = [
            client.reconnect() for client in clients if client.is_connected
        ]

        all_reconnected = True
        if reconnect_tasks:
            reconnect_results = await asyncio.gather(
                *reconnect_tasks, return_exceptions=True
            )
            successful_reconnects = sum(1 for r in reconnect_results if r is True)
            print(f"{successful_reconnects}/{len(reconnect_tasks)} reconnections successful")

            reconnect_errors = [
                r for r in reconnect_results if isinstance(r, BaseException)
            ]
            for err in reconnect_errors:
                print(f"❌ Concurrent reconnect raised: {err!r}")
            errors.extend(reconnect_errors)
            all_reconnected = successful_reconnects == len(reconnect_tasks)

        return not errors and all_reconnected

    except Exception as e:
        print(f"❌ Concurrent test failed: {e}")
        return False

    finally:
        # Cleanup all clients; keep going if one fails, but say so instead of
        # hiding it (and never swallow CancelledError/KeyboardInterrupt).
        for i, client in enumerate(clients):
            try:
                await client.disconnect()
            except Exception as e:
                print(f"⚠️ Cleanup failed for test_client_{i}: {e}")
async def main():
    """Run both WebSocket race-condition checks and print a pass/fail summary.

    The stability check runs first and the concurrent-operations check runs
    after it completes.

    Returns:
        Process exit code: 0 when both checks pass, 1 when any check fails,
        when a ``KeyboardInterrupt`` reaches this coroutine, or when the suite
        raises an exception.
    """
    separator = "=" * 60
    print("🚀 WebSocket Race Condition Fix Test Suite")
    print(separator)

    try:
        stability_passed = await test_websocket_reconnection_stability()
        concurrency_passed = await test_concurrent_operations()

        print("\n" + separator)
        print("📋 Test Summary:")
        outcomes = (
            ("Reconnection Stability", stability_passed),
            ("Concurrent Operations", concurrency_passed),
        )
        for label, passed in outcomes:
            verdict = '✅ PASS' if passed else '❌ FAIL'
            print(f" {label}: {verdict}")

        if not (stability_passed and concurrency_passed):
            print("\n❌ Some tests failed. Check logs for details.")
            return 1

        print("\n🎉 All tests passed! WebSocket race condition fixes working correctly.")
        return 0

    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted by user")
        return 1
    except Exception as e:
        print(f"\n💥 Test suite failed with exception: {e}")
        return 1
if __name__ == "__main__":
    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() re-raises Ctrl+C at this call site (on 3.11+ it first
        # cancels main()), so the handler inside main() is usually bypassed;
        # report the interruption here instead of dumping a traceback.
        print("\n⏹️ Tests interrupted by user")
        exit_code = 1
    sys.exit(exit_code)