Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions auto_tune_vllm/execution/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from ..core.trial import TrialConfig, TrialResult

Expand Down Expand Up @@ -522,12 +522,16 @@ def shutdown(self):
class LocalExecutionBackend(ExecutionBackend):
"""Local execution backend using thread/process pool."""

CANCELLATION_DETECTION_WAIT = 5
GRACEFUL_CLEANUP_TIMEOUT = 30

def __init__(self, max_concurrent: int = 1):
self.max_concurrent = max_concurrent
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=max_concurrent
)
self.active_futures: Dict[str, concurrent.futures.Future] = {}
self.active_futures: Dict[str, concurrent.futures.Future[TrialResult]] = {}
self.active_controllers: Dict[str, Any] = {}

def submit_trial(self, trial_config: TrialConfig) -> JobHandle:
"""Submit trial for local execution."""
Expand All @@ -540,6 +544,7 @@ def submit_trial(self, trial_config: TrialConfig) -> JobHandle:

job_id = str(id(future)) # Use future object ID as job ID
self.active_futures[job_id] = future
self.active_controllers[job_id] = controller

logger.info(f"Submitted trial {trial_config.trial_id} for local execution")
return JobHandle(trial_config.trial_id, job_id)
Expand All @@ -564,6 +569,7 @@ def poll_trials(
logger.info(f"Completed local trial {handle.trial_id}")
# Remove from active futures
del self.active_futures[handle.backend_job_id]
self.active_controllers.pop(handle.backend_job_id, None)
except Exception as e:
# Trial failed - create error result
from ..core.trial import ExecutionInfo, TrialResult
Expand All @@ -580,30 +586,77 @@ def poll_trials(
logger.error(f"Local trial {handle.trial_id} failed: {e}")
# Remove from active futures
del self.active_futures[handle.backend_job_id]
self.active_controllers.pop(handle.backend_job_id, None)
else:
remaining_handles.append(handle)

return completed_results, remaining_handles

def cleanup_all_trials(self):
"""Force cleanup of all active local trials by cancelling running futures."""
"""Force cleanup of all active local trials."""
if not self.active_futures:
logger.debug("No active local trials to cleanup")
return

logger.info(f"Cancelling {len(self.active_futures)} active local trial(s)")
logger.info(f"Cleaning up {len(self.active_futures)} active local trial(s)")

running_jobs = [
(job_id, future)
for job_id, future in self.active_futures.items()
if not future.done()
]

for job_id, _future in running_jobs:
controller = self.active_controllers.get(job_id)
if controller is None:
continue

# Cancel all running futures
for job_id, future in list(self.active_futures.items()):
try:
if not future.done():
future.cancel()
logger.debug(f"Cancelled local trial {job_id}")
controller.request_cancellation()
logger.debug(f"Requested cancellation for local trial {job_id}")
except Exception as e:
logger.warning(f"Failed to cancel local trial {job_id}: {e}")
logger.warning(
f"Failed to request cancellation for local trial {job_id}: {e}"
)

if running_jobs:
logger.info(
f"Waiting {self.CANCELLATION_DETECTION_WAIT}s for local trials "
"to detect cancellation..."
)
done, not_done = concurrent.futures.wait(
[future for _, future in running_jobs],
timeout=self.CANCELLATION_DETECTION_WAIT,
)
logger.info(
f"{len(done)} local trial(s) stopped after cancellation request, "
f"{len(not_done)} still running"
)

for job_id, future in running_jobs:
if future.done():
continue

controller = self.active_controllers.get(job_id)
if controller is None:
continue

try:
controller.cleanup_resources()
logger.debug(f"Forced cleanup for local trial {job_id}")
except Exception as e:
logger.warning(
f"Failed forced cleanup for local trial {job_id}: {e}"
)

if any(not future.done() for _, future in running_jobs):
concurrent.futures.wait(
[future for _, future in running_jobs],
timeout=self.GRACEFUL_CLEANUP_TIMEOUT,
)

# Clear the tracking collection
self.active_futures.clear()
self.active_controllers.clear()
logger.info("Completed cleanup of all active local trials")

def shutdown(self):
Expand Down
Loading