diff --git a/auto_tune_vllm/cli/main.py b/auto_tune_vllm/cli/main.py index cbc5d23..0935ea0 100644 --- a/auto_tune_vllm/cli/main.py +++ b/auto_tune_vllm/cli/main.py @@ -18,6 +18,7 @@ from ..core.study_controller import StudyController from ..execution.backends import RayExecutionBackend from ..logging.manager import CentralizedLogger, LogStreamer +from ..utils.grid_cardinality import get_parameter_grid_cardinality # Setup rich console and app console = Console() @@ -369,7 +370,40 @@ def run_optimization_sync( create_db: bool = False, ): """Synchronous optimization runner with progress display.""" - # Create study controller + total_trials = n_trials or config.optimization.n_trials + cardinality = get_parameter_grid_cardinality(config) + if total_trials >= cardinality: + requested = total_trials + config.optimization.sampler = "grid" + config.optimization.n_trials = cardinality + config.optimization.n_startup_trials = min( + config.optimization.n_startup_trials, max(0, cardinality - 1) + ) + n_trials = cardinality + total_trials = cardinality + console.print( + f"[yellow]n_trials ({requested}) meets or exceeds grid cardinality " + f"({cardinality}). " + "Search set to grid mode with n_trials = cardinality.[/yellow]" + ) + elif ( + config.optimization.sampler.lower() in ("tpe", "gp", "botorch") + and total_trials <= config.optimization.n_startup_trials + ): + startup_before = config.optimization.n_startup_trials + prev_sampler = config.optimization.sampler + config.optimization.sampler = "random" + config.optimization.n_trials = total_trials + config.optimization.n_startup_trials = min( + startup_before, max(0, total_trials - 1) + ) + console.print( + f"[yellow]Auto-switched sampler from '{prev_sampler}' to 'random': " + f"n_trials ({total_trials}) is <= n_startup_trials ({startup_before}), " + "so startup sampling would consume the full trial budget. "
f"n_startup_trials is now {config.optimization.n_startup_trials}.[/yellow]" ) + # Create study controller (uses config with possibly updated sampler/n_trials) controller = StudyController.create_from_config( backend, config, create_db=create_db ) @@ -381,8 +415,6 @@ def run_optimization_sync( _display_log_viewing_instructions(config) console.print() # Add blank line for better readability - total_trials = n_trials or config.optimization.n_trials - with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -1086,6 +1118,11 @@ def validate_command( "Optimization", f"{opt_summary} ({study_config.optimization.sampler})" ) table.add_row("Trials", str(study_config.optimization.n_trials)) + cardinality = get_parameter_grid_cardinality(study_config) + table.add_row( + "Possible combinations (grid cardinality)", + str(cardinality), + ) table.add_row("Model", study_config.benchmark.model) table.add_row( "Parameters", diff --git a/auto_tune_vllm/utils/__init__.py b/auto_tune_vllm/utils/__init__.py index 7a5a58f..3d2a3ca 100644 --- a/auto_tune_vllm/utils/__init__.py +++ b/auto_tune_vllm/utils/__init__.py @@ -2,6 +2,7 @@ Utilities for auto-tune-vllm package. """ +from .grid_cardinality import get_parameter_grid_cardinality from .version_manager import VLLMDefaultsVersion, VLLMVersionManager from .vllm_cli_parser import ArgumentType, CLIArgument, VLLMCLIParser @@ -11,4 +12,5 @@ "ArgumentType", "VLLMVersionManager", "VLLMDefaultsVersion", + "get_parameter_grid_cardinality", ] diff --git a/auto_tune_vllm/utils/grid_cardinality.py b/auto_tune_vllm/utils/grid_cardinality.py new file mode 100644 index 0000000..1f829d0 --- /dev/null +++ b/auto_tune_vllm/utils/grid_cardinality.py @@ -0,0 +1,61 @@ +""" +Compute the total cardinality of the parameter grid from a study config. +Uses the same logic as StudyController._create_search_space for consistency.
+""" + +from pathlib import Path +from typing import Union + +from auto_tune_vllm.core.config import StudyConfig +from auto_tune_vllm.core.parameters import ( + BooleanParameter, + EnvironmentParameter, + ListParameter, + ParameterConfig, + RangeParameter, +) + +_MAX_GRID_SIZE = 10000 + + +def _count_parameter_values(param: ParameterConfig) -> int: + """Count distinct values for one parameter (mirrors _create_search_space).""" + if isinstance(param, (ListParameter, EnvironmentParameter)): + return len(param.options) + if isinstance(param, RangeParameter): + min_val = param.min_value + max_val = param.max_value + if param.data_type is float: + step = param.step + if step is None: + return _MAX_GRID_SIZE + n_steps = int(round((max_val - min_val) / step)) + 1 + return min(n_steps, _MAX_GRID_SIZE) + step = param.step or 1 + current = min_val + count = 0 + while current <= max_val and count < _MAX_GRID_SIZE: + count += 1 + current += step + return count + if isinstance(param, BooleanParameter): + return 2 + raise ValueError(f"Unknown parameter type: {type(param)}") + + +def get_parameter_grid_cardinality( + config: Union[str, Path, StudyConfig], + vllm_version: str | None = None, +) -> int: + """ + Return the total number of points in the parameter grid (product of enabled params). + + config: Path to YAML, or an already-loaded StudyConfig. + """ + if isinstance(config, (str, Path)): + config = StudyConfig.from_file(str(config), vllm_version=vllm_version) + total = 1 + for param_config in config.parameters.values(): + if param_config.enabled: + total *= _count_parameter_values(param_config) + return total