fix recursion error on reconnection

This commit is contained in:
Vasily.onl 2025-06-03 11:42:10 +08:00
parent 01cea1d5e5
commit d508616677
2 changed files with 255 additions and 60 deletions

View File

@ -388,56 +388,83 @@ class OKXWebSocketClient:
self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress") self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
return return
# Cancel any existing tasks first # Check if tasks are already running
if (self._ping_task and not self._ping_task.done() and
self._message_handler_task and not self._message_handler_task.done()):
if self.logger:
self.logger.debug(f"{self.component_name}: Background tasks already running")
return
# Cancel any existing tasks first (safety measure)
await self._stop_background_tasks() await self._stop_background_tasks()
# Start ping task # Ensure we're still supposed to start tasks after stopping
self._ping_task = asyncio.create_task(self._ping_loop()) if self._tasks_stopping or not self.is_connected:
if self.logger:
self.logger.debug(f"{self.component_name}: Aborting task start - stopping or disconnected")
return
# Start message handler task try:
self._message_handler_task = asyncio.create_task(self._message_handler()) # Start ping task
self._ping_task = asyncio.create_task(self._ping_loop())
if self.logger:
self.logger.debug(f"{self.component_name}: Started background tasks") # Start message handler task
self._message_handler_task = asyncio.create_task(self._message_handler())
if self.logger:
self.logger.debug(f"{self.component_name}: Started background tasks")
except Exception as e:
if self.logger:
self.logger.error(f"{self.component_name}: Error starting background tasks: {e}")
# Clean up on failure
await self._stop_background_tasks()
async def _stop_background_tasks(self) -> None: async def _stop_background_tasks(self) -> None:
"""Stop background tasks with proper synchronization.""" """Stop background tasks with proper synchronization - simplified approach."""
self._tasks_stopping = True self._tasks_stopping = True
try: try:
tasks = []
# Collect tasks to cancel # Collect tasks to cancel
if self._ping_task and not self._ping_task.done(): tasks_to_cancel = []
tasks.append(self._ping_task)
if self._message_handler_task and not self._message_handler_task.done():
tasks.append(self._message_handler_task)
if not tasks: if self._ping_task and not self._ping_task.done():
tasks_to_cancel.append(('ping_task', self._ping_task))
if self._message_handler_task and not self._message_handler_task.done():
tasks_to_cancel.append(('message_handler_task', self._message_handler_task))
if not tasks_to_cancel:
if self.logger: if self.logger:
self.logger.debug(f"{self.component_name}: No background tasks to stop") self.logger.debug(f"{self.component_name}: No background tasks to stop")
return return
if self.logger: if self.logger:
self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks") self.logger.debug(f"{self.component_name}: Stopping {len(tasks_to_cancel)} background tasks")
# Cancel all tasks # Cancel tasks individually to avoid recursion
for task in tasks: for task_name, task in tasks_to_cancel:
task.cancel()
# Wait for all tasks to complete with timeout
if tasks:
try: try:
await asyncio.wait_for( if not task.done():
asyncio.gather(*tasks, return_exceptions=True), task.cancel()
timeout=5.0 if self.logger:
) self.logger.debug(f"{self.component_name}: Cancelled {task_name}")
except asyncio.TimeoutError:
if self.logger:
self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running")
except Exception as e: except Exception as e:
if self.logger: if self.logger:
self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}") self.logger.debug(f"{self.component_name}: Error cancelling {task_name}: {e}")
# Wait for tasks to complete individually with shorter timeouts
for task_name, task in tasks_to_cancel:
try:
await asyncio.wait_for(task, timeout=2.0)
except asyncio.TimeoutError:
if self.logger:
self.logger.warning(f"{self.component_name}: {task_name} shutdown timeout")
except asyncio.CancelledError:
# Expected when task is cancelled
pass
except Exception as e:
if self.logger:
self.logger.debug(f"{self.component_name}: {task_name} shutdown exception: {e}")
# Clear task references # Clear task references
self._ping_task = None self._ping_task = None
@ -446,6 +473,9 @@ class OKXWebSocketClient:
if self.logger: if self.logger:
self.logger.debug(f"{self.component_name}: Background tasks stopped successfully") self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")
except Exception as e:
if self.logger:
self.logger.error(f"{self.component_name}: Error in _stop_background_tasks: {e}")
finally: finally:
self._tasks_stopping = False self._tasks_stopping = False
@ -495,6 +525,9 @@ class OKXWebSocketClient:
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue # No message received, continue loop continue # No message received, continue loop
except asyncio.CancelledError:
# Exit immediately on cancellation
break
# Check if we're still supposed to be running # Check if we're still supposed to be running
if self._tasks_stopping: if self._tasks_stopping:
@ -512,35 +545,42 @@ class OKXWebSocketClient:
self._connection_state = ConnectionState.DISCONNECTED self._connection_state = ConnectionState.DISCONNECTED
# Use lock to prevent concurrent reconnection attempts # Use lock to prevent concurrent reconnection attempts
async with self._reconnection_lock: try:
# Double-check we still need to reconnect # Use asyncio.wait_for to prevent hanging on lock acquisition
if (self._connection_state == ConnectionState.DISCONNECTED and async with asyncio.wait_for(self._reconnection_lock.acquire(), timeout=5.0):
self._reconnect_attempts < self.max_reconnect_attempts and try:
not self._tasks_stopping): # Double-check we still need to reconnect
if (self._connection_state == ConnectionState.DISCONNECTED and
self._reconnect_attempts += 1 self._reconnect_attempts < self.max_reconnect_attempts and
if self.logger: not self._tasks_stopping):
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
self._reconnect_attempts += 1
# Stop current tasks properly if self.logger:
await self._stop_background_tasks() self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
# Attempt reconnection with stored subscriptions # Attempt reconnection (this will handle task cleanup)
stored_subscriptions = list(self._subscriptions.values()) if await self.reconnect():
if self.logger:
if await self.reconnect(): self.logger.info(f"{self.component_name}: Automatic reconnection successful")
if self.logger: # Exit this handler as reconnect will start new tasks
self.logger.info(f"{self.component_name}: Automatic reconnection successful") break
# The reconnect method will restart tasks, so we exit this handler else:
break if self.logger:
else: self.logger.error(f"{self.component_name}: Automatic reconnection failed")
if self.logger: break
self.logger.error(f"{self.component_name}: Automatic reconnection failed") else:
break if self.logger:
else: self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
if self.logger: break
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress") finally:
break self._reconnection_lock.release()
except asyncio.TimeoutError:
if self.logger:
self.logger.warning(f"{self.component_name}: Timeout acquiring reconnection lock")
break
except asyncio.CancelledError:
# Exit immediately on cancellation
break
except asyncio.CancelledError: except asyncio.CancelledError:
if self.logger: if self.logger:

