Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,9 @@ def pre_quantize(
][0:1]

# Generate preview before quantization
if model_type == "deepseek":
if args.skip_generate:
generated_ids_before_ptq = None
elif model_type == "deepseek":
Comment on lines +693 to +695
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Short-circuit preview input extraction when --skip_generate is set.

Even with --skip_generate, the code still fetches preview_input_ids at Line 684 before this branch. That extra batch fetch is unnecessary and can still fail on edge dataloader schemas while generation is intentionally disabled.

Proposed fix
 def pre_quantize(
@@
-    # Only run single sample for preview
-    preview_input_ids = next(iter(calib_dataloader))[
-        "input_features" if model_type == "whisper" else "input_ids"
-    ][0:1]
-
-    # Generate preview before quantization
-    if args.skip_generate:
-        generated_ids_before_ptq = None
+    preview_input_ids = None
+    # Generate preview before quantization
+    if args.skip_generate:
+        generated_ids_before_ptq = None
     elif model_type == "deepseek":
+        preview_input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
         # DeepSeek generation may go OOM, so we skip it
         generated_ids_before_ptq = None
     elif is_nemotron_vl_model and tokenizer is not None:
+        preview_input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
         generated_ids_before_ptq = run_nemotron_vl_preview(
@@
     else:
+        preview_input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
         generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 689-691: the preview input
extraction (preview_input_ids) is still executed even when args.skip_generate is
true, causing an unnecessary dataloader fetch that can fail. Update the logic
around args.skip_generate and generated_ids_before_ptq so that when
args.skip_generate is set, the code short-circuits before any preview_input_ids
assignment or dataloader read (set generated_ids_before_ptq and
preview_input_ids to None, or skip their assignment entirely). In other words,
move or guard the preview_input_ids extraction behind the `if not
args.skip_generate` path (the branch that currently tests model_type ==
"deepseek" and the subsequent generation logic) so that no preview or dataloader
work runs when generation is disabled.

# DeepSeek generation may go OOM, so we skip it
generated_ids_before_ptq = None
elif is_nemotron_vl_model and tokenizer is not None:
Expand All @@ -703,7 +705,6 @@ def pre_quantize(
allow_fallback=False,
)
else:
# Standard generation for non-Nemotron VL models
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")
Expand Down Expand Up @@ -1084,6 +1085,16 @@ def parse_args() -> argparse.Namespace:
default=True,
action=argparse.BooleanOptionalAction,
)
parser.add_argument(
"--skip_generate",
help=(
"Skip pre/post-quantization preview calls that invoke model.generate(). "
"Note: this does not skip calibration or batch-size probing. "
"For very large models, pair with --batch_size 1 to avoid max-batch probing."
),
default=False,
action="store_true",
)
Comment thread
arendu marked this conversation as resolved.
parser.add_argument(
"--low_memory_mode",
help=(
Expand Down