From cd44b8811476296e411ed36a6099408fcb136af6 Mon Sep 17 00:00:00 2001 From: Vincent Gimenes Date: Fri, 13 Mar 2026 18:44:23 +0100 Subject: [PATCH 1/3] add grid-cardinality-and-auto-switch Signed-off-by: Vincent Gimenes --- auto_tune_vllm/cli/main.py | 25 ++++++++-- auto_tune_vllm/utils/__init__.py | 2 + auto_tune_vllm/utils/grid_cardinality.py | 58 ++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 auto_tune_vllm/utils/grid_cardinality.py diff --git a/auto_tune_vllm/cli/main.py b/auto_tune_vllm/cli/main.py index cbc5d23..dc688c1 100644 --- a/auto_tune_vllm/cli/main.py +++ b/auto_tune_vllm/cli/main.py @@ -16,6 +16,7 @@ from ..core.config import StudyConfig from ..core.storage.postgres_utils import clear_study_data, verify_database_connection from ..core.study_controller import StudyController +from ..utils.grid_cardinality import get_parameter_grid_cardinality from ..execution.backends import RayExecutionBackend from ..logging.manager import CentralizedLogger, LogStreamer @@ -369,7 +370,22 @@ def run_optimization_sync( create_db: bool = False, ): """Synchronous optimization runner with progress display.""" - # Create study controller + total_trials = n_trials or config.optimization.n_trials + cardinality = get_parameter_grid_cardinality(config) + if total_trials > cardinality: + requested = total_trials + config.optimization.sampler = "grid" + config.optimization.n_trials = cardinality + config.optimization.n_startup_trials = min( + config.optimization.n_startup_trials, max(0, cardinality - 1) + ) + n_trials = cardinality + total_trials = cardinality + console.print( + f"[yellow]n_trials ({requested}) exceeds grid cardinality ({cardinality}). " + "Search set to grid mode with n_trials = cardinality.[/yellow]" + ) + # Create study controller (uses config with possibly updated sampler/n_trials) controller = StudyController.create_from_config( backend, config, create_db=create_db ) @@ -381,8 +397,6 @@ def run_optimization_sync( _display_log_viewing_instructions(config) console.print() # Add blank line for better readability - total_trials = n_trials or config.optimization.n_trials - with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -1086,6 +1100,11 @@ def validate_command( "Optimization", f"{opt_summary} ({study_config.optimization.sampler})" ) table.add_row("Trials", str(study_config.optimization.n_trials)) + cardinality = get_parameter_grid_cardinality(study_config) + table.add_row( + "Possible combinations (grid cardinality)", + str(cardinality), + ) table.add_row("Model", study_config.benchmark.model) table.add_row( "Parameters", diff --git a/auto_tune_vllm/utils/__init__.py b/auto_tune_vllm/utils/__init__.py index 7a5a58f..3d2a3ca 100644 --- a/auto_tune_vllm/utils/__init__.py +++ b/auto_tune_vllm/utils/__init__.py @@ -2,6 +2,7 @@ Utilities for auto-tune-vllm package. """ +from .grid_cardinality import get_parameter_grid_cardinality from .version_manager import VLLMDefaultsVersion, VLLMVersionManager from .vllm_cli_parser import ArgumentType, CLIArgument, VLLMCLIParser @@ -11,4 +12,5 @@ "ArgumentType", "VLLMVersionManager", "VLLMDefaultsVersion", + "get_parameter_grid_cardinality", ] diff --git a/auto_tune_vllm/utils/grid_cardinality.py b/auto_tune_vllm/utils/grid_cardinality.py new file mode 100644 index 0000000..1535fba --- /dev/null +++ b/auto_tune_vllm/utils/grid_cardinality.py @@ -0,0 +1,58 @@ +""" +Compute the total cardinality of the parameter grid from a study config. +Uses the same logic as StudyController._create_search_space for consistency. +""" + +from pathlib import Path +from typing import Union + +from auto_tune_vllm.core.config import StudyConfig +from auto_tune_vllm.core.parameters import ( + BooleanParameter, + EnvironmentParameter, + ListParameter, + ParameterConfig, + RangeParameter, +) + +_MAX_GRID_SIZE = 10000 + + +def _count_parameter_values(param: ParameterConfig) -> int: + """Count distinct values for one parameter (mirrors _create_search_space).""" + if isinstance(param, (ListParameter, EnvironmentParameter)): + return len(param.options) + if isinstance(param, RangeParameter): + min_val = param.min_value + max_val = param.max_value + step = param.step or 1 + if param.data_type == float: + n_steps = int((max_val - min_val) / step) + 1 + return min(n_steps, _MAX_GRID_SIZE) + current = min_val + count = 0 + while current <= max_val and count < _MAX_GRID_SIZE: + count += 1 + current += step + return count + if isinstance(param, BooleanParameter): + return 2 + return 0 + + +def get_parameter_grid_cardinality( + config: Union[str, Path, StudyConfig], + vllm_version: str | None = None, +) -> int: + """ + Return the total number of points in the parameter grid (product of enabled params). + + config: Path to YAML, or an already-loaded StudyConfig. + """ + if isinstance(config, (str, Path)): + config = StudyConfig.from_file(str(config), vllm_version=vllm_version) + total = 1 + for param_config in config.parameters.values(): + if param_config.enabled: + total *= _count_parameter_values(param_config) + return total From 3e8d3e5dd08f0c6d86185a8826c965abc48e16a9 Mon Sep 17 00:00:00 2001 From: Vincent Gimenes Date: Sat, 21 Mar 2026 19:24:04 +0100 Subject: [PATCH 2/3] fix ruff Signed-off-by: Vincent Gimenes --- auto_tune_vllm/cli/main.py | 2 +- auto_tune_vllm/utils/grid_cardinality.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_tune_vllm/cli/main.py b/auto_tune_vllm/cli/main.py index dc688c1..fa64d70 100644 --- a/auto_tune_vllm/cli/main.py +++ b/auto_tune_vllm/cli/main.py @@ -16,9 +16,9 @@ from ..core.config import StudyConfig from ..core.storage.postgres_utils import clear_study_data, verify_database_connection from ..core.study_controller import StudyController -from ..utils.grid_cardinality import get_parameter_grid_cardinality from ..execution.backends import RayExecutionBackend from ..logging.manager import CentralizedLogger, LogStreamer +from ..utils.grid_cardinality import get_parameter_grid_cardinality # Setup rich console and app console = Console() diff --git a/auto_tune_vllm/utils/grid_cardinality.py b/auto_tune_vllm/utils/grid_cardinality.py index 1535fba..5e3fe55 100644 --- a/auto_tune_vllm/utils/grid_cardinality.py +++ b/auto_tune_vllm/utils/grid_cardinality.py @@ -26,7 +26,7 @@ def _count_parameter_values(param: ParameterConfig) -> int: min_val = param.min_value max_val = param.max_value step = param.step or 1 - if param.data_type == float: + if param.data_type is float: n_steps = int((max_val - min_val) / step) + 1 return min(n_steps, _MAX_GRID_SIZE) current = min_val From 213675f3626872221f0679e85c749690375d460c Mon Sep 17 00:00:00 2001 From: Vincent Gimenes Date: Tue, 24 Mar 2026 14:56:54 +0100 Subject: [PATCH 3/3] apply rabbit ai feedbacks Signed-off-by: Vincent Gimenes --- auto_tune_vllm/cli/main.py | 22 ++++++++++++++++++++-- auto_tune_vllm/utils/grid_cardinality.py | 9 ++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/auto_tune_vllm/cli/main.py b/auto_tune_vllm/cli/main.py index fa64d70..0935ea0 100644 --- a/auto_tune_vllm/cli/main.py +++ b/auto_tune_vllm/cli/main.py @@ -372,7 +372,7 @@ def run_optimization_sync( """Synchronous optimization runner with progress display.""" total_trials = n_trials or config.optimization.n_trials cardinality = get_parameter_grid_cardinality(config) - if total_trials > cardinality: + if total_trials >= cardinality: requested = total_trials config.optimization.sampler = "grid" config.optimization.n_trials = cardinality @@ -382,9 +382,27 @@ def run_optimization_sync( n_trials = cardinality total_trials = cardinality console.print( - f"[yellow]n_trials ({requested}) exceeds grid cardinality ({cardinality}). " + f"[yellow]n_trials ({requested}) meets or exceeds grid cardinality " + f"({cardinality}). " "Search set to grid mode with n_trials = cardinality.[/yellow]" ) + elif ( + config.optimization.sampler.lower() in ("tpe", "gp", "botorch") + and total_trials <= config.optimization.n_startup_trials + ): + startup_before = config.optimization.n_startup_trials + prev_sampler = config.optimization.sampler + config.optimization.sampler = "random" + config.optimization.n_trials = total_trials + config.optimization.n_startup_trials = min( + startup_before, max(0, total_trials - 1) + ) + console.print( + f"[yellow]Auto-switched sampler from '{prev_sampler}' to 'random': " + f"n_trials ({total_trials}) is <= n_startup_trials ({startup_before}), " + "so startup sampling would consume the full trial budget. " + f"n_startup_trials is now {config.optimization.n_startup_trials}.[/yellow]" + ) # Create study controller (uses config with possibly updated sampler/n_trials) controller = StudyController.create_from_config( backend, config, create_db=create_db diff --git a/auto_tune_vllm/utils/grid_cardinality.py b/auto_tune_vllm/utils/grid_cardinality.py index 5e3fe55..1f829d0 100644 --- a/auto_tune_vllm/utils/grid_cardinality.py +++ b/auto_tune_vllm/utils/grid_cardinality.py @@ -25,10 +25,13 @@ def _count_parameter_values(param: ParameterConfig) -> int: if isinstance(param, RangeParameter): min_val = param.min_value max_val = param.max_value - step = param.step or 1 if param.data_type is float: - n_steps = int((max_val - min_val) / step) + 1 + step = param.step + if step is None: + return _MAX_GRID_SIZE + n_steps = int(round((max_val - min_val) / step)) + 1 return min(n_steps, _MAX_GRID_SIZE) + step = param.step or 1 current = min_val count = 0 while current <= max_val and count < _MAX_GRID_SIZE: @@ -37,7 +40,7 @@ def _count_parameter_values(param: ParameterConfig) -> int: return count if isinstance(param, BooleanParameter): return 2 - return 0 + raise ValueError(f"Unknown parameter type: {type(param)}") def get_parameter_grid_cardinality(