205 lines
6.8 KiB
Python
205 lines
6.8 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Test script to verify the WebSocket race-condition fixes.

This script exercises the enhanced task management and synchronization
in the OKX WebSocket client to confirm that recv() concurrency errors no
longer occur.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Put the repository root (the parent of this script's directory) at the
# front of the import path so the project packages imported below resolve
# when the script is run directly.
project_root = Path(__file__).parent.parent
sys.path[0:0] = [str(project_root)]
|
||
|
|
|
||
|
|
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
|
||
|
|
from utils.logger import get_logger
|
||
|
|
|
||
|
|
|
||
|
|
async def test_websocket_reconnection_stability():
    """Test WebSocket reconnection without race conditions.

    Drives a single OKX WebSocket client through these steps, printing each one:
      1. Connect.
      2. Subscribe to the BTC-USDT trades and books5 channels.
      3. Force three reconnections in a row, waiting 2 s after each.
      4. Check that the client still reports every subscription.
      5. Listen for 10 s, then print the client's statistics.

    Returns:
        True if every step succeeded. False as soon as a step fails or an
        exception escapes. Receiving zero messages in step 5 is only a
        warning. The client is always disconnected before returning.
    """
    logger = get_logger("websocket_test", verbose=True)

    print("🧪 Testing WebSocket Race Condition Fixes")
    print("=" * 50)

    # Create WebSocket client
    ws_client = OKXWebSocketClient(
        component_name="test_ws_client",
        ping_interval=25.0,
        max_reconnect_attempts=3,
        logger=logger,
    )

    # Number of forced reconnections performed in Test 3.
    reconnect_rounds = 3

    try:
        # Test 1: Basic connection
        print("\n📡 Test 1: Basic Connection")
        success = await ws_client.connect()
        if success:
            print("✅ Initial connection successful")
        else:
            print("❌ Initial connection failed")
            return False

        # Test 2: Subscribe to channels
        print("\n📡 Test 2: Channel Subscription")
        subscriptions = [
            OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT"),
            OKXSubscription(OKXChannelType.BOOKS5.value, "BTC-USDT"),
        ]

        success = await ws_client.subscribe(subscriptions)
        if success:
            print("✅ Subscription successful")
        else:
            print("❌ Subscription failed")
            return False

        # Test 3: Force reconnection to test race condition fixes
        print("\n📡 Test 3: Force Reconnection (Race Condition Test)")
        for i in range(reconnect_rounds):
            print(f" Reconnection attempt {i+1}/{reconnect_rounds}...")
            success = await ws_client.reconnect()
            if success:
                print(f" ✅ Reconnection {i+1} successful")
                await asyncio.sleep(2)  # Wait between reconnections
            else:
                print(f" ❌ Reconnection {i+1} failed")
                return False

        # Test 4: Verify subscriptions are maintained
        print("\n📡 Test 4: Subscription Persistence")
        current_subs = ws_client.get_subscriptions()
        expected_subs = len(subscriptions)
        if len(current_subs) == expected_subs:
            print("✅ Subscriptions persisted after reconnections")
        else:
            print(f"❌ Subscription count mismatch: expected {expected_subs}, got {len(current_subs)}")
            # Losing a subscription across reconnects is a real failure of
            # this test, not just something to report.
            return False

        # Test 5: Monitor for a few seconds to catch any errors
        print("\n📡 Test 5: Stability Monitor (10 seconds)")
        message_count = 0

        def message_callback(message):
            # Counts messages delivered to this callback; prints every 10th.
            nonlocal message_count
            message_count += 1
            if message_count % 10 == 0:
                print(f" 📊 Processed {message_count} messages")

        ws_client.add_message_callback(message_callback)

        await asyncio.sleep(10)

        stats = ws_client.get_stats()
        print("\n📊 Final Statistics:")
        print(f" Messages received: {stats['messages_received']}")
        print(f" Reconnections: {stats['reconnections']}")
        print(f" Connection state: {stats['connection_state']}")

        if stats['messages_received'] > 0:
            print("✅ Receiving data successfully")
        else:
            print("⚠️ No messages received (may be normal for low-activity symbols)")

        return True

    except Exception as e:
        print(f"❌ Test failed with exception: {e}")
        logger.error(f"Test exception: {e}")
        return False

    finally:
        # Cleanup. An error while disconnecting must not replace the test's
        # result or an in-flight exception, so report it instead of raising.
        try:
            await ws_client.disconnect()
        except Exception as e:
            print(f"⚠️ Disconnect failed during cleanup: {e}")
            logger.error(f"Cleanup exception: {e}")
        print("\n🧹 Cleanup completed")
|
||
|
|
|
||
|
|
|
||
|
|
async def test_concurrent_operations(num_clients: int = 3):
    """Test concurrent WebSocket operations to ensure no race conditions.

    Creates ``num_clients`` independent OKX WebSocket clients and connects
    them all at once with ``asyncio.gather``. It then concurrently reconnects
    every client whose ``is_connected`` is true. Every created client is
    disconnected afterwards, even when a step fails.

    Args:
        num_clients: Number of clients to create. Defaults to 3.

    Returns:
        True only if every client connected and every attempted reconnection
        returned True. False otherwise, including when an exception escapes.
    """
    print("\n🔄 Testing Concurrent Operations")
    print("=" * 50)

    logger = get_logger("concurrent_test", verbose=False)
    clients = []

    try:
        # Build the clients inside the try so that, if constructing one
        # fails, the ones already created are still cleaned up below.
        for i in range(num_clients):
            clients.append(
                OKXWebSocketClient(
                    component_name=f"test_client_{i}",
                    logger=logger,
                )
            )

        # Connect all clients concurrently
        print(f"📡 Connecting {len(clients)} clients concurrently...")
        results = await asyncio.gather(
            *(client.connect() for client in clients), return_exceptions=True
        )
        # return_exceptions=True hands raised errors back as result values.
        # Surface them instead of letting them silently count as "not connected".
        for i, result in enumerate(results):
            if isinstance(result, BaseException):
                print(f" ⚠️ Client {i} connect raised: {result!r}")

        successful_connections = sum(1 for r in results if r is True)
        all_connected = successful_connections == len(clients)
        mark = "✅" if all_connected else "❌"
        print(f"{mark} {successful_connections}/{len(clients)} clients connected successfully")

        # Test concurrent reconnections
        print("\n🔄 Testing concurrent reconnections...")
        reconnect_tasks = [client.reconnect() for client in clients if client.is_connected]

        reconnects_ok = True
        if reconnect_tasks:
            reconnect_results = await asyncio.gather(*reconnect_tasks, return_exceptions=True)
            for result in reconnect_results:
                if isinstance(result, BaseException):
                    print(f" ⚠️ Reconnect raised: {result!r}")
            successful_reconnects = sum(1 for r in reconnect_results if r is True)
            reconnects_ok = successful_reconnects == len(reconnect_tasks)
            mark = "✅" if reconnects_ok else "❌"
            print(f"{mark} {successful_reconnects}/{len(reconnect_tasks)} reconnections successful")

        # A failed connect or reconnect is exactly what this test exists to
        # detect, so it must fail the test rather than only being printed.
        return all_connected and reconnects_ok

    except Exception as e:
        print(f"❌ Concurrent test failed: {e}")
        return False

    finally:
        # Cleanup all clients. This is best-effort: keep going if one client
        # fails to disconnect, but report the error instead of discarding it.
        # Cancellation and KeyboardInterrupt are no longer swallowed here.
        for i, client in enumerate(clients):
            try:
                await client.disconnect()
            except Exception as e:
                print(f" ⚠️ Failed to disconnect test_client_{i}: {e}")
|
||
|
|
|
||
|
|
|
||
|
|
async def main():
    """Run the WebSocket race-condition test suite and print a summary.

    Runs the reconnection-stability test, then the concurrent-operations
    test, and prints a PASS/FAIL line for each.

    Returns:
        0 when both tests pass. 1 when either fails, when a KeyboardInterrupt
        is caught here, or when an unexpected exception escapes.
    """
    print("🚀 WebSocket Race Condition Fix Test Suite")
    print("=" * 60)

    try:
        stability_ok = await test_websocket_reconnection_stability()
        concurrency_ok = await test_concurrent_operations()

        print("\n" + "=" * 60)
        print("📋 Test Summary:")
        summary = (
            ("Reconnection Stability", stability_ok),
            ("Concurrent Operations", concurrency_ok),
        )
        for label, passed in summary:
            verdict = "✅ PASS" if passed else "❌ FAIL"
            print(f" {label}: {verdict}")

        if not (stability_ok and concurrency_ok):
            print("\n❌ Some tests failed. Check logs for details.")
            return 1

        print("\n🎉 All tests passed! WebSocket race condition fixes working correctly.")
        return 0

    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted by user")
        return 1
    except Exception as e:
        print(f"\n💥 Test suite failed with exception: {e}")
        return 1
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # On Ctrl+C, asyncio.run() cancels main() and re-raises KeyboardInterrupt
    # from the run call itself. main()'s own KeyboardInterrupt handler
    # therefore generally does not see it, so catch it here to exit with
    # status 1 instead of dumping a traceback.
    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted by user")
        exit_code = 1
    sys.exit(exit_code)
|