155
tests/test_recursion_fix.py Normal file
View File

@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""
Simple test to verify recursion fix in WebSocket task management.
"""
import asyncio
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
from utils.logger import get_logger
async def test_rapid_connection_cycles():
    """Run five quick connect/subscribe/disconnect cycles against OKX.

    Each cycle builds a fresh ``OKXWebSocketClient``, connects, subscribes
    to the BTC-USDT trades channel, waits briefly and disconnects.

    Returns:
        bool: ``False`` as soon as a ``RecursionError`` is observed (the
        regression this script guards against), ``True`` otherwise.
        Connection failures and other exceptions are reported on stdout
        but do not fail the test.
    """
    logger = get_logger("recursion_test", verbose=False)

    print("🧪 Testing WebSocket Recursion Fix")
    print("=" * 40)

    for cycle in range(5):
        print(f"\n🔄 Cycle {cycle + 1}/5: Rapid connect/disconnect")

        ws_client = OKXWebSocketClient(
            component_name=f"test_client_{cycle}",
            max_reconnect_attempts=2,
            logger=logger
        )
        # True only while the client is connected and disconnect() has not
        # been attempted yet, so the error path knows whether it must clean up.
        needs_disconnect = False

        try:
            # Connect
            success = await ws_client.connect()
            if not success:
                print(f" ❌ Connection failed in cycle {cycle + 1}")
                continue
            needs_disconnect = True

            # Subscribe
            subscriptions = [
                OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT")
            ]
            await ws_client.subscribe(subscriptions)

            # Quick activity
            await asyncio.sleep(0.5)

            # Disconnect (this should not cause recursion). The flag is
            # cleared first so a failing disconnect() is not retried below.
            needs_disconnect = False
            await ws_client.disconnect()
            print(f" ✅ Cycle {cycle + 1} completed successfully")

        except RecursionError as e:
            print(f" ❌ Recursion error in cycle {cycle + 1}: {e}")
            return False
        except Exception as e:
            print(f" ⚠️ Other error in cycle {cycle + 1}: {e}")
            # Do not leave a connected client (and its background tasks)
            # behind when the failure happened between connect() and
            # disconnect(); then continue with the other cycles.
            if needs_disconnect:
                try:
                    await ws_client.disconnect()
                except RecursionError as cleanup_error:
                    print(f" ❌ Recursion error in cycle {cycle + 1}: {cleanup_error}")
                    return False
                except Exception as cleanup_error:
                    print(f" ⚠️ Cleanup error in cycle {cycle + 1}: {cleanup_error}")

        # Small delay between cycles
        await asyncio.sleep(0.2)

    print("\n✅ All cycles completed without recursion errors")
    return True
async def test_concurrent_shutdowns():
    """Connect three clients, then disconnect them all concurrently.

    Returns:
        bool: ``False`` if any connect or disconnect call produced a
        ``RecursionError``; ``True`` otherwise. A shutdown timeout or any
        other exception is reported on stdout but still counts as a pass,
        because only recursion is under test here.
    """
    logger = get_logger("concurrent_shutdown_test", verbose=False)

    print("\n🔄 Testing Concurrent Shutdowns")
    print("=" * 40)

    # Create multiple clients
    clients = [
        OKXWebSocketClient(
            component_name=f"concurrent_client_{i}",
            logger=logger
        )
        for i in range(3)
    ]

    try:
        # Connect all clients
        connect_tasks = [client.connect() for client in clients]
        connect_results = await asyncio.gather(*connect_tasks, return_exceptions=True)

        successful_connections = sum(1 for r in connect_results if r is True)
        print(f"📡 Connected {successful_connections}/3 clients")

        # Let them run briefly
        await asyncio.sleep(1)

        # Shutdown all concurrently (this is where recursion might occur)
        print("🛑 Shutting down all clients concurrently...")
        shutdown_tasks = [client.disconnect() for client in clients]

        # Use wait_for to prevent hanging
        try:
            shutdown_results = await asyncio.wait_for(
                asyncio.gather(*shutdown_tasks, return_exceptions=True),
                timeout=10.0
            )
        except asyncio.TimeoutError:
            print("⚠️ Shutdown timeout - but no recursion errors")
            return True  # Timeout is better than recursion

        # gather(return_exceptions=True) hands exceptions back as ordinary
        # result values, so a RecursionError raised inside connect() or
        # disconnect() never reaches the handler below on its own.
        # Re-raise the first one found so the test actually fails on it.
        for outcome in (*connect_results, *shutdown_results):
            if isinstance(outcome, RecursionError):
                raise outcome

        print("✅ All clients shut down successfully")
        return True

    except RecursionError as e:
        print(f"❌ Recursion error during concurrent shutdown: {e}")
        return False
    except Exception as e:
        print(f"⚠️ Other error during test: {e}")
        return True  # Other errors are acceptable for this test
async def main():
    """Run both recursion checks in order and print a summary.

    Returns:
        int: process exit code, ``0`` when both checks pass and ``1`` when
        either fails, the run is interrupted, or an unexpected exception
        escapes a check.
    """
    banner = "=" * 50
    print("🚀 WebSocket Recursion Fix Test Suite")
    print(banner)

    try:
        # Test 1: rapid cycles, then Test 2: concurrent shutdowns
        test1_success = await test_rapid_connection_cycles()
        test2_success = await test_concurrent_shutdowns()

        # Summary
        print("\n" + banner)
        print("📋 Test Summary:")
        print(f" Rapid Cycles: {'✅ PASS' if test1_success else '❌ FAIL'}")
        print(f" Concurrent Shutdowns: {'✅ PASS' if test2_success else '❌ FAIL'}")

        if not (test1_success and test2_success):
            print("\n❌ Some tests failed.")
            return 1

        print("\n🎉 All tests passed! Recursion issue fixed.")
        return 0

    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted")
        return 1
    except Exception as e:
        print(f"\n💥 Test suite failed: {e}")
        return 1
if __name__ == "__main__":
    # Script entry point: the integer returned by main() is the exit status.
    sys.exit(asyncio.run(main()))