Skip to content

Commit cccfded

Browse files
yeyu-nvidia and claude
authored
Add support for offline speculative decoding model PTQ (#883)
## What does this PR do? **Type of change:** new feature **Overview:** This PR enables loading in a ModelOpt pretrained offline speculative decoding model (e.g., EAGLE3) and performs PTQ on it and export. ## Usage Follow the speculative_decoding examples to train an offline speculative decoding model first. Then follow the command below to quantize and export it: ```bash python hf_ptq.py --pyt_ckpt_path <dir_of_offline_specdec_model> --specdec_offline_dataset <dir_of_dataset> ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Offline speculative decoding workflow: support loading a local dataset for calibration, generation, and export; new CLI option to specify the offline dataset. * **Improvements** * Export and quantization paths now accept and propagate offline speculative-decoding inputs. * Offline data loading honors a sample-size limit and enforces safe batch sizing for calibration. * **Bug Fixes** * Better handling of model/config mismatches and varied batch types in offline flows. 
<!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Ye Yu <yeyu@nvidia.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8428f06 commit cccfded

File tree

14 files changed

+963
-133
lines changed

14 files changed

+963
-133
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 146 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import random
1919
import time
2020
import warnings
21+
from pathlib import Path
2122
from typing import Any
2223

2324
import numpy as np
@@ -64,6 +65,10 @@
6465
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
6566
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
6667
from modelopt.torch.quantization.utils import is_quantized
68+
from modelopt.torch.speculative.eagle.utils import (
69+
EagleOfflineDataCollator,
70+
OfflineSupervisedDataset,
71+
)
6772
from modelopt.torch.utils.dataset_utils import (
6873
create_forward_loop,
6974
get_dataset_dataloader,
@@ -163,17 +168,63 @@ def extract_and_prepare_language_model_from_vl(full_model):
163168
return None, None
164169

165170

171+
class _DeviceDataLoader:
172+
"""Wrapper around a DataLoader that moves each batch to a target device."""
173+
174+
def __init__(self, dataloader: DataLoader, device: torch.device):
175+
self.dataloader = dataloader
176+
self.device = device
177+
178+
def __iter__(self):
179+
for batch in self.dataloader:
180+
yield _move_batch_to_device(batch, self.device)
181+
182+
def __len__(self):
183+
return len(self.dataloader)
184+
185+
186+
def _move_batch_to_device(batch: dict, device: torch.device) -> dict:
187+
"""Recursively move all tensors in a batch dict to the given device."""
188+
189+
def _to_device(value):
190+
if isinstance(value, torch.Tensor):
191+
return value.to(device)
192+
if isinstance(value, dict):
193+
return {k: _to_device(v) for k, v in value.items()}
194+
return value
195+
196+
return {k: _to_device(v) for k, v in batch.items()}
197+
198+
166199
def make_calib_dataloader(
167200
args: argparse.Namespace,
168201
language_model: torch.nn.Module,
169202
processor: BaseImageProcessor | ProcessorMixin | None,
170203
tokenizer: PreTrainedTokenizerBase | None,
171204
device: torch.device,
172205
model_type: str | None,
173-
) -> tuple[DataLoader, str | None]:
206+
) -> tuple[DataLoader | _DeviceDataLoader, str | None]:
174207
calib_dataloader = None
175208
first_text_speech_dataset = None
176-
if args.calib_with_images:
209+
if args.specdec_offline_dataset is not None:
210+
offline_data_path = Path(args.specdec_offline_dataset)
211+
dumped_files = sorted(str(p) for p in offline_data_path.glob("*.pt"))
212+
if not dumped_files:
213+
raise ValueError(f"No .pt files found in {args.specdec_offline_dataset}")
214+
if args.calib_size[0] > 0:
215+
dumped_files = dumped_files[: args.calib_size[0]]
216+
dataset = OfflineSupervisedDataset(dumped_files)
217+
collator = EagleOfflineDataCollator(train_len=args.calib_seq)
218+
raw_loader = DataLoader(
219+
dataset,
220+
batch_size=args.batch_size,
221+
shuffle=False,
222+
collate_fn=collator,
223+
)
224+
# Wrap to move batches to the target device; device-transfer logic is kept
225+
# out of the data collator to avoid interference with dataloader prefetching.
226+
calib_dataloader = _DeviceDataLoader(raw_loader, device)
227+
elif args.calib_with_images:
177228
# VLM image-text calibration path: assume Nemotron VLM dataset by default.
178229
assert processor is not None, (
179230
"Please provide a processor (e.g., AutoProcessor) for image calibration."
@@ -358,7 +409,7 @@ def forward_step(model, batch):
358409
def load_model(args: argparse.Namespace):
359410
# If low memory mode is enabled, we compress the model while loading the HF checkpoint.
360411
calibration_only = False
361-
if not args.low_memory_mode:
412+
if args.specdec_offline_dataset is not None or not args.low_memory_mode:
362413
full_model = get_model(
363414
args.pyt_ckpt_path,
364415
args.device,
@@ -459,28 +510,34 @@ def load_model(args: argparse.Namespace):
459510
language_model = extracted_lm
460511
model_type = extracted_model_type
461512
else:
462-
if args.dataset is None:
463-
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
464-
warnings.warn(
465-
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
513+
if args.specdec_offline_dataset is not None:
514+
language_model = full_model
515+
else:
516+
if args.dataset is None:
517+
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
518+
warnings.warn(
519+
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
520+
)
521+
# Adjust calib_size to match dataset length by extending or truncating as needed
522+
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
523+
: len(args.dataset)
524+
]
525+
526+
# We only quantize the language model for VLMs other than the type supported above.
527+
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
528+
full_model
466529
)
467-
# Adjust calib_size to match dataset length by extending or truncating as needed
468-
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
469-
: len(args.dataset)
470-
]
530+
if extracted_lm is not None:
531+
language_model = extracted_lm
532+
model_type = extracted_model_type
533+
471534
tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
472535

473536
default_padding_side = tokenizer.padding_side
474537
default_pad_token = tokenizer.pad_token
475538
# Left padding usually provides better calibration result.
476539
tokenizer.padding_side = "left"
477540

478-
# We only quantize the language model for VLMs other than the type supported above.
479-
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
480-
if extracted_lm is not None:
481-
language_model = extracted_lm
482-
model_type = extracted_model_type
483-
484541
if model_type == "phi4mm":
485542
warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")
486543

@@ -581,7 +638,12 @@ def mono_quantize(
581638
if args.calib_with_images and is_nemotron_vl_model:
582639
calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
583640
else:
584-
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
641+
calibrate_loop = create_forward_loop(
642+
dataloader=calib_dataloader,
643+
allowed_non_tensor_keys={"base_model_outputs"}
644+
if args.specdec_offline_dataset is not None
645+
else None,
646+
)
585647

586648
if calibration_only:
587649
language_model = mtq.calibrate(
@@ -736,7 +798,7 @@ def pre_quantize(
736798
full_model: torch.nn.Module,
737799
model_type: str | None,
738800
tokenizer: PreTrainedTokenizerBase | None,
739-
calib_dataloader: DataLoader,
801+
calib_dataloader: DataLoader | None,
740802
is_nemotron_vl_model: bool,
741803
):
742804
"""
@@ -746,7 +808,12 @@ def pre_quantize(
746808
post-quantize generation.
747809
748810
"""
811+
# Offline specdec models skip pre-quantize preview (no tokenizer or standard dataloader)
812+
if args.specdec_offline_dataset is not None:
813+
return None, None
814+
749815
# Only run single sample for preview
816+
assert calib_dataloader is not None, "calib_dataloader is required for pre-quantize preview"
750817
preview_input_ids = next(iter(calib_dataloader))[
751818
"input_features" if model_type == "whisper" else "input_ids"
752819
][0:1]
@@ -781,21 +848,39 @@ def pre_quantize(
781848
def post_quantize(
782849
args: argparse.Namespace,
783850
full_model: torch.nn.Module,
851+
language_model: torch.nn.Module,
784852
model_type: str | None,
785853
tokenizer: PreTrainedTokenizerBase | None,
786854
processor: BaseImageProcessor | ProcessorMixin | None,
787855
preview_input_ids,
788856
generated_ids_before_ptq,
789857
is_nemotron_vl_model,
790858
first_text_speech_dataset,
859+
default_padding_side,
860+
default_pad_token,
861+
calib_dataloader: DataLoader,
791862
):
792863
"""
793-
Processing after the quantization.
864+
Processing after the quantization, then export.
794865
795-
Currently we run one round of generation using the quantized model for a sample prompt,
796-
and compare it with pre-quantize generation.
866+
For offline speculative decoding models, skip generation comparison and proceed
867+
directly to export. For standard models, run one round of generation using the
868+
quantized model for a sample prompt and compare it with pre-quantize generation.
797869
798870
"""
871+
# Early exit for offline speculative decoding: skip generation comparison and export directly.
872+
# The model's get_dummy_inputs() provides the right input format for the export forward pass.
873+
if args.specdec_offline_dataset is not None:
874+
export_quantized(
875+
args,
876+
full_model,
877+
language_model,
878+
model_type,
879+
tokenizer,
880+
default_padding_side,
881+
default_pad_token,
882+
)
883+
return
799884

800885
if args.verbose:
801886
try:
@@ -873,6 +958,16 @@ def output_decode(generated_ids, input_shape):
873958
f"example outputs after ptq: {output_decode(generated_ids_after_ptq, preview_input_ids.shape[1])}"
874959
)
875960

961+
export_quantized(
962+
args,
963+
full_model,
964+
language_model,
965+
model_type,
966+
tokenizer,
967+
default_padding_side,
968+
default_pad_token,
969+
)
970+
876971

877972
def quantize_main(
878973
args: argparse.Namespace,
@@ -892,6 +987,13 @@ def quantize_main(
892987
if args.calib_with_images:
893988
print("Image-text calibration enabled. Using default batch_size=1 for calibration.")
894989
args.batch_size = 1
990+
# Speculative decoding offline model does not support get_max_batch_size() because of
991+
# the customized dataloader, so we set batch_size to 1 to avoid OOM.
992+
elif args.specdec_offline_dataset is not None:
993+
print(
994+
"Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration."
995+
)
996+
args.batch_size = 1
895997
else:
896998
# Calibration/sparsification will actually take much more memory than regular inference
897999
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
@@ -1020,22 +1122,17 @@ def quantize_main(
10201122
post_quantize(
10211123
args,
10221124
full_model,
1125+
language_model,
10231126
model_type,
10241127
tokenizer,
10251128
processor,
10261129
preview_input_ids,
10271130
generated_ids_before_ptq,
10281131
is_nemotron_vl_model,
10291132
first_text_speech_dataset,
1030-
)
1031-
export_quantized(
1032-
args,
1033-
full_model,
1034-
language_model,
1035-
model_type,
1036-
tokenizer,
10371133
default_padding_side,
10381134
default_pad_token,
1135+
calib_dataloader,
10391136
)
10401137

10411138

@@ -1099,6 +1196,14 @@ def parse_args() -> argparse.Namespace:
10991196
type=str,
11001197
default=None,
11011198
)
1199+
parser.add_argument(
1200+
"--specdec_offline_dataset",
1201+
help=(
1202+
"If set, the model is a speculative decoding model,"
1203+
"which uses offline dataset for calibration. "
1204+
),
1205+
default=None,
1206+
)
11021207
parser.add_argument(
11031208
"--calib_with_images",
11041209
action="store_true",
@@ -1256,6 +1361,12 @@ def parse_args() -> argparse.Namespace:
12561361
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
12571362
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
12581363

1364+
if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
1365+
parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).")
1366+
1367+
if args.specdec_offline_dataset is not None and args.low_memory_mode:
1368+
parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.")
1369+
12591370
return args
12601371

12611372

@@ -1311,4 +1422,10 @@ def main(args: argparse.Namespace):
13111422

13121423
args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
13131424
args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
1425+
1426+
if args.specdec_offline_dataset is not None and len(args.calib_size) != 1:
1427+
raise ValueError(
1428+
"--specdec_offline_dataset expects a single --calib_size value, not a comma-separated list."
1429+
)
)
1430+
13141431
main(args)

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,10 @@ async def submit_generates():
206206
continue
207207

208208
# Tokenize and check length
209-
input_ids = tokenizer.apply_chat_template(
209+
tokenized = tokenizer.apply_chat_template(
210210
conversations, return_tensors="pt", add_generation_template=False
211-
)["input_ids"]
211+
)
212+
input_ids = tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
212213
num_input_tokens = input_ids.shape[1]
213214
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
214215
num_skipped_too_long += 1

0 commit comments

Comments
 (0)