Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
9717c73
add offline specdec ptq support in hf_ptq.py; export is not supported…
yeyu-nvidia Feb 11, 2026
163f909
debug
yeyu-nvidia Feb 12, 2026
60d4c36
add support for offline specdec export
yeyu-nvidia Feb 12, 2026
77a73c0
formatting
yeyu-nvidia Feb 12, 2026
98c4042
formatting
yeyu-nvidia Mar 9, 2026
9117489
debug
yeyu-nvidia Mar 10, 2026
b56c7a1
debug
yeyu-nvidia Mar 10, 2026
3995a2f
debug
yeyu-nvidia Mar 10, 2026
6c7acca
debug
yeyu-nvidia Mar 10, 2026
0badde0
move offline_specdec_input to EagleExporter
yeyu-nvidia Mar 10, 2026
278ec14
debug
yeyu-nvidia Mar 10, 2026
4514e70
formatting
yeyu-nvidia Mar 10, 2026
003e0fb
Make eagle_utils import lazy to avoid transformers.Trainer at module …
yeyu-nvidia Mar 18, 2026
fa3cd6b
Fix ruff formatting in hf_ptq.py
yeyu-nvidia Mar 18, 2026
971be07
Address PR review feedback for offline speculative decoding PTQ
yeyu-nvidia Mar 19, 2026
67e2b9b
Add unit tests for offline speculative decoding PTQ features
yeyu-nvidia Mar 19, 2026
b0c0454
Add CI test for offline speculative decoding PTQ pipeline
yeyu-nvidia Mar 19, 2026
6318ec5
Add fail-fast validation for --specdec_offline_dataset flag
yeyu-nvidia Mar 19, 2026
edf0512
Fix device placement, input validation, and GPU transfer for offline …
yeyu-nvidia Apr 6, 2026
3e7839c
Add Qwen3-8B EAGLE3 PTQ launcher example with offline dataset
yeyu-nvidia Apr 6, 2026
a80a9b5
refactor: address PR #883 review comments for offline quant
yeyu-nvidia Apr 7, 2026
e012208
refactor: move offline specdec guards into pre_quantize and add low_m…
yeyu-nvidia Apr 7, 2026
05957a1
refactor: replace offline_specdec_input threading with model.get_dumm…
yeyu-nvidia Apr 7, 2026
f62953c
fix: add weights_only=True to torch.load in OfflineSupervisedDataset
yeyu-nvidia Apr 7, 2026
8e986bf
fix: address PR review - revert formatting, improve dataset docstring
yeyu-nvidia Apr 7, 2026
f2cd43f
fix: use device_map="auto" for offline specdec PTQ to support large m…
yeyu-nvidia Apr 7, 2026
dca82d3
refactor: use get_model() for offline specdec PTQ model loading
yeyu-nvidia Apr 7, 2026
38d8705
fix: resolve code-quality CI failures after rebase
yeyu-nvidia Apr 7, 2026
10af262
fix: correct test_sample_size_zero to expect ValueError
yeyu-nvidia Apr 7, 2026
cceae6a
fix: handle models without num_hidden_layers in config
yeyu-nvidia Apr 7, 2026
65a94d6
fix: handle apply_chat_template returning tensor or dict
yeyu-nvidia Apr 8, 2026
aef4749
test: add unit tests for OfflineSupervisedDataset and EagleOfflineDat…
yeyu-nvidia Apr 8, 2026
ac3794f
style: fix import line length in test file
yeyu-nvidia Apr 8, 2026
a0f2507
style: ruff format long assert line
yeyu-nvidia Apr 8, 2026
e342300
fix: update offline PTQ test to use YAML config CLI
yeyu-nvidia Apr 8, 2026
48436de
fix: call hf_ptq.py directly for offline PTQ test
yeyu-nvidia Apr 8, 2026
91779f7
fix: match model dtype in offline EAGLE get_dummy_inputs
yeyu-nvidia Apr 8, 2026
db545f7
fix: include aux_hidden_states in offline EAGLE dummy inputs
yeyu-nvidia Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 146 additions & 29 deletions examples/llm_ptq/hf_ptq.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it be cleaner if we separate the speculative decoding ptq to examples/speculative_decoding?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends. Our Online PTQ will use hf_ptq so I would prefer we leave all ptq code together

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding to the above — the offline PTQ code in hf_ptq.py is ~50 lines that reuse the existing PTQ infrastructure (model loading, calibration loop, export). Separating it to examples/speculative_decoding/ would require duplicating all of that shared code, and as mentioned, the upcoming online PTQ path will also live in hf_ptq.py. Keeping it together avoids divergence.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import random
import time
import warnings
from pathlib import Path
from typing import Any

