diff --git a/CHANGELOG.rst b/CHANGELOG.rst index be2210a33f2..5fd1341ddd0 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,12 +24,16 @@ Changelog - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. +- Support Megatron-Core checkpoint restore and export for MSE ``NVFP4StaticQuantizer``. +- Add mixed-precision FP8 + NVFP4 export for Megatron-Core: per-layer ``quant_algo`` recorded under ``quantized_layers`` in ``hf_quant_config.json``, PP-aware ``kv_cache_dtype`` gather, fused-QKV exclude split into per-HF-name ``q/k/v_proj`` entries. +- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. 
Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. - The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically. **Bug Fixes** +- In Megatron-Core, only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously, EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was ``False``. - Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``.
Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. 0.44 (2026-05-14) diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py index 3552d71a1ad..651ee6b1e73 100644 --- a/examples/specdec_bench/specdec_bench/datasets/speed.py +++ b/examples/specdec_bench/specdec_bench/datasets/speed.py @@ -730,11 +730,7 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data # Strip HF metadata from the schema to avoid Feature parsing errors schema = table.schema if schema.metadata and b"huggingface" in schema.metadata: - new_meta = { - k: v - for k, v in schema.metadata.items() - if k != b"huggingface" - } + new_meta = {k: v for k, v in schema.metadata.items() if k != b"huggingface"} table = table.replace_schema_metadata(new_meta or None) dataset = HFDataset(table) if self.num_samples is not None and self.num_samples < len(dataset): diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py index 4d9bc6fc292..63ac3393486 100644 --- a/modelopt/torch/export/plugins/hf_checkpoint_utils.py +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -22,9 +22,21 @@ import torch from huggingface_hub import snapshot_download +from huggingface_hub.errors import LocalEntryNotFoundError from safetensors.torch import safe_open from tqdm import tqdm +_HF_HUB_OFFLINE_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} + + +def _is_hf_hub_offline() -> bool: + return os.environ.get("HF_HUB_OFFLINE", "").strip().upper() in _HF_HUB_OFFLINE_TRUE_VALUES + + +def _copy_python_files(source_dir: Path, save_dir: Path) -> None: + for py_file in source_dir.glob("*.py"): + shutil.copy2(py_file, save_dir / py_file.name) + def copy_hf_ckpt_remote_code( 
pretrained_model_path: str | os.PathLike, save_directory: str | os.PathLike @@ -36,7 +48,10 @@ def copy_hf_ckpt_remote_code( frameworks. If ``pretrained_model_path`` is a local directory, Python files are copied directly. - If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), files are downloaded from the Hub. + If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), the Hub + snapshot is resolved first and Python files are copied from that snapshot. When + ``HF_HUB_OFFLINE`` is set, the snapshot must already be available in the local + Hugging Face cache. Args: pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID. @@ -47,14 +62,28 @@ def copy_hf_ckpt_remote_code( save_dir.mkdir(parents=True, exist_ok=True) if hf_checkpoint_path.is_dir(): - for py_file in hf_checkpoint_path.glob("*.py"): - shutil.copy2(py_file, save_dir / py_file.name) + _copy_python_files(hf_checkpoint_path, save_dir) else: - snapshot_download( - repo_id=str(pretrained_model_path), - local_dir=str(save_dir), - allow_patterns=["*.py"], - ) + local_files_only = _is_hf_hub_offline() + try: + source_dir = Path( + snapshot_download( + repo_id=str(pretrained_model_path), + allow_patterns=["*.py"], + local_files_only=local_files_only, + ) + ) + except LocalEntryNotFoundError as exc: + if local_files_only: + raise RuntimeError( + f"Could not copy Python sidecar files for {pretrained_model_path!r} because " + "HF_HUB_OFFLINE is enabled and the files are not available in the local " + "Hugging Face cache. Populate the cache with the model's *.py files or pass " + "a local pretrained model directory." 
+ ) from exc + raise + + _copy_python_files(source_dir, save_dir) def load_multimodal_components( diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index e9d1b4e1e9b..24bd8144055 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -131,7 +131,10 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), - "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), + "core_attention": SelfAttentionScaling( + "backbone.layers.{}.mixer.", + func_kwargs={"k_scale_name": "k_proj.k_scale", "v_scale_name": "v_proj.v_scale"}, + ), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index cf9f26d51a7..b3173706b44 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -288,9 +288,25 @@ def _ensure_weight_quantizer_calibrated( module_name: Optional module name for better warning messages """ if isinstance(weight_quantizer, NVFP4StaticQuantizer): - need_per_block = not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None + + def _amax_is_invalid(t: torch.Tensor | None) -> bool: + # MCore distcp may register but not fill amax — treat missing/non-finite/negative as recompute. 
+ if t is None: + return True + t = t.detach() + if not torch.is_floating_point(t): + return False + return bool(torch.any(~torch.isfinite(t)).item() or torch.any(t < 0).item()) + + need_per_block = ( + not hasattr(weight_quantizer, "_amax") + or weight_quantizer._amax is None + or _amax_is_invalid(weight_quantizer._amax) + ) need_global = ( - not hasattr(weight_quantizer, "_global_amax") or weight_quantizer.global_amax is None + not hasattr(weight_quantizer, "_global_amax") + or weight_quantizer.global_amax is None + or _amax_is_invalid(weight_quantizer.global_amax) ) if not (need_per_block or need_global): return diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 44529ba0fad..183dd4cb8bc 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -44,6 +44,7 @@ QUANTIZATION_FP8_PB_WO, QUANTIZATION_NONE, QUANTIZATION_NVFP4, + QUANTIZATION_W4A16_NVFP4, ) from .plugins.hf_checkpoint_utils import copy_hf_ckpt_remote_code, load_multimodal_components from .plugins.mcore_common import all_mcore_hf_export_mapping @@ -61,6 +62,7 @@ get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, + process_layer_quant_config, to_quantized_weight, ) @@ -169,6 +171,7 @@ def __init__( self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] self.exclude_modules = [] + self.layer_config_dict = {} if not hasattr(model, "_modelopt_state"): return @@ -287,6 +290,8 @@ def save_pretrained( quantization = "FP8" elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" + elif quantization_format == QUANTIZATION_W4A16_NVFP4: + quantization = "W4A16_NVFP4" # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. 
@@ -324,22 +329,32 @@ def save_pretrained( print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") combined_exclude_modules = self._gather_exclude_modules() + combined_layer_config_dict = self._gather_layer_config_dict() + # kv_cache_dtype is only set on attention-owning ranks; writer rank may not be one. + gathered_kv_cache_dtype = self._gather_kv_cache_dtype() if is_last_stage_main_rank and quantization is not None: + if combined_layer_config_dict: + quantization_config = process_layer_quant_config(combined_layer_config_dict) + quantization_config["exclude_modules"] = combined_exclude_modules + else: + quantization_config = { + "quant_algo": quantization, + "exclude_modules": combined_exclude_modules, + } + if quantization in ("NVFP4", "W4A16_NVFP4"): # update block size + quantization_config["group_size"] = 16 + + if gathered_kv_cache_dtype is not None: + quantization_config["kv_cache_quant_algo"] = gathered_kv_cache_dtype + self._hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, - "quantization": { - "quant_algo": quantization, - "exclude_modules": combined_exclude_modules, - }, + "quantization": quantization_config, } - if quantization == "NVFP4": # update block size - self._hf_quant_config["quantization"]["group_size"] = 16 - if hasattr(self, "kv_cache_dtype"): - self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) @@ -359,10 +374,11 @@ def save_pretrained( # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" if self._hf_quant_config and os.path.exists(config_json_file): - converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) with open(config_json_file) as f: config_dict = json.load(f) - config_dict["quantization_config"] = converted_quant_config + config_dict["quantization_config"] = 
convert_hf_quant_config_format( + self._hf_quant_config + ) with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) @@ -814,9 +830,7 @@ def _get_quantized_state( name_to_value = {} qformat: str = self._get_quantization_format(module) if qformat is None and "norm" not in prefix: - # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty - # string then it usually ends with "." which needs to be removed. - self.exclude_modules.append(prefix.removesuffix(".")) + self._record_excluded_module(prefix) block_size = get_weight_block_size(module) name_to_value = self._get_weight_bias(module, dtype, name_to_value) @@ -861,6 +875,27 @@ def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str): return weight_scale, weight_scale_2 + def _record_layer_quant_config(self, prefix: str, qformat: str | None, block_size: int): + """Record per-HF-layer quantization metadata for mixed precision exports.""" + if qformat in (None, QUANTIZATION_NONE): + return + + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + self.layer_config_dict[layer_name + ".quantization"] = qformat + self.layer_config_dict[layer_name + ".awq_block_size"] = block_size + + def _record_excluded_module(self, prefix: str): + """Record an unquantized HF module prefix for hf_quant_config.""" + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + if layer_name not in self.exclude_modules: + self.exclude_modules.append(layer_name) + def _name_remapping( self, module: torch.nn.Module | torch.Tensor, @@ -877,6 +912,7 @@ def _name_remapping( return name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) + self._record_layer_quant_config(prefix, qformat, block_size) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -917,6 +953,8 @@ def _gated_mlp_slicing( gate_proj_prefix = 
prefix + gate_proj_name + "." up_proj_prefix = prefix + up_proj_name + "." + self._record_layer_quant_config(gate_proj_prefix, qformat, block_size) + self._record_layer_quant_config(up_proj_prefix, qformat, block_size) ffn_hidden_size = module.config.ffn_hidden_size gate_proj_weight = weight[:ffn_hidden_size, :] @@ -997,6 +1035,7 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None): for expert_id in range(num_experts): expert_prefix = prefix.format(expert_id) + "." + self._record_layer_quant_config(expert_prefix, qformat, block_size) weight_key = f"weight{expert_id}" if weight_key not in state_dict: @@ -1041,6 +1080,16 @@ def _qkv_slicing( q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." v_proj_prefix = prefix + v_proj_name + "." + self._record_layer_quant_config(q_proj_prefix, qformat, block_size) + self._record_layer_quant_config(k_proj_prefix, qformat, block_size) + self._record_layer_quant_config(v_proj_prefix, qformat, block_size) + if qformat in (None, QUANTIZATION_NONE): + # Split fused linear_qkv exclude into per-HF-name q/k/v_proj entries. 
+ fused_prefix = prefix.removesuffix(".") + self.exclude_modules = [m for m in self.exclude_modules if m != fused_prefix] + self._record_excluded_module(q_proj_prefix) + self._record_excluded_module(k_proj_prefix) + self._record_excluded_module(v_proj_prefix) config = module.config hidden_size = config.hidden_size @@ -1190,6 +1239,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): weight_scale_list.append(weight_scale) weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1258,6 +1308,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) bias_list.append(bias) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1360,6 +1411,31 @@ def _gather_exclude_modules(self): combined_exclude_modules.update(modules) return sorted(combined_exclude_modules) + def _gather_layer_config_dict(self): + """Get per-layer quantization metadata from all ranks for hf_quant_config.""" + if not torch.distributed.is_initialized(): + return dict(sorted(self.layer_config_dict.items())) + + all_layer_config_dicts = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_layer_config_dicts, self.layer_config_dict) + combined_layer_config_dict = {} + for layer_config_dict in all_layer_config_dicts: + if layer_config_dict: + combined_layer_config_dict.update(layer_config_dict) + return dict(sorted(combined_layer_config_dict.items())) + + def _gather_kv_cache_dtype(self): + """Return first non-None kv_cache_dtype across ranks (only attention ranks set it).""" + local = getattr(self, "kv_cache_dtype", None) + if not torch.distributed.is_initialized(): + return local + all_dtypes = [None] * torch.distributed.get_world_size() + 
torch.distributed.all_gather_object(all_dtypes, local) + for dt in all_dtypes: + if dt is not None: + return dt + return None + def export_mcore_gpt_to_hf( model: torch.nn.Module, diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index e4e633e36ae..d96ef4593d2 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -293,13 +293,14 @@ def get_score(self, recipe: QuantRecipe) -> float: total_score += importance.cpu().item() continue - if parallel_state.expert_model_parallel_group.is_initialized(): - # TODO: Support expert model parallelism for score estimation - warnings.warn("AutoQuantize does not support expert model parallelism yet.") importance = importance.cpu() importance = DistributedProcessGroup.get_dist_syncd_obj( importance, - [parallel_state.tensor_parallel_group, parallel_state.data_parallel_group], + [ + parallel_state.tensor_parallel_group, + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + ], sum, ) total_score += importance.item() @@ -320,13 +321,12 @@ def get_cost(self, recipe: QuantRecipe) -> float: cost += weight_size * recipe.compression continue - if parallel_state.expert_model_parallel_group.is_initialized(): - # TODO: Support expert model parallelism - warnings.warn("AutoQuantize does not support expert model parallelism yet.") - weight_size = DistributedProcessGroup.get_dist_syncd_obj( weight_size, - [parallel_state.tensor_parallel_group], + [ + parallel_state.tensor_parallel_group, + parallel_state.expert_model_parallel_group, + ], sum, ) @@ -364,6 +364,8 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): # gate_proj, up_proj, down_proj for Qwen3 like MoE models r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$", r"^(.*?\.mixer\.experts)\.\d+\.(up_proj|down_proj)$", # NemotronH MoE experts + # NemotronH MoE experts in MCore naming (linear_fc1=gate+up fused, linear_fc2=down) + 
r"^(.*?\.mlp\.experts\.local_experts)\.\d+\.(linear_fc1|linear_fc2)$", r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts @@ -722,6 +724,15 @@ def _get_total_weight_size(modules): for module in modules ) + @staticmethod + def _get_total_weight_size_from_candidate_stats(candidate_stats): + no_quant_recipe = QuantRecipe(quant_cfg=None) + total_weight_size = 0 + for candidate_stat in candidate_stats.values(): + no_quant_idx = candidate_stat["formats"].index(no_quant_recipe) + total_weight_size += candidate_stat["costs"][no_quant_idx] + return total_weight_size + def _get_constraints_for_search(self, max_weight_size, lower_bound=None): constraints = { "weight_size_after_compression": ( @@ -744,7 +755,7 @@ def run_search(self): ) compression = self._get_formatted_weight_compression_constraint() - total_weight_size = self._get_total_weight_size(self.model.modules()) + total_weight_size = self._get_total_weight_size_from_candidate_stats(self.candidate_stats) max_weight_size = total_weight_size * compression # Run the search with stats to get the best recipe and whether the constraints are satisfied @@ -754,12 +765,16 @@ def run_search(self): best_recipe = {} best_constraints, best_scores = 0, 0 for name, best_hparam_recipe_info in best_recipe_info.items(): - # Solvers could give different solutions for the same layer across DP/TP groups even though - # the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP + # Solvers could give different solutions for the same layer across DP/TP/EP groups even though + # the scores and costs are the same. 
Lets make sure the same recipe is selected across DP/TP/EP _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state best_format = DistributedProcessGroup.get_dist_syncd_obj( best_hparam_recipe_info["format"], - [_ps.data_parallel_group, _ps.tensor_parallel_group], + [ + _ps.data_parallel_group, + _ps.tensor_parallel_group, + _ps.expert_model_parallel_group, + ], lambda a: a[0], ) @@ -1379,7 +1394,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): effective_bits = constraints["effective_bits"] compression = effective_bits / 16.0 candidate_stats = search_state["candidate_stats"] - total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values()) + total_weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size_from_candidate_stats( + candidate_stats + ) max_weight_size = total_weight_size * compression method = search_state["method"] diff --git a/modelopt/torch/quantization/backends/utils.py b/modelopt/torch/quantization/backends/utils.py index 838951bef9d..6ed133f5d6c 100644 --- a/modelopt/torch/quantization/backends/utils.py +++ b/modelopt/torch/quantization/backends/utils.py @@ -20,9 +20,13 @@ def fp8_compatible(): """Check if the current device supports FP8.""" + if not torch.cuda.is_available(): + return False return torch.cuda.get_device_capability(0) >= (8, 9) def fp4_compatible(): """Check if the current device supports FP4.""" + if not torch.cuda.is_available(): + return False return torch.cuda.get_device_capability(0) >= (10, 0) diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 40c2b8dbc7e..4a79d1a6507 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -98,6 +98,13 @@ def restore_quantized_model( return restore_quantizer_state(model, config, metadata) +def maybe_promote_nvfp4_static_quantizer(module: nn.Module, quantizer_state: dict) -> None: + if 
quantizer_state.get("_is_nvfp4_static_quantizer") and not isinstance( + module, NVFP4StaticQuantizer + ): + NVFP4StaticQuantizer.from_tensor_quantizer(module) + + def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: MetadataDict): """Restore the quantizer states from the given state dict. @@ -131,12 +138,8 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: if isinstance(module, TensorQuantizer): name = get_unwrapped_name(name, model) state = quantizer_state_dict[name] - # TODO: Add a registry for TensorQuantizers and avoid this manual conversion. - if state.get("_is_nvfp4_static_quantizer") and not isinstance( - module, NVFP4StaticQuantizer - ): - NVFP4StaticQuantizer.from_tensor_quantizer(module) - module.set_from_modelopt_state(quantizer_state_dict[name]) + maybe_promote_nvfp4_static_quantizer(module, state) + module.set_from_modelopt_state(state) for name, module in model.named_modules(): if isinstance(module, QuantModule): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 78b237847b1..fa7ca1cedec 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -55,6 +55,14 @@ ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper +try: + from .plugins.megatron import _check_static_block_tp_supported +except ImportError: + + def _check_static_block_tp_supported(model: nn.Module) -> None: # no-op without megatron + return + + __all__ = [ "CalibratorFactory", "awq", @@ -99,7 +107,7 @@ def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: @torch.no_grad() -def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: +def _bootstrap_uncalibrated_static_weight_quantizers(model: nn.Module) -> int: """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. 
Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE @@ -210,26 +218,51 @@ def _has_expert_parallelism(module: nn.Module) -> bool: return ps is not None and ps.expert_model_parallel_group.is_initialized() -def _check_moe_calibration_complete(quantizer, parallel_state): - """Raise error if MoE calibration is incomplete (some ranks have amax, others don't).""" +def _iter_leaf_quantizers(quantizer): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - _check_moe_calibration_complete(_q, parallel_state) + yield from _iter_leaf_quantizers(_q) return - for group in [ - parallel_state.data_parallel_group, - parallel_state.expert_model_parallel_group, - parallel_state.tensor_parallel_group, - ]: - if not group.is_initialized(): - continue - has_amax = getattr(quantizer, "_amax", None) is not None - amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs) - if any(amax_states) and not all(amax_states): - raise RuntimeError( - "MoE calibration incomplete: some experts received no tokens during calibration. " - "Increase --calib-size to ensure all experts see calibration data." + yield quantizer + + +def _check_moe_calibration_complete(quantizer, parallel_state): + """Raise error if MoE calibration is incomplete across distributed MoE ranks.""" + for leaf_quantizer in _iter_leaf_quantizers(quantizer): + has_amax = getattr(leaf_quantizer, "_amax", None) is not None + for group in [ + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ]: + if not group.is_initialized(): + continue + amax_states = DistributedProcessGroup.get_dist_syncd_obj( + has_amax, group, lambda objs: objs ) + if any(amax_states) and not all(amax_states): + raise RuntimeError( + "MoE calibration incomplete: some experts received no tokens during " + "calibration. Increase --calib-size to ensure all experts see calibration " + "data." 
+ ) + + +def _is_routed_expert(parent_name: str) -> bool: + """Routed-expert FQN contains ``experts`` but not ``shared_experts`` (covers SequentialMLP and TEGroupedMLP).""" + return "experts" in parent_name and "shared_experts" not in parent_name + + +def _should_sync_amax_across_ep( + parent_name: str, child_name: str, sync_expert_weight_amax: bool +) -> bool: + """Skip EP sync for routed-expert weights (per-rank shards differ). + + SequentialMLP opts in via sync_expert_weight_amax. + """ + if "weight_quantizer" in child_name and _is_routed_expert(parent_name): + return sync_expert_weight_amax + return True @torch.no_grad() @@ -246,7 +279,8 @@ def max_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync input_quantizer amax across distributed processes. - sync_expert_weight_amax: Whether to sync weight quantizer amax across MoE experts. + sync_expert_weight_amax: SequentialMLP only — share one weight amax across all experts + in a MoE layer (within-rank sync + EP all-reduce when EP>1). See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -263,13 +297,14 @@ def max_calibrate( if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) - # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer - # so the static blockwise fake-quant path is used in forward and the export - # picks up the two-level (per-block + global) scaling. Run before the - # ``distributed_sync`` early return so single-process callers also get the - # promotion. ``promote_nvfp4_static_quantizers`` only promotes when - # ``is_static_block_quant`` is True and the per-block ``_amax`` buffer is - # populated, so it's a no-op for dynamic-block / non-NVFP4 configs. + # Fail fast on static-block under TP>1 (sharded_state_dict treats _amax as replicated). 
+ _check_static_block_tp_supported(model) + + # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer so + # the static blockwise fake-quant path is used in forward and export picks up the + # two-level (per-block + global) scaling. Run before the ``distributed_sync`` early + # return so single-process callers also get the promotion. No-op for dynamic-block + # / non-NVFP4 configs. promote_nvfp4_static_quantizers(model) if not distributed_sync: @@ -282,23 +317,24 @@ def max_calibrate( if isinstance(child, (TensorQuantizer, SequentialQuantizer)): _check_moe_calibration_complete(child, module.parallel_state) - def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, parent_name, child_name): + """Sync amax across DP (always) and EP (filtered — see _should_sync_amax_across_ep).""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp_ep(_q, parallel_state) + sync_quantizer_amax_across_dp_ep(_q, parallel_state, parent_name, child_name) + return + if getattr(quantizer, "_amax", None) is None: return - if getattr(quantizer, "_amax", None) is not None: - quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + if _should_sync_amax_across_ep(parent_name, child_name, sync_expert_weight_amax): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) - # TODO: create sync_bias_across_distributed_group # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): - for child in module.children(): + for child_name, child in module.named_children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - 
sync_quantizer_amax_across_dp_ep(child, module.parallel_state) + sync_quantizer_amax_across_dp_ep(child, module.parallel_state, name, child_name) # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -321,7 +357,6 @@ def sync_quantizer_amax_across_tp( # Syncing amax across TP for sequential quantizer if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - # Syncing amax across TP for sequential quantizer sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) @@ -449,13 +484,23 @@ def mse_calibrate( # Step 1: max calibrate, bootstrap dead-expert weight quantizers, # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid. max_calibrate(model, forward_loop, distributed_sync) - _bootstrap_uncalibrated_weight_quantizers(model) + _bootstrap_uncalibrated_static_weight_quantizers(model) _sync_grouped_weight_global_amax(model) # Step 2: replace calibrators with MseCalibrator for enabled quantizers. 
+ skipped_non_nvfp4: dict[str, int] = {} for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + bs = getattr(module, "block_sizes", None) + if not ( + getattr(module, "_num_bits", None) == (2, 1) + and bs is not None + and bs.get("scale_bits") == (4, 3) + ): + fmt = f"num_bits={module._num_bits} block_sizes={bs}" + skipped_non_nvfp4[fmt] = skipped_non_nvfp4.get(fmt, 0) + 1 + continue initial_amax = module._amax.clone().detach() is_nvfp4_static = module.is_nvfp4_static @@ -503,6 +548,14 @@ def mse_calibrate( quant_func=partial(_mse_quant_func, quantizer=module), ) + if skipped_non_nvfp4: + formats = ", ".join(f"{n}x [{fmt}]" for fmt, n in skipped_non_nvfp4.items()) + warnings.warn( + f"MSE calibration only meaningful for NVFP4; skipped {sum(skipped_non_nvfp4.values())} " + f"non-NVFP4 quantizer(s) — keeping max-calibrated amax: {formats}", + stacklevel=2, + ) + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. 
name_to_module = dict(model.named_modules()) seen_modules: set[int] = set() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 5e3cea44c2a..0dff5d79aed 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -547,15 +547,25 @@ def is_mxfp(self, bits): else: raise NotImplementedError() + @property + def is_block_quant(self): + """Check if is block quantization (static or dynamic).""" + return self.block_sizes is not None + @property def is_static_block_quant(self): """Check if is static block quantization.""" return ( - self.block_sizes is not None + self.is_block_quant and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ) + @property + def is_dynamic_block_quant(self): + """Check if is dynamic block quantization.""" + return self.is_block_quant and self.block_sizes.get("type", None) == "dynamic" + @property def rotate_is_enabled(self): """Check if rotate is enabled in quant config.""" diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4200aadc73a..f480d245daa 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -24,7 +24,7 @@ from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, SequentialQuantizer, TensorQuantizer +from ..nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear from ..utils import multi_context, replace_function @@ -126,7 +126,7 @@ def modelopt_post_restore(self, prefix: str = ""): def _check_unsupported_states(quantizer: TensorQuantizer): for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: + if k not in ["_amax", "_pre_quant_scale", "_global_amax"]: warnings.warn( f"Restore of {k} for {prefix} is not 
supported. The restore of this layer might be " f"incorrect. Please implement a custom restore for {k}." @@ -137,6 +137,21 @@ def _has_state(quantizer, name): quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer return hasattr(quantizer, name) + def _has_complete_static_nvfp4_weight_state(quantizer, weight): + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer + if not isinstance(quantizer, NVFP4StaticQuantizer): + return False + amax = getattr(quantizer, "_amax", None) + global_amax = getattr(quantizer, "global_amax", None) + if amax is None or global_amax is None: + return False + block_sizes = getattr(quantizer, "block_sizes", None) + block_size = block_sizes.get(-1) if isinstance(block_sizes, dict) else None + if block_size is None or weight.shape[-1] % block_size != 0: + return False + expected_blocks = weight.numel() // block_size + return amax.numel() == expected_blocks and global_amax.numel() == 1 + if self.weight is None: return @@ -144,7 +159,10 @@ def _has_state(quantizer, name): _check_unsupported_states( quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] ) - if _has_state(self.weight_quantizer, "_amax"): + # Skip max_calibrate when saved static NVFP4 state is intact; else MSE scales get overwritten. 
+ if _has_state( + self.weight_quantizer, "_amax" + ) and not _has_complete_static_nvfp4_weight_state(self.weight_quantizer, self.weight): self.weight_quantizer.reset_amax() max_calibrate(self.weight_quantizer, lambda wq: wq(self.weight), distributed_sync=False) if _has_state(self.input_quantizer, "_pre_quant_scale"): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 0b50fd937ae..bec9ab8e081 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -18,6 +18,7 @@ import logging import types import warnings +from contextlib import contextmanager from typing import Any import megatron.core.parallel_state as mcore_parallel @@ -40,7 +41,8 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer +from ..conversion import maybe_promote_nvfp4_static_quantizer +from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from ..utils import sync_moe_expert_amax @@ -68,6 +70,35 @@ __all__ = [] +def _check_static_block_tp_supported(model: torch.nn.Module) -> None: + """Raise under TP>1: static-block _amax is shard-local but sharded_state_dict treats it as replicated.""" + offending = [] + for name, module in model.named_modules(): + if not isinstance(module, QuantModule): + continue + parallel_state = getattr(module, "parallel_state", None) + if parallel_state is None: + continue + tp_group = getattr(parallel_state, "tensor_parallel_group", None) + if tp_group is None or not tp_group.is_initialized() or tp_group.world_size() <= 1: + continue + weight_quantizer = getattr(module, "weight_quantizer", None) + if weight_quantizer is None: + continue + leaves = ( + list(weight_quantizer) + if isinstance(weight_quantizer, SequentialQuantizer) + else [weight_quantizer] 
+ ) + if any(leaf.is_static_block_quant for leaf in leaves): + offending.append((name, tp_group.world_size())) + if offending: + raise NotImplementedError( + "Static-block NVFP4 weight quantization (e.g. MSE) is not supported with TP > 1. Please re-run with TP=1. " + f"Offending modules (showing first 5 of {len(offending)}): {offending[:5]}" + ) + + def real_quant_module_get_extra_state(self) -> dict: """Populating real_quantizer_state and q_tensor_state.""" extra_state = {} @@ -190,7 +221,9 @@ def quant_module_set_extra_state(self, state: Any): if quantizer_state is not None: for name, module in self.named_modules(): if isinstance(module, TensorQuantizer): - module.set_from_modelopt_state(quantizer_state[name], properties_only=False) + quantizer_substate = quantizer_state[name] + maybe_promote_nvfp4_static_quantizer(module, quantizer_substate) + module.set_from_modelopt_state(quantizer_substate, properties_only=False) self.modelopt_post_restore() # Handle real_quantizer_state and q_tensor_state @@ -250,11 +283,6 @@ def _configure_attention_for_kv_cache_quant(module: Attention): def _register_extra_state_callbacks(model: torch.nn.Module): for name, module in model.named_modules(): - if name.endswith("output_layer"): - # output_layer is not quantized, - # hence we don't need to register extra state callbacks for it - continue - if type(module) in QuantModuleRegistry: # This module will be replaced as a QuantModule register_modelopt_extra_state_callbacks( @@ -344,8 +372,13 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # output_layer.input_quantizer._amax but TP-only does not. This lead to # state_dict mismatch. 
if prefix.endswith("output_layer."): - # assert not any("_quantizer" in k for k in self.state_dict()), "quantized output_layer" - return super().sharded_state_dict(prefix, sharded_offsets, metadata) + try: + from megatron.training import get_args as _mlm_get_args + _untied = bool(getattr(_mlm_get_args(), "untie_embeddings_and_output_weights", False)) + except Exception: + _untied = False + if not _untied: + return super().sharded_state_dict(prefix, sharded_offsets, metadata) quantizer_state_dict = {} for k, v in self.state_dict(prefix="", keep_vars=True).items(): @@ -399,6 +432,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: + # Static NVFP4 _global_amax is a replicated scalar; only per-block _amax shards. + if k.endswith("_global_amax"): + continue if "weight_quantizer." in k: weight_quantizer_axis = self.get_submodule(k.rsplit(".", 1)[0]).axis if weight_quantizer_axis is not None: @@ -427,6 +463,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: + # Static NVFP4 _global_amax is a replicated scalar; only per-block _amax shards. + if k.endswith("_global_amax"): + continue if "weight_quantizer." in k: weight_quantizer_axis = None if isinstance(self.weight_quantizer, TensorQuantizer): @@ -737,3 +776,33 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # Affine KVCache Quant bias vector. state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets) + + +def _is_supported_megatron_model(model: torch.nn.Module) -> bool: + return isinstance(model, MegatronModule) + + +@contextmanager +def _megatron_grad_ckpt_context(model: torch.nn.Module): + # Megatron configures activation recompute at model build time via TransformerConfig, + # so there is no runtime flag to flip here. 
+ yield + + +def _is_param_grad_enabled_for_megatron(pname: str, model: torch.nn.Module) -> bool: + return "embed" in pname + + +def _register_auto_quantize_support() -> None: + # Local import breaks the circular path where algorithms imports model_calib, + # which imports _check_static_block_tp_supported from this plugin. + from ..algorithms import AutoQuantizeGradientSearcher + + AutoQuantizeGradientSearcher.register_custom_support( + _is_supported_megatron_model, + _megatron_grad_ckpt_context, + _is_param_grad_enabled_for_megatron, + ) + + +_register_auto_quantize_support() diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 32332bf1f6b..0f84cdea3c1 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -71,7 +71,16 @@ def get_e2m1_bounds(cls, device): @classmethod def _is_static_quantizer(cls, weight_quantizer) -> bool: """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax.""" - return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None + global_amax = cls._get_static_global_amax(weight_quantizer) + return global_amax is not None + + @classmethod + def _get_static_global_amax(cls, weight_quantizer): + """Return global amax from live or restored static NVFP4 quantizers.""" + global_amax = getattr(weight_quantizer, "global_amax", None) + if global_amax is None: + global_amax = getattr(weight_quantizer, "_global_amax", None) + return global_amax @classmethod def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): @@ -86,8 +95,9 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): Returns: The global scaling factor as a float tensor. 
""" - if cls._is_static_quantizer(weight_quantizer): - return weight_quantizer.global_amax.float() / (6.0 * 448.0) + global_amax = cls._get_static_global_amax(weight_quantizer) + if global_amax is not None: + return global_amax.float() / (6.0 * 448.0) else: assert hasattr(weight_quantizer, "_amax"), ( "Weight quantizer does not have attribute amax" @@ -125,7 +135,7 @@ def get_weights_scaling_factor_from_quantizer( if cls._is_static_quantizer(weight_quantizer): # Static path: use pre-computed per-block amax values from quantizer - global_amax = weight_quantizer.global_amax.float() + global_amax = cls._get_static_global_amax(weight_quantizer).float() per_block_amax = weight_quantizer._amax.float() # Compute scales in float diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 80ed8f9abdd..800140acd99 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -557,6 +557,22 @@ def __len__(self): return len(next(iter(self.encodings.values()))) +def get_dataloader_from_dataset( + dataset, + batch_size: int = 1, + distributed: bool = False, + sampler_kwargs: dict | None = None, + shuffle: bool = False, +) -> DataLoader: + """Wrap a pre-tokenized torch Dataset in a DataLoader, with optional DistributedSampler.""" + if distributed: + from torch.utils.data.distributed import DistributedSampler + + sampler = DistributedSampler(dataset, **(sampler_kwargs or {})) + return DataLoader(dataset, batch_size=batch_size, sampler=sampler) + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) + + def get_dataset_dataloader( dataset_name: str | list[str] = "cnn_dailymail", tokenizer: "PreTrainedTokenizerBase | None" = None, diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml new file mode 100644 index 00000000000..df2b30b8ed8 --- /dev/null +++ 
b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# but with ONE major difference: use max calibration instead of MSE +# - MoE routed experts: NVFP4 W4A4 weight, group_size 16 +# HF names: mixer.experts.*.{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts.*.linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: amax/max calibration comparison variant +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj FP8 per-tensor; FP8 KV cache; + everything else(lm_head/MTP/Latent MOE) stay BF16. Amax calibration comparison variant.
+quantize: + algorithm: + method: max + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. 
+ - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml new file mode 100644 index 00000000000..729dcd12d52 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# HF names: mixer.experts.*.{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts.*.linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: weight MSE with FP8-scale sweep over the 128 e4m3 scale values +# (NVFP4 weights use static block scales selected by MSE; FP8 per-tensor scales +# are also chosen via MSE search instead of plain amax). +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj FP8 per-tensor; FP8 KV cache; + everything else(lm_head/MTP/latent MOE) stay BF16. Weight-MSE calibration with FP8 scale sweep. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj.
+ - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. 
+ - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index 7af33fa599f..46259203b24 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -249,14 +249,28 @@ def forward_loop(model): ) -def auto_quantize_helper(model): +def auto_quantize_helper( + model, + data_loader=None, + forward_step=None, + forward_backward_step=None, + quantization_formats=None, +): + if data_loader is None: + data_loader = [model.get_dummy_input().cuda() for _ in range(2)] + if forward_step is None: + forward_step = lambda model, batch: model(batch) # noqa: E731 + if forward_backward_step is None: + forward_backward_step = lambda model, batch: model(batch).sum().backward() # noqa: E731 + if quantization_formats is None: + quantization_formats = [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG] model, search_state = mtq.auto_quantize( model, constraints={"effective_bits": 8.0}, - quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], - data_loader=[model.get_dummy_input().cuda() for _ in range(2)], - forward_step=lambda model, batch: model(batch), - forward_backward_step=lambda model, batch: model(batch).sum().backward(), + quantization_formats=quantization_formats, + data_loader=data_loader, + forward_step=forward_step, + forward_backward_step=forward_backward_step, num_calib_steps=2, num_score_steps=2, verbose=True, diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py index 2b5caea16f0..773a2a47405 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -138,6 +138,23 @@ def test_fake_quantize_with_both_amaxs(self, device): assert 
torch.allclose(output, expected) + def test_static_export_clamps_overflowing_fp8_block_scales(self, device): + """Static export should match fake quant clamping and never write NaN FP8 scales.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + quantizer.amax = torch.tensor([8.0], device=device) + quantizer.global_amax = torch.tensor(1.0, device=device) + weight = torch.ones(1, 16, device=device) + + weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(quantizer, weight) + + assert weight_scale.dtype == torch.float8_e4m3fn + assert torch.isfinite(weight_scale.float()).all() + assert torch.equal(weight_scale.float(), torch.full_like(weight_scale.float(), 448.0)) + @pytest.mark.parametrize("device", ["cuda"]) class TestNVFP4MSECalibrator: diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 3fac8269ccd..8dfbc0323c2 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -29,7 +29,7 @@ import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.export.unified_export_megatron import GPTModelExporter from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel @@ -42,15 +42,8 @@ def _verify_model_quant_config( """Verify config.json and hf_quant_config.json""" config_dict = json.load(open(export_dir / "config.json")) hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) - # 
Make sure config.json and hf_quant_config.json are consistent - assert ( - config_dict["quantization_config"]["quant_algo"] - == hf_quant_config_dict["quantization"]["quant_algo"] - ) - assert ( - config_dict["quantization_config"]["ignore"] - == hf_quant_config_dict["quantization"]["exclude_modules"] - ) + # Make sure config.json and hf_quant_config.json use the same serving config. + assert config_dict["quantization_config"] == hf_quant_config_dict # Verify config.json if kv_cache_quant_cfg: @@ -58,17 +51,17 @@ def _verify_model_quant_config( # Verify hf_quant_config.json if quant_config: - quant_config_dict = hf_quant_config_dict["quantization"] + quant_config_dict = hf_quant_config_dict quant_type = quant_config_dict["quant_algo"] assert ( quant_type in quant_config ) # quant config str is subset of quant config e.g. NVFP4 -> NVFP4_DEFAULT_CFG - assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + assert len(quant_config_dict["ignore"]) > 1 # Dynamically added exclude modules if quant_type == "NVFP4": - assert quant_config_dict["group_size"] == 16 + assert quant_config_dict["config_groups"]["group_0"]["weights"]["group_size"] == 16 if kv_cache_quant_cfg: - assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 + assert quant_config_dict["kv_cache_scheme"]["num_bits"] == 8 def _test_unified_export_megatron( @@ -295,6 +288,44 @@ def test_qkv_slicing_gqa_tp2(dist_workers_size_2, tmp_path): dist_workers_size_2.run(partial(_test_qkv_slicing_gqa_tp2, tmp_path)) +def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): + """Unquantized fused MCore linear_qkv should become HF q/k/v excludes.""" + exporter = object.__new__(GPTModelExporter) + exporter.dtype = torch.bfloat16 + exporter.exclude_modules = ["backbone.layers.0.mixer"] + exporter.layer_config_dict = {} + exporter._state_dict = {} + + hidden_size = 8 + head_size = 4 + num_attention_heads = 2 + num_query_groups = 1 + qkv_dim = num_attention_heads + 2 * 
num_query_groups + weight = torch.arange(qkv_dim * head_size * hidden_size, dtype=torch.bfloat16).reshape( + qkv_dim * head_size, hidden_size + ) + + module = torch.nn.Module() + module.config = type( + "Config", + (), + { + "hidden_size": hidden_size, + "num_query_groups": num_query_groups, + "num_attention_heads": num_attention_heads, + "kv_channels": head_size, + }, + )() + exporter._get_quantized_state = lambda *args, **kwargs: ({"weight": weight}, None, 0) + + exporter._qkv_slicing(module, "backbone.layers.0.mixer.") + + assert "backbone.layers.0.mixer" not in exporter.exclude_modules + assert "backbone.layers.0.mixer.q_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.k_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.v_proj" in exporter.exclude_modules + + def _make_exporter_for_mtp(model_dir: Path) -> GPTModelExporter: """Create a minimal GPTModelExporter instance for testing _get_mtp_state_dict.""" exporter = object.__new__(GPTModelExporter) diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index c34fb2df376..3b7985c817b 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -26,6 +26,7 @@ from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, + get_batch, get_forward, initialize_for_megatron, run_mcore_inference, @@ -694,6 +695,45 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg): ) +def _test_auto_quantize_moe_ep_helper(rank, size): + initialize_for_megatron( + tensor_model_parallel_size=1, + expert_model_parallel_size=size, + seed=SEED, + ) + model = _gpt_model_provider( + tp_size=1, + ep_size=size, + hidden_size=32, + num_moe_experts=4, + moe_grouped_gemm=False, + transformer_impl="modelopt", + ) + + def forward_step(model, batch): + 
input_ids, labels, position_ids, attention_mask, loss_mask = batch + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + auto_quantize_helper( + model, + data_loader=[get_batch(model, batch_size=2) for _ in range(2)], + forward_step=forward_step, + forward_backward_step=lambda m, b: forward_step(m, b).mean().backward(), + quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG], + ) + + +def test_auto_quantize_moe_ep(dist_workers_size_2): + """auto_quantize must sum score/cost across EP ranks and pick a consistent recipe.""" + dist_workers_size_2.run(_test_auto_quantize_moe_ep_helper) + + @pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("sync_weight_amax", [True, False]) def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax): diff --git a/tests/unit/torch/export/test_hf_checkpoint_utils.py b/tests/unit/torch/export/test_hf_checkpoint_utils.py index f83cb355749..33d17eebb3d 100644 --- a/tests/unit/torch/export/test_hf_checkpoint_utils.py +++ b/tests/unit/torch/export/test_hf_checkpoint_utils.py @@ -20,6 +20,8 @@ import pytest pytest.importorskip("huggingface_hub") +hf_hub_errors = pytest.importorskip("huggingface_hub.errors") +LocalEntryNotFoundError = hf_hub_errors.LocalEntryNotFoundError from modelopt.torch.export import copy_hf_ckpt_remote_code @@ -59,15 +61,60 @@ def test_copy_hf_ckpt_remote_code_local_dir_no_py_files(tmp_path): assert list(dst_dir.iterdir()) == [], "no files should be copied" -def test_copy_hf_ckpt_remote_code_hub_id(tmp_path): - """copy_hf_ckpt_remote_code delegates to snapshot_download for a Hub model ID.""" +def test_copy_hf_ckpt_remote_code_hub_id(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code copies .py files from the resolved Hub snapshot.""" dst_dir = tmp_path / "dst" - - with patch("modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download") as mock_sd: + 
snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + (snapshot_dir / "modeling_custom.py").write_text("# custom model") + (snapshot_dir / "not_python.txt").write_text("not python") + + monkeypatch.delenv("HF_HUB_OFFLINE", raising=False) + with patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + return_value=str(snapshot_dir), + ) as mock_sd: copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-Nano-12B-v2", dst_dir) mock_sd.assert_called_once_with( repo_id="nvidia/NVIDIA-Nemotron-Nano-12B-v2", - local_dir=str(dst_dir), allow_patterns=["*.py"], + local_files_only=False, + ) + assert (dst_dir / "modeling_custom.py").read_text() == "# custom model" + assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied" + + +def test_copy_hf_ckpt_remote_code_hub_id_offline_uses_cache(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code resolves cached Hub snapshots when HF_HUB_OFFLINE is set.""" + dst_dir = tmp_path / "dst" + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + (snapshot_dir / "nemotron_reasoning_parser.py").write_text("# parser") + + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + with patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + return_value=str(snapshot_dir), + ) as mock_sd: + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", dst_dir) + + mock_sd.assert_called_once_with( + repo_id="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + allow_patterns=["*.py"], + local_files_only=True, ) + assert (dst_dir / "nemotron_reasoning_parser.py").read_text() == "# parser" + + +def test_copy_hf_ckpt_remote_code_hub_id_offline_missing_cache_raises(tmp_path, monkeypatch): + """copy_hf_ckpt_remote_code raises a clear error when offline cache is missing.""" + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + with ( + patch( + "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download", + side_effect=LocalEntryNotFoundError("missing"), + ), + 
pytest.raises(RuntimeError, match="HF_HUB_OFFLINE"), + ): + copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", tmp_path / "dst") diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 4a759172b6f..833ee277211 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -668,7 +668,7 @@ def test_bootstrap_populates_dead_expert_quantizers(self): """ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.model_calib import ( - _bootstrap_uncalibrated_weight_quantizers, + _bootstrap_uncalibrated_static_weight_quantizers, ) model = _TinyMoEModel() @@ -722,7 +722,7 @@ def partial_forward(m): f"Dead expert {idx} down_proj should be uncalibrated pre-bootstrap" ) - n_bootstrapped = _bootstrap_uncalibrated_weight_quantizers(model) + n_bootstrapped = _bootstrap_uncalibrated_static_weight_quantizers(model) assert n_bootstrapped >= 2 * len(dead), ( f"Expected ≥{2 * len(dead)} bootstrapped (gate_up + down per dead expert), " f"got {n_bootstrapped}" diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..f7b4965cf8f 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -24,8 +24,10 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.algorithms import ( + AutoQuantizeGradientSearcher, QuantRecipe, QuantRecipeHparam, + _AutoQuantizeBaseSearcher, estimate_quant_compression, ) from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg @@ -305,6 +307,39 @@ def test_data_parallel_auto_quantize(skip_on_windows): spawn_multiprocess_job(4, _test_data_parallel_auto_quantize, backend="gloo") +def 
test_auto_quantize_budget_uses_no_quant_candidate_cost(monkeypatch): + class _BudgetCaptureSearcher(AutoQuantizeGradientSearcher): + def run_search_with_stats(self, max_weight_size, verbose=False): + self.max_weight_size = max_weight_size + return {}, True + + def _raise_local_total_weight_size(modules): + pytest.fail("run_search should derive total weight size from candidate costs") + + monkeypatch.setattr( + _AutoQuantizeBaseSearcher, + "_get_total_weight_size", + staticmethod(_raise_local_total_weight_size), + ) + + searcher = _BudgetCaptureSearcher() + searcher.reset_search() + searcher.model = torch.nn.Module() + searcher.config = {"verbose": False} + searcher.constraints = {"effective_bits": 8.0} + searcher.candidate_stats = { + "local_expert.quant_recipe": { + "formats": [QuantRecipe(mtq.NVFP4_DEFAULT_CFG), QuantRecipe(None)], + "scores": [1.0, 0.0], + "costs": [25.0, 100.0], + } + } + + searcher.run_search() + + assert searcher.max_weight_size == 50.0 + + def test_estimate_quant_compression(): nvfp4_affine_kv_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_AFFINE_KV_CFG) assert estimate_quant_compression(nvfp4_affine_kv_cfg) == 0.25 diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 4332b093861..04266447a61 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -558,14 +558,19 @@ def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True): from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend register_quant_backend(backend_name, lambda x, tq: x) - model = torch.nn.Linear(8, 8, bias=False) - inputs = torch.randn(1, 8) + model = torch.nn.Linear(16, 8, bias=False) + inputs = torch.randn(1, 16) config = { "quant_cfg": [ {"quantizer_name": "*", "enable": False}, { "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": 8, "axis": None, "backend": backend_name}, + 
"cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "backend": backend_name, + }, }, ], "algorithm": "max", diff --git a/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py new file mode 100644 index 00000000000..dfb776a0484 --- /dev/null +++ b/tests/unit/torch/quantization/test_nvfp4_static_export_cpu.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""CPU round-trip tests for NVFP4 static export with extreme per-block amax (underflow/overflow)."""

from __future__ import annotations

import pytest
import torch

from modelopt.torch.export.quant_utils import QUANTIZATION_NVFP4, to_quantized_weight
from modelopt.torch.quantization.config import QuantizerAttributeConfig
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
from modelopt.torch.quantization.qtensor import NVFP4QTensor

# Number of consecutive weights (along the last dim) that share one block scale in these tests.
BLOCK_SIZE = 16
# Unscaled values the manual round-trip test accepts as "on the FP4 grid":
# the eight magnitudes followed by their negations.
FP4_VALUES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])


