Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions auto_tune_vllm/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..core.study_controller import StudyController
from ..execution.backends import RayExecutionBackend
from ..logging.manager import CentralizedLogger, LogStreamer
from ..utils.grid_cardinality import get_parameter_grid_cardinality

# Setup rich console and app
console = Console()
Expand Down Expand Up @@ -369,7 +370,40 @@ def run_optimization_sync(
create_db: bool = False,
):
"""Synchronous optimization runner with progress display."""
# Create study controller
total_trials = n_trials or config.optimization.n_trials
cardinality = get_parameter_grid_cardinality(config)
if total_trials >= cardinality:
requested = total_trials
config.optimization.sampler = "grid"
config.optimization.n_trials = cardinality
config.optimization.n_startup_trials = min(
config.optimization.n_startup_trials, max(0, cardinality - 1)
)
n_trials = cardinality
total_trials = cardinality
console.print(
f"[yellow]n_trials ({requested}) meets or exceeds grid cardinality "
f"({cardinality}). "
"Search set to grid mode with n_trials = cardinality.[/yellow]"
)
elif (
config.optimization.sampler.lower() in ("tpe", "gp", "botorch")
and total_trials <= config.optimization.n_startup_trials
):
startup_before = config.optimization.n_startup_trials
prev_sampler = config.optimization.sampler
config.optimization.sampler = "random"
config.optimization.n_trials = total_trials
config.optimization.n_startup_trials = min(
startup_before, max(0, total_trials - 1)
)
console.print(
f"[yellow]Auto-switched sampler from '{prev_sampler}' to 'random': "
f"n_trials ({total_trials}) is <= n_startup_trials ({startup_before}), "
"so startup sampling would consume the full trial budget. "
f"n_startup_trials is now {config.optimization.n_startup_trials}.[/yellow]"
)
# Create study controller (uses config with possibly updated sampler/n_trials)
controller = StudyController.create_from_config(
backend, config, create_db=create_db
)
Expand All @@ -381,8 +415,6 @@ def run_optimization_sync(
_display_log_viewing_instructions(config)
console.print() # Add blank line for better readability

total_trials = n_trials or config.optimization.n_trials

with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
Expand Down Expand Up @@ -1086,6 +1118,11 @@ def validate_command(
"Optimization", f"{opt_summary} ({study_config.optimization.sampler})"
)
table.add_row("Trials", str(study_config.optimization.n_trials))
cardinality = get_parameter_grid_cardinality(study_config)
table.add_row(
"Possible combinations (grid cardinality)",
str(cardinality),
)
table.add_row("Model", study_config.benchmark.model)
table.add_row(
"Parameters",
Expand Down
2 changes: 2 additions & 0 deletions auto_tune_vllm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Utilities for auto-tune-vllm package.
"""

from .grid_cardinality import get_parameter_grid_cardinality
from .version_manager import VLLMDefaultsVersion, VLLMVersionManager
from .vllm_cli_parser import ArgumentType, CLIArgument, VLLMCLIParser

Expand All @@ -11,4 +12,5 @@
"ArgumentType",
"VLLMVersionManager",
"VLLMDefaultsVersion",
"get_parameter_grid_cardinality",
]
61 changes: 61 additions & 0 deletions auto_tune_vllm/utils/grid_cardinality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Compute the total cardinality of the parameter grid from a study config.
Uses the same logic as StudyController._create_search_space for consistency.
"""

from pathlib import Path
from typing import Union

from auto_tune_vllm.core.config import StudyConfig
from auto_tune_vllm.core.parameters import (
BooleanParameter,
EnvironmentParameter,
ListParameter,
ParameterConfig,
RangeParameter,
)

# Cap on the number of values counted for any single parameter. A continuous
# range (e.g. a float range with no step) would otherwise be unbounded; using
# this cap keeps the overall grid-cardinality product finite and meaningful.
_MAX_GRID_SIZE = 10000


def _count_parameter_values(param: ParameterConfig) -> int:
    """Count distinct values for one parameter (mirrors _create_search_space).

    Returns the number of grid points the parameter contributes, clamped to
    the range [0, _MAX_GRID_SIZE] so continuous or degenerate (inverted /
    non-positive-step) ranges cannot blow up or corrupt the product.

    Raises:
        ValueError: If the parameter type is not one of the known classes.
    """
    if isinstance(param, (ListParameter, EnvironmentParameter)):
        # Categorical parameters: one grid point per declared option.
        return len(param.options)
    if isinstance(param, RangeParameter):
        min_val = param.min_value
        max_val = param.max_value
        if param.data_type is float:
            step = param.step
            if step is None:
                # A stepless float range is continuous — treat it as
                # "effectively infinite" by returning the cap.
                return _MAX_GRID_SIZE
            # Step increments that fit in [min_val, max_val] plus the start
            # point; clamp so an inverted range yields 0, not a negative count.
            n_steps = int(round((max_val - min_val) / step)) + 1
            return max(0, min(n_steps, _MAX_GRID_SIZE))
        # Integer range: closed form replaces the original O(n) counting loop.
        # floor((max - min) / step) + 1 equals the loop's iteration count for
        # any positive step; the max(0, ...) clamp makes an inverted range (or
        # a degenerate negative step) report 0 instead of a bogus value.
        step = param.step or 1
        return max(0, min((max_val - min_val) // step + 1, _MAX_GRID_SIZE))
    if isinstance(param, BooleanParameter):
        return 2
    raise ValueError(f"Unknown parameter type: {type(param)}")


def get_parameter_grid_cardinality(
    config: Union[str, Path, StudyConfig],
    vllm_version: str | None = None,
) -> int:
    """
    Return the total number of points in the parameter grid (product of enabled params).

    config: Path to YAML, or an already-loaded StudyConfig.
    """
    # Accept a path-like argument by loading the study config first.
    if isinstance(config, (str, Path)):
        config = StudyConfig.from_file(str(config), vllm_version=vllm_version)
    # Only enabled parameters contribute to the grid; disabled ones are fixed.
    enabled_params = (p for p in config.parameters.values() if p.enabled)
    cardinality = 1
    for param in enabled_params:
        cardinality *= _count_parameter_values(param)
    return cardinality