-
Notifications
You must be signed in to change notification settings - Fork 449
fix: layerwise calibration backward-compat, recipe split, batch-size guard #1310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -988,6 +988,25 @@ def quantize_main( | |
| default_pad_token, | ||
| device: torch.device, | ||
| ): | ||
| # Load the recipe up front so we can detect layerwise calibration before batch-size probing. | ||
| recipe = None | ||
| if args.recipe is not None and not args.auto_quantize_bits: | ||
| print(f"Use recipe {args.recipe} for quantization") | ||
| recipe = load_recipe(args.recipe) | ||
| if not isinstance(recipe, ModelOptPTQRecipe): | ||
| raise TypeError( | ||
| f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" | ||
| ) | ||
|
|
||
| def _is_layerwise(obj): | ||
| if isinstance(obj, ModelOptPTQRecipe): | ||
| return _is_layerwise(obj.quantize.algorithm) | ||
| if isinstance(obj, list): | ||
| return any(_is_layerwise(a) for a in obj) | ||
| return bool(getattr(obj, "layerwise", False)) | ||
|
Comment on lines
+1001
to
+1006
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle dict-form algorithms in At Line 973, Suggested fix def _is_layerwise(obj):
if isinstance(obj, ModelOptPTQRecipe):
return _is_layerwise(obj.quantize.algorithm)
+ if isinstance(obj, dict):
+ if "layerwise" in obj:
+ return bool(obj["layerwise"])
+ if "algorithm" in obj:
+ return _is_layerwise(obj["algorithm"])
+ return False
if isinstance(obj, list):
return any(_is_layerwise(a) for a in obj)
return bool(getattr(obj, "layerwise", False))🤖 Prompt for AI Agents |
||
|
|
||
| is_layerwise = _is_layerwise(recipe) | ||
|
|
||
| if args.batch_size == 0: | ||
| # For VL models with image-text calibration, skip automatic batch size detection | ||
| # since get_max_batch_size can't handle multimodal inputs | ||
|
|
@@ -1001,6 +1020,11 @@ def quantize_main( | |
| "Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration." | ||
| ) | ||
| args.batch_size = 1 | ||
| # Layerwise calibration processes one layer at a time; auto batch-size probing runs a | ||
| # full-model forward which defeats the point and can OOM on very large models. | ||
| elif is_layerwise: | ||
| print("Layerwise calibration enabled. Using default batch_size=1 for calibration.") | ||
| args.batch_size = 1 | ||
| else: | ||
| # Calibration/sparsification will actually take much more memory than regular inference | ||
| # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio | ||
|
|
@@ -1064,12 +1088,7 @@ def quantize_main( | |
| else: | ||
| # mono quantization | ||
|
|
||
| if args.recipe is not None: | ||
| print(f"Use recipe {args.recipe} for quantization") | ||
| recipe = load_recipe(args.recipe) | ||
| assert isinstance(recipe, ModelOptPTQRecipe), ( | ||
| f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" | ||
| ) | ||
| if recipe is not None: | ||
| quant_cfg = recipe.quantize.model_dump() | ||
|
|
||
| else: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| imports: | ||
| base_disable_all: configs/ptq/units/base_disable_all | ||
| default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers | ||
| nvfp4: configs/numerics/nvfp4 | ||
| kv_fp8: configs/ptq/units/kv_fp8 | ||
|
|
||
| metadata: | ||
| recipe_type: ptq | ||
| description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration. | ||
| quantize: | ||
| algorithm: | ||
| method: max | ||
| # Max calibration is fast and does not typically need checkpointing. | ||
| layerwise: true | ||
| quant_cfg: | ||
| - $import: base_disable_all | ||
| - quantizer_name: '*mlp.experts*weight_quantizer' | ||
| cfg: | ||
| $import: nvfp4 | ||
| - quantizer_name: '*mlp.experts*input_quantizer' | ||
| cfg: | ||
| $import: nvfp4 | ||
| - quantizer_name: '*block_sparse_moe*weight_quantizer' | ||
| cfg: | ||
| $import: nvfp4 | ||
| - quantizer_name: '*block_sparse_moe*input_quantizer' | ||
| cfg: | ||
| $import: nvfp4 | ||
| - $import: kv_fp8 | ||
| - $import: default_disabled_quantizers |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.