import numpy as np
Expand Down Expand Up @@ -64,6 +65,10 @@
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.speculative.eagle.utils import (
EagleOfflineDataCollator,
OfflineSupervisedDataset,
)
from modelopt.torch.utils.dataset_utils import (
create_forward_loop,
get_dataset_dataloader,
Expand Down Expand Up @@ -163,17 +168,63 @@ def extract_and_prepare_language_model_from_vl(full_model):
return None, None


class _DeviceDataLoader:
    """Iterable adapter that yields batches from ``dataloader`` with all
    tensors transferred to ``device``.

    Device placement happens here rather than in the collate function so
    that DataLoader worker prefetching is not affected.
    """

    def __init__(self, dataloader: DataLoader, device: torch.device):
        self.dataloader = dataloader  # source of CPU-resident batches
        self.device = device  # target device for every tensor in a batch

    def __iter__(self):
        # Lazily transfer each batch as it is consumed.
        return (_move_batch_to_device(batch, self.device) for batch in self.dataloader)

    def __len__(self):
        # Length is delegated to the wrapped DataLoader.
        return len(self.dataloader)


def _move_batch_to_device(batch: dict, device: torch.device) -> dict:
"""Recursively move all tensors in a batch dict to the given device."""

def _to_device(value):
if isinstance(value, torch.Tensor):
return value.to(device)
if isinstance(value, dict):
return {k: _to_device(v) for k, v in value.items()}
return value

return {k: _to_device(v) for k, v in batch.items()}


def make_calib_dataloader(
args: argparse.Namespace,
language_model: torch.nn.Module,
processor: BaseImageProcessor | ProcessorMixin | None,
tokenizer: PreTrainedTokenizerBase | None,
device: torch.device,
model_type: str | None,
) -> tuple[DataLoader, str | None]:
) -> tuple[DataLoader | _DeviceDataLoader, str | None]:
calib_dataloader = None
first_text_speech_dataset = None
if args.calib_with_images:
if args.specdec_offline_dataset is not None:
offline_data_path = Path(args.specdec_offline_dataset)
dumped_files = sorted(str(p) for p in offline_data_path.glob("*.pt"))
if not dumped_files:
raise ValueError(f"No .pt files found in {args.specdec_offline_dataset}")
if args.calib_size[0] > 0:
dumped_files = dumped_files[: args.calib_size[0]]
dataset = OfflineSupervisedDataset(dumped_files)
collator = EagleOfflineDataCollator(train_len=args.calib_seq)
raw_loader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=collator,
)
# Wrap to move batches to the target device; device-transfer logic is kept
# out of the data collator to avoid interference with dataloader prefetching.
calib_dataloader = _DeviceDataLoader(raw_loader, device)
elif args.calib_with_images:
# VLM image-text calibration path: assume Nemotron VLM dataset by default.
assert processor is not None, (
"Please provide a processor (e.g., AutoProcessor) for image calibration."
Expand Down Expand Up @@ -358,7 +409,7 @@ def forward_step(model, batch):
def load_model(args: argparse.Namespace):
# If low memory mode is enabled, we compress the model while loading the HF checkpoint.
calibration_only = False
if not args.low_memory_mode:
if args.specdec_offline_dataset is not None or not args.low_memory_mode:
full_model = get_model(
args.pyt_ckpt_path,
args.device,
Expand Down Expand Up @@ -459,28 +510,34 @@ def load_model(args: argparse.Namespace):
language_model = extracted_lm
model_type = extracted_model_type
else:
if args.dataset is None:
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
warnings.warn(
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
if args.specdec_offline_dataset is not None:
language_model = full_model
else:
if args.dataset is None:
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
warnings.warn(
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
)
# Adjust calib_size to match dataset length by extending or truncating as needed
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]

# We only quantize the language model for VLMs other than the type supported above.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
full_model
)
# Adjust calib_size to match dataset length by extending or truncating as needed
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]
if extracted_lm is not None:
language_model = extracted_lm
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sys.path.append to import from a sibling example directory is fragile — breaks if __file__ is unset or the relative path changes. Consider making eagle_utils a proper importable module, or at minimum wrap in try/except with a helpful error message.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 4586572d — replaced sys.path.append with importlib.util.spec_from_file_location to load eagle_utils directly by path without polluting sys.path. Since this is an examples directory (not a package), making it a proper module isn't appropriate.

model_type = extracted_model_type

tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

default_padding_side = tokenizer.padding_side
default_pad_token = tokenizer.pad_token
Comment thread
ChenhanYu marked this conversation as resolved.
# Left padding usually provides better calibration result.
tokenizer.padding_side = "left"

# We only quantize the language model for VLMs other than the type supported above.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type

if model_type == "phi4mm":
warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")

