Enhance OKX WebSocket client with improved task management and error handling
- Implemented enhanced task synchronization to prevent race conditions during WebSocket operations. - Introduced reconnection locking to avoid concurrent reconnection attempts. - Improved error handling in message processing and reconnection logic, ensuring graceful shutdown and task management. - Added unit tests to verify the stability and reliability of the WebSocket client under concurrent operations.
This commit is contained in:
parent
aaebd9a308
commit
01cea1d5e5
@ -23,6 +23,7 @@
|
|||||||
"orderbook"
|
"orderbook"
|
||||||
],
|
],
|
||||||
"timeframes": [
|
"timeframes": [
|
||||||
|
"5s",
|
||||||
"1m",
|
"1m",
|
||||||
"5m",
|
"5m",
|
||||||
"15m",
|
"15m",
|
||||||
@ -42,6 +43,7 @@
|
|||||||
"orderbook"
|
"orderbook"
|
||||||
],
|
],
|
||||||
"timeframes": [
|
"timeframes": [
|
||||||
|
"5s",
|
||||||
"1m",
|
"1m",
|
||||||
"5m",
|
"5m",
|
||||||
"15m",
|
"15m",
|
||||||
|
|||||||
@ -122,9 +122,11 @@ class OKXWebSocketClient:
|
|||||||
self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = []
|
self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = []
|
||||||
self._subscriptions: Dict[str, OKXSubscription] = {}
|
self._subscriptions: Dict[str, OKXSubscription] = {}
|
||||||
|
|
||||||
# Tasks
|
# Enhanced task management
|
||||||
self._ping_task: Optional[asyncio.Task] = None
|
self._ping_task: Optional[asyncio.Task] = None
|
||||||
self._message_handler_task: Optional[asyncio.Task] = None
|
self._message_handler_task: Optional[asyncio.Task] = None
|
||||||
|
self._reconnection_lock = asyncio.Lock() # Prevent concurrent reconnections
|
||||||
|
self._tasks_stopping = False # Flag to prevent task overlap
|
||||||
|
|
||||||
# Statistics
|
# Statistics
|
||||||
self._stats = {
|
self._stats = {
|
||||||
@ -380,6 +382,15 @@ class OKXWebSocketClient:
|
|||||||
|
|
||||||
async def _start_background_tasks(self) -> None:
|
async def _start_background_tasks(self) -> None:
|
||||||
"""Start background tasks for ping and message handling."""
|
"""Start background tasks for ping and message handling."""
|
||||||
|
# Ensure no tasks are currently stopping
|
||||||
|
if self._tasks_stopping:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cancel any existing tasks first
|
||||||
|
await self._stop_background_tasks()
|
||||||
|
|
||||||
# Start ping task
|
# Start ping task
|
||||||
self._ping_task = asyncio.create_task(self._ping_loop())
|
self._ping_task = asyncio.create_task(self._ping_loop())
|
||||||
|
|
||||||
@ -390,22 +401,53 @@ class OKXWebSocketClient:
|
|||||||
self.logger.debug(f"{self.component_name}: Started background tasks")
|
self.logger.debug(f"{self.component_name}: Started background tasks")
|
||||||
|
|
||||||
async def _stop_background_tasks(self) -> None:
|
async def _stop_background_tasks(self) -> None:
|
||||||
"""Stop background tasks."""
|
"""Stop background tasks with proper synchronization."""
|
||||||
tasks = [self._ping_task, self._message_handler_task]
|
self._tasks_stopping = True
|
||||||
|
|
||||||
for task in tasks:
|
try:
|
||||||
if task and not task.done():
|
tasks = []
|
||||||
|
|
||||||
|
# Collect tasks to cancel
|
||||||
|
if self._ping_task and not self._ping_task.done():
|
||||||
|
tasks.append(self._ping_task)
|
||||||
|
if self._message_handler_task and not self._message_handler_task.done():
|
||||||
|
tasks.append(self._message_handler_task)
|
||||||
|
|
||||||
|
if not tasks:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: No background tasks to stop")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks")
|
||||||
|
|
||||||
|
# Cancel all tasks
|
||||||
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
# Wait for all tasks to complete with timeout
|
||||||
|
if tasks:
|
||||||
try:
|
try:
|
||||||
await task
|
await asyncio.wait_for(
|
||||||
except asyncio.CancelledError:
|
asyncio.gather(*tasks, return_exceptions=True),
|
||||||
pass
|
timeout=5.0
|
||||||
|
)
|
||||||
self._ping_task = None
|
except asyncio.TimeoutError:
|
||||||
self._message_handler_task = None
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running")
|
||||||
if self.logger:
|
except Exception as e:
|
||||||
self.logger.debug(f"{self.component_name}: Stopped background tasks")
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}")
|
||||||
|
|
||||||
|
# Clear task references
|
||||||
|
self._ping_task = None
|
||||||
|
self._message_handler_task = None
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._tasks_stopping = False
|
||||||
|
|
||||||
async def _ping_loop(self) -> None:
|
async def _ping_loop(self) -> None:
|
||||||
"""Background task for sending ping messages."""
|
"""Background task for sending ping messages."""
|
||||||
@ -435,58 +477,91 @@ class OKXWebSocketClient:
|
|||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
async def _message_handler(self) -> None:
|
async def _message_handler(self) -> None:
|
||||||
"""Background task for handling incoming messages."""
|
"""Background task for handling incoming messages with enhanced error handling."""
|
||||||
while self.is_connected:
|
if self.logger:
|
||||||
try:
|
self.logger.debug(f"{self.component_name}: Message handler started")
|
||||||
if not self._websocket:
|
|
||||||
break
|
try:
|
||||||
|
while self.is_connected and not self._tasks_stopping:
|
||||||
# Receive message with timeout
|
|
||||||
try:
|
try:
|
||||||
message = await asyncio.wait_for(
|
if not self._websocket or self._tasks_stopping:
|
||||||
self._websocket.recv(),
|
|
||||||
timeout=1.0
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue # No message received, continue loop
|
|
||||||
|
|
||||||
# Process message
|
|
||||||
await self._process_message(message)
|
|
||||||
|
|
||||||
except ConnectionClosed as e:
|
|
||||||
if self.logger:
|
|
||||||
self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}")
|
|
||||||
self._connection_state = ConnectionState.DISCONNECTED
|
|
||||||
|
|
||||||
# Attempt automatic reconnection if enabled
|
|
||||||
if self._reconnect_attempts < self.max_reconnect_attempts:
|
|
||||||
self._reconnect_attempts += 1
|
|
||||||
if self.logger:
|
|
||||||
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
|
|
||||||
|
|
||||||
# Stop current tasks
|
|
||||||
await self._stop_background_tasks()
|
|
||||||
|
|
||||||
# Attempt reconnection
|
|
||||||
if await self.reconnect():
|
|
||||||
if self.logger:
|
|
||||||
self.logger.info(f"{self.component_name}: Automatic reconnection successful")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
if self.logger:
|
|
||||||
self.logger.error(f"{self.component_name}: Automatic reconnection failed")
|
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
if self.logger:
|
|
||||||
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded")
|
|
||||||
break
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
# Receive message with timeout
|
||||||
break
|
try:
|
||||||
except Exception as e:
|
message = await asyncio.wait_for(
|
||||||
if self.logger:
|
self._websocket.recv(),
|
||||||
self.logger.error(f"{self.component_name}: Error in message handler: {e}")
|
timeout=1.0
|
||||||
await asyncio.sleep(1)
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue # No message received, continue loop
|
||||||
|
|
||||||
|
# Check if we're still supposed to be running
|
||||||
|
if self._tasks_stopping:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Process message
|
||||||
|
await self._process_message(message)
|
||||||
|
|
||||||
|
except ConnectionClosed as e:
|
||||||
|
if self._tasks_stopping:
|
||||||
|
break # Expected during shutdown
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}")
|
||||||
|
self._connection_state = ConnectionState.DISCONNECTED
|
||||||
|
|
||||||
|
# Use lock to prevent concurrent reconnection attempts
|
||||||
|
async with self._reconnection_lock:
|
||||||
|
# Double-check we still need to reconnect
|
||||||
|
if (self._connection_state == ConnectionState.DISCONNECTED and
|
||||||
|
self._reconnect_attempts < self.max_reconnect_attempts and
|
||||||
|
not self._tasks_stopping):
|
||||||
|
|
||||||
|
self._reconnect_attempts += 1
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
|
||||||
|
|
||||||
|
# Stop current tasks properly
|
||||||
|
await self._stop_background_tasks()
|
||||||
|
|
||||||
|
# Attempt reconnection with stored subscriptions
|
||||||
|
stored_subscriptions = list(self._subscriptions.values())
|
||||||
|
|
||||||
|
if await self.reconnect():
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(f"{self.component_name}: Automatic reconnection successful")
|
||||||
|
# The reconnect method will restart tasks, so we exit this handler
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: Automatic reconnection failed")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
|
||||||
|
break
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Message handler cancelled")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
if self._tasks_stopping:
|
||||||
|
break
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: Error in message handler: {e}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Message handler task cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: Fatal error in message handler: {e}")
|
||||||
|
finally:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Message handler exiting")
|
||||||
|
|
||||||
async def _send_message(self, message: Dict[str, Any]) -> None:
|
async def _send_message(self, message: Dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -626,34 +701,40 @@ class OKXWebSocketClient:
|
|||||||
|
|
||||||
async def reconnect(self) -> bool:
|
async def reconnect(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Reconnect to WebSocket with retry logic.
|
Reconnect to WebSocket with enhanced synchronization.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if reconnection successful, False otherwise
|
True if reconnection successful, False otherwise
|
||||||
"""
|
"""
|
||||||
if self.logger:
|
async with self._reconnection_lock:
|
||||||
self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket")
|
if self.logger:
|
||||||
self._connection_state = ConnectionState.RECONNECTING
|
self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket")
|
||||||
self._stats['reconnections'] += 1
|
self._connection_state = ConnectionState.RECONNECTING
|
||||||
|
self._stats['reconnections'] += 1
|
||||||
# Disconnect first
|
|
||||||
await self.disconnect()
|
# Store current subscriptions before disconnect
|
||||||
|
stored_subscriptions = list(self._subscriptions.values())
|
||||||
# Wait a moment before reconnecting
|
|
||||||
await asyncio.sleep(1)
|
# Disconnect first with proper cleanup
|
||||||
|
await self.disconnect()
|
||||||
# Attempt to reconnect
|
|
||||||
success = await self.connect()
|
# Wait a moment before reconnecting
|
||||||
|
await asyncio.sleep(1)
|
||||||
if success:
|
|
||||||
# Re-subscribe to previous subscriptions
|
# Attempt to reconnect
|
||||||
if self._subscriptions:
|
success = await self.connect()
|
||||||
subscriptions = list(self._subscriptions.values())
|
|
||||||
if self.logger:
|
if success:
|
||||||
self.logger.info(f"{self.component_name}: Re-subscribing to {len(subscriptions)} channels")
|
# Re-subscribe to previous subscriptions
|
||||||
await self.subscribe(subscriptions)
|
if stored_subscriptions:
|
||||||
|
if self.logger:
|
||||||
return success
|
self.logger.info(f"{self.component_name}: Re-subscribing to {len(stored_subscriptions)} channels")
|
||||||
|
await self.subscribe(stored_subscriptions)
|
||||||
|
|
||||||
|
# Reset reconnect attempts on successful reconnection
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<OKXWebSocketClient(state={self._connection_state.value}, subscriptions={len(self._subscriptions)})>"
|
return f"<OKXWebSocketClient(state={self._connection_state.value}, subscriptions={len(self._subscriptions)})>"
|
||||||
@ -634,93 +634,56 @@ OKX requires specific ping/pong format:
|
|||||||
# Ping interval must be < 30 seconds to avoid disconnection
|
# Ping interval must be < 30 seconds to avoid disconnection
|
||||||
```
|
```
|
||||||
|
|
||||||
## Error Handling and Troubleshooting
|
## Error Handling & Resilience
|
||||||
|
|
||||||
### Common Issues and Solutions
|
The OKX collector includes comprehensive error handling and automatic recovery mechanisms:
|
||||||
|
|
||||||
#### 1. Connection Failures
|
### Connection Management
|
||||||
|
- **Automatic Reconnection**: Handles network disconnections with exponential backoff
|
||||||
|
- **Task Synchronization**: Prevents race conditions during reconnection using asyncio locks
|
||||||
|
- **Graceful Shutdown**: Properly cancels background tasks and closes connections
|
||||||
|
- **Connection State Tracking**: Monitors connection health and validity
|
||||||
|
|
||||||
|
### Enhanced WebSocket Handling (v2.1+)
|
||||||
|
- **Race Condition Prevention**: Uses synchronization locks to prevent multiple recv() calls
|
||||||
|
- **Task Lifecycle Management**: Properly manages background task startup and shutdown
|
||||||
|
- **Reconnection Locking**: Prevents concurrent reconnection attempts
|
||||||
|
- **Subscription Persistence**: Automatically re-subscribes to channels after reconnection
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Check connection status
|
# The collector handles these scenarios automatically:
|
||||||
status = collector.get_status()
|
# - Network interruptions
|
||||||
if not status['websocket_connected']:
|
# - WebSocket connection drops
|
||||||
print("WebSocket not connected")
|
# - OKX server maintenance
|
||||||
|
# - Rate limiting responses
|
||||||
# Check WebSocket state
|
# - Malformed data packets
|
||||||
ws_state = status.get('websocket_state', 'unknown')
|
|
||||||
|
# Enhanced error logging for diagnostics
|
||||||
if ws_state == 'error':
|
collector = OKXCollector('BTC-USDT', [DataType.TRADE])
|
||||||
print("WebSocket in error state - will auto-restart")
|
stats = collector.get_status()
|
||||||
elif ws_state == 'reconnecting':
|
print(f"Connection state: {stats['connection_state']}")
|
||||||
print("WebSocket is reconnecting...")
|
print(f"Reconnection attempts: {stats['reconnect_attempts']}")
|
||||||
|
print(f"Error count: {stats['error_count']}")
|
||||||
# Manual restart if needed
|
|
||||||
await collector.restart()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 2. Ping/Pong Issues
|
### Common Error Patterns
|
||||||
|
|
||||||
```python
|
#### WebSocket Concurrency Errors (Fixed in v2.1)
|
||||||
# Monitor ping/pong status
|
|
||||||
if 'websocket_stats' in status:
|
|
||||||
ws_stats = status['websocket_stats']
|
|
||||||
pings_sent = ws_stats.get('pings_sent', 0)
|
|
||||||
pongs_received = ws_stats.get('pongs_received', 0)
|
|
||||||
|
|
||||||
if pings_sent > pongs_received + 3: # Allow some tolerance
|
|
||||||
print("Ping/pong issue detected - connection may be stale")
|
|
||||||
# Auto-restart will handle this
|
|
||||||
```
|
```
|
||||||
|
ERROR: cannot call recv while another coroutine is already running recv or recv_streaming
|
||||||
#### 3. Data Validation Errors
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Monitor for validation errors
|
|
||||||
errors = status.get('errors', 0)
|
|
||||||
if errors > 0:
|
|
||||||
print(f"Data validation errors detected: {errors}")
|
|
||||||
|
|
||||||
# Check logs for details:
|
|
||||||
# - Malformed messages
|
|
||||||
# - Missing required fields
|
|
||||||
# - Invalid data types
|
|
||||||
```
|
```
|
||||||
|
**Solution**: Updated WebSocket client with proper task synchronization and reconnection locking.
|
||||||
|
|
||||||
#### 4. Performance Issues
|
#### Connection Recovery
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Monitor message processing rate
|
# Monitor connection health
|
||||||
messages = status.get('messages_processed', 0)
|
async def monitor_connection():
|
||||||
uptime = status.get('uptime_seconds', 1)
|
while True:
|
||||||
rate = messages / uptime
|
if collector.is_connected():
|
||||||
|
print("✅ Connected and receiving data")
|
||||||
if rate < 1.0: # Less than 1 message per second
|
else:
|
||||||
print("Low message rate - check:")
|
print("❌ Connection issue - auto-recovery in progress")
|
||||||
print("- Network connectivity")
|
await asyncio.sleep(30)
|
||||||
print("- OKX API status")
|
|
||||||
print("- Symbol activity")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Debug Mode
|
|
||||||
|
|
||||||
Enable debug logging for detailed information:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
os.environ['LOG_LEVEL'] = 'DEBUG'
|
|
||||||
|
|
||||||
# Create collector with verbose logging
|
|
||||||
collector = create_okx_collector(
|
|
||||||
symbol='BTC-USDT',
|
|
||||||
data_types=[DataType.TRADE, DataType.ORDERBOOK]
|
|
||||||
)
|
|
||||||
|
|
||||||
await collector.start()
|
|
||||||
|
|
||||||
# Check logs in ./logs/ directory:
|
|
||||||
# - okx_collector_btc_usdt_debug.log
|
|
||||||
# - okx_collector_btc_usdt_info.log
|
|
||||||
# - okx_collector_btc_usdt_error.log
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|||||||
205
tests/test_websocket_race_condition_fix.py
Normal file
205
tests/test_websocket_race_condition_fix.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify WebSocket race condition fixes.
|
||||||
|
|
||||||
|
This script tests the enhanced task management and synchronization
|
||||||
|
in the OKX WebSocket client to ensure no more recv() concurrency errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
|
||||||
|
from utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
async def test_websocket_reconnection_stability():
|
||||||
|
"""Test WebSocket reconnection without race conditions."""
|
||||||
|
logger = get_logger("websocket_test", verbose=True)
|
||||||
|
|
||||||
|
print("🧪 Testing WebSocket Race Condition Fixes")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Create WebSocket client
|
||||||
|
ws_client = OKXWebSocketClient(
|
||||||
|
component_name="test_ws_client",
|
||||||
|
ping_interval=25.0,
|
||||||
|
max_reconnect_attempts=3,
|
||||||
|
logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test 1: Basic connection
|
||||||
|
print("\n📡 Test 1: Basic Connection")
|
||||||
|
success = await ws_client.connect()
|
||||||
|
if success:
|
||||||
|
print("✅ Initial connection successful")
|
||||||
|
else:
|
||||||
|
print("❌ Initial connection failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test 2: Subscribe to channels
|
||||||
|
print("\n📡 Test 2: Channel Subscription")
|
||||||
|
subscriptions = [
|
||||||
|
OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT"),
|
||||||
|
OKXSubscription(OKXChannelType.BOOKS5.value, "BTC-USDT")
|
||||||
|
]
|
||||||
|
|
||||||
|
success = await ws_client.subscribe(subscriptions)
|
||||||
|
if success:
|
||||||
|
print("✅ Subscription successful")
|
||||||
|
else:
|
||||||
|
print("❌ Subscription failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test 3: Force reconnection to test race condition fixes
|
||||||
|
print("\n📡 Test 3: Force Reconnection (Race Condition Test)")
|
||||||
|
for i in range(3):
|
||||||
|
print(f" Reconnection attempt {i+1}/3...")
|
||||||
|
success = await ws_client.reconnect()
|
||||||
|
if success:
|
||||||
|
print(f" ✅ Reconnection {i+1} successful")
|
||||||
|
await asyncio.sleep(2) # Wait between reconnections
|
||||||
|
else:
|
||||||
|
print(f" ❌ Reconnection {i+1} failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test 4: Verify subscriptions are maintained
|
||||||
|
print("\n📡 Test 4: Subscription Persistence")
|
||||||
|
current_subs = ws_client.get_subscriptions()
|
||||||
|
if len(current_subs) == 2:
|
||||||
|
print("✅ Subscriptions persisted after reconnections")
|
||||||
|
else:
|
||||||
|
print(f"❌ Subscription count mismatch: expected 2, got {len(current_subs)}")
|
||||||
|
|
||||||
|
# Test 5: Monitor for a few seconds to catch any errors
|
||||||
|
print("\n📡 Test 5: Stability Monitor (10 seconds)")
|
||||||
|
message_count = 0
|
||||||
|
|
||||||
|
def message_callback(message):
|
||||||
|
nonlocal message_count
|
||||||
|
message_count += 1
|
||||||
|
if message_count % 10 == 0:
|
||||||
|
print(f" 📊 Processed {message_count} messages")
|
||||||
|
|
||||||
|
ws_client.add_message_callback(message_callback)
|
||||||
|
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
stats = ws_client.get_stats()
|
||||||
|
print(f"\n📊 Final Statistics:")
|
||||||
|
print(f" Messages received: {stats['messages_received']}")
|
||||||
|
print(f" Reconnections: {stats['reconnections']}")
|
||||||
|
print(f" Connection state: {stats['connection_state']}")
|
||||||
|
|
||||||
|
if stats['messages_received'] > 0:
|
||||||
|
print("✅ Receiving data successfully")
|
||||||
|
else:
|
||||||
|
print("⚠️ No messages received (may be normal for low-activity symbols)")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test failed with exception: {e}")
|
||||||
|
logger.error(f"Test exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
await ws_client.disconnect()
|
||||||
|
print("\n🧹 Cleanup completed")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent_operations():
|
||||||
|
"""Test concurrent WebSocket operations to ensure no race conditions."""
|
||||||
|
print("\n🔄 Testing Concurrent Operations")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
logger = get_logger("concurrent_test", verbose=False)
|
||||||
|
|
||||||
|
# Create multiple clients
|
||||||
|
clients = []
|
||||||
|
for i in range(3):
|
||||||
|
client = OKXWebSocketClient(
|
||||||
|
component_name=f"test_client_{i}",
|
||||||
|
logger=logger
|
||||||
|
)
|
||||||
|
clients.append(client)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect all clients concurrently
|
||||||
|
print("📡 Connecting 3 clients concurrently...")
|
||||||
|
tasks = [client.connect() for client in clients]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
successful_connections = sum(1 for r in results if r is True)
|
||||||
|
print(f"✅ {successful_connections}/3 clients connected successfully")
|
||||||
|
|
||||||
|
# Test concurrent reconnections
|
||||||
|
print("\n🔄 Testing concurrent reconnections...")
|
||||||
|
reconnect_tasks = []
|
||||||
|
for client in clients:
|
||||||
|
if client.is_connected:
|
||||||
|
reconnect_tasks.append(client.reconnect())
|
||||||
|
|
||||||
|
if reconnect_tasks:
|
||||||
|
reconnect_results = await asyncio.gather(*reconnect_tasks, return_exceptions=True)
|
||||||
|
successful_reconnects = sum(1 for r in reconnect_results if r is True)
|
||||||
|
print(f"✅ {successful_reconnects}/{len(reconnect_tasks)} reconnections successful")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Concurrent test failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup all clients
|
||||||
|
for client in clients:
|
||||||
|
try:
|
||||||
|
await client.disconnect()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all WebSocket tests."""
|
||||||
|
print("🚀 WebSocket Race Condition Fix Test Suite")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test 1: Basic reconnection stability
|
||||||
|
test1_success = await test_websocket_reconnection_stability()
|
||||||
|
|
||||||
|
# Test 2: Concurrent operations
|
||||||
|
test2_success = await test_concurrent_operations()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("📋 Test Summary:")
|
||||||
|
print(f" Reconnection Stability: {'✅ PASS' if test1_success else '❌ FAIL'}")
|
||||||
|
print(f" Concurrent Operations: {'✅ PASS' if test2_success else '❌ FAIL'}")
|
||||||
|
|
||||||
|
if test1_success and test2_success:
|
||||||
|
print("\n🎉 All tests passed! WebSocket race condition fixes working correctly.")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print("\n❌ Some tests failed. Check logs for details.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⏹️ Tests interrupted by user")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n💥 Test suite failed with exception: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = asyncio.run(main())
|
||||||
|
sys.exit(exit_code)
|
||||||
Loading…
x
Reference in New Issue
Block a user