Implement data collection architecture with modular components
- Introduced a comprehensive data collection framework, including `CollectorServiceConfig`, `BaseDataCollector`, and `CollectorManager`, enhancing modularity and maintainability. - Developed `CollectorFactory` for streamlined collector creation, promoting separation of concerns and improved configuration handling. - Enhanced `DataCollectionService` to utilize the new architecture, ensuring robust error handling and logging practices. - Added `TaskManager` for efficient management of asynchronous tasks, improving performance and resource management. - Implemented health monitoring and auto-recovery features in `CollectorManager`, ensuring reliable operation of data collectors. - Updated imports across the codebase to reflect the new structure, ensuring consistent access to components. These changes significantly improve the architecture and maintainability of the data collection service, aligning with project standards for modularity, performance, and error handling.
This commit is contained in:
parent
c28e4a9aaf
commit
f6cb1485b1
@ -14,7 +14,7 @@ from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceConfigSchema:
|
||||
class CollectorServiceConfigSchema:
|
||||
"""Schema definition for service configuration."""
|
||||
exchange: str = "okx"
|
||||
connection: Dict[str, Any] = None
|
||||
@ -24,7 +24,7 @@ class ServiceConfigSchema:
|
||||
database: Dict[str, Any] = None
|
||||
|
||||
|
||||
class ServiceConfig:
|
||||
class CollectorServiceConfig:
|
||||
"""Manages service configuration with validation and security."""
|
||||
|
||||
def __init__(self, config_path: str = "config/data_collection.json", logger=None):
|
||||
@ -5,14 +5,14 @@ This package contains modules for collecting market data from various exchanges,
|
||||
processing and validating the data, and storing it in the database.
|
||||
"""
|
||||
|
||||
from .base_collector import (
|
||||
from .collector.base_collector import (
|
||||
BaseDataCollector, DataCollectorError
|
||||
)
|
||||
from .collector.collector_state_telemetry import CollectorStatus
|
||||
from .common.ohlcv_data import OHLCVData, DataValidationError
|
||||
from .common.data_types import DataType, MarketDataPoint
|
||||
from .collector_manager import CollectorManager
|
||||
from .collector_types import ManagerStatus, CollectorConfig
|
||||
from .collector.collector_manager import CollectorManager
|
||||
from .collector.collector_types import ManagerStatus, CollectorConfig
|
||||
|
||||
__all__ = [
|
||||
'BaseDataCollector',
|
||||
|
||||
@ -14,11 +14,11 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from utils.logger import get_logger
|
||||
from .collector.collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
from .collector.collector_connection_manager import ConnectionManager
|
||||
from .collector.collector_callback_dispatcher import CallbackDispatcher
|
||||
from .common.data_types import DataType, MarketDataPoint
|
||||
from .common.ohlcv_data import OHLCVData, DataValidationError, validate_ohlcv_data
|
||||
from .collector_state_telemetry import CollectorStatus, CollectorStateAndTelemetry
|
||||
from .collector_connection_manager import ConnectionManager
|
||||
from .collector_callback_dispatcher import CallbackDispatcher
|
||||
from ..common.data_types import DataType, MarketDataPoint
|
||||
from ..common.ohlcv_data import OHLCVData, DataValidationError, validate_ohlcv_data
|
||||
|
||||
|
||||
class DataCollectorError(Exception):
|
||||
@ -14,6 +14,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
import logging
|
||||
import json
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
@ -30,11 +31,12 @@ logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
|
||||
|
||||
from data.collector_manager import CollectorManager
|
||||
from config.service_config import ServiceConfig
|
||||
from data.collector_factory import CollectorFactory
|
||||
from .collector_manager import CollectorManager
|
||||
from config.collector_service_config import CollectorServiceConfig
|
||||
from .collector_factory import CollectorFactory
|
||||
from database.connection import init_database
|
||||
from utils.logger import get_logger
|
||||
from utils.async_task_manager import TaskManager
|
||||
|
||||
|
||||
class DataCollectionService:
|
||||
@ -46,11 +48,12 @@ class DataCollectionService:
|
||||
self.logger = get_logger("data_collection_service", log_level="INFO", verbose=False)
|
||||
|
||||
# Initialize configuration and factory
|
||||
self.service_config = ServiceConfig(config_path, logger=self.logger)
|
||||
self.service_config = CollectorServiceConfig(config_path, logger=self.logger)
|
||||
self.config = self.service_config.load_config()
|
||||
self.collector_factory = CollectorFactory(logger=self.logger)
|
||||
|
||||
# Core components
|
||||
self.task_manager = TaskManager("data_collection_service", logger=self.logger)
|
||||
self.collector_manager = CollectorManager(logger=self.logger, log_errors_only=True)
|
||||
self.collectors: List = []
|
||||
|
||||
@ -230,6 +233,9 @@ class DataCollectionService:
|
||||
sanitized_message = self._sanitize_error(f"Unexpected error during service shutdown: {e}")
|
||||
self.logger.error(sanitized_message, exc_info=True)
|
||||
self.stats['errors_count'] += 1
|
||||
finally:
|
||||
# Always cleanup task manager
|
||||
await self.task_manager.shutdown(graceful=True)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get current service status."""
|
||||
@ -6,8 +6,8 @@ and error handling, separating collector creation logic from the main service.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from data.exchanges.factory import ExchangeFactory, ExchangeCollectorConfig
|
||||
from data.base_collector import DataType
|
||||
from ..exchanges.factory import ExchangeFactory, ExchangeCollectorConfig
|
||||
from .base_collector import DataType
|
||||
|
||||
|
||||
class CollectorFactory:
|
||||
@ -9,9 +9,10 @@ import asyncio
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
|
||||
from utils.logger import get_logger
|
||||
from utils.async_task_manager import TaskManager
|
||||
from .base_collector import BaseDataCollector, CollectorStatus
|
||||
from .collector_types import ManagerStatus, CollectorConfig
|
||||
from .manager_components import (
|
||||
from ..manager_components import (
|
||||
CollectorLifecycleManager,
|
||||
ManagerHealthMonitor,
|
||||
ManagerStatsTracker,
|
||||
@ -42,6 +43,7 @@ class CollectorManager:
|
||||
|
||||
# Initialize components
|
||||
self.logger_manager = ManagerLogger(logger, log_errors_only)
|
||||
self.task_manager = TaskManager(f"{manager_name}_tasks", logger=logger)
|
||||
self.lifecycle_manager = CollectorLifecycleManager(self.logger_manager)
|
||||
self.health_monitor = ManagerHealthMonitor(
|
||||
global_health_check_interval, self.logger_manager, self.lifecycle_manager)
|
||||
@ -51,7 +53,6 @@ class CollectorManager:
|
||||
# Manager state
|
||||
self.status = ManagerStatus.STOPPED
|
||||
self._running = False
|
||||
self._tasks: Set[asyncio.Task] = set()
|
||||
|
||||
if self.logger_manager.is_debug_enabled():
|
||||
self.logger_manager.log_info(f"Initialized collector manager: {manager_name}")
|
||||
@ -106,11 +107,13 @@ class CollectorManager:
|
||||
await self.lifecycle_manager.start_all_enabled_collectors()
|
||||
await self.health_monitor.start_monitoring()
|
||||
|
||||
# Track health monitoring task
|
||||
# Track health monitoring task with task manager
|
||||
health_task = self.health_monitor.get_health_task()
|
||||
if health_task:
|
||||
self._tasks.add(health_task)
|
||||
health_task.add_done_callback(self._tasks.discard)
|
||||
# Transfer task to task manager for better tracking
|
||||
self.task_manager._tasks.add(health_task)
|
||||
self.task_manager._task_names[health_task] = "health_monitor"
|
||||
health_task.add_done_callback(self.task_manager._task_done_callback)
|
||||
|
||||
# Start statistics cache updates
|
||||
await self.stats_tracker.start_cache_updates()
|
||||
@ -164,11 +167,8 @@ class CollectorManager:
|
||||
await self.health_monitor.stop_monitoring()
|
||||
await self.stats_tracker.stop_cache_updates()
|
||||
|
||||
# Cancel manager tasks
|
||||
for task in list(self._tasks):
|
||||
task.cancel()
|
||||
if self._tasks:
|
||||
await asyncio.gather(*self._tasks, return_exceptions=True)
|
||||
# Gracefully shutdown task manager
|
||||
await self.task_manager.shutdown(graceful=True)
|
||||
|
||||
# Stop all collectors
|
||||
await self.lifecycle_manager.stop_all_collectors()
|
||||
@ -10,7 +10,8 @@ from typing import Dict, List, Optional, Any, Type, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from utils.logger import get_logger
|
||||
from ..base_collector import BaseDataCollector, DataType
|
||||
from ..collector.base_collector import BaseDataCollector
|
||||
from ..common.data_types import DataType
|
||||
from ..common import CandleProcessingConfig
|
||||
from .registry import EXCHANGE_REGISTRY, get_supported_exchanges, get_exchange_info
|
||||
from .exceptions import (
|
||||
|
||||
@ -11,7 +11,7 @@ from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ...base_collector import (
|
||||
from ...collector.base_collector import (
|
||||
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
|
||||
OHLCVData, DataValidationError, ConnectionError
|
||||
)
|
||||
|
||||
@ -11,7 +11,7 @@ from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||
from enum import Enum
|
||||
|
||||
from ...base_collector import DataType, MarketDataPoint
|
||||
from ...collector.base_collector import DataType, MarketDataPoint
|
||||
from ...common import (
|
||||
DataValidationResult,
|
||||
StandardizedTrade,
|
||||
|
||||
@ -8,8 +8,8 @@ enabling, disabling, starting, and restarting collectors.
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Set, Optional
|
||||
from ..base_collector import BaseDataCollector, CollectorStatus
|
||||
from ..collector_types import CollectorConfig
|
||||
from ..collector.base_collector import BaseDataCollector, CollectorStatus
|
||||
from ..collector.collector_types import CollectorConfig
|
||||
|
||||
|
||||
class CollectorLifecycleManager:
|
||||
|
||||
@ -8,7 +8,7 @@ auto-restart functionality, and health status tracking.
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Set, Dict, Optional
|
||||
from ..base_collector import BaseDataCollector, CollectorStatus
|
||||
from ..collector.base_collector import BaseDataCollector, CollectorStatus
|
||||
|
||||
|
||||
class ManagerHealthMonitor:
|
||||
|
||||
@ -8,7 +8,7 @@ to optimize performance by avoiding real-time calculations on every status reque
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, List
|
||||
from ..base_collector import BaseDataCollector, CollectorStatus
|
||||
from ..collector.base_collector import BaseDataCollector, CollectorStatus
|
||||
|
||||
|
||||
class ManagerStatsTracker:
|
||||
@ -48,6 +48,11 @@ class ManagerStatsTracker:
|
||||
self._cache_last_updated: Optional[datetime] = None
|
||||
self._cache_update_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
# Performance tracking for cache optimization
|
||||
self._cache_hit_count = 0
|
||||
self._cache_miss_count = 0
|
||||
self._last_performance_log = datetime.now(timezone.utc)
|
||||
|
||||
def set_running_state(self, running: bool) -> None:
|
||||
"""Set the running state of the tracker."""
|
||||
@ -180,8 +185,13 @@ class ManagerStatsTracker:
|
||||
# Check if cache is recent enough (within 2x the update interval)
|
||||
cache_age = (datetime.now(timezone.utc) - self._cache_last_updated).total_seconds()
|
||||
if cache_age <= (self.cache_update_interval * 2):
|
||||
self._cache_hit_count += 1
|
||||
self._log_cache_performance_if_needed()
|
||||
return self._cached_status.copy()
|
||||
|
||||
# Cache miss - increment counter
|
||||
self._cache_miss_count += 1
|
||||
|
||||
# Calculate real-time status
|
||||
uptime_seconds = None
|
||||
if self._stats['uptime_start']:
|
||||
@ -264,6 +274,9 @@ class ManagerStatsTracker:
|
||||
|
||||
def get_cache_info(self) -> Dict[str, Any]:
|
||||
"""Get information about the cache state."""
|
||||
total_requests = self._cache_hit_count + self._cache_miss_count
|
||||
hit_rate = (self._cache_hit_count / total_requests * 100) if total_requests > 0 else 0
|
||||
|
||||
return {
|
||||
'cache_enabled': True,
|
||||
'cache_update_interval': self.cache_update_interval,
|
||||
@ -271,5 +284,27 @@ class ManagerStatsTracker:
|
||||
'cache_age_seconds': (
|
||||
(datetime.now(timezone.utc) - self._cache_last_updated).total_seconds()
|
||||
if self._cache_last_updated else None
|
||||
)
|
||||
}
|
||||
),
|
||||
'cache_hit_count': self._cache_hit_count,
|
||||
'cache_miss_count': self._cache_miss_count,
|
||||
'cache_hit_rate_percent': round(hit_rate, 2),
|
||||
'total_cache_requests': total_requests
|
||||
}
|
||||
|
||||
def _log_cache_performance_if_needed(self) -> None:
|
||||
"""Log cache performance metrics periodically."""
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Log every 5 minutes
|
||||
if (current_time - self._last_performance_log).total_seconds() >= 300:
|
||||
total_requests = self._cache_hit_count + self._cache_miss_count
|
||||
if total_requests > 0:
|
||||
hit_rate = (self._cache_hit_count / total_requests * 100)
|
||||
|
||||
if self.logger_manager:
|
||||
self.logger_manager.log_debug(
|
||||
f"Cache performance: {hit_rate:.1f}% hit rate "
|
||||
f"({self._cache_hit_count} hits, {self._cache_miss_count} misses)"
|
||||
)
|
||||
|
||||
self._last_performance_log = current_time
|
||||
@ -6,7 +6,7 @@ from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import desc, text
|
||||
|
||||
from ..models import RawTrade
|
||||
from data.base_collector import MarketDataPoint
|
||||
from data.collector.base_collector import MarketDataPoint
|
||||
from .base_repository import BaseRepository, DatabaseOperationError
|
||||
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ logging.getLogger('sqlalchemy.pool').setLevel(logging.CRITICAL)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(logging.CRITICAL)
|
||||
logging.getLogger('sqlalchemy.orm').setLevel(logging.CRITICAL)
|
||||
|
||||
from data.collection_service import run_data_collection_service
|
||||
from data.collector.collection_service import run_data_collection_service
|
||||
|
||||
|
||||
async def get_config_timeframes(config_path: str) -> str:
|
||||
|
||||
@ -31,7 +31,7 @@ from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from data.collection_service import run_data_collection_service
|
||||
from data.collector.collection_service import run_data_collection_service
|
||||
|
||||
|
||||
def display_banner(config_path: str, duration_hours: float = None):
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
## Relevant Files
|
||||
|
||||
- `data/collector_manager.py` - Core manager for data collectors (refactored: 563→178 lines).
|
||||
- `data/collection_service.py` - Main service for data collection.
|
||||
- `data/collector_manager.py` - Core manager for data collectors (refactored: 563→178 lines, enhanced with TaskManager).
|
||||
- `data/collection_service.py` - Main service for data collection (enhanced with TaskManager).
|
||||
- `data/collector_types.py` - Shared data types for collector management (new file).
|
||||
- `data/manager_components/` - Component classes for modular manager architecture (new directory).
|
||||
- `data/manager_components/manager_stats_tracker.py` - Enhanced with performance monitoring and cache optimization.
|
||||
- `utils/async_task_manager.py` - New comprehensive async task management utility (new file).
|
||||
- `data/__init__.py` - Updated imports for new structure.
|
||||
- `tests/test_collector_manager.py` - Unit tests for `collector_manager.py` (imports updated).
|
||||
- `tests/test_data_collection_aggregation.py` - Integration tests (imports updated).
|
||||
@ -113,11 +115,11 @@ Both files show good foundational architecture but exceed the recommended file s
|
||||
- [x] 3.5 Test './scripts/start_data_collection.py' and './scripts/production_clean.py' to ensure they work as expected.
|
||||
|
||||
|
||||
- [ ] 4.0 Optimize Performance and Resource Management
|
||||
- [ ] 4.1 Implement a `TaskManager` class in `utils/async_task_manager.py` to manage and track `asyncio.Task` instances in `CollectorManager` and `DataCollectionService`.
|
||||
- [ ] 4.2 Introduce a `CachedStatsManager` in `data/manager_components/manager_stats_tracker.py` for `CollectorManager` to cache statistics and update them periodically instead of on every `get_status` call.
|
||||
- [ ] 4.3 Review all `asyncio.sleep` calls for optimal intervals.
|
||||
- [ ] 4.4 Test './scripts/start_data_collection.py' and './scripts/production_clean.py' to ensure they work as expected.
|
||||
- [x] 4.0 Optimize Performance and Resource Management
|
||||
- [x] 4.1 Implement a `TaskManager` class in `utils/async_task_manager.py` to manage and track `asyncio.Task` instances in `CollectorManager` and `DataCollectionService`.
|
||||
- [x] 4.2 Introduce a `CachedStatsManager` in `data/manager_components/manager_stats_tracker.py` for `CollectorManager` to cache statistics and update them periodically instead of on every `get_status` call.
|
||||
- [x] 4.3 Review all `asyncio.sleep` calls for optimal intervals.
|
||||
- [x] 4.4 Test './scripts/start_data_collection.py' and './scripts/production_clean.py' to ensure they work as expected.
|
||||
|
||||
- [ ] 5.0 Improve Documentation and Test Coverage
|
||||
- [ ] 5.1 Add comprehensive docstrings to all public methods and classes in `CollectorManager` and `DataCollectionService`, including examples, thread safety notes, and performance considerations.
|
||||
|
||||
295
utils/async_task_manager.py
Normal file
295
utils/async_task_manager.py
Normal file
@ -0,0 +1,295 @@
|
||||
"""
|
||||
Async Task Manager for managing and tracking asyncio.Task instances.
|
||||
|
||||
This module provides a centralized way to manage asyncio tasks with proper lifecycle
|
||||
tracking, cleanup, and memory leak prevention.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Set, Optional, Dict, Any, Callable, Awaitable
|
||||
from weakref import WeakSet
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
class TaskManager:
    """
    Manages asyncio.Task instances with proper lifecycle tracking and cleanup.

    Features:
    - Automatic task cleanup when tasks complete (done-callbacks remove
      bookkeeping entries so a long-lived manager does not leak task objects)
    - Graceful shutdown with timeout handling and forced-cancellation fallback
    - Task statistics and monitoring (created/completed/cancelled/failed counts)
    """

    def __init__(self,
                 name: str = "task_manager",
                 logger: Optional[logging.Logger] = None,
                 cleanup_timeout: float = 5.0):
        """
        Initialize the task manager.

        Args:
            name: Name identifier for this task manager (used in log messages
                and as a prefix for asyncio task names).
            logger: Optional logger for logging operations.
            cleanup_timeout: Timeout for graceful task cleanup in seconds.
        """
        self.name = name
        self.logger = logger
        self.cleanup_timeout = cleanup_timeout

        # Active task tracking. _task_names/_task_created_at are keyed by the
        # task object and are cleaned up in _remove_task so the dicts cannot
        # grow without bound.
        self._tasks: Set[asyncio.Task] = set()
        self._task_names: Dict[asyncio.Task, str] = {}
        self._task_created_at: Dict[asyncio.Task, datetime] = {}

        # Lifetime statistics; 'active_tasks' mirrors len(self._tasks).
        self._stats = {
            'tasks_created': 0,
            'tasks_completed': 0,
            'tasks_cancelled': 0,
            'tasks_failed': 0,
            'active_tasks': 0
        }

        # Once shut down, the manager refuses to create new tasks.
        self._shutdown = False

        if self.logger:
            self.logger.debug(f"TaskManager '{self.name}' initialized")

    def create_task(self,
                    coro: Awaitable,
                    name: Optional[str] = None,
                    auto_cleanup: bool = True) -> asyncio.Task:
        """
        Create and track an asyncio task.

        Args:
            coro: Coroutine to run as a task.
            name: Optional name for the task (for debugging/monitoring).
            auto_cleanup: Whether to automatically remove the task from
                tracking when it completes.

        Returns:
            The created asyncio.Task.

        Raises:
            RuntimeError: If the manager has already been shut down.
        """
        if self._shutdown:
            raise RuntimeError(f"TaskManager '{self.name}' is shutdown")

        # Create the task
        task = asyncio.create_task(coro)

        # Track the task and record metadata used for monitoring.
        self._tasks.add(task)
        if name:
            self._task_names[task] = name
            if hasattr(task, 'set_name'):  # Python 3.8+
                task.set_name(f"{self.name}:{name}")

        self._task_created_at[task] = datetime.now(timezone.utc)
        self._stats['tasks_created'] += 1
        self._stats['active_tasks'] = len(self._tasks)

        # Add callback for automatic cleanup if requested
        if auto_cleanup:
            task.add_done_callback(self._task_done_callback)

        if self.logger:
            task_name = name or f"task-{id(task)}"
            self.logger.debug(f"Created task '{task_name}' (total active: {len(self._tasks)})")

        return task

    def _task_done_callback(self, task: asyncio.Task) -> None:
        """Callback invoked when a tracked task completes.

        Updates statistics and drops the bookkeeping entries for the task.
        """
        # BUG FIX: resolve the task's display name BEFORE removing tracking
        # state. The original called _remove_task() first, which pops
        # _task_names, so the failure log below could never report the
        # registered name and always fell back to the anonymous form.
        task_name = self._task_names.get(task, f"task-{id(task)}")
        self._remove_task(task)

        # Update statistics based on task result. cancelled() must be checked
        # first: Task.exception() raises CancelledError on a cancelled task.
        if task.cancelled():
            self._stats['tasks_cancelled'] += 1
        elif task.exception() is not None:
            self._stats['tasks_failed'] += 1
            if self.logger:
                self.logger.warning(f"Task '{task_name}' failed: {task.exception()}")
        else:
            self._stats['tasks_completed'] += 1

        self._stats['active_tasks'] = len(self._tasks)

    def _remove_task(self, task: asyncio.Task) -> None:
        """Remove a task from tracking (safe to call for untracked tasks)."""
        self._tasks.discard(task)
        self._task_names.pop(task, None)
        self._task_created_at.pop(task, None)

    def cancel_task(self, task: asyncio.Task, reason: str = "Manual cancellation") -> bool:
        """
        Cancel a specific task.

        Args:
            task: The task to cancel.
            reason: Reason for cancellation (for logging).

        Returns:
            True if task was cancelled, False if already done.
        """
        if task.done():
            return False

        task.cancel()

        if self.logger:
            task_name = self._task_names.get(task, f"task-{id(task)}")
            self.logger.debug(f"Cancelled task '{task_name}': {reason}")

        return True

    def cancel_all(self, reason: str = "Shutdown") -> int:
        """
        Cancel all tracked tasks.

        Args:
            reason: Reason for cancellation (for logging).

        Returns:
            Number of tasks cancelled.
        """
        # Snapshot first: done-callbacks mutate self._tasks during iteration.
        tasks_to_cancel = [task for task in self._tasks if not task.done()]

        for task in tasks_to_cancel:
            self.cancel_task(task, reason)

        if self.logger and tasks_to_cancel:
            self.logger.debug(f"Cancelled {len(tasks_to_cancel)} tasks: {reason}")

        return len(tasks_to_cancel)

    async def wait_for_completion(self, timeout: Optional[float] = None) -> bool:
        """
        Wait for all tracked tasks to complete.

        Args:
            timeout: Optional timeout in seconds.

        Returns:
            True if all tasks completed, False if timeout occurred.
        """
        if not self._tasks:
            return True

        pending_tasks = [task for task in self._tasks if not task.done()]

        if not pending_tasks:
            return True

        try:
            # return_exceptions=True so one failing task does not abort the wait.
            await asyncio.wait_for(
                asyncio.gather(*pending_tasks, return_exceptions=True),
                timeout=timeout
            )
            return True
        except asyncio.TimeoutError:
            if self.logger:
                self.logger.warning(f"Timeout waiting for {len(pending_tasks)} tasks to complete")
            return False

    async def shutdown(self, graceful: bool = True) -> None:
        """
        Shutdown the task manager and cleanup all tasks.

        Idempotent: a second call returns immediately.

        Args:
            graceful: Whether to wait for tasks to complete gracefully
                (up to cleanup_timeout) before force-cancelling them.
        """
        if self._shutdown:
            return

        self._shutdown = True

        if self.logger:
            self.logger.debug(f"Shutting down TaskManager '{self.name}' ({len(self._tasks)} active tasks)")

        if graceful and self._tasks:
            # Try graceful shutdown first
            completed = await self.wait_for_completion(timeout=self.cleanup_timeout)

            if not completed:
                # Force cancellation if graceful shutdown failed
                cancelled_count = self.cancel_all("Forced shutdown after timeout")
                if self.logger:
                    self.logger.warning(f"Force cancelled {cancelled_count} tasks after timeout")
        else:
            # Immediate cancellation
            self.cancel_all("Immediate shutdown")

        # Wait for cancelled tasks to complete so cancellation is observed
        # before tracking is cleared.
        if self._tasks:
            try:
                await asyncio.wait_for(
                    asyncio.gather(*list(self._tasks), return_exceptions=True),
                    timeout=2.0
                )
            except asyncio.TimeoutError:
                if self.logger:
                    self.logger.warning("Some tasks did not complete after cancellation")

        # Clear all tracking
        self._tasks.clear()
        self._task_names.clear()
        self._task_created_at.clear()

        if self.logger:
            self.logger.debug(f"TaskManager '{self.name}' shutdown complete")

    def get_stats(self) -> Dict[str, Any]:
        """Get task management statistics as a plain dict snapshot."""
        return {
            'name': self.name,
            'active_tasks': len(self._tasks),
            'tasks_created': self._stats['tasks_created'],
            'tasks_completed': self._stats['tasks_completed'],
            'tasks_cancelled': self._stats['tasks_cancelled'],
            'tasks_failed': self._stats['tasks_failed'],
            'is_shutdown': self._shutdown
        }

    def get_active_tasks(self) -> Dict[str, Dict[str, Any]]:
        """
        Get information about currently active tasks.

        Returns:
            Dictionary mapping task names to task information
            (task_id, created_at, age_seconds, done, cancelled).
        """
        active_tasks = {}
        current_time = datetime.now(timezone.utc)

        for task in self._tasks:
            if task.done():
                continue

            task_name = self._task_names.get(task, f"task-{id(task)}")
            created_at = self._task_created_at.get(task)
            age_seconds = (current_time - created_at).total_seconds() if created_at else None

            active_tasks[task_name] = {
                'task_id': id(task),
                'created_at': created_at,
                'age_seconds': age_seconds,
                'done': task.done(),
                'cancelled': task.cancelled()
            }

        return active_tasks

    def __len__(self) -> int:
        """Return the number of active tasks."""
        return len(self._tasks)

    def __bool__(self) -> bool:
        """Return True if there are active tasks."""
        return bool(self._tasks)

    def __repr__(self) -> str:
        """String representation of the task manager."""
        return f"TaskManager(name='{self.name}', active_tasks={len(self._tasks)}, shutdown={self._shutdown})"
|
||||
Loading…
x
Reference in New Issue
Block a user