Skip to content

Commit 52eee84

Browse files
committed
add image-text data calibration support
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent e1bd013 commit 52eee84

3 files changed

Lines changed: 222 additions & 72 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,33 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTok
221221
if "vila" in ckpt_path.lower():
222222
ckpt_path += "/llm"
223223

224-
tokenizer = AutoTokenizer.from_pretrained(
225-
ckpt_path, trust_remote_code=trust_remote_code, **kwargs
226-
)
224+
# Suppress verbose tokenizer output (e.g., printing all special tokens)
225+
import contextlib
226+
import io
227+
import logging
228+
import os
229+
230+
# Save current settings
231+
old_verbosity = os.environ.get("TOKENIZERS_PARALLELISM", None)
232+
transformers_log_level = logging.getLogger("transformers").level
233+
234+
# Suppress output
235+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
236+
logging.getLogger("transformers").setLevel(logging.ERROR)
237+
238+
# Also capture stdout to suppress verbose tokenizer printing
239+
with contextlib.redirect_stdout(io.StringIO()):
240+
try:
241+
tokenizer = AutoTokenizer.from_pretrained(
242+
ckpt_path, trust_remote_code=trust_remote_code, **kwargs
243+
)
244+
finally:
245+
# Restore original settings
246+
if old_verbosity is not None:
247+
os.environ["TOKENIZERS_PARALLELISM"] = old_verbosity
248+
else:
249+
os.environ.pop("TOKENIZERS_PARALLELISM", None)
250+
logging.getLogger("transformers").setLevel(transformers_log_level)
227251

