-
Notifications
You must be signed in to change notification settings - Fork 362
Add support for offline speculative decoding model PTQ #883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9717c73
163f909
60d4c36
77a73c0
98c4042
9117489
b56c7a1
3995a2f
6c7acca
0badde0
278ec14
4514e70
003e0fb
fa3cd6b
971be07
67e2b9b
b0c0454
6318ec5
edf0512
3e7839c
a80a9b5
e012208
05957a1
f62953c
8e986bf
f2cd43f
dca82d3
38d8705
10af262
cceae6a
65a94d6
aef4749
ac3794f
a0f2507
e342300
48436de
91779f7
db545f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import random | ||
| import time | ||
| import warnings | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
|
|
@@ -64,6 +65,10 @@ | |
| from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration | ||
| from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights | ||
| from modelopt.torch.quantization.utils import is_quantized | ||
| from modelopt.torch.speculative.eagle.utils import ( | ||
| EagleOfflineDataCollator, | ||
| OfflineSupervisedDataset, | ||
| ) | ||
| from modelopt.torch.utils.dataset_utils import ( | ||
| create_forward_loop, | ||
| get_dataset_dataloader, | ||
|
|
@@ -163,17 +168,63 @@ def extract_and_prepare_language_model_from_vl(full_model): | |
| return None, None | ||
|
|
||
|
|
||
| class _DeviceDataLoader: | ||
| """Wrapper around a DataLoader that moves each batch to a target device.""" | ||
|
|
||
| def __init__(self, dataloader: DataLoader, device: torch.device): | ||
| self.dataloader = dataloader | ||
| self.device = device | ||
|
|
||
| def __iter__(self): | ||
| for batch in self.dataloader: | ||
| yield _move_batch_to_device(batch, self.device) | ||
|
|
||
| def __len__(self): | ||
| return len(self.dataloader) | ||
|
|
||
|
|
||
| def _move_batch_to_device(batch: dict, device: torch.device) -> dict: | ||
| """Recursively move all tensors in a batch dict to the given device.""" | ||
|
|
||
| def _to_device(value): | ||
| if isinstance(value, torch.Tensor): | ||
| return value.to(device) | ||
| if isinstance(value, dict): | ||
| return {k: _to_device(v) for k, v in value.items()} | ||
| return value | ||
|
|
||
| return {k: _to_device(v) for k, v in batch.items()} | ||
|
|
||
|
|
||
| def make_calib_dataloader( | ||
| args: argparse.Namespace, | ||
| language_model: torch.nn.Module, | ||
| processor: BaseImageProcessor | ProcessorMixin | None, | ||
| tokenizer: PreTrainedTokenizerBase | None, | ||
| device: torch.device, | ||
| model_type: str | None, | ||
| ) -> tuple[DataLoader, str | None]: | ||
| ) -> tuple[DataLoader | _DeviceDataLoader, str | None]: | ||
| calib_dataloader = None | ||
| first_text_speech_dataset = None | ||
| if args.calib_with_images: | ||
| if args.specdec_offline_dataset is not None: | ||
| offline_data_path = Path(args.specdec_offline_dataset) | ||
| dumped_files = sorted(str(p) for p in offline_data_path.glob("*.pt")) | ||
| if not dumped_files: | ||
| raise ValueError(f"No .pt files found in {args.specdec_offline_dataset}") | ||
| if args.calib_size[0] > 0: | ||
| dumped_files = dumped_files[: args.calib_size[0]] | ||
| dataset = OfflineSupervisedDataset(dumped_files) | ||
| collator = EagleOfflineDataCollator(train_len=args.calib_seq) | ||
| raw_loader = DataLoader( | ||
| dataset, | ||
| batch_size=args.batch_size, | ||
| shuffle=False, | ||
| collate_fn=collator, | ||
| ) | ||
| # Wrap to move batches to the target device; device-transfer logic is kept | ||
| # out of the data collator to avoid interference with dataloader prefetching. | ||
| calib_dataloader = _DeviceDataLoader(raw_loader, device) | ||
| elif args.calib_with_images: | ||
| # VLM image-text calibration path: assume Nemotron VLM dataset by default. | ||
| assert processor is not None, ( | ||
| "Please provide a processor (e.g., AutoProcessor) for image calibration." | ||
|
|
@@ -358,7 +409,7 @@ def forward_step(model, batch): | |
| def load_model(args: argparse.Namespace): | ||
| # If low memory mode is enabled, we compress the model while loading the HF checkpoint. | ||
| calibration_only = False | ||
| if not args.low_memory_mode: | ||
| if args.specdec_offline_dataset is not None or not args.low_memory_mode: | ||
| full_model = get_model( | ||
| args.pyt_ckpt_path, | ||
| args.device, | ||
|
|
@@ -459,28 +510,34 @@ def load_model(args: argparse.Namespace): | |
| language_model = extracted_lm | ||
| model_type = extracted_model_type | ||
| else: | ||
| if args.dataset is None: | ||
| args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] | ||
| warnings.warn( | ||
| "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." | ||
| if args.specdec_offline_dataset is not None: | ||
| language_model = full_model | ||
| else: | ||
| if args.dataset is None: | ||
| args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] | ||
| warnings.warn( | ||
| "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." | ||
| ) | ||
| # Adjust calib_size to match dataset length by extending or truncating as needed | ||
| args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ | ||
| : len(args.dataset) | ||
| ] | ||
|
|
||
| # We only quantize the language model for VLMs other than the type supported above. | ||
| extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl( | ||
| full_model | ||
| ) | ||
| # Adjust calib_size to match dataset length by extending or truncating as needed | ||
| args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ | ||
| : len(args.dataset) | ||
| ] | ||
| if extracted_lm is not None: | ||
| language_model = extracted_lm | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in 4586572d — replaced |
||
| model_type = extracted_model_type | ||
|
|
||
| tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) | ||
|
|
||
| default_padding_side = tokenizer.padding_side | ||
| default_pad_token = tokenizer.pad_token | ||
|
ChenhanYu marked this conversation as resolved.
|
||
| # Left padding usually provides better calibration result. | ||
| tokenizer.padding_side = "left" | ||
|
|
||
| # We only quantize the language model for VLMs other than the type supported above. | ||
| extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model) | ||
| if extracted_lm is not None: | ||
| language_model = extracted_lm | ||
| model_type = extracted_model_type | ||
|
|
||
| if model_type == "phi4mm": | ||
| warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.") | ||
|
|
||
|
|
@@ -581,7 +638,12 @@ def mono_quantize( | |
| if args.calib_with_images and is_nemotron_vl_model: | ||
| calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader) | ||
| else: | ||
| calibrate_loop = create_forward_loop(dataloader=calib_dataloader) | ||
| calibrate_loop = create_forward_loop( | ||
| dataloader=calib_dataloader, | ||
| allowed_non_tensor_keys={"base_model_outputs"} | ||
| if args.specdec_offline_dataset is not None | ||
| else None, | ||
| ) | ||
|
|
||
| if calibration_only: | ||
| language_model = mtq.calibrate( | ||
|
|
@@ -736,7 +798,7 @@ def pre_quantize( | |
| full_model: torch.nn.Module, | ||
| model_type: str | None, | ||
| tokenizer: PreTrainedTokenizerBase | None, | ||
| calib_dataloader: DataLoader, | ||
| calib_dataloader: DataLoader | None, | ||
| is_nemotron_vl_model: bool, | ||
| ): | ||
| """ | ||
|
|
@@ -746,7 +808,12 @@ def pre_quantize( | |
| post-quantize generation. | ||
|
|
||
| """ | ||
| # Offline specdec models skip pre-quantize preview (no tokenizer or standard dataloader) | ||
| if args.specdec_offline_dataset is not None: | ||
| return None, None | ||
|
|
||
| # Only run single sample for preview | ||
| assert calib_dataloader is not None, "calib_dataloader is required for pre-quantize preview" | ||
| preview_input_ids = next(iter(calib_dataloader))[ | ||
| "input_features" if model_type == "whisper" else "input_ids" | ||
| ][0:1] | ||
|
|
@@ -781,21 +848,39 @@ def pre_quantize( | |
| def post_quantize( | ||
| args: argparse.Namespace, | ||
| full_model: torch.nn.Module, | ||
| language_model: torch.nn.Module, | ||
| model_type: str | None, | ||
| tokenizer: PreTrainedTokenizerBase | None, | ||
| processor: BaseImageProcessor | ProcessorMixin | None, | ||
| preview_input_ids, | ||
| generated_ids_before_ptq, | ||
| is_nemotron_vl_model, | ||
| first_text_speech_dataset, | ||
| default_padding_side, | ||
| default_pad_token, | ||
| calib_dataloader: DataLoader, | ||
| ): | ||
| """ | ||
| Processing after the quantization. | ||
| Processing after the quantization, then export. | ||
|
|
||
| Currently we run one round of generation using the quantized model for a sample prompt, | ||
| and compare it with pre-quantize generation. | ||
| For offline speculative decoding models, skip generation comparison and proceed | ||
| directly to export. For standard models, run one round of generation using the | ||
| quantized model for a sample prompt and compare it with pre-quantize generation. | ||
|
|
||
| """ | ||
| # Early exit for offline speculative decoding: skip generation comparison and export directly. | ||
| # The model's get_dummy_inputs() provides the right input format for the export forward pass. | ||
| if args.specdec_offline_dataset is not None: | ||
| export_quantized( | ||
| args, | ||
| full_model, | ||
| language_model, | ||
| model_type, | ||
| tokenizer, | ||
| default_padding_side, | ||
| default_pad_token, | ||
| ) | ||
| return | ||
|
|
||
| if args.verbose: | ||
| try: | ||
|
|
@@ -873,6 +958,16 @@ def output_decode(generated_ids, input_shape): | |
| f"example outputs after ptq: {output_decode(generated_ids_after_ptq, preview_input_ids.shape[1])}" | ||
| ) | ||
|
|
||
| export_quantized( | ||
| args, | ||
| full_model, | ||
| language_model, | ||
| model_type, | ||
| tokenizer, | ||
| default_padding_side, | ||
| default_pad_token, | ||
| ) | ||
|
|
||
|
|
||
| def quantize_main( | ||
| args: argparse.Namespace, | ||
|
|
@@ -892,6 +987,13 @@ def quantize_main( | |
| if args.calib_with_images: | ||
| print("Image-text calibration enabled. Using default batch_size=1 for calibration.") | ||
| args.batch_size = 1 | ||
| # Speculative decoding offline model does not support get_max_batch_size() because of | ||
| # the customized dataloader, so we set batch_size to 1 to avoid OOM. | ||
| elif args.specdec_offline_dataset is not None: | ||
| print( | ||
| "Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration." | ||
| ) | ||
| args.batch_size = 1 | ||
| else: | ||
| # Calibration/sparsification will actually take much more memory than regular inference | ||
| # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio | ||
|
|
@@ -1020,22 +1122,17 @@ def quantize_main( | |
| post_quantize( | ||
| args, | ||
| full_model, | ||
| language_model, | ||
| model_type, | ||
| tokenizer, | ||
| processor, | ||
| preview_input_ids, | ||
| generated_ids_before_ptq, | ||
| is_nemotron_vl_model, | ||
| first_text_speech_dataset, | ||
| ) | ||
| export_quantized( | ||
| args, | ||
| full_model, | ||
| language_model, | ||
| model_type, | ||
| tokenizer, | ||
| default_padding_side, | ||
| default_pad_token, | ||
| calib_dataloader, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -1099,6 +1196,14 @@ def parse_args() -> argparse.Namespace: | |
| type=str, | ||
| default=None, | ||
| ) | ||
| parser.add_argument( | ||
| "--specdec_offline_dataset", | ||
| help=( | ||
| "If set, the model is a speculative decoding model," | ||
| "which uses offline dataset for calibration. " | ||
| ), | ||
| default=None, | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| parser.add_argument( | ||
| "--calib_with_images", | ||
| action="store_true", | ||
|
|
@@ -1256,6 +1361,12 @@ def parse_args() -> argparse.Namespace: | |
| if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): | ||
| parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") | ||
|
|
||
| if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense": | ||
| parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).") | ||
|
|
||
| if args.specdec_offline_dataset is not None and args.low_memory_mode: | ||
| parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.") | ||
|
|
||
| return args | ||
|
|
||
|
|
||
|
|
@@ -1311,4 +1422,10 @@ def main(args: argparse.Namespace): | |
|
|
||
| args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset | ||
| args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] | ||
|
|
||
| if args.specdec_offline_dataset is not None and len(args.calib_size) != 1: | ||
| raise ValueError( | ||
| "--specdec_offline_dataset expects a single --calib value, not a comma-separated list." | ||
| ) | ||
|
|
||
| main(args) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will it be cleaner if we separate the the speculative decoding ptq to examples/speculative_decoding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Depends. Our Online PTQ will use hf_ptq so I would prefer we leave all ptq code together
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding to the above — the offline PTQ code in
`hf_ptq.py` is ~50 lines that reuse the existing PTQ infrastructure (model loading, calibration loop, export). Separating it to `examples/speculative_decoding/` would require duplicating all of that shared code, and as mentioned, the upcoming online PTQ path will also live in `hf_ptq.py`. Keeping it together avoids divergence. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK