Commit f26b9c3
skip generate option for large models and mxfp8 (#942)
## What does this PR do?

**Type of change:** New feature

**Overview:** Adds a `--skip_generate` flag to `hf_ptq.py` that skips the pre/post-quantization generation preview calls. These calls run `model.generate()`, which crashes for very large models (500B+) that are split across GPU and CPU via `device_map="auto"` (e.g., models with Mamba/Triton kernels that cannot handle CPU-offloaded tensors).

## Usage

```
python examples/llm_ptq/hf_ptq.py \
    --pyt_ckpt_path /path/to/model \
    --export_path /path/to/output \
    --qformat mxfp8 \
    --trust_remote_code \
    --export_fmt hf \
    --batch_size 1 \
    --skip_generate \
    --kv_cache_qformat none
```

## Testing

Tested with a 500B-parameter NemotronH hybrid Mamba/attention model on 4x GB200 GPUs. Without `--skip_generate`, the script crashes at `model.generate()` because the Mamba Triton kernels fail on CPU-offloaded tensors. With `--skip_generate`, the generation preview is skipped and quantization proceeds normally.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

The `--skip_generate` flag sets `generated_ids_before_ptq = None` early, which also causes the post-quantization generate to be skipped via the existing `if generated_ids_before_ptq is None: pass` guard. Combined with `--batch_size 1` (to skip the `get_max_batch_size` forward-pass probe), this eliminates all forward passes that can crash for device-map-split models.

## Summary by CodeRabbit

* **New Features**
  * Introduced `--skip_generate` CLI option to skip pre-quantization text and image generation, reducing processing time for very large models. Useful when generation previews are computationally expensive.

---------

Signed-off-by: adithyare <adithyare@nvidia.com>
Signed-off-by: Adi Renduchintala <adithya.r@gmail.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
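The flag/guard interaction described above can be sketched as follows. This is a minimal, self-contained sketch, not the real `hf_ptq.py` code: `preview_generate` and `run` are illustrative stand-ins, and the placeholder token ids stand in for real `model.generate()` output; only the `--skip_generate` flag and the `generated_ids_before_ptq is None` guard come from the PR.

```python
import argparse


def preview_generate(skip_generate: bool):
    """Return preview token ids, or None when the preview is skipped.

    Stand-in for the pre-quantization model.generate() preview; the
    token list below is a placeholder for real generated ids.
    """
    if skip_generate:
        return None
    return [101, 2023, 102]  # placeholder token ids


def run(argv: list) -> str:
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_generate", default=False, action="store_true")
    args = parser.parse_args(argv)

    # Setting the preview to None early is the whole mechanism:
    generated_ids_before_ptq = preview_generate(args.skip_generate)

    # ... quantization/calibration would run here regardless of the flag ...

    # Existing guard: a None preview also skips the post-quantization generate.
    if generated_ids_before_ptq is None:
        return "post-quantization generate skipped"
    return "post-quantization generate ran"


if __name__ == "__main__":
    print(run(["--skip_generate"]))  # → post-quantization generate skipped
```

Because the same `None` value flows through both sites, one flag disables both preview passes without a second conditional.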
1 parent ba29ad7 commit f26b9c3

1 file changed

Lines changed: 13 additions & 2 deletions

File tree

examples/llm_ptq/hf_ptq.py

```diff
@@ -690,7 +690,9 @@ def pre_quantize(
         ][0:1]

     # Generate preview before quantization
-    if model_type == "deepseek":
+    if args.skip_generate:
+        generated_ids_before_ptq = None
+    elif model_type == "deepseek":
         # DeepSeek generation may go OOM, so we skip it
         generated_ids_before_ptq = None
     elif is_nemotron_vl_model and tokenizer is not None:
@@ -703,7 +705,6 @@ def pre_quantize(
             allow_fallback=False,
         )
     else:
-        # Standard generation for non-Nemotron VL models
         generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
     if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
         print("Applying nvfp4 quantization (MoE only) for gpt-oss")
@@ -1084,6 +1085,16 @@ def parse_args() -> argparse.Namespace:
         default=True,
         action=argparse.BooleanOptionalAction,
     )
+    parser.add_argument(
+        "--skip_generate",
+        help=(
+            "Skip pre/post-quantization preview calls that invoke model.generate(). "
+            "Note: this does not skip calibration or batch-size probing. "
+            "For very large models, pair with --batch_size 1 to avoid max-batch probing."
+        ),
+        default=False,
+        action="store_true",
+    )
     parser.add_argument(
         "--low_memory_mode",
         help=(
```
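For context on the two argparse styles visible in the last hunk — `store_true` for the new `--skip_generate` flag versus `argparse.BooleanOptionalAction` on the preceding option — here is a small sketch. The `--some_default_on_flag` name is hypothetical; only `--skip_generate` and its `store_true`/`default=False` setup come from the diff.

```python
import argparse

parser = argparse.ArgumentParser()
# store_true (used for --skip_generate): off by default, the flag switches it
# on; no negative form is generated.
parser.add_argument("--skip_generate", default=False, action="store_true")
# BooleanOptionalAction (Python 3.9+): also auto-generates a --no-... negative
# form, so a default-on option can be explicitly disabled. Flag name here is
# hypothetical.
parser.add_argument(
    "--some_default_on_flag", default=True, action=argparse.BooleanOptionalAction
)

args = parser.parse_args(["--skip_generate", "--no-some_default_on_flag"])
print(args.skip_generate, args.some_default_on_flag)  # → True False
```

`store_true` fits a one-way opt-in like `--skip_generate`, while `BooleanOptionalAction` suits options whose default is `True` and must be switch-off-able from the command line.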
