1818import random
1919import time
2020import warnings
21+ from pathlib import Path
2122from typing import Any
2223
2324import numpy as np
6465from modelopt .torch .quantization .config import _default_disabled_quantizer_cfg , need_calibration
6566from modelopt .torch .quantization .plugins .accelerate import init_quantized_weights
6667from modelopt .torch .quantization .utils import is_quantized
68+ from modelopt .torch .speculative .eagle .utils import (
69+ EagleOfflineDataCollator ,
70+ OfflineSupervisedDataset ,
71+ )
6772from modelopt .torch .utils .dataset_utils import (
6873 create_forward_loop ,
6974 get_dataset_dataloader ,
@@ -163,17 +168,63 @@ def extract_and_prepare_language_model_from_vl(full_model):
163168 return None , None
164169
165170
171+ class _DeviceDataLoader :
172+ """Wrapper around a DataLoader that moves each batch to a target device."""
173+
174+ def __init__ (self , dataloader : DataLoader , device : torch .device ):
175+ self .dataloader = dataloader
176+ self .device = device
177+
178+ def __iter__ (self ):
179+ for batch in self .dataloader :
180+ yield _move_batch_to_device (batch , self .device )
181+
182+ def __len__ (self ):
183+ return len (self .dataloader )
184+
185+
186+ def _move_batch_to_device (batch : dict , device : torch .device ) -> dict :
187+ """Recursively move all tensors in a batch dict to the given device."""
188+
189+ def _to_device (value ):
190+ if isinstance (value , torch .Tensor ):
191+ return value .to (device )
192+ if isinstance (value , dict ):
193+ return {k : _to_device (v ) for k , v in value .items ()}
194+ return value
195+
196+ return {k : _to_device (v ) for k , v in batch .items ()}
197+
198+
166199def make_calib_dataloader (
167200 args : argparse .Namespace ,
168201 language_model : torch .nn .Module ,
169202 processor : BaseImageProcessor | ProcessorMixin | None ,
170203 tokenizer : PreTrainedTokenizerBase | None ,
171204 device : torch .device ,
172205 model_type : str | None ,
173- ) -> tuple [DataLoader , str | None ]:
206+ ) -> tuple [DataLoader | _DeviceDataLoader , str | None ]:
174207 calib_dataloader = None
175208 first_text_speech_dataset = None
176- if args .calib_with_images :
209+ if args .specdec_offline_dataset is not None :
210+ offline_data_path = Path (args .specdec_offline_dataset )
211+ dumped_files = sorted (str (p ) for p in offline_data_path .glob ("*.pt" ))
212+ if not dumped_files :
213+ raise ValueError (f"No .pt files found in { args .specdec_offline_dataset } " )
214+ if args .calib_size [0 ] > 0 :
215+ dumped_files = dumped_files [: args .calib_size [0 ]]
216+ dataset = OfflineSupervisedDataset (dumped_files )
217+ collator = EagleOfflineDataCollator (train_len = args .calib_seq )
218+ raw_loader = DataLoader (
219+ dataset ,
220+ batch_size = args .batch_size ,
221+ shuffle = False ,
222+ collate_fn = collator ,
223+ )
224+ # Wrap to move batches to the target device; device-transfer logic is kept
225+ # out of the data collator to avoid interference with dataloader prefetching.
226+ calib_dataloader = _DeviceDataLoader (raw_loader , device )
227+ elif args .calib_with_images :
177228 # VLM image-text calibration path: assume Nemotron VLM dataset by default.
178229 assert processor is not None , (
179230 "Please provide a processor (e.g., AutoProcessor) for image calibration."
@@ -358,7 +409,7 @@ def forward_step(model, batch):
358409def load_model (args : argparse .Namespace ):
359410 # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
360411 calibration_only = False
361- if not args .low_memory_mode :
412+ if args . specdec_offline_dataset is not None or not args .low_memory_mode :
362413 full_model = get_model (
363414 args .pyt_ckpt_path ,
364415 args .device ,
@@ -459,28 +510,34 @@ def load_model(args: argparse.Namespace):
459510 language_model = extracted_lm
460511 model_type = extracted_model_type
461512 else :
462- if args .dataset is None :
463- args .dataset = ["cnn_dailymail" , "nemotron-post-training-dataset-v2" ]
464- warnings .warn (
465- "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
513+ if args .specdec_offline_dataset is not None :
514+ language_model = full_model
515+ else :
516+ if args .dataset is None :
517+ args .dataset = ["cnn_dailymail" , "nemotron-post-training-dataset-v2" ]
518+ warnings .warn (
519+ "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
520+ )
521+ # Adjust calib_size to match dataset length by extending or truncating as needed
522+ args .calib_size = (args .calib_size + [args .calib_size [- 1 ]] * len (args .dataset ))[
523+ : len (args .dataset )
524+ ]
525+
526+ # We only quantize the language model for VLMs other than the type supported above.
527+ extracted_lm , extracted_model_type = extract_and_prepare_language_model_from_vl (
528+ full_model
466529 )
467- # Adjust calib_size to match dataset length by extending or truncating as needed
468- args . calib_size = ( args . calib_size + [ args . calib_size [ - 1 ]] * len ( args . dataset ))[
469- : len ( args . dataset )
470- ]
530+ if extracted_lm is not None :
531+ language_model = extracted_lm
532+ model_type = extracted_model_type
533+
471534 tokenizer = get_tokenizer (args .pyt_ckpt_path , trust_remote_code = args .trust_remote_code )
472535
473536 default_padding_side = tokenizer .padding_side
474537 default_pad_token = tokenizer .pad_token
475538 # Left padding usually provides better calibration result.
476539 tokenizer .padding_side = "left"
477540
478- # We only quantize the language model for VLMs other than the type supported above.
479- extracted_lm , extracted_model_type = extract_and_prepare_language_model_from_vl (full_model )
480- if extracted_lm is not None :
481- language_model = extracted_lm
482- model_type = extracted_model_type
483-
484541 if model_type == "phi4mm" :
485542 warnings .warn ("Please set the default input_mode to InputMode.LANGUAGE before quantizing." )
486543
@@ -581,7 +638,12 @@ def mono_quantize(
581638 if args .calib_with_images and is_nemotron_vl_model :
582639 calibrate_loop = create_vlm_calibration_loop (full_model , calib_dataloader )
583640 else :
584- calibrate_loop = create_forward_loop (dataloader = calib_dataloader )
641+ calibrate_loop = create_forward_loop (
642+ dataloader = calib_dataloader ,
643+ allowed_non_tensor_keys = {"base_model_outputs" }
644+ if args .specdec_offline_dataset is not None
645+ else None ,
646+ )
585647
586648 if calibration_only :
587649 language_model = mtq .calibrate (
@@ -736,7 +798,7 @@ def pre_quantize(
736798 full_model : torch .nn .Module ,
737799 model_type : str | None ,
738800 tokenizer : PreTrainedTokenizerBase | None ,
739- calib_dataloader : DataLoader ,
801+ calib_dataloader : DataLoader | None ,
740802 is_nemotron_vl_model : bool ,
741803):
742804 """
@@ -746,7 +808,12 @@ def pre_quantize(
746808 post-quantize generation.
747809
748810 """
811+ # Offline specdec models skip pre-quantize preview (no tokenizer or standard dataloader)
812+ if args .specdec_offline_dataset is not None :
813+ return None , None
814+
749815 # Only run single sample for preview
816+ assert calib_dataloader is not None , "calib_dataloader is required for pre-quantize preview"
750817 preview_input_ids = next (iter (calib_dataloader ))[
751818 "input_features" if model_type == "whisper" else "input_ids"
752819 ][0 :1 ]
@@ -781,21 +848,39 @@ def pre_quantize(
781848def post_quantize (
782849 args : argparse .Namespace ,
783850 full_model : torch .nn .Module ,
851+ language_model : torch .nn .Module ,
784852 model_type : str | None ,
785853 tokenizer : PreTrainedTokenizerBase | None ,
786854 processor : BaseImageProcessor | ProcessorMixin | None ,
787855 preview_input_ids ,
788856 generated_ids_before_ptq ,
789857 is_nemotron_vl_model ,
790858 first_text_speech_dataset ,
859+ default_padding_side ,
860+ default_pad_token ,
861+ calib_dataloader : DataLoader ,
791862):
792863 """
793- Processing after the quantization.
864+ Processing after the quantization, then export .
794865
795- Currently we run one round of generation using the quantized model for a sample prompt,
796- and compare it with pre-quantize generation.
866+ For offline speculative decoding models, skip generation comparison and proceed
867+ directly to export. For standard models, run one round of generation using the
868+ quantized model for a sample prompt and compare it with pre-quantize generation.
797869
798870 """
871+ # Early exit for offline speculative decoding: skip generation comparison and export directly.
872+ # The model's get_dummy_inputs() provides the right input format for the export forward pass.
873+ if args .specdec_offline_dataset is not None :
874+ export_quantized (
875+ args ,
876+ full_model ,
877+ language_model ,
878+ model_type ,
879+ tokenizer ,
880+ default_padding_side ,
881+ default_pad_token ,
882+ )
883+ return
799884
800885 if args .verbose :
801886 try :
@@ -873,6 +958,16 @@ def output_decode(generated_ids, input_shape):
873958 f"example outputs after ptq: { output_decode (generated_ids_after_ptq , preview_input_ids .shape [1 ])} "
874959 )
875960
961+ export_quantized (
962+ args ,
963+ full_model ,
964+ language_model ,
965+ model_type ,
966+ tokenizer ,
967+ default_padding_side ,
968+ default_pad_token ,
969+ )
970+
876971
877972def quantize_main (
878973 args : argparse .Namespace ,
@@ -892,6 +987,13 @@ def quantize_main(
892987 if args .calib_with_images :
893988 print ("Image-text calibration enabled. Using default batch_size=1 for calibration." )
894989 args .batch_size = 1
990+ # Speculative decoding offline model does not support get_max_batch_size() because of
991+ # the customized dataloader, so we set batch_size to 1 to avoid OOM.
992+ elif args .specdec_offline_dataset is not None :
993+ print (
994+ "Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration."
995+ )
996+ args .batch_size = 1
895997 else :
896998 # Calibration/sparsification will actually take much more memory than regular inference
897999 # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
@@ -1020,22 +1122,17 @@ def quantize_main(
10201122 post_quantize (
10211123 args ,
10221124 full_model ,
1125+ language_model ,
10231126 model_type ,
10241127 tokenizer ,
10251128 processor ,
10261129 preview_input_ids ,
10271130 generated_ids_before_ptq ,
10281131 is_nemotron_vl_model ,
10291132 first_text_speech_dataset ,
1030- )
1031- export_quantized (
1032- args ,
1033- full_model ,
1034- language_model ,
1035- model_type ,
1036- tokenizer ,
10371133 default_padding_side ,
10381134 default_pad_token ,
1135+ calib_dataloader ,
10391136 )
10401137
10411138
@@ -1099,6 +1196,14 @@ def parse_args() -> argparse.Namespace:
10991196 type = str ,
11001197 default = None ,
11011198 )
1199+ parser .add_argument (
1200+ "--specdec_offline_dataset" ,
1201+ help = (
1202+ "If set, the model is a speculative decoding model,"
1203+ "which uses offline dataset for calibration. "
1204+ ),
1205+ default = None ,
1206+ )
11021207 parser .add_argument (
11031208 "--calib_with_images" ,
11041209 action = "store_true" ,
@@ -1256,6 +1361,12 @@ def parse_args() -> argparse.Namespace:
12561361 if args .moe_calib_experts_ratio is not None and not (0.0 < args .moe_calib_experts_ratio <= 1.0 ):
12571362 parser .error ("--moe_calib_experts_ratio must be in the range (0.0, 1.0]." )
12581363
1364+ if args .specdec_offline_dataset is not None and args .sparsity_fmt != "dense" :
1365+ parser .error ("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ)." )
1366+
1367+ if args .specdec_offline_dataset is not None and args .low_memory_mode :
1368+ parser .error ("--specdec_offline_dataset is not compatible with --low_memory_mode." )
1369+
12591370 return args
12601371
12611372
@@ -1311,4 +1422,10 @@ def main(args: argparse.Namespace):
13111422
13121423 args .dataset = args .dataset .split ("," ) if isinstance (args .dataset , str ) else args .dataset
13131424 args .calib_size = [int (num_sample ) for num_sample in args .calib_size .split ("," )]
1425+
1426+ if args .specdec_offline_dataset is not None and len (args .calib_size ) != 1 :
1427+ raise ValueError (
1428+ "--specdec_offline_dataset expects a single --calib value, not a comma-separated list."
1429+ )
1430+
13141431 main (args )
0 commit comments