""" Async Task Manager for managing and tracking asyncio.Task instances. This module provides a centralized way to manage asyncio tasks with proper lifecycle tracking, cleanup, and memory leak prevention. """ import asyncio import logging from typing import Set, Optional, Dict, Any, Callable, Awaitable from weakref import WeakSet from datetime import datetime, timezone class TaskManager: """ Manages asyncio.Task instances with proper lifecycle tracking and cleanup. Features: - Automatic task cleanup when tasks complete - Graceful shutdown with timeout handling - Task statistics and monitoring - Memory leak prevention through weak references where appropriate """ def __init__(self, name: str = "task_manager", logger: Optional[logging.Logger] = None, cleanup_timeout: float = 5.0): """ Initialize the task manager. Args: name: Name identifier for this task manager logger: Optional logger for logging operations cleanup_timeout: Timeout for graceful task cleanup in seconds """ self.name = name self.logger = logger self.cleanup_timeout = cleanup_timeout # Active task tracking self._tasks: Set[asyncio.Task] = set() self._task_names: Dict[asyncio.Task, str] = {} self._task_created_at: Dict[asyncio.Task, datetime] = {} # Statistics self._stats = { 'tasks_created': 0, 'tasks_completed': 0, 'tasks_cancelled': 0, 'tasks_failed': 0, 'active_tasks': 0 } # State self._shutdown = False if self.logger: self.logger.debug(f"TaskManager '{self.name}' initialized") def create_task(self, coro: Awaitable, name: Optional[str] = None, auto_cleanup: bool = True) -> asyncio.Task: """ Create and track an asyncio task. Args: coro: Coroutine to run as a task name: Optional name for the task (for debugging/monitoring) auto_cleanup: Whether to automatically remove task when done Returns: The created asyncio.Task """ if self._shutdown: raise RuntimeError(f"TaskManager '{self.name}' is shutdown") # Create the task task = asyncio.create_task(coro) # Track the task self._tasks.add(task) if name: self._task_names[task] = name if hasattr(task, 'set_name'): # Python 3.8+ task.set_name(f"{self.name}:{name}") self._task_created_at[task] = datetime.now(timezone.utc) self._stats['tasks_created'] += 1 self._stats['active_tasks'] = len(self._tasks) # Add callback for automatic cleanup if requested if auto_cleanup: task.add_done_callback(self._task_done_callback) if self.logger: task_name = name or f"task-{id(task)}" self.logger.debug(f"Created task '{task_name}' (total active: {len(self._tasks)})") return task def _task_done_callback(self, task: asyncio.Task) -> None: """Callback called when a tracked task completes.""" self._remove_task(task) # Update statistics based on task result if task.cancelled(): self._stats['tasks_cancelled'] += 1 elif task.exception() is not None: self._stats['tasks_failed'] += 1 if self.logger: task_name = self._task_names.get(task, f"task-{id(task)}") self.logger.warning(f"Task '{task_name}' failed: {task.exception()}") else: self._stats['tasks_completed'] += 1 self._stats['active_tasks'] = len(self._tasks) def _remove_task(self, task: asyncio.Task) -> None: """Remove a task from tracking.""" self._tasks.discard(task) self._task_names.pop(task, None) self._task_created_at.pop(task, None) def cancel_task(self, task: asyncio.Task, reason: str = "Manual cancellation") -> bool: """ Cancel a specific task. Args: task: The task to cancel reason: Reason for cancellation (for logging) Returns: True if task was cancelled, False if already done """ if task.done(): return False task.cancel() if self.logger: task_name = self._task_names.get(task, f"task-{id(task)}") self.logger.debug(f"Cancelled task '{task_name}': {reason}") return True def cancel_all(self, reason: str = "Shutdown") -> int: """ Cancel all tracked tasks. Args: reason: Reason for cancellation (for logging) Returns: Number of tasks cancelled """ tasks_to_cancel = [task for task in self._tasks if not task.done()] for task in tasks_to_cancel: self.cancel_task(task, reason) if self.logger and tasks_to_cancel: self.logger.debug(f"Cancelled {len(tasks_to_cancel)} tasks: {reason}") return len(tasks_to_cancel) async def wait_for_completion(self, timeout: Optional[float] = None) -> bool: """ Wait for all tracked tasks to complete. Args: timeout: Optional timeout in seconds Returns: True if all tasks completed, False if timeout occurred """ if not self._tasks: return True pending_tasks = [task for task in self._tasks if not task.done()] if not pending_tasks: return True try: await asyncio.wait_for( asyncio.gather(*pending_tasks, return_exceptions=True), timeout=timeout ) return True except asyncio.TimeoutError: if self.logger: self.logger.warning(f"Timeout waiting for {len(pending_tasks)} tasks to complete") return False async def shutdown(self, graceful: bool = True) -> None: """ Shutdown the task manager and cleanup all tasks. Args: graceful: Whether to wait for tasks to complete gracefully """ if self._shutdown: return self._shutdown = True if self.logger: self.logger.debug(f"Shutting down TaskManager '{self.name}' ({len(self._tasks)} active tasks)") if graceful and self._tasks: # Try graceful shutdown first completed = await self.wait_for_completion(timeout=self.cleanup_timeout) if not completed: # Force cancellation if graceful shutdown failed cancelled_count = self.cancel_all("Forced shutdown after timeout") if self.logger: self.logger.warning(f"Force cancelled {cancelled_count} tasks after timeout") else: # Immediate cancellation self.cancel_all("Immediate shutdown") # Wait for cancelled tasks to complete if self._tasks: try: await asyncio.wait_for( asyncio.gather(*list(self._tasks), return_exceptions=True), timeout=2.0 ) except asyncio.TimeoutError: if self.logger: self.logger.warning("Some tasks did not complete after cancellation") # Clear all tracking self._tasks.clear() self._task_names.clear() self._task_created_at.clear() if self.logger: self.logger.debug(f"TaskManager '{self.name}' shutdown complete") def get_stats(self) -> Dict[str, Any]: """Get task management statistics.""" return { 'name': self.name, 'active_tasks': len(self._tasks), 'tasks_created': self._stats['tasks_created'], 'tasks_completed': self._stats['tasks_completed'], 'tasks_cancelled': self._stats['tasks_cancelled'], 'tasks_failed': self._stats['tasks_failed'], 'is_shutdown': self._shutdown } def get_active_tasks(self) -> Dict[str, Dict[str, Any]]: """ Get information about currently active tasks. Returns: Dictionary mapping task names to task information """ active_tasks = {} current_time = datetime.now(timezone.utc) for task in self._tasks: if task.done(): continue task_name = self._task_names.get(task, f"task-{id(task)}") created_at = self._task_created_at.get(task) age_seconds = (current_time - created_at).total_seconds() if created_at else None active_tasks[task_name] = { 'task_id': id(task), 'created_at': created_at, 'age_seconds': age_seconds, 'done': task.done(), 'cancelled': task.cancelled() } return active_tasks def __len__(self) -> int: """Return the number of active tasks.""" return len(self._tasks) def __bool__(self) -> bool: """Return True if there are active tasks.""" return bool(self._tasks) def __repr__(self) -> str: """String representation of the task manager.""" return f"TaskManager(name='{self.name}', active_tasks={len(self._tasks)}, shutdown={self._shutdown})"