Skip to content

Commit 2a9f431

Browse files
Edwardf0t1 authored and kevalmorabia97 committed
Add Nemotron parse PTQ support (#786)
## What does this PR do? **Type of change:** New model support <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Add PTQ support for https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1 ## Usage <!-- You can potentially add a usage example below. --> ```python python3 hf_ptq.py --pyt_ckpt_path /home/omniml_data_3/models/NVIDIA-Nemotron-Parse-v1.1 --qformat fp8 --export_path /home/omniml_data_3/zhiyuc/checkpoints/NVIDIA-Nemotron-Parse-v1.1-FP8 --trust_remote_code --kv_cache_qformat none --attn_implementation eager ``` By default, image-text data will be used in calibration for VLMs. ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Not yet <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for Nemotron-Parse multimodal models, including proper device mapping, processor loading, and generation handling. * **Improvements** * Enhanced quantization robustness with safer handling of quantization attributes and fallback logic. * Improved model loading with better device placement and encoder buffer management for vision-language models. 
<!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent d7f62d3 commit 2a9f431

File tree

6 files changed

+145
-73
lines changed

6 files changed

+145
-73
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ NVIDIA Model Optimizer Changelog (Linux)
2121
- Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow.
2222
- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
2323
- Add support for image-text data calibration in PTQ for Nemotron VL models.
24+
- Add PTQ support for Nemotron Parse.
2425

2526
0.41 (2026-01-19)
2627
^^^^^^^^^^^^^^^^^

examples/llm_ptq/example_utils.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from safetensors.torch import load_file
3232
from transformers import (
3333
AutoConfig,
34+
AutoModel,
3435
AutoModelForCausalLM,
3536
AutoProcessor,
3637
AutoTokenizer,
@@ -75,19 +76,18 @@ def run_nemotron_vl_preview(
7576
"eos_token_id": tokenizer.eos_token_id,
7677
}
7778

78-
# Try text-only generation
79+
# Try text-only generation (may fail for encoder-decoder models like Nemotron-Parse)
7980
text_response = run_text_only_generation(
8081
full_model, tokenizer, question, generation_config, pyt_ckpt_path
8182
)
8283

84+
generated_ids = None
8385
if text_response is not None:
8486
print(f"✅ Text-only generation successful: {text_response[:100]}...")
8587
generated_ids = text_response
8688
elif allow_fallback:
8789
print("Text-only generation failed, falling back to standard generate...")
8890
generated_ids = full_model.generate(input_ids, max_new_tokens=100)
89-
else:
90-
generated_ids = None
9191

9292
# Run additional VL test with images
9393
print(f"Running additional VL test with images ({stage_name})...")
@@ -106,6 +106,10 @@ def _is_multimodal_config(config):
106106
or (
107107
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
108108
) # Image embedding layers
109+
or getattr(config, "is_encoder_decoder", False) # Encoder-decoder VL models
110+
or any( # Architecture-based detection for custom VL models (e.g., Nemotron-Parse)
111+
"conditionalgeneration" in arch.lower() for arch in getattr(config, "architectures", [])
112+
)
109113
)
110114

111115

@@ -158,9 +162,20 @@ def calibrate_loop(_model):
158162
)
159163
allowed_keys = set(forward_params.keys())
160164

165+
# Check if model is encoder-decoder (needs decoder_input_ids instead of input_ids)
166+
is_enc_dec = getattr(full_model.config, "is_encoder_decoder", False)
167+
161168
full_model.eval()
162169
with torch.no_grad():
163170
for batch in calib_dataloader:
171+
# For encoder-decoder models, rename input_ids → decoder_input_ids
172+
# and disable KV caching to avoid tuple index errors in decoder layers
173+
if is_enc_dec and "input_ids" in batch and "pixel_values" in batch:
174+
batch["decoder_input_ids"] = batch.pop("input_ids")
175+
if "attention_mask" in batch:
176+
batch["decoder_attention_mask"] = batch.pop("attention_mask")
177+
batch["use_cache"] = False
178+
164179
# Filter batch to only include parameters the model accepts
165180
if accepts_kwargs:
166181
call_kwargs = batch
@@ -172,10 +187,8 @@ def calibrate_loop(_model):
172187
# Use safe_nemotron_vl_forward for Nemotron Nano VL (embedding-injection style)
173188
# For other VLMs (like Nemotron-Parse), use standard forward
174189
if hasattr(full_model, "img_context_token_id"):
175-
# Nemotron Nano VL style
176190
safe_nemotron_vl_forward(full_model, call_kwargs)
177191
else:
178-
# Standard encoder-decoder or other VLM architectures
179192
full_model(**call_kwargs)
180193

181194
return calibrate_loop
@@ -312,8 +325,15 @@ def get_processor(
312325
)
313326

314327
return MllamaImageProcessor(processor, device)
315-
316-
return None
328+
else:
329+
# Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
330+
try:
331+
processor = AutoProcessor.from_pretrained(ckpt_path, **model_kwargs)
332+
print(f"Loaded AutoProcessor for model type: {model_type}")
333+
return processor
334+
except Exception as e:
335+
print(f"Could not load processor for {model_type}: {e}")
336+
return None
317337