Expand Down Expand Up @@ -581,7 +638,12 @@ def mono_quantize(
if args.calib_with_images and is_nemotron_vl_model:
calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
else:
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
calibrate_loop = create_forward_loop(
dataloader=calib_dataloader,
allowed_non_tensor_keys={"base_model_outputs"}
if args.specdec_offline_dataset is not None
else None,
)

if calibration_only:
language_model = mtq.calibrate(
Expand Down Expand Up @@ -736,7 +798,7 @@ def pre_quantize(
full_model: torch.nn.Module,
model_type: str | None,
tokenizer: PreTrainedTokenizerBase | None,
calib_dataloader: DataLoader,
calib_dataloader: DataLoader | None,
is_nemotron_vl_model: bool,
):
"""
Expand All @@ -746,7 +808,12 @@ def pre_quantize(
post-quantize generation.

"""
# Offline specdec models skip pre-quantize preview (no tokenizer or standard dataloader)
if args.specdec_offline_dataset is not None:
return None, None

# Only run single sample for preview
assert calib_dataloader is not None, "calib_dataloader is required for pre-quantize preview"
preview_input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
Expand Down Expand Up @@ -781,21 +848,39 @@ def pre_quantize(
def post_quantize(
args: argparse.Namespace,
full_model: torch.nn.Module,
language_model: torch.nn.Module,
model_type: str | None,
tokenizer: PreTrainedTokenizerBase | None,
processor: BaseImageProcessor | ProcessorMixin | None,
preview_input_ids,
generated_ids_before_ptq,
is_nemotron_vl_model,
first_text_speech_dataset,
default_padding_side,
default_pad_token,
calib_dataloader: DataLoader,
):
"""
Processing after the quantization.
Processing after the quantization, then export.

Currently we run one round of generation using the quantized model for a sample prompt,
and compare it with pre-quantize generation.
For offline speculative decoding models, skip generation comparison and proceed
directly to export. For standard models, run one round of generation using the
quantized model for a sample prompt and compare it with pre-quantize generation.

"""
# Early exit for offline speculative decoding: skip generation comparison and export directly.
# The model's get_dummy_inputs() provides the right input format for the export forward pass.
if args.specdec_offline_dataset is not None:
export_quantized(
args,
full_model,
language_model,
model_type,
tokenizer,
default_padding_side,
default_pad_token,
)
return

if args.verbose:
try:
Expand Down Expand Up @@ -873,6 +958,16 @@ def output_decode(generated_ids, input_shape):
f"example outputs after ptq: {output_decode(generated_ids_after_ptq, preview_input_ids.shape[1])}"
)

export_quantized(
args,
full_model,
language_model,
model_type,
tokenizer,
default_padding_side,
default_pad_token,
)


def quantize_main(
args: argparse.Namespace,
Expand All @@ -892,6 +987,13 @@ def quantize_main(
if args.calib_with_images:
print("Image-text calibration enabled. Using default batch_size=1 for calibration.")
args.batch_size = 1
# Speculative decoding offline model does not support get_max_batch_size() because of
# the customized dataloader, so we set batch_size to 1 to avoid OOM.
elif args.specdec_offline_dataset is not None:
print(
"Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration."
)
args.batch_size = 1
else:
# Calibration/sparsification will actually take much more memory than regular inference
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
Expand Down Expand Up @@ -1020,22 +1122,17 @@ def quantize_main(
post_quantize(
args,
full_model,
language_model,
model_type,
tokenizer,
processor,
preview_input_ids,
generated_ids_before_ptq,
is_nemotron_vl_model,
first_text_speech_dataset,
)
export_quantized(
args,
full_model,
language_model,
model_type,
tokenizer,
default_padding_side,
default_pad_token,
calib_dataloader,
)


Expand Down Expand Up @@ -1099,6 +1196,14 @@ def parse_args() -> argparse.Namespace:
type=str,
default=None,
)
parser.add_argument(
"--specdec_offline_dataset",
help=(
"If set, the model is a speculative decoding model,"
"which uses offline dataset for calibration. "
),
default=None,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
parser.add_argument(
"--calib_with_images",
action="store_true",
Expand Down Expand Up @@ -1256,6 +1361,12 @@ def parse_args() -> argparse.Namespace:
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")

if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).")

if args.specdec_offline_dataset is not None and args.low_memory_mode:
parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.")

return args


Expand Down Expand Up @@ -1311,4 +1422,10 @@ def main(args: argparse.Namespace):

args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]

if args.specdec_offline_dataset is not None and len(args.calib_size) != 1:
raise ValueError(
"--specdec_offline_dataset expects a single --calib value, not a comma-separated list."
)

main(args)
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,10 @@ async def submit_generates():
continue

# Tokenize and check length
input_ids = tokenizer.apply_chat_template(
tokenized = tokenizer.apply_chat_template(
conversations, return_tensors="pt", add_generation_template=False
)["input_ids"]
)
input_ids = tokenized["input_ids"] if isinstance(tokenized, dict) else tokenized
num_input_tokens = input_ids.shape[1]
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
num_skipped_too_long += 1
Expand Down
Loading
Loading