fix recursion error on reconnection
This commit is contained in:
parent
01cea1d5e5
commit
d508616677
@ -388,56 +388,83 @@ class OKXWebSocketClient:
|
|||||||
self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
|
self.logger.warning(f"{self.component_name}: Cannot start tasks while stopping is in progress")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Cancel any existing tasks first
|
# Check if tasks are already running
|
||||||
|
if (self._ping_task and not self._ping_task.done() and
|
||||||
|
self._message_handler_task and not self._message_handler_task.done()):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Background tasks already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cancel any existing tasks first (safety measure)
|
||||||
await self._stop_background_tasks()
|
await self._stop_background_tasks()
|
||||||
|
|
||||||
# Start ping task
|
# Ensure we're still supposed to start tasks after stopping
|
||||||
self._ping_task = asyncio.create_task(self._ping_loop())
|
if self._tasks_stopping or not self.is_connected:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Aborting task start - stopping or disconnected")
|
||||||
|
return
|
||||||
|
|
||||||
# Start message handler task
|
try:
|
||||||
self._message_handler_task = asyncio.create_task(self._message_handler())
|
# Start ping task
|
||||||
|
self._ping_task = asyncio.create_task(self._ping_loop())
|
||||||
if self.logger:
|
|
||||||
self.logger.debug(f"{self.component_name}: Started background tasks")
|
# Start message handler task
|
||||||
|
self._message_handler_task = asyncio.create_task(self._message_handler())
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: Started background tasks")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: Error starting background tasks: {e}")
|
||||||
|
# Clean up on failure
|
||||||
|
await self._stop_background_tasks()
|
||||||
|
|
||||||
async def _stop_background_tasks(self) -> None:
|
async def _stop_background_tasks(self) -> None:
|
||||||
"""Stop background tasks with proper synchronization."""
|
"""Stop background tasks with proper synchronization - simplified approach."""
|
||||||
self._tasks_stopping = True
|
self._tasks_stopping = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tasks = []
|
|
||||||
|
|
||||||
# Collect tasks to cancel
|
# Collect tasks to cancel
|
||||||
if self._ping_task and not self._ping_task.done():
|
tasks_to_cancel = []
|
||||||
tasks.append(self._ping_task)
|
|
||||||
if self._message_handler_task and not self._message_handler_task.done():
|
|
||||||
tasks.append(self._message_handler_task)
|
|
||||||
|
|
||||||
if not tasks:
|
if self._ping_task and not self._ping_task.done():
|
||||||
|
tasks_to_cancel.append(('ping_task', self._ping_task))
|
||||||
|
if self._message_handler_task and not self._message_handler_task.done():
|
||||||
|
tasks_to_cancel.append(('message_handler_task', self._message_handler_task))
|
||||||
|
|
||||||
|
if not tasks_to_cancel:
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.debug(f"{self.component_name}: No background tasks to stop")
|
self.logger.debug(f"{self.component_name}: No background tasks to stop")
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.debug(f"{self.component_name}: Stopping {len(tasks)} background tasks")
|
self.logger.debug(f"{self.component_name}: Stopping {len(tasks_to_cancel)} background tasks")
|
||||||
|
|
||||||
# Cancel all tasks
|
# Cancel tasks individually to avoid recursion
|
||||||
for task in tasks:
|
for task_name, task in tasks_to_cancel:
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
# Wait for all tasks to complete with timeout
|
|
||||||
if tasks:
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
if not task.done():
|
||||||
asyncio.gather(*tasks, return_exceptions=True),
|
task.cancel()
|
||||||
timeout=5.0
|
if self.logger:
|
||||||
)
|
self.logger.debug(f"{self.component_name}: Cancelled {task_name}")
|
||||||
except asyncio.TimeoutError:
|
|
||||||
if self.logger:
|
|
||||||
self.logger.warning(f"{self.component_name}: Task shutdown timeout - some tasks may still be running")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.debug(f"{self.component_name}: Expected exception during task shutdown: {e}")
|
self.logger.debug(f"{self.component_name}: Error cancelling {task_name}: {e}")
|
||||||
|
|
||||||
|
# Wait for tasks to complete individually with shorter timeouts
|
||||||
|
for task_name, task in tasks_to_cancel:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(task, timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.component_name}: {task_name} shutdown timeout")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Expected when task is cancelled
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"{self.component_name}: {task_name} shutdown exception: {e}")
|
||||||
|
|
||||||
# Clear task references
|
# Clear task references
|
||||||
self._ping_task = None
|
self._ping_task = None
|
||||||
@ -446,6 +473,9 @@ class OKXWebSocketClient:
|
|||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")
|
self.logger.debug(f"{self.component_name}: Background tasks stopped successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"{self.component_name}: Error in _stop_background_tasks: {e}")
|
||||||
finally:
|
finally:
|
||||||
self._tasks_stopping = False
|
self._tasks_stopping = False
|
||||||
|
|
||||||
@ -495,6 +525,9 @@ class OKXWebSocketClient:
|
|||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue # No message received, continue loop
|
continue # No message received, continue loop
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Exit immediately on cancellation
|
||||||
|
break
|
||||||
|
|
||||||
# Check if we're still supposed to be running
|
# Check if we're still supposed to be running
|
||||||
if self._tasks_stopping:
|
if self._tasks_stopping:
|
||||||
@ -512,35 +545,42 @@ class OKXWebSocketClient:
|
|||||||
self._connection_state = ConnectionState.DISCONNECTED
|
self._connection_state = ConnectionState.DISCONNECTED
|
||||||
|
|
||||||
# Use lock to prevent concurrent reconnection attempts
|
# Use lock to prevent concurrent reconnection attempts
|
||||||
async with self._reconnection_lock:
|
try:
|
||||||
# Double-check we still need to reconnect
|
# Use asyncio.wait_for to prevent hanging on lock acquisition
|
||||||
if (self._connection_state == ConnectionState.DISCONNECTED and
|
async with asyncio.wait_for(self._reconnection_lock.acquire(), timeout=5.0):
|
||||||
self._reconnect_attempts < self.max_reconnect_attempts and
|
try:
|
||||||
not self._tasks_stopping):
|
# Double-check we still need to reconnect
|
||||||
|
if (self._connection_state == ConnectionState.DISCONNECTED and
|
||||||
self._reconnect_attempts += 1
|
self._reconnect_attempts < self.max_reconnect_attempts and
|
||||||
if self.logger:
|
not self._tasks_stopping):
|
||||||
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
|
|
||||||
|
self._reconnect_attempts += 1
|
||||||
# Stop current tasks properly
|
if self.logger:
|
||||||
await self._stop_background_tasks()
|
self.logger.info(f"{self.component_name}: Attempting automatic reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})")
|
||||||
|
|
||||||
# Attempt reconnection with stored subscriptions
|
# Attempt reconnection (this will handle task cleanup)
|
||||||
stored_subscriptions = list(self._subscriptions.values())
|
if await self.reconnect():
|
||||||
|
if self.logger:
|
||||||
if await self.reconnect():
|
self.logger.info(f"{self.component_name}: Automatic reconnection successful")
|
||||||
if self.logger:
|
# Exit this handler as reconnect will start new tasks
|
||||||
self.logger.info(f"{self.component_name}: Automatic reconnection successful")
|
break
|
||||||
# The reconnect method will restart tasks, so we exit this handler
|
else:
|
||||||
break
|
if self.logger:
|
||||||
else:
|
self.logger.error(f"{self.component_name}: Automatic reconnection failed")
|
||||||
if self.logger:
|
break
|
||||||
self.logger.error(f"{self.component_name}: Automatic reconnection failed")
|
else:
|
||||||
break
|
if self.logger:
|
||||||
else:
|
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
|
||||||
if self.logger:
|
break
|
||||||
self.logger.error(f"{self.component_name}: Max reconnection attempts exceeded or shutdown in progress")
|
finally:
|
||||||
break
|
self._reconnection_lock.release()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"{self.component_name}: Timeout acquiring reconnection lock")
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Exit immediately on cancellation
|
||||||
|
break
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
if self.logger:
|
if self.logger:
|
||||||
|
|||||||
155
tests/test_recursion_fix.py
Normal file
155
tests/test_recursion_fix.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple test to verify recursion fix in WebSocket task management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
|
||||||
|
from utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rapid_connection_cycles():
    """Run rapid connect/subscribe/disconnect cycles and watch for recursion.

    Each of the five cycles builds a fresh ``OKXWebSocketClient``, connects,
    subscribes to a BTC-USDT trades subscription, idles for half a second and
    disconnects again.

    Returns:
        bool: ``False`` as soon as a cycle raises ``RecursionError``;
        ``True`` otherwise. Failed connections and any other exception are
        reported on stdout but do not fail the test.
    """
    logger = get_logger("recursion_test", verbose=False)

    print("🧪 Testing WebSocket Recursion Fix")
    print("=" * 40)

    for cycle in range(5):
        print(f"\n🔄 Cycle {cycle + 1}/5: Rapid connect/disconnect")

        ws_client = OKXWebSocketClient(
            component_name=f"test_client_{cycle}",
            max_reconnect_attempts=2,
            logger=logger
        )

        connected = False
        try:
            try:
                # Connect
                connected = await ws_client.connect()
                if not connected:
                    print(f" ❌ Connection failed in cycle {cycle + 1}")
                    continue

                # Subscribe
                subscriptions = [
                    OKXSubscription(OKXChannelType.TRADES.value, "BTC-USDT")
                ]
                await ws_client.subscribe(subscriptions)

                # Quick activity
                await asyncio.sleep(0.5)
            finally:
                # Always tear down a client that did connect, even when the
                # subscribe/idle step failed; otherwise its connection would
                # stay open while the following cycles run. Keeping this
                # inside the outer ``try`` means a RecursionError raised by
                # disconnect() is still reported as a failure below.
                if connected:
                    await ws_client.disconnect()

            print(f" ✅ Cycle {cycle + 1} completed successfully")

        except RecursionError as e:
            print(f" ❌ Recursion error in cycle {cycle + 1}: {e}")
            return False
        except Exception as e:
            # Non-recursion failures (network, API) are out of scope for this
            # test: report them and carry on with the remaining cycles.
            print(f" ⚠️ Other error in cycle {cycle + 1}: {e}")

        # Small delay between cycles
        await asyncio.sleep(0.2)

    print("\n✅ All cycles completed without recursion errors")
    return True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent_shutdowns():
    """Disconnect several clients at once and check that none hits recursion.

    Three ``OKXWebSocketClient`` instances are connected concurrently, left
    running for one second, then disconnected concurrently under a
    ten-second timeout.

    Returns:
        bool: ``False`` when a ``RecursionError`` is raised, including one
        raised inside a client's ``disconnect()``; ``True`` otherwise. A
        shutdown timeout or any other exception is reported on stdout and
        still counts as a pass.
    """
    logger = get_logger("concurrent_shutdown_test", verbose=False)

    print("\n🔄 Testing Concurrent Shutdowns")
    print("=" * 40)

    # Create multiple clients
    clients = []
    for i in range(3):
        client = OKXWebSocketClient(
            component_name=f"concurrent_client_{i}",
            logger=logger
        )
        clients.append(client)

    try:
        # Connect all clients
        connect_tasks = [client.connect() for client in clients]
        results = await asyncio.gather(*connect_tasks, return_exceptions=True)

        successful_connections = sum(1 for r in results if r is True)
        print(f"📡 Connected {successful_connections}/3 clients")

        # Let them run briefly
        await asyncio.sleep(1)

        # Shutdown all concurrently (this is where recursion might occur)
        print("🛑 Shutting down all clients concurrently...")
        shutdown_tasks = [client.disconnect() for client in clients]

        # Use wait_for to prevent hanging
        try:
            shutdown_results = await asyncio.wait_for(
                asyncio.gather(*shutdown_tasks, return_exceptions=True),
                timeout=10.0
            )
        except asyncio.TimeoutError:
            print("⚠️ Shutdown timeout - but no recursion errors")
            return True  # Timeout is better than recursion

        # gather(return_exceptions=True) hands exceptions back as result
        # values instead of raising them, so a RecursionError raised inside
        # disconnect() has to be surfaced explicitly for the handler below.
        for outcome in shutdown_results:
            if isinstance(outcome, RecursionError):
                raise outcome

        print("✅ All clients shut down successfully")
        return True

    except RecursionError as e:
        print(f"❌ Recursion error during concurrent shutdown: {e}")
        return False
    except Exception as e:
        print(f"⚠️ Other error during test: {e}")
        return True  # Other errors are acceptable for this test
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Run both recursion scenarios in order and print a pass/fail summary.

    Returns:
        int: 0 when both scenarios report success; 1 when either fails, when
        the run is interrupted, or when an unexpected exception is raised.
    """
    divider = "=" * 50
    print("🚀 WebSocket Recursion Fix Test Suite")
    print(divider)

    try:
        # Scenario 1: rapid connect/disconnect cycles
        rapid_ok = await test_rapid_connection_cycles()

        # Scenario 2: concurrent shutdown of several clients
        concurrent_ok = await test_concurrent_shutdowns()

        # Summary
        print("\n" + divider)
        print("📋 Test Summary:")
        print(f" Rapid Cycles: {'✅ PASS' if rapid_ok else '❌ FAIL'}")
        print(f" Concurrent Shutdowns: {'✅ PASS' if concurrent_ok else '❌ FAIL'}")

        if not (rapid_ok and concurrent_ok):
            print("\n❌ Some tests failed.")
            return 1

        print("\n🎉 All tests passed! Recursion issue fixed.")
        return 0

    except KeyboardInterrupt:
        print("\n⏹️ Tests interrupted")
        return 1
    except Exception as e:
        print(f"\n💥 Test suite failed: {e}")
        return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the async suite and use its return value as the process exit status.
    sys.exit(asyncio.run(main()))
|
||||||
Loading…
x
Reference in New Issue
Block a user