6464)
6565from modelopt .torch .utils .image_processor import BaseImageProcessor , MllamaImageProcessor
6666from modelopt .torch .utils .memory_monitor import launch_memory_monitor
67+ from modelopt .torch .utils .nemotron_vlm_dataset_utils import get_nemotron_vlm_dataset_dataloader
6768from modelopt .torch .utils .speech_dataset_utils import get_speech_dataset_dataloader
6869from modelopt .torch .utils .vlm_dataset_utils import get_vlm_dataset_dataloader
6970
@@ -173,9 +174,50 @@ def make_calib_dataloader(
173174 tokenizer : PreTrainedTokenizerBase | None ,
174175 device : torch .device ,
175176 model_type : str | None ,
177+ full_model : torch .nn .Module | None = None ,
176178) -> tuple [DataLoader , str | None ]:
177179 calib_dataloader = None
178180 first_text_speech_dataset = None
181+
182+ # Check if this is Nemotron-Parse - use image-text data for better calibration
183+ if full_model is not None :
184+ config = full_model .config
185+ architectures = getattr (config , "architectures" , [])
186+ is_nemotron_parse = any ("nemotronparse" in arch .lower () for arch in architectures )
187+
188+ if is_nemotron_parse and processor is not None :
189+ print (
190+ "✓ Detected Nemotron-Parse model. Using image-text dataset for calibration "
191+ "to provide realistic visual embeddings to the decoder."
192+ )
193+
194+ # Override dataset to use image-text dataset if not specified
195+ supported_datasets = ["nemotron_vlm_v2" , "chartqa" , "scienceqa" ]
196+ if not args .dataset or args .dataset [0 ] not in supported_datasets :
197+ print (
198+ f"[INFO] Dataset '{ args .dataset } ' is not a supported image-text dataset. "
199+ f"Automatically using 'nemotron_vlm_v2' for Nemotron-Parse calibration."
200+ )
201+ dataset_to_use = "nemotron_vlm_v2"
202+ else :
203+ dataset_to_use = args .dataset [0 ]
204+
205+ # Nemotron-Parse needs single dataset for now
206+ if len (args .calib_size ) > 1 :
207+ print (f"[INFO] Using first calib_size value: { args .calib_size [0 ]} " )
208+ calib_size_to_use = args .calib_size [0 ]
209+ else :
210+ calib_size_to_use = args .calib_size [0 ] if args .calib_size else 512
211+
212+ calib_dataloader = get_nemotron_vlm_dataset_dataloader (
213+ dataset_name = dataset_to_use ,
214+ processor = processor ,
215+ batch_size = args .batch_size ,
216+ num_samples = calib_size_to_use ,
217+ device = device , # Move data to model's device
218+ )
219+ return calib_dataloader , first_text_speech_dataset
220+
179221 if model_type == "mllama" :
180222 assert processor is not None and isinstance (processor , MllamaImageProcessor ), (
181223 "The MllamaImageProcessor must be set."
@@ -377,18 +419,35 @@ def load_model(args: argparse.Namespace):
377419 trust_remote_code = args .trust_remote_code ,
378420 )
379421 else :
422+ # Check if this is a Nemotron VL model that needs a processor
423+ # Do this BEFORE setting default datasets so we can use image-text data for Nemotron-Parse
424+ is_nemotron_vl_model = is_nemotron_vl (full_model )
425+
426+ # Check specifically for Nemotron-Parse to set appropriate dataset defaults
427+ config = full_model .config
428+ architectures = getattr (config , "architectures" , [])
429+ is_nemotron_parse = any ("nemotronparse" in arch .lower () for arch in architectures )
430+
380431 if args .dataset is None :
381- args .dataset = ["cnn_dailymail" , "nemotron-post-training-dataset-v2" ]
382- warnings .warn (
383- "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
384- )
432+ if is_nemotron_parse :
433+ # For Nemotron-Parse, default to Nemotron VLM Dataset v2
434+ args .dataset = ["nemotron_vlm_v2" ]
435+ print (
436+ "No dataset specified. Defaulting to 'nemotron_vlm_v2' for Nemotron-Parse "
437+ "(NVIDIA's image-text dataset for better calibration)."
438+ )
439+ else :
440+ # For other models, use text-only datasets
441+ args .dataset = ["cnn_dailymail" , "nemotron-post-training-dataset-v2" ]
442+ warnings .warn (
443+ "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
444+ )
445+
385446 # Adjust calib_size to match dataset length by extending or truncating as needed
386447 args .calib_size = (args .calib_size + [args .calib_size [- 1 ]] * len (args .dataset ))[
387448 : len (args .dataset )
388449 ]
389450
390- # Check if this is a Nemotron VL model that needs a processor
391- is_nemotron_vl_model = is_nemotron_vl (full_model )
392451 if is_nemotron_vl_model :
393452 # Load processor for Nemotron VL models (like Nemotron-Parse)
394453 processor = get_processor (
@@ -404,26 +463,41 @@ def load_model(args: argparse.Namespace):
404463 # Left padding usually provides better calibration result.
405464 tokenizer .padding_side = "left"
406465
407- # We only quantize the language model for VLMs other than the type supported above.
408- language_model_lineage = get_language_model_from_vl (full_model )
409- if language_model_lineage is not None :
410- language_model = language_model_lineage .pop (- 1 )
411- ancestors = language_model_lineage
412- # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
413- # HF export.
414- disabled_quant_cfg = {
415- "quant_cfg" : {"default" : {"enable" : False }},
416- "algorithm" : "max" ,
417- }
418-
419- memo = set (ancestors ) | {language_model }
420- for ancestor in ancestors :
421- for _ , module in ancestor .named_children ():
422- if module not in memo :
423- mtq .quantize (module , disabled_quant_cfg , forward_loop = None )
424- memo .add (module )
425-
426- model_type = get_model_type (language_model )
466+ # Check if this is Nemotron-Parse
467+ config = full_model .config
468+ architectures = getattr (config , "architectures" , [])
469+ is_nemotron_parse = any ("nemotronparse" in arch .lower () for arch in architectures )
470+
471+ # For Nemotron-Parse, DON'T extract the decoder
472+ # We want to calibrate the full model so the decoder sees realistic visual embeddings
473+ # The vision encoder won't be quantized (disabled via quant_cfg in mono_quantize)
474+ if is_nemotron_parse :
475+ print (
476+ "Nemotron-Parse detected: Keeping full encoder-decoder model for calibration "
477+ "with image-text data. Vision encoder will be disabled from quantization."
478+ )
479+ # language_model = full_model (already set above)
480+ else :
481+ # For other VLMs, extract the language model for quantization
482+ language_model_lineage = get_language_model_from_vl (full_model )
483+ if language_model_lineage is not None :
484+ language_model = language_model_lineage .pop (- 1 )
485+ ancestors = language_model_lineage
486+ # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
487+ # HF export.
488+ disabled_quant_cfg = {
489+ "quant_cfg" : {"default" : {"enable" : False }},
490+ "algorithm" : "max" ,
491+ }
492+
493+ memo = set (ancestors ) | {language_model }
494+ for ancestor in ancestors :
495+ for _ , module in ancestor .named_children ():
496+ if module not in memo :
497+ mtq .quantize (module , disabled_quant_cfg , forward_loop = None )
498+ memo .add (module )
499+
500+ model_type = get_model_type (language_model )
427501
428502 if model_type == "phi4mm" :
429503 warnings .warn ("Please set the default input_mode to InputMode.LANGUAGE before quantizing." )
@@ -494,14 +568,23 @@ def mono_quantize(
494568 "Consider reducing calib_size to reduce calibration time.\n ####\n "
495569 )
496570
571+ # Check if this is Nemotron-Parse
572+ config = full_model .config
573+ architectures = getattr (config , "architectures" , [])
574+ is_nemotron_parse = any ("nemotronparse" in arch .lower () for arch in architectures )
575+ original_forward = None # Track original forward method if we wrap it
576+
497577 # For Nemotron VL models, disable quantization of vision components
498578 if is_nemotron_vl_model :
499579 print ("Disabling quantization for vision components in Nemotron VL model" )
500580 quant_cfg ["quant_cfg" ]["*vision*" ] = {"enable" : False }
501581 quant_cfg ["quant_cfg" ]["*image*" ] = {"enable" : False }
502- # Also disable radio model components specifically
582+ # Also disable radio model components specifically (for Nemotron-Parse)
503583 quant_cfg ["quant_cfg" ]["*radio*" ] = {"enable" : False }
504584 quant_cfg ["quant_cfg" ]["*visual*" ] = {"enable" : False }
585+ quant_cfg ["quant_cfg" ]["*encoder*" ] = {"enable" : False } # Disable encoder
586+ quant_cfg ["quant_cfg" ]["*model_encoder*" ] = {"enable" : False } # Nemotron-Parse specific
587+ print ("Quantization will only be applied to the decoder (text generation) component" )
505588
506589 if not model_is_already_quantized or calibration_only :
507590 if model_type == "gptoss" and args .qformat == "nvfp4_mlp_only" :
@@ -513,9 +596,25 @@ def mono_quantize(
513596
514597 if not use_calibration :
515598 warnings .warn ("Dynamic quantization. Calibration skipped." )
516- calibrate_loop = (
517- create_forward_loop (dataloader = calib_dataloader ) if use_calibration else None
518- )
599+
600+ # Create calibration loop
601+ if use_calibration :
602+ if is_nemotron_parse :
603+ # For Nemotron-Parse, wrap the model to force use_cache=False
604+ print ("Wrapping Nemotron-Parse model for calibration (use_cache=False)" )
605+ original_forward = language_model .forward
606+
def wrapped_forward(*fwd_args, **fwd_kwargs):
    # Delegate to the saved forward, always forcing use_cache off so the
    # decoder calibrates without building a KV cache.
    return original_forward(*fwd_args, **{**fwd_kwargs, "use_cache": False})
610+
611+ # Temporarily replace forward method
612+ language_model .forward = wrapped_forward
613+ calibrate_loop = create_forward_loop (dataloader = calib_dataloader )
614+ else :
615+ calibrate_loop = create_forward_loop (dataloader = calib_dataloader )
616+ else :
617+ calibrate_loop = None
519618
520619 if calibration_only :
521620 language_model = mtq .calibrate (
@@ -524,8 +623,15 @@ def mono_quantize(
524623 else :
525624 language_model = mtq .quantize (language_model , quant_cfg , forward_loop = calibrate_loop )
526625
527- # For VL models, update full_model to use the quantized language model
528- if is_nemotron_vl_model :
626+ # Restore original forward method if we wrapped it for Nemotron-Parse
627+ if is_nemotron_parse and original_forward is not None :
628+ print ("Restoring original forward method after calibration" )
629+ language_model .forward = original_forward
630+ original_forward = None
631+
632+ # For VL models (except Nemotron-Parse), update full_model to use the quantized language model
633+ # For Nemotron-Parse, language_model IS full_model, so no update needed
634+ if is_nemotron_vl_model and language_model is not full_model :
529635 language_model_lineage = get_language_model_from_vl (full_model )
530636 if language_model_lineage is not None :
531637 print ("Updating full_model with quantized language_model..." )
@@ -828,38 +934,12 @@ def quantize_main(
828934 print (f"Use calib batch_size { args .batch_size } " )
829935
830936 calib_dataloader , first_text_speech_dataset = make_calib_dataloader (
831- args , language_model , processor , tokenizer , device , model_type
937+ args , language_model , processor , tokenizer , device , model_type , full_model
832938 )
833939
834940 # Detect if this is a Nemotron VL model using architecture-based detection
835941 is_nemotron_vl_model = is_nemotron_vl (full_model )
836942
837- # For Nemotron-Parse, wrap the text-only dataloader to add dummy images
838- # Nemotron-Parse is an encoder-decoder model that requires pixel_values
839- if is_nemotron_vl_model and processor is not None :
840- config = full_model .config
841- architectures = getattr (config , "architectures" , [])
842- is_nemotron_parse = any ("nemotronparse" in arch .lower () for arch in architectures )
843-
844- if is_nemotron_parse :
845- # Check if we're quantizing just the decoder or the full model
846- decoder_only = language_model is not full_model
847-
848- if decoder_only :
849- print (
850- "Calibration will use text-only inputs for Nemotron-Parse decoder. "
851- "Vision encoder is excluded from quantization."
852- )
853- else :
854- print (
855- "Wrapping calibration dataloader for Nemotron-Parse to add dummy images. "
856- "Nemotron-Parse requires pixel_values for full model calibration."
857- )
858-
859- calib_dataloader = create_nemotron_parse_calib_wrapper (
860- calib_dataloader , processor , device , decoder_only = decoder_only
861- )
862-
863943 preview_input_ids , generated_ids_before_ptq = pre_quantize (
864944 args , full_model , model_type , tokenizer , calib_dataloader , is_nemotron_vl_model
865945 )
0 commit comments