318338

319339
def load_mtp_weights(
@@ -447,6 +467,7 @@ def get_model(
447467
# Load config once and handle VL model detection
448468
try:
449469
hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
470+
450471
if is_nemotron_vl(hf_config):
451472
print(
452473
"Detected Nemotron VL model from config. "
@@ -466,8 +487,6 @@ def get_model(
466487
model_kwargs.setdefault("torch_dtype", "auto")
467488

468489
if "vila" in ckpt_path.lower():
469-
from transformers import AutoModel
470-
471490
hf_vila = AutoModel.from_pretrained(
472491
ckpt_path,
473492
device_map=device_map,
@@ -510,13 +529,17 @@ def get_model(
510529
if not hasattr(transformers, architecture):
511530
warnings.warn(
512531
f"Architecture {architecture} not found in transformers: {transformers.__version__}. "
513-
"Falling back to AutoModelForCausalLM."
532+
"Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)."
514533
)
515534
assert trust_remote_code, (
516535
"Please set trust_remote_code to True if you want to use this architecture"
517536
)
518537

519-
auto_model_module = AutoModelForCausalLM
538+
# Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models
539+
if getattr(hf_config, "is_encoder_decoder", False):
540+
auto_model_module = AutoModel
541+
else:
542+
auto_model_module = AutoModelForCausalLM
520543
from_config = auto_model_module.from_config
521544
else:
522545
auto_model_module = getattr(transformers, architecture)
@@ -527,7 +550,7 @@ def get_model(
527550
# unless specified by the hf_config.
528551
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
529552
model_kwargs2 = model_kwargs.copy()
530-
if auto_model_module != AutoModelForCausalLM:
553+
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
531554
model_kwargs2.pop("trust_remote_code", None)
532555
model_kwargs2["torch_dtype"] = torch_dtype
533556
model_kwargs2.pop("max_memory", None)

examples/llm_ptq/hf_ptq.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,12 @@ def load_model(args: argparse.Namespace):
361361
default_pad_token = None
362362

363363
is_nemotron_vl_model = is_nemotron_vl(full_model)
364+
365+
# Default to image-text calibration for VLM models
366+
if is_nemotron_vl_model and not args.calib_with_images:
367+
print("Nemotron VL model detected. Enabling image-text calibration by default.")
368+
args.calib_with_images = True
369+
364370
if model_type == "mllama":
365371
processor = get_processor(
366372
args.pyt_ckpt_path,
@@ -499,9 +505,12 @@ def mono_quantize(
499505
print("Disabling quantization for vision components in Nemotron VL model")
500506
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
501507
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
502-
# Also disable radio model components specifically
508+
# Also disable radio model components specifically (for Nemotron-Parse)
503509
quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
504510
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
511+
quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder
512+
quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} # Nemotron-Parse specific
513+
print("Quantization will only be applied to the decoder (text generation) component")
505514

506515
if not model_is_already_quantized or calibration_only:
507516
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
@@ -686,7 +695,7 @@ def pre_quantize(
686695
preview_input_ids,
687696
args.pyt_ckpt_path,
688697
"before quantization",
689-
allow_fallback=True,
698+
allow_fallback=False,
690699
)
691700
else:
692701
# Standard generation for non-Nemotron VL models
@@ -800,36 +809,42 @@ def quantize_main(
800809
device: torch.device,
801810
):
802811
if args.batch_size == 0:
803-
# Calibration/sparsification will actually take much more memory than regular inference
804-
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
805-
# to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
806-
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
807-
# Whisper model expects mel-spectrogram input features of length 3000
808-
# Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
809-
# As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
810-
# For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
811-
if model_type == "whisper":
812-
max_sample_length = 3000
813-
num_mel_bins = language_model.config.num_mel_bins
814-
sample_input_single_batch = (
815-
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
816-
language_model.device
817-
)
818-
* 100
819-
)
812+
# For VL models with image-text calibration, skip automatic batch size detection
813+
# since get_max_batch_size can't handle multimodal inputs
814+
if args.calib_with_images:
815+
print("Image-text calibration enabled. Using default batch_size=1 for calibration.")
816+
args.batch_size = 1
820817
else:
821-
sample_input_single_batch = None
818+
# Calibration/sparsification will actually take much more memory than regular inference
819+
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
820+
# to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
821+
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
822+
# Whisper model expects mel-spectrogram input features of length 3000
823+
# Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
824+
# As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
825+
# For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
826+
if model_type == "whisper":
827+
max_sample_length = 3000
828+
num_mel_bins = language_model.config.num_mel_bins
829+
sample_input_single_batch = (
830+
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
831+
language_model.device
832+
)
833+
* 100
834+
)
835+
else:
836+
sample_input_single_batch = None
822837

823-
run_auto_quant = args.auto_quantize_bits is not None
838+
run_auto_quant = args.auto_quantize_bits is not None
824839

825-
args.batch_size = get_max_batch_size(
826-
language_model,
827-
max_sample_length=args.calib_seq,
828-
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
829-
sample_input_single_batch=sample_input_single_batch,
830-
enable_grad=run_auto_quant,
831-
)
832-
args.batch_size = min(args.batch_size, sum(args.calib_size))
840+
args.batch_size = get_max_batch_size(
841+
language_model,
842+
max_sample_length=args.calib_seq,
843+
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
844+
sample_input_single_batch=sample_input_single_batch,
845+
enable_grad=run_auto_quant,
846+
)
847+
args.batch_size = min(args.batch_size, sum(args.calib_size))
833848

834849
print(f"Use calib batch_size {args.batch_size}")
835850

examples/llm_ptq/vlm_utils.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -105,27 +105,31 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
105105
else:
106106
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
107107

108-
messages = [
109-
{"role": "system", "content": "/no_think"},
110-
{
111-
"role": "user",
112-
"content": [
113-
{
114-
"type": "image",
115-
"image": "",
116-
},
117-
{
118-
"type": "text",
119-
"text": question,
120-
},
121-
],
122-
},
123-
]
124-
125-
# Apply chat template
126-
prompt = tokenizer.apply_chat_template(
127-
messages, tokenize=False, add_generation_prompt=True
128-
)
108+
# Use chat template if available, otherwise fall back to default task prompt
109+
if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None:
110+
messages = [
111+
{"role": "system", "content": "/no_think"},
112+
{
113+
"role": "user",
114+
"content": [
115+
{
116+
"type": "image",
117+
"image": "",
118+
},
119+
{
120+
"type": "text",
121+
"text": question,
122+
},
123+
],
124+
},
125+
]
126+
prompt = tokenizer.apply_chat_template(
127+
messages, tokenize=False, add_generation_prompt=True
128+
)
129+
else:
130+
# For models without chat templates (e.g., encoder-decoder VL models),
131+
# use the tokenizer's bos/eos tokens as a minimal prompt
132+
prompt = (tokenizer.bos_token or "") + question
129133

130134
# Process inputs using the processor with single image
131135
inputs = processor(
@@ -139,6 +143,12 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
139143
inputs = inputs.to(model_device)
140144
print(f" Moved inputs to {model_device}")
141145

146+
# Verify we have pixel_values for the vision encoder
147+
if not hasattr(inputs, "pixel_values") or inputs.pixel_values is None:
148+
raise ValueError(
149+
"Processor did not generate pixel_values. Check processor configuration."
150+
)
151+
142152
# Generate response using model.generate
143153
generated_ids = model.generate(
144154
pixel_values=inputs.pixel_values,
@@ -148,12 +158,23 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
148158
)
149159

150160
# Decode the response (trim input tokens like in the working example)
161+
if generated_ids is None:
162+
raise ValueError("Model generate returned None")
163+
151164
generated_ids_trimmed = [
152165
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
153166
]
154-
output_text = processor.batch_decode(
155-
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
167+
# Use processor.batch_decode if available, otherwise fall back to tokenizer
168+
decoder = processor if hasattr(processor, "batch_decode") else tokenizer
169+
output_text = decoder.batch_decode(
170+
generated_ids_trimmed,
171+
skip_special_tokens=True,
172+
clean_up_tokenization_spaces=False,
156173
)
174+
175+
if output_text is None or len(output_text) == 0:
176+
raise ValueError("Decoding returned empty output")
177+
157178
response = output_text[0]
158179

159180
print(f"✅ VL generation {stage_name} successful!")

modelopt/torch/export/model_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def is_multimodal_model(model):
8585
- Vision LoRA configurations
8686
- Audio processing capabilities
8787
- Image embedding layers
88+
- Nemotron-Parse conditional generation models
8889
8990
Args:
9091
model: The HuggingFace model instance to check
@@ -103,6 +104,10 @@ def is_multimodal_model(model):
103104
"""
104105
config = model.config
105106

107+
# Check for Nemotron-Parse encoder-decoder architecture
108+
architectures = getattr(config, "architectures", [])
109+
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
110+
106111
return (
107112
hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL)
108113
or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA)
@@ -112,6 +117,7 @@ def is_multimodal_model(model):
112117
or (
113118
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
114119
) # Image embedding layers
120+
or is_nemotron_parse # Nemotron-Parse conditional generation model
115121
)
116122

117123

@@ -141,5 +147,11 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None:
141147
if hasattr(model, "language_model"):
142148
return [model, model.language_model]
143149

144-
# Pattern 3: No language_model found
150+
# Pattern 3: For encoder-decoder VL models (e.g., Nemotron-Parse), the decoder is the language model.
151+
# Only match if the model is detected as multimodal to avoid matching non-VLM encoder-decoder
152+
# models like T5, Bart, Whisper which also have .decoder.
153+
if hasattr(model, "decoder") and is_multimodal_model(model):
154+
return [model, model.decoder]
155+
156+
# Pattern 4: No language_model found
145157
return None

0 commit comments

Comments (0)