Commit 0ebcd70
Support VLM calibration with image-text data (#755)
## What does this PR do?

**Type of change:** New feature

**Overview:** This PR allows the model optimizer to use image-text pair data during the calibration phase of quantization, which is likely to help improve the accuracy of quantized VLMs such as Nemotron VL, particularly on visual understanding tasks, compared to text-only calibration data.

- New feature: adds support for VLM calibration using image-text data.
- Dataset integration: introduces support for sampling from the `Nemotron-VLM-Dataset-v2`.
- Refactoring: creates a separate utility for VLM datasets to keep the main Hugging Face PTQ script (`hf_ptq.py`) clean.
- Simplifies the logic for handling multimodal inputs.
- Addresses specific issues encountered when calibrating the `Nemotron-Nano-VL-12B-V2` model with image data.
- Documentation: updates the README with instructions and examples for VLM calibration.

This PR complements #347; the llm_ptq and vlm_ptq examples will be consolidated in follow-up PRs.

## Usage

```bash
python3 hf_ptq.py \
    --pyt_ckpt_path /home/scratch.omniml_data_2/models/Nemotron-Nano-VL-12B-V2 \
    --qformat nvfp4 \
    --export_path /home/omniml_data_3/zhiyuc/checkpoints/Nemotron-Nano-VL-12B-V2-NVFP4-doccalib \
    --trust_remote_code \
    --kv_cache_qformat none \
    --calib_with_images \
    --vlm_dataset nemotron_vlm_dataset_v2 \
    --vlm_subsets sparsetables,plotqa_cot \
    --calib_size 512
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Not yet

## Summary by CodeRabbit

* **New Features**
  * Added Vision-Language Model (VLM) calibration support with image-text pair data, specifically for Nemotron VL models.
  * Added a new `--calib_with_images` CLI flag to enable image-based calibration workflows.
  * Integrated Nemotron VLM dataset v2 for streaming multimodal calibration data.
* **Documentation**
  * Added VLM calibration guidance in the PTQ README with usage examples and dataset information.

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 8c36f5a

File tree

6 files changed, +904 −44 lines changed

examples/llm_ptq/README.md

Lines changed: 17 additions & 0 deletions

````diff
@@ -162,6 +162,23 @@ scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_
 
 [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM.
 
+#### VLM calibration with image-text pairs (e.g., Nemotron VL)
+
+For vision-language models, calibration quality can likely improve by using image-text pairs instead of text-only data, especially on visual understanding tasks:
+
+```bash
+python hf_ptq.py \
+    --pyt_ckpt_path <huggingface_model_card> \
+    --qformat nvfp4 \
+    --export_path <quantized_ckpt_path> \
+    --trust_remote_code \
+    --calib_with_images \
+    --calib_size 512
+```
+
+> Note: when `--calib_with_images` is set, `--calib_size` must be a single value, and the calibration dataset is nvidia/nemotron_vlm_dataset_v2.
+This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`.
 
 ### NeMo Example Script
 
 NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo. Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details.
````
examples/llm_ptq/example_utils.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -15,6 +15,7 @@
 import copy
 import glob
+import inspect
 import os
 import shutil
 import sys
@@ -131,6 +132,53 @@ def is_nemotron_vl(model_or_config):
     return any("nemotron" in arch.lower() for arch in architectures)
 
 
+def create_vlm_calibration_loop(full_model, calib_dataloader):
+    """Create a calibration loop for VLM models that handles multimodal inputs.
+
+    This function inspects the model's forward signature and filters batch kwargs
+    to only include supported parameters, then calls the appropriate forward method.
+
+    Args:
+        full_model: The full VLM model
+        calib_dataloader: DataLoader yielding multimodal batches
+
+    Returns:
+        A calibration function that can be passed to mtq.quantize()
+    """
+    # Import here to avoid circular dependency
+    from nemotron_vl_calib import safe_nemotron_vl_forward
+
+    def calibrate_loop(_model):
+        # Inspect model's forward signature to determine what parameters it accepts
+        forward_params = inspect.signature(full_model.forward).parameters
+        accepts_kwargs = any(
+            p.kind == inspect.Parameter.VAR_KEYWORD for p in forward_params.values()
+        )
+        allowed_keys = set(forward_params.keys())
+
+        full_model.eval()
+        with torch.no_grad():
+            for batch in calib_dataloader:
+                # Filter batch to only include parameters the model accepts
+                if accepts_kwargs:
+                    call_kwargs = batch
+                else:
+                    call_kwargs = {k: v for k, v in batch.items() if k in allowed_keys}
+                # Remove None values
+                call_kwargs = {k: v for k, v in call_kwargs.items() if v is not None}
+
+                # Use safe_nemotron_vl_forward for Nemotron Nano VL (embedding-injection style)
+                # For other VLMs (like Nemotron-Parse), use standard forward
+                if hasattr(full_model, "img_context_token_id"):
+                    # Nemotron Nano VL style
+                    safe_nemotron_vl_forward(full_model, call_kwargs)
+                else:
+                    # Standard encoder-decoder or other VLM architectures
+                    full_model(**call_kwargs)
+
+    return calibrate_loop
+
+
 def build_quant_cfg(
     qformat,
     kv_cache_qformat,
```
examples/llm_ptq/hf_ptq.py

Lines changed: 129 additions & 25 deletions

```diff
@@ -25,6 +25,7 @@
 from example_utils import (
     build_quant_cfg,
     copy_custom_model_files,
+    create_vlm_calibration_loop,
     get_model,
     get_processor,
     get_tokenizer,
@@ -98,6 +99,39 @@
 mto.enable_huggingface_checkpointing()
 
 
+def extract_and_prepare_language_model_from_vl(full_model):
+    """Extract language model from VL model and disable quantization for non-language components.
+
+    Args:
+        full_model: The full VLM model
+
+    Returns:
+        tuple: (language_model, model_type) or (None, None) if not a VLM
+    """
+    language_model_lineage = get_language_model_from_vl(full_model)
+    if language_model_lineage is not None:
+        language_model = language_model_lineage.pop(-1)
+        ancestors = language_model_lineage
+        # Apply disabled quant to all modules that are not part of language_model
+        # This excludes them during HF export
+        disabled_quant_cfg = {
+            "quant_cfg": {"default": {"enable": False}},
+            "algorithm": "max",
+        }
+
+        memo = set(ancestors) | {language_model}
+        for ancestor in ancestors:
+            for _, module in ancestor.named_children():
+                if module not in memo:
+                    mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
+                    memo.add(module)
+
+        model_type = get_model_type(language_model)
+        return language_model, model_type
+
+    return None, None
+
+
 def make_calib_dataloader(
     args: argparse.Namespace,
     language_model: torch.nn.Module,
@@ -108,7 +142,30 @@ def make_calib_dataloader(
 ) -> tuple[DataLoader, str | None]:
     calib_dataloader = None
     first_text_speech_dataset = None
-    if model_type == "mllama":
+    if args.calib_with_images:
+        # VLM image-text calibration path: assume Nemotron VLM dataset by default.
+        assert processor is not None, (
+            "Please provide a processor (e.g., AutoProcessor) for image calibration."
+        )
+        assert len(args.calib_size) == 1, (
+            "Image calibration currently supports a single dataset. "
+            "Please pass --calib_size with one value (e.g., --calib_size 256)."
+        )
+        calib_dataloader = get_vlm_dataset_dataloader(
+            dataset_name="nemotron_vlm_dataset_v2",
+            processor=processor,
+            batch_size=args.batch_size,
+            num_samples=args.calib_size[0],
+            device=device,
+            max_length=args.calib_seq,
+            require_image=True,
+            subsets=["sparsetables", "plotqa_cot", "wiki_en"],
+            shuffle_buffer_size=10_000,
+            seed=42,
+            use_media_shards=True,
+            max_shards=1,
+        )
+    elif model_type == "mllama":
         assert processor is not None and isinstance(processor, MllamaImageProcessor), (
             "The MllamaImageProcessor must be set."
         )
@@ -165,6 +222,12 @@ def auto_quantize(
 ):
     """Auto search quantization of multiple formats."""
 
+    if args.calib_with_images:
+        raise NotImplementedError(
+            "AutoQuantize with image-text calibration is not supported yet. "
+            "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
+        )
+
     assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
         "Auto Quantization is not supported for pipeline parallel size > 1"
     )
@@ -292,7 +355,9 @@ def load_model(args: argparse.Namespace):
     tokenizer = None
     language_model = full_model
     default_padding_side = None
+    default_pad_token = None
 
+    is_nemotron_vl_model = is_nemotron_vl(full_model)
     if model_type == "mllama":
         processor = get_processor(
             args.pyt_ckpt_path,
@@ -308,6 +373,31 @@ def load_model(args: argparse.Namespace):
             device,
             trust_remote_code=args.trust_remote_code,
         )
+    elif is_nemotron_vl_model and args.calib_with_images:
+        # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs.
+        processor = AutoProcessor.from_pretrained(
+            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left"
+        )
+
+        if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
+            tokenizer = processor.tokenizer
+        else:
+            tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
+
+        default_pad_token = tokenizer.pad_token
+        # Some Nemotron tokenizers may not define pad_token by default; but we use padding=True during calibration.
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!"
+
+        default_padding_side = tokenizer.padding_side
+        tokenizer.padding_side = "left"
+
+        # Quantize only the language model, but keep the full_model for calibration forward.
+        extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
+        if extracted_lm is not None:
+            language_model = extracted_lm
+            model_type = extracted_model_type
     else:
         if args.dataset is None:
             args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
@@ -321,29 +411,15 @@ def load_model(args: argparse.Namespace):
         tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
 
         default_padding_side = tokenizer.padding_side
+        default_pad_token = tokenizer.pad_token
         # Left padding usually provides better calibration result.
         tokenizer.padding_side = "left"
 
         # We only quantize the language model for VLMs other than the type supported above.
-        language_model_lineage = get_language_model_from_vl(full_model)
-        if language_model_lineage is not None:
-            language_model = language_model_lineage.pop(-1)
-            ancestors = language_model_lineage
-            # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
-            # HF export.
-            disabled_quant_cfg = {
-                "quant_cfg": {"default": {"enable": False}},
-                "algorithm": "max",
-            }
-
-            memo = set(ancestors) | {language_model}
-            for ancestor in ancestors:
-                for _, module in ancestor.named_children():
-                    if module not in memo:
-                        mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
-                        memo.add(module)
-
-            model_type = get_model_type(language_model)
+        extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
+        if extracted_lm is not None:
+            language_model = extracted_lm
+            model_type = extracted_model_type
 
     if model_type == "phi4mm":
         warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")
@@ -356,6 +432,7 @@ def load_model(args: argparse.Namespace):
         processor,
         tokenizer,
         default_padding_side,
+        default_pad_token,
         device,
     )
 
@@ -433,9 +510,14 @@ def mono_quantize(
 
     if not use_calibration:
         warnings.warn("Dynamic quantization. Calibration skipped.")
-    calibrate_loop = (
-        create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
-    )
+    calibrate_loop = None
+    if use_calibration:
+        # For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values).
+        # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model.
+        if args.calib_with_images and is_nemotron_vl_model:
+            calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
+        else:
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
 
     if calibration_only:
         language_model = mtq.calibrate(
@@ -462,6 +544,7 @@ def export_quantized(
     model_type: str | None,
     tokenizer: PreTrainedTokenizerBase | None,
     default_padding_side,
+    default_pad_token,
 ):
     with torch.inference_mode():
         if model_type is None:
@@ -551,6 +634,8 @@ def export_quantized(
         # Restore default padding and export the tokenizer as well.
         if tokenizer is not None:
             tokenizer.padding_side = default_padding_side
+            if default_pad_token is not None:
+                tokenizer.pad_token = default_pad_token
             tokenizer.save_pretrained(export_path)
 
     end_time = time.time()
@@ -700,6 +785,7 @@ def quantize_main(
     processor: BaseImageProcessor | ProcessorMixin | None,
     tokenizer: PreTrainedTokenizerBase | None,
     default_padding_side,
+    default_pad_token,
     device: torch.device,
 ):
     if args.batch_size == 0:
@@ -815,7 +901,15 @@ def quantize_main(
         is_nemotron_vl_model,
         first_text_speech_dataset,
     )
-    export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side)
+    export_quantized(
+        args,
+        full_model,
+        language_model,
+        model_type,
+        tokenizer,
+        default_padding_side,
+        default_pad_token,
+    )
 
 
 def parse_args() -> argparse.Namespace:
@@ -866,6 +960,14 @@ def parse_args() -> argparse.Namespace:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--calib_with_images",
+        action="store_true",
+        help=(
+            "Calibrate with image-text pairs (for VLMs). "
+            "This uses nemotron_vlm_dataset_v2 with default subsets (sparsetables, plotqa_cot, wiki_en)."
+        ),
+    )
     parser.add_argument("--inference_tensor_parallel", type=int, default=1)
     parser.add_argument("--inference_pipeline_parallel", type=int, default=1)
     parser.add_argument("--awq_block_size", default=0, type=int)
@@ -1003,6 +1105,7 @@ def main(args: argparse.Namespace):
         processor,
         tokenizer,
         default_padding_side,
+        default_pad_token,
         device,
     ) = load_model(args)
 
@@ -1020,6 +1123,7 @@ def main(args: argparse.Namespace):
         processor,
         tokenizer,
         default_padding_side,
+        default_pad_token,
         device,
     )
 
@@ -1030,6 +1134,6 @@ def main(args: argparse.Namespace):
     if args.export_fmt != "hf":
         warnings.warn("Deprecated. --export_fmt forced to hf.")
 
-    args.dataset = args.dataset.split(",") if args.dataset else None
+    args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
     args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
     main(args)
```
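The `--calib_size` handling above combines two pieces: the entry point splits the comma-separated string into a list of ints, and `make_calib_dataloader` asserts a single value when `--calib_with_images` is set. A minimal sketch of the combined behavior (the helper `parse_calib_sizes` is illustrative only; in the PR the split and the assertion live in separate functions):

```python
def parse_calib_sizes(calib_size_arg: str, calib_with_images: bool) -> list[int]:
    """Split a comma-separated --calib_size string into ints; image calibration
    accepts exactly one value, mirroring the assertion in make_calib_dataloader."""
    sizes = [int(num_sample) for num_sample in calib_size_arg.split(",")]
    if calib_with_images and len(sizes) != 1:
        raise ValueError(
            "Image calibration currently supports a single dataset. "
            "Please pass --calib_size with one value (e.g., --calib_size 256)."
        )
    return sizes

print(parse_calib_sizes("512", calib_with_images=True))       # [512]
print(parse_calib_sizes("256,128", calib_with_images=False))  # [256, 128]
```

With `--calib_with_images`, passing `--calib_size 256,128` would therefore fail fast rather than silently using one of the values.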
