Enhance OKX WebSocket client with improved task management and error handling

- Implemented enhanced task synchronization to prevent race conditions during WebSocket operations.
- Introduced reconnection locking to avoid concurrent reconnection attempts.
- Improved error handling in message processing and reconnection logic, ensuring graceful shutdown and task management.
- Added unit tests to verify the stability and reliability of the WebSocket client under concurrent operations.
This commit is contained in:
Vasily.onl 2025-06-02 23:14:04 +08:00
parent aaebd9a308
commit 01cea1d5e5
4 changed files with 414 additions and 163 deletions

View File

@ -23,6 +23,7 @@
"orderbook" "orderbook"
], ],
"timeframes": [ "timeframes": [
"5s",
"1m", "1m",
"5m", "5m",
"15m", "15m",
@ -42,6 +43,7 @@
"orderbook" "orderbook"
], ],
"timeframes": [ "timeframes": [
"5s",
"1m", "1m",
"5m", "5m",
"15m", "15m",

View File

@ -122,9 +122,11 @@ class OKXWebSocketClient:
self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = [] self._message_callbacks: List[Callable[[Dict[str, Any]], None]] = []
self._subscriptions: Dict[str, OKXSubscription] = {} self._subscriptions: Dict[str, OKXSubscription] = {}
# Tasks # Enhanced task management
self._ping_task: Optional[asyncio.Task] = None self._ping_task: Optional[asyncio.Task] = None
self._message_handler_task: Optional[asyncio.Task] = None self._message_handler_task: Optional[asyncio.Task] = None
self._reconnection_lock = asyncio.Lock() # Prevent concurrent reconnections
self._tasks_stopping = False # Flag to prevent task overlap
# Statistics # Statistics
self._stats = { self._stats = {
@ -380,6 +382,15 @@ class OKXWebSocketClient:
async def _start_background_tasks(self) -> None: async def _start_background_tasks(self) -> None:
"""Start background tasks for ping and message handling.""" """Start background tasks for ping and message handling."""
# Ensure no tasks are currently stopping
if self._tasks_stopping:
if self.logger:
self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
return
# Cancel any existing tasks first
await self._stop_background_tasks()
# Start ping task # Start ping task
self._ping_task = asyncio.create_task(self._ping_loop()) self._ping_task = asyncio.create_task(self._ping_loop())
@ -390,22 +401,53 @@ class OKXWebSocketClient:
self.logger.debug(f"{self.component_name}: Started background tasks") self.logger.debug(f"{self.component_name}: Started background tasks")
async def _stop_background_tasks(self) -> None: async def _stop_background_tasks(self) -> None:
"""Stop background tasks.""" """Stop background tasks with proper synchronization."""
tasks = [self._ping_task, self._message_handler_task] self._tasks_stopping = True
for task in tasks: try:
if task and not task.done(): tasks = []
# Collect tasks to cancel
if self._ping_task and not self._ping_task.done():
tasks.append(self._ping_task)
if self._message_handler_task and not self._message_handler_task.done():
tasks.append(self._message_handler_task)
if not tasks:
if self.logger:
self.logger.debug(f"{self.component_name}: No background tasks to stop")
return
if self.logger:
self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks")
# Cancel all tasks
for task in tasks:
task.cancel() task.cancel()
# Wait for all tasks to complete with timeout
if tasks:
try: try:
await task await asyncio.wait_for(
except asyncio.CancelledError: asyncio.gather(*tasks, return_exceptions=True),
pass timeout=5.0
)
self._ping_task = None except asyncio.TimeoutError:
self._message_handler_task = None if self.logger:
self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running")
if self.logger: except Exception as e:
self.logger.debug(f"{self.component_name}: Stopped background tasks") if self.logger:
self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}")
# Clear task references
self._ping_task = None
self._message_handler_task = None
if self.logger:
self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")
finally:
self._tasks_stopping = False
async def _ping_loop(self) -> None: async def _ping_loop(self) -> None:
"""Background task for sending ping messages.""" """Background task for sending ping messages."""
@ -435,58 +477,91 @@ class OKXWebSocketClient:
await asyncio.sleep(5) await asyncio.sleep(5)
async def _message_handler(self) -> None: async def _message_handler(self) -> None:
"""Background task for handling incoming messages.""" """Background task for handling incoming messages with enhanced error handling."""
while self.is_connected: if self.logger:
try: self.logger.debug(f"{self.component_name}: Message handler started")
if not self._websocket:
break try:
while self.is_connected and not self._tasks_stopping:
# Receive message with timeout
try: try:
message = await asyncio.wait_for( if not self._websocket or self._tasks_stopping:
self._websocket.recv(),
timeout=1.0
)
except asyncio.TimeoutError:
continue # No message received, continue loop
# Process message
await self._process_message(message)
except ConnectionClosed as e:
if self.logger:
self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}")
self._connection_state = ConnectionState.DISCONNECTED
# Attempt automatic reconnection if enabled
if self._reconnect_attempts < self.max_reconnect_attempts:
self._reconnect_attempts += 1
if self.logger:
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
# Stop current tasks
await self._stop_background_tasks()
# Attempt reconnection
if await self.reconnect():
if self.logger:
self.logger.info(f"{self.component_name}: Automatic reconnection successful")
continue
else:
if self.logger:
self.logger.error(f"{self.component_name}: Automatic reconnection failed")
break break
else:
if self.logger:
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded")
break
except asyncio.CancelledError: # Receive message with timeout
break try:
except Exception as e: message = await asyncio.wait_for(
if self.logger: self._websocket.recv(),
self.logger.error(f"{self.component_name}: Error in message handler: {e}") timeout=1.0
await asyncio.sleep(1) )
except asyncio.TimeoutError:
continue # No message received, continue loop
# Check if we're still supposed to be running
if self._tasks_stopping:
break
# Process message
await self._process_message(message)
except ConnectionClosed as e:
if self._tasks_stopping:
break # Expected during shutdown
if self.logger:
self.logger.warning(f"{self.component_name}: WebSocket connection closed: {e}")
self._connection_state = ConnectionState.DISCONNECTED
# Use lock to prevent concurrent reconnection attempts
async with self._reconnection_lock:
# Double-check we still need to reconnect
if (self._connection_state == ConnectionState.DISCONNECTED and
self._reconnect_attempts < self.max_reconnect_attempts and
not self._tasks_stopping):
self._reconnect_attempts += 1
if self.logger:
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
# Stop current tasks properly
await self._stop_background_tasks()
# Attempt reconnection with stored subscriptions
stored_subscriptions = list(self._subscriptions.values())
if await self.reconnect():
if self.logger:
self.logger.info(f"{self.component_name}: Automatic reconnection successful")
# The reconnect method will restart tasks, so we exit this handler
break
else:
if self.logger:
self.logger.error(f"{self.component_name}: Automatic reconnection failed")
break
else:
if self.logger:
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
break
except asyncio.CancelledError:
if self.logger:
self.logger.debug(f"{self.component_name}: Message handler cancelled")
break
except Exception as e:
if self._tasks_stopping:
break
if self.logger:
self.logger.error(f"{self.component_name}: Error in message handler: {e}")
await asyncio.sleep(1)
except asyncio.CancelledError:
if self.logger:
self.logger.debug(f"{self.component_name}: Message handler task cancelled")
except Exception as e:
if self.logger:
self.logger.error(f"{self.component_name}: Fatal error in message handler: {e}")
finally:
if self.logger:
self.logger.debug(f"{self.component_name}: Message handler exiting")
async def _send_message(self, message: Dict[str, Any]) -> None: async def _send_message(self, message: Dict[str, Any]) -> None:
""" """
@ -626,34 +701,40 @@ class OKXWebSocketClient:
async def reconnect(self) -> bool: async def reconnect(self) -> bool:
""" """
Reconnect to WebSocket with retry logic. Reconnect to WebSocket with enhanced synchronization.
Returns: Returns:
True if reconnection successful, False otherwise True if reconnection successful, False otherwise
""" """
if self.logger: async with self._reconnection_lock:
self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket") if self.logger:
self._connection_state = ConnectionState.RECONNECTING self.logger.info(f"{self.component_name}: Attempting to reconnect to OKX WebSocket")
self._stats['reconnections'] += 1 self._connection_state = ConnectionState.RECONNECTING
self._stats['reconnections'] += 1
# Disconnect first
await self.disconnect() # Store current subscriptions before disconnect
stored_subscriptions = list(self._subscriptions.values())
# Wait a moment before reconnecting
await asyncio.sleep(1) # Disconnect first with proper cleanup
await self.disconnect()
# Attempt to reconnect
success = await self.connect() # Wait a moment before reconnecting
await asyncio.sleep(1)
if success:
# Re-subscribe to previous subscriptions # Attempt to reconnect
if self._subscriptions: success = await self.connect()
subscriptions = list(self._subscriptions.values())
if self.logger: if success:
self.logger.info(f"{self.component_name}: Re-subscribing to {len(subscriptions)} channels") # Re-subscribe to previous subscriptions
await self.subscribe(subscriptions) if stored_subscriptions:
if self.logger:
return success self.logger.info(f"{self.component_name}: Re-subscribing to {len(stored_subscriptions)} channels")
await self.subscribe(stored_subscriptions)
# Reset reconnect attempts on successful reconnection
self._reconnect_attempts = 0
return success
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<OKXWebSocketClient(state={self._connection_state.value}, subscriptions={len(self._subscriptions)})>" return f"<OKXWebSocketClient(state={self._connection_state.value}, subscriptions={len(self._subscriptions)})>"

View File

@ -634,93 +634,56 @@ OKX requires specific ping/pong format:
# Ping interval must be < 30 seconds to avoid disconnection # Ping interval must be < 30 seconds to avoid disconnection
``` ```
## Error Handling and Troubleshooting ## Error Handling & Resilience
### Common Issues and Solutions The OKX collector includes comprehensive error handling and automatic recovery mechanisms:
#### 1. Connection Failures ### Connection Management
- **Automatic Reconnection**: Handles network disconnections with exponential backoff
- **Task Synchronization**: Prevents race conditions during reconnection using asyncio locks
- **Graceful Shutdown**: Properly cancels background tasks and closes connections
- **Connection State Tracking**: Monitors connection health and validity
### Enhanced WebSocket Handling (v2.1+)
- **Race Condition Prevention**: Uses synchronization locks to prevent multiple recv() calls
- **Task Lifecycle Management**: Properly manages background task startup and shutdown
- **Reconnection Locking**: Prevents concurrent reconnection attempts
- **Subscription Persistence**: Automatically re-subscribes to channels after reconnection
```python ```python
# Check connection status # The collector handles these scenarios automatically:
status = collector.get_status() # - Network interruptions
if not status['websocket_connected']: # - WebSocket connection drops
print("WebSocket not connected") # - OKX server maintenance
# - Rate limiting responses
# Check WebSocket state # - Malformed data packets
ws_state = status.get('websocket_state', 'unknown')
# Enhanced error logging for diagnostics
if ws_state == 'error': collector = OKXCollector('BTC-USDT', [DataType.TRADE])
print("WebSocket in error state - will auto-restart") stats = collector.get_status()
elif ws_state == 'reconnecting': print(f"Connection state: {stats['connection_state']}")
print("WebSocket is reconnecting...") print(f"Reconnection attempts: {stats['reconnect_attempts']}")
print(f"Error count: {stats['error_count']}")
# Manual restart if needed
await collector.restart()
``` ```
#### 2. Ping/Pong Issues ### Common Error Patterns
```python #### WebSocket Concurrency Errors (Fixed in v2.1)
# Monitor ping/pong status
if 'websocket_stats' in status:
ws_stats = status['websocket_stats']
pings_sent = ws_stats.get('pings_sent', 0)
pongs_received = ws_stats.get('pongs_received', 0)
if pings_sent > pongs_received + 3: # Allow some tolerance
print("Ping/pong issue detected - connection may be stale")
# Auto-restart will handle this
``` ```
ERROR: cannot call recv while another coroutine is already running recv or recv_streaming
#### 3. Data Validation Errors
```python
# Monitor for validation errors
errors = status.get('errors', 0)
if errors > 0:
print(f"Data validation errors detected: {errors}")
# Check logs for details:
# - Malformed messages
# - Missing required fields
# - Invalid data types
``` ```
**Solution**: Updated WebSocket client with proper task synchronization and reconnection locking.
#### 4. Performance Issues #### Connection Recovery
```python ```python
# Monitor message processing rate # Monitor connection health
messages = status.get('messages_processed', 0) async def monitor_connection():
uptime = status.get('uptime_seconds', 1) while True:
rate = messages / uptime if collector.is_connected():
print("✅ Connected and receiving data")
if rate < 1.0: # Less than 1 message per second else:
print("Low message rate - check:") print("❌ Connection issue - auto-recovery in progress")
print("- Network connectivity") await asyncio.sleep(30)
print("- OKX API status")
print("- Symbol activity")
```
### Debug Mode
Enable debug logging for detailed information:
```python
import os
os.environ['LOG_LEVEL'] = 'DEBUG'
# Create collector with verbose logging
collector = create_okx_collector(
symbol='BTC-USDT',
data_types=[DataType.TRADE, DataType.ORDERBOOK]
)
await collector.start()
# Check logs in ./logs/ directory:
# - okx_collector_btc_usdt_debug.log
# - okx_collector_btc_usdt_info.log
# - okx_collector_btc_usdt_error.log
``` ```
## Testing ## Testing

View File

@ -0,0 +1,205 @@
#!/usr/bin/env python3
"""
Test script to verify WebSocket race condition fixes.
This script tests the enhanced task management and synchronization
in the OKX WebSocket client to ensure no more recv() concurrency errors.
"""
import asyncio
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
from utils.logger import get_logger
async def test_websocket_reconnection_stability():
"""Test WebSocket reconnection without race conditions."""
logger = get_logger("websocket_test", verbose=True)
print("🧪 Testing WebSocket Race Condition Fixes")
print("=" * 50)
# Create WebSocket client
ws_client = OKXWebSocketClient(
component_name="test_ws_client",
ping_interval=25.0,
max_reconnect_attempts=3,
logger=logger
)
try:
# Test 1: Basic connection
print("\n📡 Test 1: Basic Connection")
success = await ws_client.connect()
if success:
print("✅ Initial connection successful")
else:
print("❌ Initial connection failed")
return False
# Test 2: Subscribe to channels
print("\n📡 Test 2: Channel Subscription")
subscriptions = [
OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT"),
OKXSubscription(OKXChannelType.BOOKS5.value, "BTC-USDT")
]
success = await ws_client.subscribe(subscriptions)
if success:
print("✅ Subscription successful")
else:
print("❌ Subscription failed")
return False
# Test 3: Force reconnection to test race condition fixes
print("\n📡 Test 3: Force Reconnection (Race Condition Test)")
for i in range(3):
print(f" Reconnection attempt {i+1}/3...")
success = await ws_client.reconnect()
if success:
print(f" ✅ Reconnection {i+1} successful")
await asyncio.sleep(2) # Wait between reconnections
else:
print(f" ❌ Reconnection {i+1} failed")
return False
# Test 4: Verify subscriptions are maintained
print("\n📡 Test 4: Subscription Persistence")
current_subs = ws_client.get_subscriptions()
if len(current_subs) == 2:
print("✅ Subscriptions persisted after reconnections")
else:
print(f"❌ Subscription count mismatch: expected 2, got {len(current_subs)}")
# Test 5: Monitor for a few seconds to catch any errors
print("\n📡 Test 5: Stability Monitor (10 seconds)")
message_count = 0
def message_callback(message):
nonlocal message_count
message_count += 1
if message_count % 10 == 0:
print(f" 📊 Processed {message_count} messages")
ws_client.add_message_callback(message_callback)
await asyncio.sleep(10)
stats = ws_client.get_stats()
print(f"\n📊 Final Statistics:")
print(f" Messages received: {stats['messages_received']}")
print(f" Reconnections: {stats['reconnections']}")
print(f" Connection state: {stats['connection_state']}")
if stats['messages_received'] > 0:
print("✅ Receiving data successfully")
else:
print("⚠️ No messages received (may be normal for low-activity symbols)")
return True
except Exception as e:
print(f"❌ Test failed with exception: {e}")
logger.error(f"Test exception: {e}")
return False
finally:
# Cleanup
await ws_client.disconnect()
print("\n🧹 Cleanup completed")
async def test_concurrent_operations():
"""Test concurrent WebSocket operations to ensure no race conditions."""
print("\n🔄 Testing Concurrent Operations")
print("=" * 50)
logger = get_logger("concurrent_test", verbose=False)
# Create multiple clients
clients = []
for i in range(3):
client = OKXWebSocketClient(
component_name=f"test_client_{i}",
logger=logger
)
clients.append(client)
try:
# Connect all clients concurrently
print("📡 Connecting 3 clients concurrently...")
tasks = [client.connect() for client in clients]
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_connections = sum(1 for r in results if r is True)
print(f"{successful_connections}/3 clients connected successfully")
# Test concurrent reconnections
print("\n🔄 Testing concurrent reconnections...")
reconnect_tasks = []
for client in clients:
if client.is_connected:
reconnect_tasks.append(client.reconnect())
if reconnect_tasks:
reconnect_results = await asyncio.gather(*reconnect_tasks, return_exceptions=True)
successful_reconnects = sum(1 for r in reconnect_results if r is True)
print(f"{successful_reconnects}/{len(reconnect_tasks)} reconnections successful")
return True
except Exception as e:
print(f"❌ Concurrent test failed: {e}")
return False
finally:
# Cleanup all clients
for client in clients:
try:
await client.disconnect()
except:
pass
async def main():
"""Run all WebSocket tests."""
print("🚀 WebSocket Race Condition Fix Test Suite")
print("=" * 60)
try:
# Test 1: Basic reconnection stability
test1_success = await test_websocket_reconnection_stability()
# Test 2: Concurrent operations
test2_success = await test_concurrent_operations()
# Summary
print("\n" + "=" * 60)
print("📋 Test Summary:")
print(f" Reconnection Stability: {'✅ PASS' if test1_success else '❌ FAIL'}")
print(f" Concurrent Operations: {'✅ PASS' if test2_success else '❌ FAIL'}")
if test1_success and test2_success:
print("\n🎉 All tests passed! WebSocket race condition fixes working correctly.")
return 0
else:
print("\n❌ Some tests failed. Check logs for details.")
return 1
except KeyboardInterrupt:
print("\n⏹️ Tests interrupted by user")
return 1
except Exception as e:
print(f"\n💥 Test suite failed with exception: {e}")
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)