diff --git a/docs/source/conf.py b/docs/source/conf.py index 6fe7a860024..47f997a0113 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -124,7 +124,7 @@ # Mock imports for autodoc -autodoc_mock_imports = ["mpi4py", "tensorrt_llm", "triton"] +autodoc_mock_imports = ["mpi4py", "tensorrt_llm", "triton", "vllm"] autosummary_generate = True autosummary_imported_members = False diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index 571b40ca499..dce76866d6d 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -343,6 +343,46 @@ See [Megatron-Bridge distillation](../megatron_bridge/README.md#distillation) fo For distillation results on Puzzletron-compressed models, see [examples/pruning/puzzletron/](../pruning/puzzletron/README.md). +## Runtime-Based Latency Optimization + +You can enable **runtime stats** to measure actual inference latency via vLLM, which unlocks latency-based MIP constraints. + +A ready-to-run example config is included at [`configs/llama-3_1-8B_pruneffn_runtime/`](./configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml). The following key fields enable and control execution of the runtime statistics in the `llama-3_1-8B_pruneffn_runtime.yaml` config file: + +```yaml +calc_subblock_stats: + runtime_stats: + enabled: true + num_warmup_iters: 2 + num_iters: 10 +``` + +The runtime constraint is specified in the `human_constraints` section of the config `Llama-3_1-8B.yaml`: + +```yaml +human_constraints: + target_latency_seconds: 21 +``` + +Run the pipeline against this config the same way as the memory-constrained variant: + +```bash +torchrun --nproc_per_node 2 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress" +``` + +The MIP solver will now search for a heterogeneous architecture whose measured end-to-end latency is at or below `target_latency_seconds`, instead of optimizing for a memory budget. + +Because vLLM startup adds substantial overhead during stats collection, extend the distributed process group timeout accordingly (already included in the example config): + +```yaml +nccl_timeout_minutes: 90 # default is 10 if omitted +``` + +This field is supported in any Puzzletron YAML config and overrides the default 10-minute distributed timeout. + +Due to non-linear extension of the runtime stats of single subblocks to the total runtime of the model, the `target_latency_seconds` value should be set to a value that is slightly lower than the desired latency. For example, in our experiments, the `target_latency_seconds` value of 5 resulted in a final model latency of 5.4 seconds. + ## Advanced Usage Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios. diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml index 21903db1623..1c302fd4c30 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -42,7 +42,7 @@ scoring: teacher_dir: ${to_path:${teacher_dir}} output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation - eval_samples: 128 + eval_samples: 8 micro_batch_size: 1 seed: 42 shuffle_seed: 444 diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml index ce1749d9698..6b36142a3a8 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -3,7 +3,7 @@ autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model block_size: 8192 bos_rate: 0.5 data_column: messages -val_dataset_name: valid +val_dataset_name: validation shuffle_seed: 81436 seed: 42 fim_rate: 0 diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/Llama-3_1-8B.yaml new file mode 100644 index 00000000000..b4adbb82add --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/Llama-3_1-8B.yaml @@ -0,0 +1,103 @@ +defaults: + - ../llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning@pruning + - ../llama-3_1-8B_pruneffn_memory/validate_solutions_defaults@scoring + - ../llama-3_1-8B_pruneffn_memory/validate_solutions_defaults@realize_model + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # ppath to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [1, 4] + prefill_seq_len: 1024 + generation_seq_len: 1024 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 1 + weights_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_latency_seconds: 5 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:120} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml new file mode 100644 index 00000000000..588df25f27d --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_runtime/llama-3_1-8B_pruneffn_runtime.yaml @@ -0,0 +1,22 @@ +defaults: + - Llama-3_1-8B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for puzzletron outputs +puzzle_dir: /workspace/puzzle_dir + +calc_subblock_stats: + runtime_stats: + enabled: true + num_warmup_iters: 2 + num_iters: 10 + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 8ceed378318..f093e5b7e68 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -68,7 +68,6 @@ def run_full_puzzletron(hydra_config_path: str): config_path: Path to the YAML configuration file """ mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") - dist.setup(timeout=timedelta(minutes=10)) # Register Hydra custom resolvers (needed for config resolution) mtpz.tools.register_hydra_resolvers() @@ -84,6 +83,14 @@ def run_full_puzzletron(hydra_config_path: str): overrides=[], ) + # Default timeout: 10 minutes, or extended to nccl_timeout_minutes if set in config + if hasattr(hydra_cfg, "nccl_timeout_minutes"): + timeout_minutes = hydra_cfg.nccl_timeout_minutes + else: + timeout_minutes = timedelta(minutes=10) + + dist.setup(timeout=timeout_minutes) + # Convert model (convert from HF to DeciLM, score pruning activations, # prune the model and save pruned checkpoints) input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel() diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 971c423f711..61707f63013 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -200,17 +200,18 @@ def attention_calibrate( measuring how many KV tiles would be skipped at each threshold in ``threshold_trials``. No autograd — forward only. + All arguments except ``threshold_trials`` match + :func:`modelopt.torch.kernels.common.attention.attention`. + Args: - q, k, v, b_start_loc, b_seq_len, max_input_len, is_causal, - softmax_scale, b_start_loc_k, b_seq_len_k, max_input_len_k: - Same as :func:`modelopt.torch.kernels.common.attention.attention`. threshold_trials: List of threshold values to measure sparsity for. Each value is converted to log2-scaled space for the kernel. Returns: - Tuple of (output, sparsity_counters): - - output: ``[total_q_tokens, num_q_heads, head_dim]`` - - sparsity_counters: ``[num_thresholds, 2]`` int64 tensor where + Tuple of ``(output, sparsity_counters)``: + + - ``output``: ``[total_q_tokens, num_q_heads, head_dim]`` + - ``sparsity_counters``: ``[num_thresholds, 2]`` int64 tensor where ``[:, 0]`` = total tile evaluations, ``[:, 1]`` = skipped tiles. Sparsity per threshold = ``counters[:, 1] / counters[:, 0]``. """ diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index 761534f6df9..22c8b471546 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -79,7 +79,7 @@ class Type(enum.Enum): _ALLOWED_HUMAN_CONSTRAINTS = { "target_memory", "target_throughput", - "target_latency", + "target_latency_seconds", "target_time_to_first_token", "num_params", "stats.has_attention", @@ -175,8 +175,8 @@ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: throughput_constraints.append( batch_size * generation_seq_len / self.constraints["target_throughput"] ) - if "target_latency" in self.constraints: - throughput_constraints.append(self.constraints["target_latency"]) + if "target_latency_seconds" in self.constraints: + throughput_constraints.append(self.constraints["target_latency_seconds"]) if throughput_constraints: mip_constraints["stats.runtime_ms"] = 1000 * min(throughput_constraints) diff --git a/modelopt/torch/puzzletron/subblock_stats/__init__.py b/modelopt/torch/puzzletron/subblock_stats/__init__.py index fbbeb3ff709..4964dba0cfa 100644 --- a/modelopt/torch/puzzletron/subblock_stats/__init__.py +++ b/modelopt/torch/puzzletron/subblock_stats/__init__.py @@ -15,5 +15,4 @@ """Subblock statistics collection for Puzzletron.""" -from .calc_subblock_params_and_memory import * from .calc_subblock_stats import * diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_runtime_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_runtime_stats.py new file mode 100644 index 00000000000..6e4821936e7 --- /dev/null +++ b/modelopt/torch/puzzletron/subblock_stats/calc_runtime_stats.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Runtime statistics calculation for NAS subblock benchmarking via vLLM.""" + +import tempfile +from dataclasses import replace +from functools import cache +from pathlib import Path + +from omegaconf import DictConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM + +from ..anymodel.models.llama import LlamaModelDescriptor +from ..anymodel.puzzformer import deci_x_patcher +from ..block_config import AttentionConfig, BlockConfig, FFNConfig, SubblockConfig +from .runtime_utils import RuntimeConfig, save_model +from .runtime_vllm import run_vllm_latency_benchmark + + +def _make_standard_block_config(num_key_value_heads: int) -> BlockConfig: + return BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=256, moe=None), + ) + + +def create_benchmark_model( + vocab_size: int, + hidden_size: int, + num_key_value_heads: int, + num_attention_heads: int, + prefill_seq_len: int, + generation_seq_len: int, + block_config: BlockConfig | None, + repeat_block_n_times: int = 10, +) -> LlamaForCausalLM: + """Build a small Llama model with repeated subblocks for latency benchmarking.""" + block_configs = [_make_standard_block_config(num_key_value_heads)] + + if block_config: + block_configs.extend([block_config] * repeat_block_n_times) + + model_config = LlamaConfig( + max_position_embeddings=prefill_seq_len + generation_seq_len, + vocab_size=vocab_size, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_hidden_layers=len(block_configs), + head_dim=None, # Compute from hidden_size // num_attention_heads instead of using default 128 + # this is required for trt-llm convertion to know which model classes to use to the checkpoint + auto_map={ + "AutoConfig": "transformers.models.llama.configuration_llama.LlamaConfig", + "AutoModelForCausalLM": "transformers.models.llama.modeling_llama.LlamaForCausalLM", + }, + ) + + for idx, bc in enumerate(block_configs): + block_configs[idx] = bc.to_dict() + model_config.block_configs = block_configs + + with deci_x_patcher(LlamaModelDescriptor, block_configs): + model = AutoModelForCausalLM.from_config(model_config) + + model.config.architectures = ["AnyModel"] + model.config.base_architecture = "LlamaForCausalLM" + + return model + + +def calc_model_runtime(model: LlamaForCausalLM, runtime_config: RuntimeConfig) -> float: + """Measure total runtime of a model via vLLM latency benchmark.""" + with tempfile.TemporaryDirectory() as model_tmpdir: + save_model(model, Path(runtime_config.tokenizer_path), Path(model_tmpdir)) + model_total_runtime_ms = run_vllm_latency_benchmark(Path(model_tmpdir), runtime_config) + return model_total_runtime_ms + + +@cache +def calc_subblock_runtime( + runtime_config: RuntimeConfig, + subblock_config: SubblockConfig | None, +) -> float: + """Measure total runtime of a repeated subblock via vLLM latency benchmark.""" + block_config: BlockConfig | None = None + + if subblock_config is not None: + if isinstance(subblock_config, BlockConfig): + block_config = subblock_config + elif isinstance(subblock_config, (AttentionConfig, FFNConfig)): + if isinstance(subblock_config, FFNConfig): + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=runtime_config.num_key_value_heads + ), + ffn=subblock_config, + ) + else: + block_config = subblock_config.to_blockconfig() + else: + raise Exception(f"Runtime stats: Not supported subblock type: {subblock_config}") + + model = create_benchmark_model( + runtime_config.vocab_size, + runtime_config.hidden_size, + runtime_config.num_key_value_heads, + runtime_config.num_attention_heads, + runtime_config.prefill_seq_len, + runtime_config.generation_seq_len, + block_config=block_config, + repeat_block_n_times=runtime_config.repeat_block_n_times, + ) + return calc_model_runtime(model, runtime_config) + + +@cache +def calc_base_runtime(runtime_config: RuntimeConfig, subblock_config: SubblockConfig) -> float: + """Calculate the base runtime of a model with no subblocks.""" + base_runtime_ms = None + if isinstance(subblock_config, AttentionConfig): + base_runtime_ms = calc_subblock_runtime(runtime_config, None) + elif isinstance(subblock_config, FFNConfig): + attn_block_config = AttentionConfig( + no_op=False, num_key_value_heads=runtime_config.num_key_value_heads + ).to_blockconfig() + base_runtime_ms = calc_subblock_runtime(runtime_config, attn_block_config) + else: + raise ValueError(f"Unsupported subblock type: {type(subblock_config)}") + + return base_runtime_ms + + +@cache +def calc_no_block_runtime(runtime_config: RuntimeConfig) -> float: + """Estimate the overhead runtime (embedding + LM head) with no decoder blocks.""" + runtime_cfg_ten_blocks = replace(runtime_config, repeat_block_n_times=9) + + block_config = _make_standard_block_config(runtime_config.num_key_value_heads) + + runtime_ms_one_block = calc_subblock_runtime(runtime_config, None) # only one base block + runtime_ms_ten_blocks = calc_subblock_runtime( + runtime_cfg_ten_blocks, block_config + ) # one base block + 9 repeated blocks + + no_block_runtime_ms = runtime_ms_one_block - (runtime_ms_ten_blocks - runtime_ms_one_block) / 9 + + return no_block_runtime_ms + + +def calc_runtime_for_subblocks( + subblock_config_set: set[SubblockConfig], + runtime_stats_config: DictConfig, + vocab_size: int, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + tokenizer_path: str, + prefill_seq_len: int, + generation_seq_len: int, + batch_size: int, +) -> tuple[dict[SubblockConfig, float], float]: + """Benchmark each unique subblock and return per-subblock runtimes and no-block overhead.""" + repeat_block_n_times = 10 + + runtime_config = RuntimeConfig( + vocab_size, + hidden_size, + num_attention_heads, + num_key_value_heads, + tokenizer_path, + repeat_block_n_times, + prefill_seq_len, + generation_seq_len, + batch_size, + runtime_stats_config.get("num_iters", 30), + runtime_stats_config.get("num_warmup_iters", 10), + ) + + runtime_by_subblock_dict = {} + + for subblock_config in tqdm( + sorted(subblock_config_set), + desc=(f"Computing runtime for {len(subblock_config_set)} subblocks\n"), + ): + baseline_runtime_ms = calc_base_runtime(runtime_config, subblock_config) + + if subblock_config.no_op: + total_runtime_ms = 0.0 + else: + subblock_total_runtime_ms = calc_subblock_runtime(runtime_config, subblock_config) + total_runtime_ms = ( + subblock_total_runtime_ms - baseline_runtime_ms + ) / repeat_block_n_times + + runtime_by_subblock_dict[subblock_config] = total_runtime_ms + + no_block_runtime_ms = calc_no_block_runtime(runtime_config) + + return runtime_by_subblock_dict, no_block_runtime_ms diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py index d893eb55bb3..531f7a3f0a1 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -25,7 +25,6 @@ import json import math from pathlib import Path -from typing import Type import numpy as np import torch @@ -48,16 +47,16 @@ ) __all__ = [ - "calculate_subblock_memory", - "calculate_subblock_params", "calc_subblock_active_params", - "load_moe_stats", - "estimate_num_active_experts", + "calculate_ffn_memory", "calculate_mamba_memory", "calculate_mamba_state_size", - "calculate_ffn_memory", "calculate_non_block_memory", "calculate_non_block_params", + "calculate_subblock_memory", + "calculate_subblock_params", + "estimate_num_active_experts", + "load_moe_stats", ] @@ -73,9 +72,29 @@ def calculate_subblock_memory( kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, model_config: PretrainedConfig, - descriptor: Type[ModelDescriptor], + descriptor: type[ModelDescriptor], ) -> float | dict[str, float]: - """``model_config`` / ``descriptor`` are required (puzzletron-style); FFN uses them for meta init.""" + """Calculate the memory usage of a single subblock (FFN or Attention). + + Given its configuration and runtime dimensions, returns bytes or a detailed dict. + + Args: + subblock_config: Subblock configuration dataclass. + batch_size: Batch size for memory estimate. + prefill_seq_len: Sequence length for prefill phase. + generation_seq_len: Sequence length for generation phase (token-by-token). + prefill_queue_size: Token queue size for prefill attention memory allocation. + n_embd: Embedding (hidden) dimension. + n_head: Number of attention heads (used for non-FFN). + weights_dtype: PyTorch dtype for model weights. + kv_cache_dtype: PyTorch dtype for KV cache. + allocate_prefill_query: Whether to allocate query cache for prefill tokens. + model_config: HuggingFace-style config instance describing the model. + descriptor: Model descriptor type (for puzzletron model types). + + Returns: + Memory usage in bytes (float), or a dictionary by memory type. + """ if subblock_config.no_op: return 0 if isinstance(subblock_config, FFNConfig): @@ -116,7 +135,7 @@ def calculate_subblock_memory( def calculate_subblock_params( config: PretrainedConfig, layer_config: BlockConfig | FFNConfig | AttentionConfig, - descriptor: Type[ModelDescriptor], + descriptor: type[ModelDescriptor], ) -> int: """Count parameters on one meta decoder layer. @@ -124,9 +143,7 @@ def calculate_subblock_params( ``hybrid_override_pattern``) before passing ``config``; see ``ModelDescriptor.truncate_pattern_for_subblock``. """ - if isinstance(layer_config, FFNConfig): - block_config = layer_config.to_blockconfig() - elif isinstance(layer_config, AttentionConfig): + if isinstance(layer_config, (FFNConfig, AttentionConfig)): block_config = layer_config.to_blockconfig() else: block_config = layer_config @@ -189,12 +206,31 @@ def calculate_subblock_params( def calc_subblock_active_params( sublayer_config: FFNConfig | AttentionConfig, model_config: PretrainedConfig, - descriptor: Type[ModelDescriptor], + descriptor: type[ModelDescriptor], n_embd: int, moe_stats_file: str, batch_size: int, block_idx: int, ) -> int: + """Calculate the number of "active" parameters for a subblock (FFN, Attention, or MoE). + + For non-MoE subblocks, simply calls `calculate_subblock_params` to count all parameters. + For MoE (Mixture-of-Experts) FFN subblocks, estimates the expected number of active parameters + per batch by leveraging expert activation statistics (from a given stats file) and calculating + the expected number of active experts, then multiplies by the number of parameters per expert. + + Args: + sublayer_config: The subblock configuration (either FFNConfig or AttentionConfig). + model_config: The Hugging Face model configuration. + descriptor: The ModelDescriptor class corresponding to this model family. + n_embd: The embedding size (hidden dimension). + moe_stats_file: Path to file containing expert activation probabilities. + batch_size: The batch size used for the estimate. + block_idx: The index of the block/subblock within the network, used to index into the stats. + + Returns: + The expected number of "active" parameters for the given subblock. + """ if not (isinstance(sublayer_config, FFNConfig) and sublayer_config.is_moe): return calculate_subblock_params(model_config, sublayer_config, descriptor) return estimate_moe_active_params( @@ -203,14 +239,45 @@ def calc_subblock_active_params( def load_moe_stats(stats_file: str) -> dict: + """Load MoE (Mixture-of-Experts) routing statistics from a file. + + This function reads a JSON file containing expert activation probabilities or counts for each MoE block. + It returns the normalized probability distributions over experts for each block, as a list of numpy arrays. + + Args: + stats_file: Path to the JSON file containing expert routing statistics for each block. + + Returns: + A list where each element is a numpy array containing the normalized probability + distribution over experts for the corresponding block. If a block's expert list is empty, + its entry is 0. + """ with open(stats_file) as f: stats = json.load(f) - return [np.array(l) / np.sum(l) if len(l) > 0 else 0 for l in stats] + return [ + np.array(expert_probs) / np.sum(expert_probs) if len(expert_probs) > 0 else 0 + for expert_probs in stats + ] def estimate_num_active_experts( dist_over_experts: np.ndarray, batch_size: int, num_experts: int ) -> int: + """Estimate the expected number of active experts in a Mixture-of-Experts (MoE) layer. + + This function computes the expected number of unique experts that are selected at least once when performing + inference with a given batch size. It assumes, for each input in the batch, an expert is chosen with probability + given by `dist_over_experts` (typically a vector of probabilities for each expert). For a batch of size B, the + expected number of active (i.e., selected at least once) experts is computed. + + Args: + dist_over_experts: A 1D array of probabilities for each expert. + batch_size: The number of samples in the batch. + num_experts: The maximum number of experts to consider (fewer if `dist_over_experts` is shorter). + + Returns: + The expected number of experts selected at least once across the batch. + """ # cut the tail and renormalize dist_over_experts = np.sort(dist_over_experts)[::-1][:num_experts] dist_over_experts = dist_over_experts / (dist_over_experts.sum()) @@ -226,6 +293,18 @@ def estimate_moe_active_params( batch_size: int, block_idx: int, ) -> int: + """Estimate the expected number of active (used) parameters for a Mixture-of-Experts (MoE) FFN subblock. + + Args: + subblock_config: The FFNConfig for the MoE subblock (with .moe field configured). + n_embd: The embedding dimension (input and output size per expert). + moe_stats_file: Path to the JSON file containing routing/selection probabilities for experts. + batch_size: Batch size to simulate/extrapolate expected expert use. + block_idx: The index of the block/layer whose expert routing statistics should be used. + + Returns: + Estimated number of parameters actively used for the current batch and expert selection statistics. + """ assert Path(moe_stats_file).exists() # if not Path(moe_stats_file).exists(): # if path is not provided, should we assume uniform distribution? # return calculate_subblock_params(subblock_config, n_embd, n_head=None) @@ -255,7 +334,7 @@ def estimate_moe_active_params( def calculate_attention_memory( attention_config: AttentionConfig, model_config: PretrainedConfig, - descriptor: Type[ModelDescriptor], + descriptor: type[ModelDescriptor], batch_size: int, prefill_seq_len: int, generation_seq_len: int, @@ -267,6 +346,7 @@ def calculate_attention_memory( allocate_prefill_query: bool, ) -> dict[str, float]: """allocate_prefill_query: infery-llm style. + Infery used a unified Wqkv matrix, so before extracting the kv-cache, the query also had to be kept in-memory, once per layer. """ @@ -294,11 +374,24 @@ def calculate_attention_memory( def calculate_mamba_memory( attention_config: AttentionConfig, model_config: PretrainedConfig, - descriptor: Type[ModelDescriptor], + descriptor: type[ModelDescriptor], batch_size: int, weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, ) -> int: + """Calculate memory usage (MiB) for a Mamba attention subblock. + + Args: + attention_config: Mamba attention configuration, including Mamba-specific settings. + model_config: Model configuration. + descriptor: Model descriptor class. + batch_size: Batch size for memory estimate. + weights_dtype: Data type for model weights. + kv_cache_dtype: Data type for state/kv-cache. + + Returns: + Estimated memory usage in mebibytes (MiB) for the Mamba subblock. + """ assert attention_config.mamba is not None mamba_config = attention_config.mamba num_params = calculate_subblock_params(model_config, attention_config, descriptor) @@ -312,7 +405,16 @@ def calculate_mamba_state_size( mamba_config: MambaConfig, batch_size: int, ) -> int: - d_inner, in_proj_dim, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config) + """Calculate the total state size for a Mamba attention subblock. + + Args: + mamba_config: Configuration object containing Mamba subblock parameters. + batch_size: Batch size to estimate the memory/state requirements for. + + Returns: + Total state size (number of elements) required for the Mamba subblock, including convolution and SSM state. + """ + _, _, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config) conv_state_size = math.prod((batch_size, conv_dim, kernel_size)) ssm_state_size = math.prod( (batch_size, mamba_config.num_heads, mamba_config.head_dim, mamba_config.state_dim) @@ -333,10 +435,22 @@ def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...] def calculate_ffn_memory( ffn_config: FFNConfig, model_config: PretrainedConfig, - descriptor: Type[ModelDescriptor], + descriptor: type[ModelDescriptor], weights_dtype: torch.dtype | str, experts_dtype: torch.dtype | str | None = None, ) -> float: + """Estimate the memory usage in MiB of a feed-forward network (FFN) subblock. + + Args: + ffn_config: FFN configuration for the block. + model_config: The parent model configuration. + descriptor: Model descriptor class. + weights_dtype: Data type for FFN weights. + experts_dtype: Data type for expert weights (for MoE layers, if present). + + Returns: + Estimated FFN memory usage in mebibytes (MiB). + """ # TODO: How to separate between expert weights and the rest for any model (same as puzzletron). num_params = calculate_subblock_params(model_config, ffn_config, descriptor) return num_params * sizeof_dtype(weights_dtype) / 2**20 @@ -347,6 +461,7 @@ def calculate_non_block_memory( vocab_size: int, weight_dtype: torch.dtype, ) -> float: + """Estimate the memory usage in MiB of non-subblock components (e.g., embeddings, output projection).""" return calculate_non_block_params(n_embd, vocab_size) * sizeof_dtype(weight_dtype) / 2**20 @@ -354,4 +469,5 @@ def calculate_non_block_params( n_embd: int, vocab_size: int, ) -> int: + """Calculate the number of parameters for non-subblock components (e.g., embeddings, output projection).""" return vocab_size * n_embd * 2 + n_embd diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index dc89a1f6450..1d04cc01add 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -19,12 +19,11 @@ import copy import dataclasses import json -import os import warnings from functools import partial from itertools import product from pathlib import Path -from typing import Iterable, Optional, Type, TypeVar +from typing import Iterable, Type, TypeVar import pandas as pd import torch @@ -52,7 +51,6 @@ __all__ = [ "calculate_subblock_stats", "launch_calc_subblock_stats", - "add_int8_runtime_estimates", ] # Type variable for dataclasses @@ -60,10 +58,10 @@ """ Usage: -python -m modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] +python -m modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --runtime_stats ] ---benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, - only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. +--runtime_stats_enabled=False (the default) means that the code won't benchmark runtime, + only memory stats will be calculated. If you want to benchmark runtime, run inside an trtllm docker. """ @@ -82,7 +80,7 @@ def calculate_subblock_stats( n_embd: int, n_head: int, vocab_size: int, - benchmark_iterations: Optional[int], + runtime_stats_enabled: bool, use_cuda_graph: bool, weights_dtype: torch.dtype, activations_dtype: torch.dtype, @@ -90,14 +88,14 @@ def calculate_subblock_stats( allocate_prefill_query: bool, moe_stats_file: str | Path | None = None, ) -> dict: - is_calc_runtime = benchmark_iterations is not None - if is_calc_runtime: - raise NotImplementedError("Runtime stats calculation is not implemented yet") + if runtime_stats_enabled: + from modelopt.torch.puzzletron.subblock_stats.calc_runtime_stats import ( + calc_runtime_for_subblocks, + ) gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name() subblock_stats = { "args": dict( - is_calc_runtime=is_calc_runtime, gpu=gpu, batch_size=batch_size, prefill_seq_len=prefill_seq_len, @@ -106,7 +104,7 @@ def calculate_subblock_stats( n_embd=n_embd, n_head=n_head, vocab_size=vocab_size, - benchmark_iterations=benchmark_iterations, + runtime_stats=runtime_stats_enabled, use_cuda_graph=use_cuda_graph, weights_dtype=str(weights_dtype), activations_dtype=str(activations_dtype), @@ -116,27 +114,24 @@ def calculate_subblock_stats( "subblocks": list(), } # Compute runtime stats for unique subblocks only - if is_calc_runtime: - raise NotImplementedError("Runtime stats calculation is not implemented yet") + if runtime_stats_enabled: subblock_configs_nolayerindex = set( [subblock_config["subblock_config"] for subblock_config in subblock_configs] ) - # dict[SubblockConfig, float], float - # TODO: Manage default values for calc_subblock_stats_config in one place, e.g. within a dataclass for hydra config. - synth_dataset_num_requests = calc_subblock_stats_config.get("runtime_stats", {}).get( - "synth_dataset_num_requests", 200 - ) - backend = calc_subblock_stats_config.get("runtime_stats", {}).get("backend", "trt_torch") - runtime_by_subblock_dict, non_block_runtime_ms = calc_runtime_ms_for_subblocks( - subblock_configs_nolayerindex, - vocab_size, - n_embd, - n_head, - master_puzzle_dir, - teacher_dir, - synth_dataset_num_requests, - backend, + runtime_stats_config = calc_subblock_stats_config.get("runtime_stats", {}) + + runtime_by_subblock_dict, non_block_runtime_ms = calc_runtime_for_subblocks( + subblock_config_set=subblock_configs_nolayerindex, + runtime_stats_config=runtime_stats_config, + vocab_size=vocab_size, + hidden_size=n_embd, + num_attention_heads=n_head, + num_key_value_heads=model_config.num_key_value_heads, + tokenizer_path=teacher_dir, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + batch_size=batch_size, ) sorted_subblock_config = sorted( @@ -144,7 +139,7 @@ def calculate_subblock_stats( ) it = ( tqdm(sorted_subblock_config, desc="Measuring subblock runtimes") - if is_calc_runtime + if runtime_stats_enabled else sorted_subblock_config ) for subblock_config_indexed in it: @@ -156,7 +151,7 @@ def calculate_subblock_stats( descriptor.get_language_model_config(layer_model_config), parent_layer_indices[0] ) - if is_calc_runtime: + if runtime_stats_enabled: total_runtime_ms = runtime_by_subblock_dict[subblock_config] prefill_runtime_ms = None decode_runtime_ms = None @@ -207,25 +202,13 @@ def calculate_subblock_stats( } ) - if is_calc_runtime: - # TODO: fix - # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms - # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ - # measure_non_block_runtime_ms(batch_size, prefill_seq_len, generation_seq_len, n_embd, vocab_size, - # benchmark_iterations, use_cuda_graph) - embedding_runtime_ms, lm_head_runtime_ms = None, None - else: - non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = None, None, None + if not runtime_stats_enabled: + non_block_runtime_ms = None non_block_memory = calculate_non_block_memory(n_embd, vocab_size, weights_dtype) non_block_params = calculate_non_block_params(n_embd, vocab_size) - # TODO - # the semantics here is wrong why do we refer, prefill_runtime_ms as embedding_runtime_ms and lm_head_runtime_ms as decode_runtime_ms ? - # Prefill is the first the user prompt inference, and Decode refer to the next generation process. both processes use all the model layers. subblock_stats["non_block"] = { "runtime_ms": non_block_runtime_ms, - "prefill_runtime_ms": embedding_runtime_ms, - "decode_runtime_ms": lm_head_runtime_ms, "memory_mib": non_block_memory, "num_params": non_block_params, } @@ -256,7 +239,9 @@ def launch_calc_subblock_stats(cfg: DictConfig) -> None: num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None), prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size, allocate_prefill_query=cfg.calc_subblock_stats.get("allocate_prefill_query", False), - benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None), + runtime_stats_enabled=cfg.calc_subblock_stats.get("runtime_stats", {}).get( + "enabled", False + ), merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats, subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename, moe_stats_filename=cfg.calc_subblock_stats.moe_stats_filename, @@ -276,9 +261,7 @@ def calculate_subblock_stats_for_puzzle_dir( num_active_tokens_override: int | None = None, prefill_queue_size: int = 0, # it's an infery-llm thing allocate_prefill_query: bool = False, - benchmark_iterations: ( - int | None - ) = None, # If set then compute runtime performance statistics. TODO: recommend default value, is 1000 good? + runtime_stats_enabled: bool = False, # Compute runtime statistics. merge_with_existing_stats: bool = False, subblock_stats_filename: str = "subblock_stats.json", moe_stats_filename: str = "moe_stats.json", @@ -344,8 +327,8 @@ def calculate_subblock_stats_for_puzzle_dir( if num_active_tokens_override is not None: prefill_seq_len = generation_seq_len = int(num_active_tokens_override / batch_size / 2) - curr_benchmark_iterations = ( - benchmark_iterations if weights_dtype == torch.bfloat16 else None + curr_runtime_stats_enabled = ( + runtime_stats_enabled if weights_dtype == torch.bfloat16 else False ) curr_subblock_stats = calculate_subblock_stats( @@ -362,7 +345,7 @@ def calculate_subblock_stats_for_puzzle_dir( n_embd=model_hidden_size, n_head=lm_config.num_attention_heads, vocab_size=lm_config.vocab_size, - benchmark_iterations=curr_benchmark_iterations, + runtime_stats_enabled=curr_runtime_stats_enabled, use_cuda_graph=True, weights_dtype=weights_dtype, activations_dtype=activations_dtype, @@ -378,8 +361,6 @@ def calculate_subblock_stats_for_puzzle_dir( subblock_stats.append(curr_subblock_stats) - # TODO fix: add_int8_runtime_estimates(subblock_stats) - json_dump(subblock_stats, subblock_stats_file) mprint(subblock_stats_file) @@ -433,10 +414,7 @@ def _load_subblock_configs_from_replacement_library( 4 intermediate_size + teacher_intermediate_size + ffn_noop + att_op (teacher) + att_noop. Args: - master_puzzle_dir (Path): Directory with "replacement_library.json" file - - Returns: - list[SubblockConfig]: + master_puzzle_dir: Directory with "replacement_library.json" file """ replacement_library = json.loads((master_puzzle_dir / "replacement_library.json").read_text()) subblock_configs = set() @@ -501,67 +479,3 @@ def _dataclass_from_dict( if pd.isna(d): return None raise ValueError(f"_dataclass_from_dict: unrecognized {type(d)=} {d=}") - - -def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None: - for curr_subblock_stats in subblock_stats: - args = curr_subblock_stats["args"] - if args["weights_dtype"] == "torch.int8": - assert args["activations_dtype"] == "torch.int8" - ffn_factor = 0.5 - attention_factor = 0.5 if args["kv_cache_dtype"] == "torch.int8" else 0.8 - - bf16_stats = _find_corresponding_bf16_stats(args, subblock_stats) - if bf16_stats is not None: - curr_subblocks = curr_subblock_stats["subblocks"] + [ - curr_subblock_stats["non_block"] - ] - bf16_subblocks = bf16_stats["subblocks"] + [bf16_stats["non_block"]] - for curr_subblock, bf16_subblock in zip(curr_subblocks, bf16_subblocks): - assert curr_subblock.get("subblock_config", None) == bf16_subblock.get( - "subblock_config", None - ) - is_attention = False - if (subblock_config := curr_subblock.get("subblock_config")) is not None: - if hasattr(subblock_config, "__dataclass_fields__"): - subblock_config = dataclasses.asdict(subblock_config) - is_attention = subblock_config.get("num_key_value_heads", None) is not None - runtime_factor = attention_factor if is_attention else ffn_factor - for stat_name, stat_value in bf16_subblock.items(): - if "runtime" in stat_name: - curr_subblock[stat_name] = stat_value * runtime_factor - - -def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> dict | None: - scenario_keys = [ - "batch_size", - "prefill_seq_len", - "generation_seq_len", - "prefill_queue_size", - "gpu", - "n_embd", - "n_head", - "vocab_size", - ] - corresponding_bf16_args = { - **{k: v for k, v in args.items() if k in scenario_keys}, - "is_calc_runtime": True, - "weights_dtype": "torch.bfloat16", - "activations_dtype": "torch.bfloat16", - "kv_cache_dtype": "torch.bfloat16", - } - matching_bf16_stats = [ - stats - for stats in subblock_stats - if all( - [ - stats["args"][key] == corresponding_bf16_args[key] - for key in corresponding_bf16_args.keys() - ] - ) - ] - if len(matching_bf16_stats) == 0: - return None - if len(matching_bf16_stats) == 1: - return matching_bf16_stats[0] - raise ValueError(f"Found more than 1 matching bf16 stats for {args=}") diff --git a/modelopt/torch/puzzletron/subblock_stats/runtime_utils.py b/modelopt/torch/puzzletron/subblock_stats/runtime_utils.py new file mode 100644 index 00000000000..3259e706c73 --- /dev/null +++ b/modelopt/torch/puzzletron/subblock_stats/runtime_utils.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for runtime benchmarking and model saving in ModelOpt NAS. + +This module provides classes and utility functions used for empirical runtime +estimation of Transformer subblocks and for saving models and tokenizers in +formats suitable for benchmarking (e.g., vLLM latency benchmark) or the +AnyModel subblock-safetensors format. It defines the configuration dataclass +used to parameterize runtime benchmarks, as well as model checkpointing helpers +to ensure compatibility with downstream evaluation pipelines. +""" + +import json +from dataclasses import dataclass +from pathlib import Path + +import torch +from transformers import AutoTokenizer, LlamaForCausalLM + +from ..anymodel.converter import Converter +from ..anymodel.models.llama import LlamaModelDescriptor + + +@dataclass(frozen=True) +class RuntimeConfig: + """Configuration for a vLLM latency benchmark run.""" + + vocab_size: int + hidden_size: int + num_attention_heads: int + num_key_value_heads: int + tokenizer_path: str + repeat_block_n_times: int + prefill_seq_len: int + generation_seq_len: int + batch_size: int + num_iters: int + num_warmup_iters: int + + +def save_model(model: LlamaForCausalLM, tokenizer_path: Path, output_path: Path) -> None: + """Save model weights as AnyModel and copy the tokenizer to ``output_path``.""" + model = model.to(dtype=torch.bfloat16) + save_model_as_anymodel(model, output_path, LlamaModelDescriptor) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + tokenizer.save_pretrained(output_path) + + +def save_model_as_anymodel(model, output_dir: Path, descriptor): + """Save a model checkpoint in AnyModel subblock-safetensors format.""" + # Save standard model checkpoint (as safetensors, HF format) + model.save_pretrained(output_dir, safe_serialization=True) + + # Convert/slice weights into AnyModel subblock_safetensors format + Converter.convert_model_weights( + input_dir=output_dir, + output_dir=output_dir, + descriptor=descriptor, + num_hidden_layers=model.config.num_hidden_layers, + ) + # Load the model config.json, update "architectures" to ["AnyModel"], and write back to disk. + + config_path = output_dir / "config.json" + if config_path.exists(): + with open(config_path) as f: + config_data = json.load(f) + config_data["architectures"] = ["AnyModel"] + with open(config_path, "w") as f: + json.dump(config_data, f, indent=2) diff --git a/modelopt/torch/puzzletron/subblock_stats/runtime_vllm.py b/modelopt/torch/puzzletron/subblock_stats/runtime_vllm.py new file mode 100644 index 00000000000..14eb337b707 --- /dev/null +++ b/modelopt/torch/puzzletron/subblock_stats/runtime_vllm.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""vLLM Runtime Benchmark Integration for ModelOpt NAS Subblocks. + +This module provides the integration logic to empirically benchmark subblock +runtime statistics within transformer architectures using the vLLM latency +benchmark. Each invocation is launched in a dedicated subprocess so that GPU +memory and CUDA state are fully reclaimed when the subprocess exits, allowing +many sequential benchmarks to run in a single Python session without leaking. + +Usage: + - Call `run_vllm_latency_benchmark` with a model path and a + `RuntimeConfig` instance to run a latency benchmark and + return the average latency for the configuration (in milliseconds). +""" + +import json +import subprocess # nosec B404 +from pathlib import Path + +from ..tools.logger import mprint +from ..utils.vllm_adapter import convert_block_configs_to_per_layer_config +from .runtime_utils import RuntimeConfig + + +def run_vllm_latency_benchmark(model_path: Path, runtime_config: RuntimeConfig) -> float: + """Run ``vllm bench latency`` in a fresh subprocess and return avg latency in ms. + + Spawning a subprocess per call gives OS-level isolation: GPU memory, CUDA + context, and vLLM engine state are fully released on subprocess exit, so + many calls in one parent process do not accumulate. + """ + output_json_path = model_path / "vllm_latency_benchmark.json" + max_model_len = runtime_config.prefill_seq_len + runtime_config.generation_seq_len + + with open(model_path / "config.json") as f: + config = json.load(f) + + if convert_block_configs_to_per_layer_config(config): + mprint("Converted block configs to per-layer config") + with open(model_path / "config.json", "w") as f: + json.dump(config, f, indent=2) + else: + mprint("No block configs to convert") + + cmd = [ + "vllm", + "bench", + "latency", + "--model", + str(model_path), + "--input-len", + str(runtime_config.prefill_seq_len), + "--output-len", + str(runtime_config.generation_seq_len), + "--batch-size", + str(runtime_config.batch_size), + "--output-json", + str(output_json_path), + "--max-model-len", + str(max_model_len), + "--num-iters-warmup", + str(runtime_config.num_warmup_iters), + "--num-iters", + str(runtime_config.num_iters), + "--max-num-seqs", + "1", + "--tensor-parallel-size", + "1", + "--pipeline-parallel-size", + "1", + "--distributed-executor-backend", + "external_launcher", + # Required for accurate per-block runtime stats. + "--optimization-level", + "0", + ] + + # cmd is a fixed list of strings (no shell, no untrusted input). + try: + subprocess.run( + cmd, + check=True, + capture_output=True, + text=True, + timeout=1800, # 30 minutes + ) # nosec B603 + except subprocess.TimeoutExpired as exc: + raise TimeoutError("vLLM latency benchmark timed out") from exc + except subprocess.CalledProcessError as exc: + raise RuntimeError(exc.stderr or exc.stdout or "vLLM latency benchmark failed") from exc + + if output_json_path.exists(): + with open(output_json_path) as f: + vllm_results = json.load(f) + if "avg_latency" in vllm_results: + return vllm_results["avg_latency"] * 1000 # seconds -> milliseconds + + raise RuntimeError(f"vLLM benchmark output not found at {output_json_path}") diff --git a/modelopt/torch/puzzletron/utils/vllm_adapter.py b/modelopt/torch/puzzletron/utils/vllm_adapter.py new file mode 100644 index 00000000000..ae8409a1de7 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/vllm_adapter.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ModelOpt/AnyModel -> vLLM/AnyModel config adapter. + +ModelOpt/AnyModel checkpoints describe per-layer overrides via a dense +``block_configs`` list with nested ``attention`` / ``ffn`` sub-sections. +AnyModel in vLLM now consumes the HuggingFace heterogeneity schema: a sparse +``per_layer_config`` dict mapping ``layer_idx -> {flat HF keys + optional +"skip" list}``. + +This module rewrites the Puzzletron schema in-place so vLLM only +ever sees ``per_layer_config``. It is invoked from +``AnyModelConfig.verify_and_update_model_config`` before the arch +convertor or layer-patching runs. +""" + +from __future__ import annotations + +from typing import Any + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +# (num_experts_field, moe_intermediate_size_field) per base architecture. +# ModelOpt always writes ``moe.num_local_experts`` and +# ``moe.expert_intermediate_{size,dim}``; the adapter rewrites them into the +# field names the base HF config actually reads. +_MOE_FIELDS_BY_ARCH: dict[str, tuple[str, str]] = { + "Qwen2MoeForCausalLM": ("num_experts", "moe_intermediate_size"), + "Qwen3MoeForCausalLM": ("num_experts", "moe_intermediate_size"), + "MixtralForCausalLM": ("num_local_experts", "intermediate_size"), + "GptOssForCausalLM": ("num_local_experts", "intermediate_size"), + "NemotronHForCausalLM": ("n_routed_experts", "moe_intermediate_size"), + "DeepseekV3ForCausalLM": ("n_routed_experts", "moe_intermediate_size"), + "DeepseekV2ForCausalLM": ("n_routed_experts", "moe_intermediate_size"), +} + +_DEFAULT_MOE_FIELDS: tuple[str, str] = ("num_local_experts", "intermediate_size") + + +def _get(obj: Any, key: str, default: Any = None) -> Any: + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _convert_block_entry( + block: Any, + *, + global_kv: int | None, + global_isize: int | None, + global_hact: str | None, + global_moe_num: int | None, + global_moe_size: int | None, + moe_num_field: str, + moe_size_field: str, +) -> dict[str, Any]: + """Translate a single ModelOpt ``block_configs`` entry into a flat + per-layer override dict. Only attributes that differ from the global + config are emitted; sub-module no-ops become a ``"skip"`` list.""" + attn = _get(block, "attention") or {} + ffn = _get(block, "ffn") or {} + a_noop = bool(_get(attn, "no_op", False)) + f_noop = bool(_get(ffn, "no_op", False)) + + entry: dict[str, Any] = {} + skip: list[str] = [] + if a_noop: + skip.append("attention") + if f_noop: + skip.append("mlp") + if skip: + entry["skip"] = skip + + if not a_noop: + kv = _get(attn, "num_key_value_heads") + if kv is not None and kv != global_kv: + entry["num_key_value_heads"] = kv + + if not f_noop: + isize = _get(ffn, "intermediate_size") + if isize is not None and isize != global_isize: + entry["intermediate_size"] = isize + + hact = _get(ffn, "hidden_act") + if hact is not None and hact != global_hact: + entry["hidden_act"] = hact + + moe = _get(ffn, "moe") + if moe: + n_exp = _get(moe, "num_local_experts") + if n_exp is None: + n_exp = _get(moe, "num_experts") + if n_exp is None: + n_exp = _get(moe, "n_routed_experts") + if n_exp is not None and n_exp != global_moe_num: + entry[moe_num_field] = n_exp + + exp_size = _get( + moe, + "expert_intermediate_size", + _get(moe, "expert_intermediate_dim"), + ) + if exp_size is not None and exp_size != global_moe_size: + entry[moe_size_field] = exp_size + + return entry + + +def convert_block_configs_to_per_layer_config(hf_config: Any) -> bool: + """In-place: convert legacy ``block_configs`` on ``hf_config`` to + ``per_layer_config`` on its text config. + + Returns ``True`` if a conversion happened, ``False`` if there was + nothing to convert. If ``per_layer_config`` is already set, the legacy + field is dropped and a warning emitted (the new schema wins). + """ + block_configs = getattr(hf_config, "block_configs", None) + if not block_configs: + return False + + text_config = ( + hf_config.get_text_config() if hasattr(hf_config, "get_text_config") else hf_config + ) + + existing = getattr(text_config, "per_layer_config", None) + if existing: + logger.warning_once( + "AnyModel config has both legacy 'block_configs' and new " + "'per_layer_config'; using per_layer_config and ignoring " + "block_configs." + ) + if hasattr(hf_config, "block_configs"): + try: + delattr(hf_config, "block_configs") + except AttributeError: + pass + return False + + base_architecture = getattr(hf_config, "base_architecture", None) or "" + moe_num_field, moe_size_field = _MOE_FIELDS_BY_ARCH.get(base_architecture, _DEFAULT_MOE_FIELDS) + + global_kv = getattr(text_config, "num_key_value_heads", None) + global_isize = getattr(text_config, "intermediate_size", None) + global_hact = getattr(text_config, "hidden_act", None) + global_moe_num = getattr(text_config, moe_num_field, None) + global_moe_size = getattr(text_config, moe_size_field, None) + + per_layer_config: dict[str, dict[str, Any]] = {} + for idx, block in enumerate(block_configs): + entry = _convert_block_entry( + block, + global_kv=global_kv, + global_isize=global_isize, + global_hact=global_hact, + global_moe_num=global_moe_num, + global_moe_size=global_moe_size, + moe_num_field=moe_num_field, + moe_size_field=moe_size_field, + ) + if entry: + per_layer_config[str(idx)] = entry + + n_layers = getattr(text_config, "num_hidden_layers", None) + if n_layers is not None and len(block_configs) != n_layers: + logger.warning( + "block_configs length (%d) does not match num_hidden_layers " + "(%d); converted entries beyond num_hidden_layers will fail " + "AnyModel validation.", + len(block_configs), + n_layers, + ) + + setattr(text_config, "per_layer_config", per_layer_config) + try: + delattr(hf_config, "block_configs") + except AttributeError: + pass + + logger.info( + "Converted ModelOpt block_configs (%d entries) to AnyModel " + "per_layer_config (%d non-empty entries) for base_architecture=%r.", + len(block_configs), + len(per_layer_config), + base_architecture or "", + ) + return True diff --git a/noxfile.py b/noxfile.py index 059f351b7f9..1a28321bbd7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -142,7 +142,7 @@ def gpu_trtllm(session): # Pin must stay in sync with examples/vllm_serve/Dockerfile. @nox.session(venv_backend="none") def gpu_vllm(session): - session.run("python3", "-m", "pip", "install", "-e", ".[hf,dev-test]") + session.run("python3", "-m", "pip", "install", "-e", ".[hf,puzzletron,dev-test]") session.run("python3", "-m", "pytest", "tests/gpu_vllm", *_cov_args()) diff --git a/tests/gpu_vllm/torch/puzzletron/test_calc_runtime_stats.py b/tests/gpu_vllm/torch/puzzletron/test_calc_runtime_stats.py new file mode 100644 index 00000000000..377a2ffed19 --- /dev/null +++ b/tests/gpu_vllm/torch/puzzletron/test_calc_runtime_stats.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU test for ``calc_runtime_for_subblocks``. + +Exercises the end-to-end vLLM latency benchmarking pipeline on a tiny model: +constructs a small subblock set, runs the benchmark for each candidate, and +checks the returned per-subblock runtime dict and no-block overhead. +""" + +import math +from pathlib import Path + +import pytest +from _test_utils.torch.transformers_models import get_tiny_tokenizer +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.block_config import AttentionConfig, FFNConfig +from modelopt.torch.puzzletron.subblock_stats.calc_runtime_stats import calc_runtime_for_subblocks + + +@pytest.mark.skip(reason="AnyModel is not supported in vLLM yet") +def test_calc_runtime_for_subblocks(tmp_path: Path): + """End-to-end: a tiny subblock set yields a runtime dict + positive no-block overhead.""" + tokenizer = get_tiny_tokenizer() + tokenizer_dir = tmp_path / "tokenizer" + tokenizer.save_pretrained(str(tokenizer_dir)) + + attn = AttentionConfig(no_op=False, num_key_value_heads=2) + ffn = FFNConfig(no_op=False, intermediate_size=256, moe=None) + attn_noop = AttentionConfig(no_op=True) + subblock_set = {attn, ffn, attn_noop} + + # vLLM's bench latency samples input ids in [0, 10000) (see + # vllm/benchmarks/latency.py), and its input validator accepts an id when + # it fits in max(tokenizer.max_token_id, model_vocab_size - 1). The tiny + # tokenizer's vocab is ~200, so we size the model vocab past 10000 to + # cover the sampled range. + runtime_by_subblock, no_block_runtime_ms = calc_runtime_for_subblocks( + subblock_config_set=subblock_set, + runtime_stats_config=OmegaConf.create({"num_iters": 1, "num_warmup_iters": 1}), + vocab_size=10016, + hidden_size=256, + num_attention_heads=4, + num_key_value_heads=2, + tokenizer_path=str(tokenizer_dir), + prefill_seq_len=8, + generation_seq_len=4, + batch_size=1, + ) + + assert set(runtime_by_subblock) == subblock_set + assert runtime_by_subblock[attn_noop] == 0.0 + assert math.isfinite(runtime_by_subblock[attn]) + assert math.isfinite(runtime_by_subblock[ffn]) + # The 1-block model is always slower than the per-block extrapolation from + # the 10-block model, so the (embedding + LM-head) overhead is positive. + assert no_block_runtime_ms > 0