diff --git a/auto_tune_vllm/execution/backends.py b/auto_tune_vllm/execution/backends.py index e325394..7aca5d6 100644 --- a/auto_tune_vllm/execution/backends.py +++ b/auto_tune_vllm/execution/backends.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from ..core.trial import TrialConfig, TrialResult @@ -522,12 +522,16 @@ def shutdown(self): class LocalExecutionBackend(ExecutionBackend): """Local execution backend using thread/process pool.""" + CANCELLATION_DETECTION_WAIT = 5 + GRACEFUL_CLEANUP_TIMEOUT = 30 + def __init__(self, max_concurrent: int = 1): self.max_concurrent = max_concurrent self.executor = concurrent.futures.ThreadPoolExecutor( max_workers=max_concurrent ) - self.active_futures: Dict[str, concurrent.futures.Future] = {} + self.active_futures: Dict[str, concurrent.futures.Future[TrialResult]] = {} + self.active_controllers: Dict[str, Any] = {} def submit_trial(self, trial_config: TrialConfig) -> JobHandle: """Submit trial for local execution.""" @@ -540,6 +544,7 @@ def submit_trial(self, trial_config: TrialConfig) -> JobHandle: job_id = str(id(future)) # Use future object ID as job ID self.active_futures[job_id] = future + self.active_controllers[job_id] = controller logger.info(f"Submitted trial {trial_config.trial_id} for local execution") return JobHandle(trial_config.trial_id, job_id) @@ -564,6 +569,7 @@ def poll_trials( logger.info(f"Completed local trial {handle.trial_id}") # Remove from active futures del self.active_futures[handle.backend_job_id] + self.active_controllers.pop(handle.backend_job_id, None) except Exception as e: # Trial failed - create error result from ..core.trial import ExecutionInfo, TrialResult @@ -580,30 +586,77 @@ def poll_trials( logger.error(f"Local trial {handle.trial_id} failed: {e}") # Remove from active futures del self.active_futures[handle.backend_job_id] + self.active_controllers.pop(handle.backend_job_id, None) else: remaining_handles.append(handle) return completed_results, remaining_handles def cleanup_all_trials(self): - """Force cleanup of all active local trials by cancelling running futures.""" + """Force cleanup of all active local trials.""" if not self.active_futures: logger.debug("No active local trials to cleanup") return - logger.info(f"Cancelling {len(self.active_futures)} active local trial(s)") + logger.info(f"Cleaning up {len(self.active_futures)} active local trial(s)") + + running_jobs = [ + (job_id, future) + for job_id, future in self.active_futures.items() + if not future.done() + ] + + for job_id, _future in running_jobs: + controller = self.active_controllers.get(job_id) + if controller is None: + continue - # Cancel all running futures - for job_id, future in list(self.active_futures.items()): try: - if not future.done(): - future.cancel() - logger.debug(f"Cancelled local trial {job_id}") + controller.request_cancellation() + logger.debug(f"Requested cancellation for local trial {job_id}") except Exception as e: - logger.warning(f"Failed to cancel local trial {job_id}: {e}") + logger.warning( + f"Failed to request cancellation for local trial {job_id}: {e}" + ) + + if running_jobs: + logger.info( + f"Waiting {self.CANCELLATION_DETECTION_WAIT}s for local trials " + "to detect cancellation..." + ) + done, not_done = concurrent.futures.wait( + [future for _, future in running_jobs], + timeout=self.CANCELLATION_DETECTION_WAIT, + ) + logger.info( + f"{len(done)} local trial(s) stopped after cancellation request, " + f"{len(not_done)} still running" + ) + + for job_id, future in running_jobs: + if future.done(): + continue + + controller = self.active_controllers.get(job_id) + if controller is None: + continue + + try: + controller.cleanup_resources() + logger.debug(f"Forced cleanup for local trial {job_id}") + except Exception as e: + logger.warning( + f"Failed forced cleanup for local trial {job_id}: {e}" + ) + + if any(not future.done() for _, future in running_jobs): + concurrent.futures.wait( + [future for _, future in running_jobs], + timeout=self.GRACEFUL_CLEANUP_TIMEOUT, + ) - # Clear the tracking collection self.active_futures.clear() + self.active_controllers.clear() logger.info("Completed cleanup of all active local trials") def shutdown(self):