"""
Async Task Manager for managing and tracking asyncio.Task instances.

This module provides a centralized way to manage asyncio tasks with proper lifecycle
tracking, cleanup, and memory leak prevention.
"""

import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, Optional, Set
from weakref import WeakSet


class TaskManager:
    """
    Manages asyncio.Task instances with lifecycle tracking and cleanup.

    Features:
    - Holds a strong reference to every task it creates until the task is
      untracked, so tasks are not lost mid-execution
    - Automatic removal of finished tasks from tracking (``auto_cleanup``)
    - Graceful shutdown with timeout handling, falling back to cancellation
    - Task statistics (created / completed / cancelled / failed) and monitoring
    """

    # Seconds shutdown() gives already-cancelled tasks to finish unwinding.
    _CANCEL_GRACE_PERIOD = 2.0

    def __init__(self,
                 name: str = "task_manager",
                 logger: Optional[logging.Logger] = None,
                 cleanup_timeout: float = 5.0):
        """
        Initialize the task manager.

        Args:
            name: Name identifier for this task manager; also used as the
                prefix when renaming named tasks ("<name>:<task name>").
            logger: Optional logger for logging operations; None disables logging.
            cleanup_timeout: Seconds ``shutdown(graceful=True)`` waits for tasks
                to finish before cancelling them.
        """
        self.name = name
        self.logger = logger
        self.cleanup_timeout = cleanup_timeout

        # Active task tracking
        self._tasks: Set[asyncio.Task] = set()
        self._task_names: Dict[asyncio.Task, str] = {}
        self._task_created_at: Dict[asyncio.Task, datetime] = {}

        # Statistics
        self._stats: Dict[str, int] = {
            'tasks_created': 0,
            'tasks_completed': 0,
            'tasks_cancelled': 0,
            'tasks_failed': 0,
            'active_tasks': 0,
        }

        # Set when shutdown() starts; create_task() refuses new work afterwards.
        self._shutdown = False

        if self.logger:
            self.logger.debug("TaskManager '%s' initialized", self.name)

    def create_task(self,
                    coro: Awaitable,
                    name: Optional[str] = None,
                    auto_cleanup: bool = True) -> asyncio.Task:
        """
        Create and track an asyncio task.

        Must be called while an event loop is running, because it delegates
        to ``asyncio.create_task``.

        Args:
            coro: Coroutine to run as a task.
            name: Optional name for the task (for debugging/monitoring). When
                given, the task is also renamed "<manager name>:<name>".
            auto_cleanup: Whether to remove the task from tracking when it
                finishes. Outcome statistics are recorded either way.

        Returns:
            The created asyncio.Task.

        Raises:
            RuntimeError: If the manager has been shut down. The coroutine is
                closed first so it is not reported as "never awaited".
        """
        if self._shutdown:
            # The coroutine will never be scheduled; close it explicitly so
            # Python does not emit a "coroutine ... was never awaited" warning.
            if asyncio.iscoroutine(coro):
                coro.close()
            raise RuntimeError(f"TaskManager '{self.name}' is shutdown")

        task = asyncio.create_task(coro)

        # The event loop only keeps weak references to tasks; this set is the
        # strong reference that keeps the task alive while it is tracked.
        self._tasks.add(task)
        if name:
            self._task_names[task] = name
            if hasattr(task, 'set_name'):  # Python 3.8+
                task.set_name(f"{self.name}:{name}")

        self._task_created_at[task] = datetime.now(timezone.utc)
        self._stats['tasks_created'] += 1
        self._stats['active_tasks'] = len(self._tasks)

        if auto_cleanup:
            task.add_done_callback(self._task_done_callback)
        else:
            # The task stays tracked after it finishes, but its outcome is
            # still counted in the statistics.
            task.add_done_callback(self._record_outcome)

        if self.logger:
            self.logger.debug("Created task '%s' (total active: %d)",
                              name or f"task-{id(task)}", len(self._tasks))

        return task

    def _task_name(self, task: asyncio.Task) -> str:
        """Return the caller-supplied name of *task*, or a "task-<id>" fallback."""
        return self._task_names.get(task, f"task-{id(task)}")

    def _task_done_callback(self, task: asyncio.Task) -> None:
        """Done-callback for auto-cleanup tasks: untrack the task, then record its outcome."""
        # Resolve the name before _remove_task() drops it, so a failure
        # warning reports the caller-supplied name rather than the fallback.
        task_name = self._task_name(task)
        self._remove_task(task)
        self._record_outcome(task, task_name)

    def _record_outcome(self, task: asyncio.Task, task_name: Optional[str] = None) -> None:
        """
        Update statistics for a finished task, logging a warning if it failed.

        Args:
            task: A task that is done.
            task_name: Name to use in the failure warning; looked up from the
                tracking tables when omitted.
        """
        # cancelled() must be checked first: exception() raises CancelledError
        # on a cancelled task.
        if task.cancelled():
            self._stats['tasks_cancelled'] += 1
        elif task.exception() is not None:
            self._stats['tasks_failed'] += 1
            if self.logger:
                if task_name is None:
                    task_name = self._task_name(task)
                self.logger.warning("Task '%s' failed: %s", task_name, task.exception())
        else:
            self._stats['tasks_completed'] += 1

        self._stats['active_tasks'] = len(self._tasks)

    def _remove_task(self, task: asyncio.Task) -> None:
        """Remove a task from all tracking tables (no-op if it is not tracked)."""
        self._tasks.discard(task)
        self._task_names.pop(task, None)
        self._task_created_at.pop(task, None)

    def cancel_task(self, task: asyncio.Task, reason: str = "Manual cancellation") -> bool:
        """
        Request cancellation of a specific task.

        Args:
            task: The task to cancel.
            reason: Reason for cancellation (for logging only).

        Returns:
            True if cancellation was requested, False if the task was already done.
        """
        if task.done():
            return False

        task.cancel()

        if self.logger:
            self.logger.debug("Cancelled task '%s': %s", self._task_name(task), reason)

        return True

    def cancel_all(self, reason: str = "Shutdown") -> int:
        """
        Request cancellation of every tracked task that is not yet done.

        Args:
            reason: Reason for cancellation (for logging only).

        Returns:
            Number of tasks for which cancellation was requested.
        """
        # Snapshot first: cancellation must not race with set mutation.
        tasks_to_cancel = [task for task in self._tasks if not task.done()]

        for task in tasks_to_cancel:
            self.cancel_task(task, reason)

        if self.logger and tasks_to_cancel:
            self.logger.debug("Cancelled %d tasks: %s", len(tasks_to_cancel), reason)

        return len(tasks_to_cancel)

    async def wait_for_completion(self, timeout: Optional[float] = None) -> bool:
        """
        Wait for all currently tracked tasks to complete.

        Tasks still running when the timeout expires are left running; they
        are not cancelled (use cancel_all() or shutdown() for that). Tasks
        created after this call starts are not waited for. Task failures do
        not raise here.

        Args:
            timeout: Optional timeout in seconds; None waits indefinitely.

        Returns:
            True if all tasks completed, False if the timeout occurred first.
        """
        pending_tasks = [task for task in self._tasks if not task.done()]
        if not pending_tasks:
            return True

        # asyncio.wait() is used instead of wait_for(gather(...)) because
        # wait_for() cancels the gathered tasks when it times out.
        _, still_pending = await asyncio.wait(pending_tasks, timeout=timeout)
        if not still_pending:
            return True

        if self.logger:
            self.logger.warning("Timeout waiting for %d tasks to complete", len(still_pending))
        return False

    async def shutdown(self, graceful: bool = True) -> None:
        """
        Shut down the task manager and clean up all tasks.

        Once started, create_task() raises RuntimeError; repeated calls are no-ops.

        Args:
            graceful: If True, first wait up to ``cleanup_timeout`` seconds for
                tasks to finish and cancel whatever is still running. If False,
                cancel everything immediately.
        """
        if self._shutdown:
            return

        self._shutdown = True

        if self.logger:
            self.logger.debug("Shutting down TaskManager '%s' (%d active tasks)",
                              self.name, len(self._tasks))

        if graceful and self._tasks:
            # Try graceful shutdown first
            completed = await self.wait_for_completion(timeout=self.cleanup_timeout)

            if not completed:
                # Force cancellation if graceful shutdown failed
                cancelled_count = self.cancel_all("Forced shutdown after timeout")
                if self.logger:
                    self.logger.warning("Force cancelled %d tasks after timeout", cancelled_count)
        else:
            # Immediate cancellation
            self.cancel_all("Immediate shutdown")

        # Give cancelled tasks a bounded chance to run their cleanup code.
        # asyncio.wait() keeps shutdown bounded even if a task suppresses
        # CancelledError, whereas wait_for(gather(...)) waits for it to unwind.
        unfinished = [task for task in self._tasks if not task.done()]
        if unfinished:
            _, still_running = await asyncio.wait(unfinished, timeout=self._CANCEL_GRACE_PERIOD)
            if still_running and self.logger:
                self.logger.warning("Some tasks did not complete after cancellation")

        # Clear all tracking
        self._tasks.clear()
        self._task_names.clear()
        self._task_created_at.clear()
        self._stats['active_tasks'] = 0

        if self.logger:
            self.logger.debug("TaskManager '%s' shutdown complete", self.name)

    def get_stats(self) -> Dict[str, Any]:
        """Return task management statistics; ``active_tasks`` is the live tracked count."""
        return {
            'name': self.name,
            'active_tasks': len(self._tasks),
            'tasks_created': self._stats['tasks_created'],
            'tasks_completed': self._stats['tasks_completed'],
            'tasks_cancelled': self._stats['tasks_cancelled'],
            'tasks_failed': self._stats['tasks_failed'],
            'is_shutdown': self._shutdown,
        }

    def get_active_tasks(self) -> Dict[str, Dict[str, Any]]:
        """
        Get information about currently running (not done) tracked tasks.

        Returns:
            Dictionary mapping task names to task information. Unnamed tasks
            use "task-<id>". If several running tasks share a name, the extra
            entries are keyed "<name>#<id>" so none is hidden.
        """
        active_tasks: Dict[str, Dict[str, Any]] = {}
        current_time = datetime.now(timezone.utc)

        for task in self._tasks:
            if task.done():
                continue

            key = self._task_name(task)
            if key in active_tasks:
                # Disambiguate instead of overwriting an earlier same-named task.
                key = f"{key}#{id(task)}"

            created_at = self._task_created_at.get(task)
            age_seconds = (current_time - created_at).total_seconds() if created_at else None

            active_tasks[key] = {
                'task_id': id(task),
                'created_at': created_at,
                'age_seconds': age_seconds,
                # Done tasks are skipped above, so these are always False here;
                # kept for compatibility with existing consumers.
                'done': task.done(),
                'cancelled': task.cancelled(),
            }

        return active_tasks

    def __len__(self) -> int:
        """Return the number of tracked tasks (finished auto_cleanup=False tasks included)."""
        return len(self._tasks)

    def __bool__(self) -> bool:
        """Return True if any task is tracked (so an idle manager is falsy)."""
        return bool(self._tasks)

    def __repr__(self) -> str:
        """String representation of the task manager."""
        return f"TaskManager(name='{self.name}', active_tasks={len(self._tasks)}, shutdown={self._shutdown})"