228252
# can't set attribute 'pad_token' for "<unk>"
229253
# We skip this step for Nemo models
@@ -279,10 +303,23 @@ def get_processor(
279303
# Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
280304
# This will only work if the model has a processor config
281305
try:
282-
processor = AutoProcessor.from_pretrained(
283-
ckpt_path,
284-
**model_kwargs,
285-
)
306+
import contextlib
307+
import io
308+
import logging
309+
310+
# Suppress verbose output from processor/tokenizer loading
311+
transformers_log_level = logging.getLogger("transformers").level
312+
logging.getLogger("transformers").setLevel(logging.ERROR)
313+
314+
with contextlib.redirect_stdout(io.StringIO()):
315+
processor = AutoProcessor.from_pretrained(
316+
ckpt_path,
317+
**model_kwargs,
318+
)
319+
320+
# Restore logging
321+
logging.getLogger("transformers").setLevel(transformers_log_level)
322+
286323
print(f"Loaded AutoProcessor for model type: {model_type}")
287324
return processor
288325
except Exception as e:
@@ -330,12 +367,26 @@ def get_model(
330367
# Load config once and handle VL model detection
331368
try:
332369
hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
370+
371+
# Check specifically for Nemotron-Parse
372+
architectures = getattr(hf_config, "architectures", [])
373+
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
374+
333375
if is_nemotron_vl(hf_config):
334-
print(
335-
"Detected Nemotron VL model from config. "
336-
"Disabling automatic device mapping for compatibility."
337-
)
338-
device_map = None
376+
if is_nemotron_parse:
377+
# Nemotron-Parse works fine with device_map="auto"
378+
# Keep device_map="auto" to ensure proper device placement
379+
print(
380+
"Detected Nemotron-Parse model from config. "
381+
"Using automatic device mapping."
382+
)
383+
else:
384+
# For other Nemotron VL models, disable device_map for compatibility
385+
print(
386+
"Detected Nemotron VL model from config. "
387+
"Disabling automatic device mapping for compatibility."
388+
)
389+
device_map = None
339390
except Exception as e:
340391
print(f"Error: Could not load config from {ckpt_path}: {e}")
341392
raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e
@@ -433,6 +484,21 @@ def get_model(
433484
print(f"Moving model to {device} device...")
434485
model = model.to(device)
435486

487+
# For Nemotron-Parse, ensure the encoder (including RADIO) is fully on device
488+
# The RADIO encoder has buffers that might not be properly moved even with device_map="auto"
489+
# This is because custom RADIO modules might not fully support accelerate's device_map
490+
if device != "cpu" and hasattr(model, "encoder"):
491+
# Check if encoder has any buffers on CPU
492+
cpu_buffers = []
493+
for name, buffer in model.encoder.named_buffers():
494+
if buffer.device.type == "cpu":
495+
cpu_buffers.append(name)
496+
497+
if cpu_buffers:
498+
print(f"Found {len(cpu_buffers)} encoder buffers on CPU. Moving encoder to {device}...")
499+
model.encoder = model.encoder.to(device)
500+
print(f"Encoder moved to {device}")
501+
436502
if device == "cuda" and not is_model_on_gpu(model):
437503
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
438504

examples/llm_ptq/hf_ptq.py

Lines changed: 139 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
)
6565
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
6666
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
67+
from modelopt.torch.utils.nemotron_vlm_dataset_utils import get_nemotron_vlm_dataset_dataloader
6768
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
6869
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
6970

@@ -173,9 +174,50 @@ def make_calib_dataloader(
173174
tokenizer: PreTrainedTokenizerBase | None,
174175
device: torch.device,
175176
model_type: str | None,
177+
full_model: torch.nn.Module | None = None,
176178
) -> tuple[DataLoader, str | None]:
177179
calib_dataloader = None
178180
first_text_speech_dataset = None
181+
182+
# Check if this is Nemotron-Parse - use image-text data for better calibration
183+
if full_model is not None:
184+
config = full_model.config
185+
architectures = getattr(config, "architectures", [])
186+
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
187+
188+
if is_nemotron_parse and processor is not None:
189+
print(
190+
"✓ Detected Nemotron-Parse model. Using image-text dataset for calibration "
191+
"to provide realistic visual embeddings to the decoder."
192+
)
193+
194+
# Override dataset to use image-text dataset if not specified
195+
supported_datasets = ["nemotron_vlm_v2", "chartqa", "scienceqa"]
196+
if not args.dataset or args.dataset[0] not in supported_datasets:
197+
print(
198+
f"[INFO] Dataset '{args.dataset}' is not a supported image-text dataset. "
199+
f"Automatically using 'nemotron_vlm_v2' for Nemotron-Parse calibration."
200+
)
201+
dataset_to_use = "nemotron_vlm_v2"
202+
else:
203+
dataset_to_use = args.dataset[0]
204+
205+
# Nemotron-Parse needs single dataset for now
206+
if len(args.calib_size) > 1:
207+
print(f"[INFO] Using first calib_size value: {args.calib_size[0]}")
208+
calib_size_to_use = args.calib_size[0]
209+
else:
210+
calib_size_to_use = args.calib_size[0] if args.calib_size else 512
211+
212+
calib_dataloader = get_nemotron_vlm_dataset_dataloader(
213+
dataset_name=dataset_to_use,
214+
processor=processor,
215+
batch_size=args.batch_size,
216+
num_samples=calib_size_to_use,
217+
device=device, # Move data to model's device
218+
)
219+
return calib_dataloader, first_text_speech_dataset
220+
179221
if model_type == "mllama":
180222
assert processor is not None and isinstance(processor, MllamaImageProcessor), (
181223
"The MllamaImageProcessor must be set."
@@ -377,18 +419,35 @@ def load_model(args: argparse.Namespace):
377419
trust_remote_code=args.trust_remote_code,
378420
)
379421
else:
422+
# Check if this is a Nemotron VL model that needs a processor
423+
# Do this BEFORE setting default datasets so we can use image-text data for Nemotron-Parse
424+
is_nemotron_vl_model = is_nemotron_vl(full_model)
425+
426+
# Check specifically for Nemotron-Parse to set appropriate dataset defaults
427+
config = full_model.config
428+
architectures = getattr(config, "architectures", [])
429+
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
430+
380431
if args.dataset is None:
381-
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
382-
warnings.warn(
383-
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
384-
)
432+
if is_nemotron_parse:
433+
# For Nemotron-Parse, default to Nemotron VLM Dataset v2
434+
args.dataset = ["nemotron_vlm_v2"]
435+
print(
436+
"No dataset specified. Defaulting to 'nemotron_vlm_v2' for Nemotron-Parse "
437+
"(NVIDIA's image-text dataset for better calibration)."
438+
)
439+
else:
440+
# For other models, use text-only datasets
441+
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
442+
warnings.warn(
443+
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
444+
)
445+
385446
# Adjust calib_size to match dataset length by extending or truncating as needed
386447
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
387448
: len(args.dataset)
388449
]
389450

390-
# Check if this is a Nemotron VL model that needs a processor
391-
is_nemotron_vl_model = is_nemotron_vl(full_model)
392451
if is_nemotron_vl_model:
393452
# Load processor for Nemotron VL models (like Nemotron-Parse)
394453
processor = get_processor(
@@ -404,26 +463,41 @@ def load_model(args: argparse.Namespace):
404463
# Left padding usually provides better calibration result.
405464
tokenizer.padding_side = "left"
406465

407-
# We only quantize the language model for VLMs other than the type supported above.
408-
language_model_lineage = get_language_model_from_vl(full_model)
409-
if language_model_lineage is not None:
410-
language_model = language_model_lineage.pop(-1)
411-
ancestors = language_model_lineage
412-
# Apply disabled quant to all modules that are not part of language_model so we can exclude them during
413-
# HF export.
414-
disabled_quant_cfg = {
415-
"quant_cfg": {"default": {"enable": False}},
416-
"algorithm": "max",
417-
}
418-
419-
memo = set(ancestors) | {language_model}
420-
for ancestor in ancestors:
421-
for _, module in ancestor.named_children():
422-
if module not in memo:
423-
mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
424-
memo.add(module)
425-
426-
model_type = get_model_type(language_model)
466+
# Check if this is Nemotron-Parse
467+
config = full_model.config
468+
architectures = getattr(config, "architectures", [])
469+
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
470+
471+
# For Nemotron-Parse, DON'T extract the decoder
472+
# We want to calibrate the full model so the decoder sees realistic visual embeddings
473+
# The vision encoder won't be quantized (disabled via quant_cfg in mono_quantize)
474+
if is_nemotron_parse:
475+
print(
476+
"Nemotron-Parse detected: Keeping full encoder-decoder model for calibration "
477+
"with image-text data. Vision encoder will be disabled from quantization."
478+
)
479+
# language_model = full_model (already set above)
480+
else:
481+
# For other VLMs, extract the language model for quantization
482+
language_model_lineage = get_language_model_from_vl(full_model)
483+
if language_model_lineage is not None:
484+
language_model = language_model_lineage.pop(-1)
485+
ancestors = language_model_lineage
486+
# Apply disabled quant to all modules that are not part of language_model so we can exclude them during
487+
# HF export.
488+
disabled_quant_cfg = {
489+
"quant_cfg": {"default": {"enable": False}},
490+
"algorithm": "max",
491+
}
492+
493+
memo = set(ancestors) | {language_model}
494+
for ancestor in ancestors:
495+
for _, module in ancestor.named_children():
496+
if module not in memo:
497+
mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
498+
memo.add(module)
499+
500+
model_type = get_model_type(language_model)
427501

428502
if model_type == "phi4mm":
429503
warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")
@@ -494,14 +568,23 @@ def mono_quantize(
494568
"Consider reducing calib_size to reduce calibration time.\n####\n"
495569
)
496570

571+
# Check if this is Nemotron-Parse
572+
config = full_model.config
573+
architectures = getattr(config, "architectures", [])
574+
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
575+
original_forward = None # Track original forward method if we wrap it
576+
497577
# For Nemotron VL models, disable quantization of vision components
498578
if is_nemotron_vl_model:
499579
print("Disabling quantization for vision components in Nemotron VL model")
500580
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
501581
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
502-
# Also disable radio model components specifically
582+
# Also disable radio model components specifically (for Nemotron-Parse)
503583
quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
504584
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
585+
quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder
586+
quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} # Nemotron-Parse specific
587+
print("Quantization will only be applied to the decoder (text generation) component")
505588

506589
if not model_is_already_quantized or calibration_only:
507590
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
@@ -513,9 +596,25 @@ def mono_quantize(
513596

514597
if not use_calibration:
515598
warnings.warn("Dynamic quantization. Calibration skipped.")
516-
calibrate_loop = (
517-
create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
518-
)
599+
600+
# Create calibration loop
601+
if use_calibration:
602+
if is_nemotron_parse:
603+
# For Nemotron-Parse, wrap the model to force use_cache=False
604+
print("Wrapping Nemotron-Parse model for calibration (use_cache=False)")
605+
original_forward = language_model.forward
606+
607+
def wrapped_forward(*args, **kwargs):
608+
kwargs["use_cache"] = False
609+
return original_forward(*args, **kwargs)
610+
611+
# Temporarily replace forward method
612+
language_model.forward = wrapped_forward
613+
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
614+
else:
615+
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
616+
else:
617+
calibrate_loop = None
519618

520619
if calibration_only:
521620
language_model = mtq.calibrate(
@@ -524,8 +623,15 @@ def mono_quantize(
524623
else:
525624
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
526625

527-
# For VL models, update full_model to use the quantized language model
528-
if is_nemotron_vl_model:
626+
# Restore original forward method if we wrapped it for Nemotron-Parse
627+
if is_nemotron_parse and original_forward is not None:
628+
print("Restoring original forward method after calibration")
629+
language_model.forward = original_forward
630+
original_forward = None
631+
632+
# For VL models (except Nemotron-Parse), update full_model to use the quantized language model
633+
# For Nemotron-Parse, language_model IS full_model, so no update needed
634+
if is_nemotron_vl_model and language_model is not full_model:
529635
language_model_lineage = get_language_model_from_vl(full_model)
530636
if language_model_lineage is not None:
531637
print("Updating full_model with quantized language_model...")
@@ -828,38 +934,12 @@ def quantize_main(
828934
print(f"Use calib batch_size {args.batch_size}")
829935

830936
calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
831-
args, language_model, processor, tokenizer, device, model_type
937+
args, language_model, processor, tokenizer, device, model_type, full_model
832938
)
833939

834940
# Detect if this is a Nemotron VL model using architecture-based detection
835941
is_nemotron_vl_model = is_nemotron_vl(full_model)
836942

837-
# For Nemotron-Parse, wrap the text-only dataloader to add dummy images
838-
# Nemotron-Parse is an encoder-decoder model that requires pixel_values
839-
if is_nemotron_vl_model and processor is not None:
840-
config = full_model.config
841-
architectures = getattr(config, "architectures", [])
842-
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
843-
844-
if is_nemotron_parse:
845-
# Check if we're quantizing just the decoder or the full model
846-
decoder_only = language_model is not full_model
847-
848-
if decoder_only:
849-
print(
850-
"Calibration will use text-only inputs for Nemotron-Parse decoder. "
851-
"Vision encoder is excluded from quantization."
852-
)
853-
else:
854-
print(
855-
"Wrapping calibration dataloader for Nemotron-Parse to add dummy images. "
856-
"Nemotron-Parse requires pixel_values for full model calibration."
857-
)
858-
859-
calib_dataloader = create_nemotron_parse_calib_wrapper(
860-
calib_dataloader, processor, device, decoder_only=decoder_only
861-
)
862-
863943
preview_input_ids, generated_ids_before_ptq = pre_quantize(
864944
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
865945
)

0 commit comments

Comments (0)