def _make_static_quantizer(
    per_block_amax: torch.Tensor, global_amax: torch.Tensor
) -> NVFP4StaticQuantizer:
    """Build an ``NVFP4StaticQuantizer`` carrying caller-chosen amax values.

    Args:
        per_block_amax: Per-block amax assigned to ``quantizer.amax``.
        global_amax: Tensor-wide amax assigned to ``quantizer.global_amax``.

    Returns:
        A quantizer configured with ``num_bits=(2, 1)`` and static blocks of
        ``BLOCK_SIZE`` with ``scale_bits=(4, 3)``. Both inputs are cloned, so the
        caller's tensors are never aliased by the quantizer.
    """
    cfg = QuantizerAttributeConfig(
        num_bits=(2, 1),
        block_sizes={-1: BLOCK_SIZE, "type": "static", "scale_bits": (4, 3)},
    )
    q = NVFP4StaticQuantizer(quant_attribute_cfg=cfg)
    q.amax = per_block_amax.clone()
    q.global_amax = global_amax.clone()
    return q


def _export_round_trip(
    weight: torch.Tensor, quantizer: NVFP4StaticQuantizer
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the full export path and dequantize it back, mimicking vLLM serving.

    The scales are derived from ``quantizer``, the weight is packed with
    ``to_quantized_weight`` and then dequantized through ``NVFP4QTensor`` using
    the same scales.

    Returns (weight_scale_fp8, weight_scale_2_fp32, dequantized_weight_bf16).
    """
    weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(quantizer)
    weight_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
        quantizer, weight, weight_scale_2.to(weight.device)
    )
    packed = to_quantized_weight(
        weight,
        weight_scale,
        QUANTIZATION_NVFP4,
        weight_scale_2,
        BLOCK_SIZE,
    )
    qtensor = NVFP4QTensor(weight.shape, weight.dtype, packed)
    dequant = qtensor.dequantize(
        scale=weight_scale,
        double_scale=weight_scale_2,
        block_sizes={-1: BLOCK_SIZE},
        dtype=weight.dtype,
    )
    return weight_scale, weight_scale_2, dequant


def _layer1_routed_expert_like(
    out_dim: int, in_dim: int, *, n_outliers: int, seed: int
) -> torch.Tensor:
    """Synthesize a tensor whose block-amax distribution matches the failure case.

    The vast majority of blocks have tiny absolute values (~1e-7), and a handful
    of rows carry an outlier magnitude (~1e-1) that drives the global tensor
    amax to the FP8-normal regime. This is the distribution that produces 81%
    raw-1 FP8 block scales in the broken Ultra V3 MSE export.

    Args:
        out_dim: Number of rows of the generated weight.
        in_dim: Number of columns of the generated weight.
        n_outliers: Number of outlier positions to draw; values below 1 are
            raised to 1 so at least one outlier is always injected.
        seed: Seed for the private ``torch.Generator`` (the global RNG is untouched).

    Returns:
        A ``(out_dim, in_dim)`` bfloat16 tensor.
    """
    g = torch.Generator().manual_seed(seed)
    weight = torch.randn(out_dim, in_dim, generator=g, dtype=torch.bfloat16) * 1e-7
    # Inject outliers at up to n (row, col) positions. Rows and columns are
    # drawn independently with randint, so positions may repeat and the number
    # of distinct outliers can be smaller than n.
    n = max(1, n_outliers)
    rows = torch.randint(0, out_dim, (n,), generator=g)
    cols = torch.randint(0, in_dim, (n,), generator=g)
    weight[rows, cols] = torch.randn(n, generator=g, dtype=torch.bfloat16) * 0.1
    return weight


def _per_block_max(weight: torch.Tensor) -> torch.Tensor:
    """Per-(out, num_blocks) absolute max, mirroring NVFP4 block-amax reduction.

    The last dimension is split into chunks of ``BLOCK_SIZE`` and reduced in
    float32, so the last dimension of ``weight`` must be a multiple of ``BLOCK_SIZE``.
    """
    blocks = weight.float().view(*weight.shape[:-1], -1, BLOCK_SIZE)
    return blocks.abs().amax(dim=-1)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestNVFP4StaticExportFiniteAndBounded:
    """The static export path must never emit NaN/Inf scales or dequant values.

    These tests fail fast if any saved bit pattern (weight_scale, packed weight,
    dequantized result) contains a NaN or Inf, no matter what amax the caller
    set on the static quantizer.
    """

    def test_layer1_like_distribution_no_nan(self):
        weight = _layer1_routed_expert_like(64, 256, n_outliers=4, seed=0)
        block_max = _per_block_max(weight)
        global_amax = block_max.max()
        # MSE-style: a mix of "shrunk" (multiplier 0.5) and "expanded"
        # (multiplier 2.5) per-block amax around the actual block maxima.
        mult = torch.where(
            torch.arange(block_max.numel()) % 2 == 0,
            torch.tensor(0.5),
            torch.tensor(2.5),
        ).view_as(block_max)
        amax = (block_max * mult).clamp(min=1e-30)
        q = _make_static_quantizer(amax, global_amax)

        ws, ws2, deq = _export_round_trip(weight, q)

        assert torch.isfinite(ws.float()).all(), "weight_scale (FP8) must be finite"
        assert torch.isfinite(ws2).all(), "weight_scale_2 (FP32) must be finite"
        assert torch.isfinite(deq.float()).all(), "dequantized weight must be finite"

    def test_overflow_amax_saturates_no_nan(self):
        """When _amax > _global_amax (MSE multiplier > 1 on the global-max
        block), the FP8 cast must saturate to 448, not produce NaN."""
        weight = torch.randn(8, 16, dtype=torch.bfloat16) * 1e-2
        block_max = _per_block_max(weight)
        global_amax = block_max.max()
        # Force the first block to have amax 4x the global max.
        amax = block_max.clone()
        amax[0] = global_amax * 4.0
        q = _make_static_quantizer(amax, global_amax)

        ws, _, deq = _export_round_trip(weight, q)

        assert torch.isfinite(ws.float()).all(), "FP8 weight_scale must saturate, not NaN"
        # FP8 e4m3fn max is 448. The byte for the overflowing block should be 448.
        assert ws[0].float().max().item() == pytest.approx(448.0)
        assert torch.isfinite(deq.float()).all()

    def test_underflow_amax_no_inf_in_dequant(self):
        """When per_block_amax / global_amax is below FP8 representable range,
        the static export must not emit Inf or NaN in the *dequantized* tensor.
        Whether the FP8 byte is 0 (natural underflow) or a clamped subnormal,
        the dequant of every weight in the affected blocks must be finite."""
        weight = torch.randn(8, 16, dtype=torch.bfloat16) * 1e-7
        block_max = _per_block_max(weight)
        global_amax = block_max.max() * 1e6  # _global_amax much larger than blocks
        amax = block_max.clamp(min=1e-30)
        # global_amax is already a tensor and _make_static_quantizer clones it;
        # re-wrapping it in torch.tensor() would only add a redundant copy and
        # trigger PyTorch's copy-construct UserWarning.
        q = _make_static_quantizer(amax, global_amax)

        ws, ws2, deq = _export_round_trip(weight, q)

        assert torch.isfinite(ws.float()).all()
        assert torch.isfinite(ws2).all()
        assert torch.isfinite(deq.float()).all(), (
            "dequantized weight has NaN/Inf in underflow regime — this is the "
            "failure pattern that breaks vLLM serving"
        )


class TestNVFP4StaticExportRoundTripBound:
    """Round-trip dequant magnitude must stay bounded by the encoded amax.

    For every block, ``|dequant| <= 6 * weight_scale_FP8 * weight_scale_2``,
    and that product must be bounded above by ``max(_amax_block, 448 * scale_2)``
    (FP8 saturation). If any block's dequant exceeds that bound, an
    out-of-distribution outlier was synthesized by the export path itself.
    """

    def test_dequant_magnitude_within_amax(self):
        weight = _layer1_routed_expert_like(32, 128, n_outliers=4, seed=1)
        block_max = _per_block_max(weight)
        global_amax = block_max.max()
        # Use _amax = block_max (no shrinking, no expansion).
        amax = block_max.clamp(min=1e-30)
        q = _make_static_quantizer(amax, global_amax)

        ws, ws2, deq = _export_round_trip(weight, q)

        # The maximum representable magnitude per block is 6 * dequant_scale.
        # Allow a small relative tolerance for the bf16 quantization that
        # NVFP4QTensor.dequantize applies to its output (~0.4% per element).
        dequant_scale_per_block = ws.float() * ws2.float()
        expected_block_bound = 6.0 * dequant_scale_per_block  # shape (out, num_blocks)
        deq_block_max = deq.float().view(*deq.shape[:-1], -1, BLOCK_SIZE).abs().amax(dim=-1)
        violation = (deq_block_max - expected_block_bound).clamp(min=0)
        # Reject any per-block dequant magnitude that exceeds the FP4 saturation
        # bound by more than 1% (well above bf16 round-up noise) — that would
        # indicate the export synthesized out-of-distribution outliers.
        relative = violation / expected_block_bound.clamp(min=1e-30)
        max_relative = relative.max().item()
        assert max_relative <= 1e-2, (
            f"dequant block max exceeds the FP4 saturation bound by "
            f"{max_relative:.2%}. Worst block index: "
            f"{tuple(int(i) for i in (relative == relative.max()).nonzero()[0].tolist())}"
        )


class TestNVFP4StaticVsDynamicEquivalence:
    """When _amax = per_block_amax (no MSE shrink/expand), the static and
    dynamic export paths must produce bit-identical FP8 weight_scale bytes.
    Both paths apply the same lower clamp at the fp8 subnormal min (2**-9)
    so tiny-amax blocks land on 0x01 instead of underflowing to 0x00."""

    def test_static_matches_dynamic_when_amax_is_block_max(self):
        weight = _layer1_routed_expert_like(16, 64, n_outliers=2, seed=2)
        block_max = _per_block_max(weight)
        global_amax = block_max.max()

        # Static path
        q = _make_static_quantizer(block_max.clamp(min=1e-30), global_amax)
        static_ws_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(q)
        static_ws, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
            q, weight, static_ws_2
        )

        # Dynamic path (fed a clone of the static scale_2 so only the per-block
        # scale computation differs between the two paths).
        dynamic_ws, dynamic_ws_2 = NVFP4QTensor.get_weights_scaling_factor(
            weight, BLOCK_SIZE, weights_scaling_factor_2=static_ws_2.clone()
        )

        # Both paths should yield identical FP8 byte patterns when amax matches.
        assert torch.equal(static_ws.view(torch.uint8), dynamic_ws.view(torch.uint8)), (
            "static and dynamic export paths produced different FP8 "
            "weight_scale bytes for the same per-block amax — this means "
            "the static path's scale computation diverges from the dynamic path"
        )
        assert torch.allclose(static_ws_2, dynamic_ws_2)


class TestNVFP4StaticManualRoundTrip:
    """Cross-check the export path against a manual per-block computation.

    For each block: ``dequant_scale = FP8(amax * 448 / global_amax) * (global_amax / (6 * 448))``,
    and each FP4 value lies in {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}. The
    dequantized weight for any element should be the FP4 value at the rounded
    ordinal, multiplied by the block's dequant_scale.
    """

    def test_single_block_matches_manual(self):
        # One block of 16 elements, mid-magnitude (FP8-normal regime).
        torch.manual_seed(3)
        weight = (torch.rand(1, BLOCK_SIZE, dtype=torch.bfloat16) - 0.5) * 0.05
        block_max = _per_block_max(weight)
        global_amax = block_max.max()
        amax = block_max.clamp(min=1e-30)
        q = _make_static_quantizer(amax, global_amax)

        ws, ws2, deq = _export_round_trip(weight, q)

        # Manual per-block dequant scale: FP8-quantized ratio * scale_2.
        dequant_scale = ws.float()[0, 0].item() * ws2.float().item()

        # Every dequant element must be FP4 value * dequant_scale, allowing for
        # the bf16 round-trip applied by NVFP4QTensor.dequantize on its output
        # (~0.4% relative). Use a tolerance that's loose enough for bf16 and
        # tight enough to catch a real off-grid value.
        deq_vals = deq.float().reshape(-1)
        grid = torch.tensor(
            [v.item() * dequant_scale for v in FP4_VALUES.float()],
            dtype=torch.float32,
        )
        for v in deq_vals.tolist():
            distance = (grid - v).abs().min().item()
            tolerance = max(abs(v) * 1e-2, 1e-12)
            assert distance <= tolerance, (
                f"dequant value {v} is not on the FP4 grid (min distance {distance:g}, "
                f"tolerance {tolerance:g}); grid = {sorted(grid.tolist())}"
            )


class TestNVFP4StaticCornerCases:
    """Edge cases that have historically caused trouble in MSE static export."""

    def test_zero_amax_block_does_not_explode(self):
        """If MSE selects amax=0 for a block (e.g., dead expert), the export
        must not emit NaN/Inf or amplify dequant magnitude."""
        weight = torch.zeros(2, BLOCK_SIZE, dtype=torch.bfloat16)
        weight[1, :] = 0.05  # one block with real values
        block_max = _per_block_max(weight)
        global_amax = block_max.max()
        amax = block_max.clone()
        amax[0] = 0.0  # explicit zero amax for the dead block
        q = _make_static_quantizer(amax, global_amax)

        ws, _, deq = _export_round_trip(weight, q)

        assert torch.isfinite(ws.float()).all()
        assert torch.isfinite(deq.float()).all()
        # The dead block must dequantize to all zeros (no leakage from the
        # special amax==0 substitution).
        assert torch.equal(deq[0].float(), torch.zeros_like(deq[0].float())), (
            "amax==0 block leaked nonzero dequant values"
        )

    def test_ultra_v3_layer1_distribution_byte_distribution_sane(self):
        """Sanity: in the Ultra-V3 layer-1-like regime, the export's FP8 byte
        distribution does not contain raw NaN bytes (0x7F or 0xFF)."""
        weight = _layer1_routed_expert_like(128, 512, n_outliers=8, seed=4)
        block_max = _per_block_max(weight)
        global_amax = block_max.max()
        amax = block_max.clamp(min=1e-30)
        q = _make_static_quantizer(amax, global_amax)

        ws, _, _ = _export_round_trip(weight, q)
        ws_bytes = ws.view(torch.uint8).reshape(-1)

        # FP8 e4m3fn NaN bytes are 0x7F (127) and 0xFF (255).
        nan_count = int(((ws_bytes == 127) | (ws_bytes == 255)).sum().item())
        assert nan_count == 0, f"static export emitted {nan_count} NaN FP8 weight_scale bytes"