
Add support for offline speculative decoding model PTQ #883

Merged
yeyu-nvidia merged 38 commits into main from yeyu/offline_quant
Apr 8, 2026

Conversation

@yeyu-nvidia
Contributor

@yeyu-nvidia yeyu-nvidia commented Feb 12, 2026

What does this PR do?

Type of change:
new feature

Overview:
This PR enables loading a ModelOpt-pretrained offline speculative decoding model (e.g., EAGLE3), running PTQ on it, and exporting the result.

Usage

Follow the speculative_decoding examples to train an offline speculative decoding model first.
Then follow the command below to quantize and export it:

python hf_ptq.py --pyt_ckpt_path <dir_of_offline_specdec_model> --specdec_offline_dataset <dir_of_dataset>

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Offline speculative decoding workflow: support loading a local dataset for calibration, generation, and export; new CLI option to specify the offline dataset.
  • Improvements

    • Export and quantization paths now accept and propagate offline speculative-decoding inputs.
    • Offline data loading honors a sample-size limit and enforces safe batch sizing for calibration.
  • Bug Fixes

    • Better handling of model/config mismatches and varied batch types in offline flows.

@copy-pr-bot

copy-pr-bot bot commented Feb 12, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 12, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Introduces offline speculative decoding support via a new --specdec_offline_dataset flag/path and threads an offline_specdec_input parameter through CLI, Eagle data loading, calibration, model loading, exporter plugins, and unified HF export logic.

Changes

  • Speculative Decoding CLI & PTQ flow (examples/llm_ptq/hf_ptq.py): Adds the --specdec_offline_dataset CLI flag; when provided, builds the calibration DataLoader from the Eagle data module (forcing batch_size=1), constructs offline_specdec_input, skips the pre/post-quantize calibration steps in offline mode, loads the full model from pyt_ckpt_path, and updates `export_quantized(..., offline_specdec_input: dict
  • SpecDec data utilities (examples/speculative_decoding/eagle_utils.py): Truncates discovered .pt dump files to sample_size when sample_size > 0 before creating OfflineSupervisedDataset and its collator.
  • SpecDec training args & main (examples/speculative_decoding/main.py): Removes eval_data_path from DataArguments; adds sample_size: int = -1; deletes layer_types from model.config in the offline training path to avoid config/model mismatches.
  • Speculative decoding export plugins (modelopt/torch/export/plugins/hf_spec_export.py): Adds `offline_specdec_input: dict
  • Unified HF export pipeline (modelopt/torch/export/unified_export_hf.py): Adds offline_specdec_input to requantize_resmooth_fused_llm_layers and export_speculative_decoding; llm_dummy_forward can call model(**offline_specdec_input) when present; propagates the offline input through _export_transformers_checkpoint and the export flow.
  • Dataset batch handling (modelopt/torch/utils/dataset_utils.py): Relaxes the _process_batch type checks so base_model_outputs values may be tensors, None, or any other type, and updates the error message accordingly.

Sequence Diagram(s)

sequenceDiagram
    participant CLI as CLI (hf_ptq.py)
    participant Eagle as Eagle Data Module (eagle_utils.py)
    participant Calib as Calib DataLoader
    participant Exporter as SpecDec Exporter (hf_spec_export.py)
    participant Unified as Unified Export (unified_export_hf.py)

    CLI->>Eagle: check --specdec_offline_dataset
    alt offline_specdec_dataset provided
        Eagle->>Calib: build Eagle supervised DataLoader (respect sample_size, batch_size=1)
        Calib->>CLI: produce offline_specdec_input (batch)
        CLI->>Exporter: call export(..., offline_specdec_input)
    else standard path
        CLI->>Calib: use standard calib_dataloader
        CLI->>Exporter: call export(..., offline_specdec_input=None)
    end

    Exporter->>Unified: export(export_dir, dtype, offline_specdec_input?)
    Unified->>Unified: requantize_resmooth_fused_llm_layers(offline_specdec_input)
    alt offline_specdec_input present
        Unified->>Unified: llm_dummy_forward uses model(**offline_specdec_input)
    else
        Unified->>Unified: llm_dummy_forward uses standard dummy inputs
    end
    Unified->>CLI: export complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 58.82%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Security Anti-Patterns: ❓ Inconclusive. Security analysis of the PR could not be completed because the specified files were unavailable in the repository. Resolution: verify that the repository contains the files mentioned in the PR summary (examples/llm_ptq/hf_ptq.py, examples/speculative_decoding/, modelopt/torch/export/) and that the correct branch is checked out.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title check: ✅ Passed. The pull request title accurately summarizes the main change: adding support for offline speculative decoding model PTQ, which is clearly reflected in the code changes across multiple files, including new CLI arguments, data loading paths, and export functionality.

Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov bot commented Feb 12, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.46%. Comparing base (3a177f6) to head (db545f7).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #883      +/-   ##
==========================================
+ Coverage   71.65%   76.46%   +4.81%     
==========================================
  Files         353      353              
  Lines       40355    40755     +400     
==========================================
+ Hits        28915    31165    +2250     
+ Misses      11440     9590    -1850     
Flag Coverage Δ
examples 44.44% <96.42%> (+1.13%) ⬆️
gpu 56.93% <37.50%> (+9.62%) ⬆️
unit 55.17% <89.28%> (+0.13%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/offline_quant branch 2 times, most recently from b1bc618 to 9811c67 Compare March 10, 2026 20:37
@yeyu-nvidia yeyu-nvidia marked this pull request as ready for review March 10, 2026 20:37
@yeyu-nvidia yeyu-nvidia requested review from a team as code owners March 10, 2026 20:37
@yeyu-nvidia yeyu-nvidia requested review from h-guo18 and meenchen March 10, 2026 20:37
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)

173-202: ⚠️ Potential issue | 🔴 Critical

Expose trust_remote_code as a caller-configurable parameter defaulting to False.

This example hardcodes trust_remote_code=True across four transformer loader calls (lines 176, 178, 187, 201) without inline justification. Per security guidelines, this flag must be configurable by the caller, not hardcoded, to avoid executing arbitrary Python shipped with untrusted checkpoints. Add a parameter to the train() function to control this behavior with a safe default.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 173 - 202, The example
currently hardcodes trust_remote_code=True in calls to
load_vlm_or_llm_with_kwargs and transformers.AutoTokenizer.from_pretrained; add
a new parameter trust_remote_code: bool = False to the train() function
signature (or the main caller entry) and use that parameter in place of the
hardcoded True in all four calls (the calls to load_vlm_or_llm_with_kwargs and
transformers.AutoTokenizer.from_pretrained shown in the snippet). Ensure the
default remains False, propagate the parameter through any helper calls if
needed, and update any caller sites to pass the desired value when invoking
train().
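The parameterization the review asks for can be sketched as below; `load_model` and `train` are hypothetical stand-ins for the transformers loader calls in the example, not the actual code:

```python
# Illustrative sketch of threading trust_remote_code as a caller-controlled
# flag that defaults to False, instead of hardcoding True at each call site.
# In the real code, load_model would wrap AutoModelForCausalLM.from_pretrained.


def load_model(ckpt_path: str, trust_remote_code: bool = False):
    # Stand-in for a Hugging Face from_pretrained call; returns a record of
    # what would be passed through so the propagation is visible.
    return {"path": ckpt_path, "trust_remote_code": trust_remote_code}


def train(ckpt_path: str, trust_remote_code: bool = False):
    # The flag is propagated from the entry point rather than hardcoded.
    return load_model(ckpt_path, trust_remote_code=trust_remote_code)


print(train("some/ckpt"))                          # safe default: False
print(train("some/ckpt", trust_remote_code=True))  # explicit caller opt-in
```

The key point is that the safe default lives in the function signature, so untrusted checkpoints never execute repository-supplied Python unless the caller opts in.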
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 330-335: The branch that handles args.specdec_offline_dataset
calls AutoModelForCausalLM.from_pretrained without applying device placement or
attention implementation, causing full_model to load on CPU and skip attn
optimizations; change this branch to mirror the other paths by either calling
the existing get_model(...) helper (used in the elif clause) with the same
arguments or by passing the same model_kwargs (including args.device and
attn_implementation) into from_pretrained so full_model respects device
placement and attention implementation settings; update the code around the
full_model creation (the AutoModelForCausalLM.from_pretrained call and
surrounding condition for args.specdec_offline_dataset) to reuse get_model or to
include args.device and args.attn_implementation in model_kwargs.

In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 161-162: The current check treats sample_size==0 as “use all”;
update the logic around data_args.sample_size and dumped_files so only -1 means
“use all”: keep the existing slice when data_args.sample_size > 0, explicitly
set dumped_files to an empty list when data_args.sample_size == 0, and leave
dumped_files unchanged when data_args.sample_size == -1; locate and modify the
conditional around data_args.sample_size in eagle_utils.py to implement those
three branches (referencing data_args.sample_size and dumped_files).

In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 382-384: The offline_specdec_input dict must be moved to the
model's device before calling model(**offline_specdec_input) to avoid CPU/CUDA
mismatches; obtain the model device (e.g., from next(model.parameters()) or
model.device), recursively traverse offline_specdec_input and call .to(device)
on every tensor (handling nested dicts/lists/tuples), then invoke
model(**offline_specdec_input). Ensure you update the branch handling
offline_specdec_input (the block that currently does
model(**offline_specdec_input)) to perform this device transfer first.

In `@modelopt/torch/utils/dataset_utils.py`:
- Around line 522-526: The assertion allowing arbitrary base_model_outputs is
unsafe because _process_batch expects tensor-like behavior; change the check and
logic in _process_batch so that batch_size is derived from an actual tensor
entry (scan batch_data for the first value where torch.is_tensor(value) and use
its shape[0]) and either (A) tighten the assertion to require base_model_outputs
to be tensor-like or (B) add explicit handling when splitting: when recursively
splitting for OOM, if base_model_outputs is a tensor use tensor slicing, if it
is a sequence/type that supports indexing slice per-sub-batch safely, and if it
cannot be sliced raise a clear error; update the initial assert to reflect the
new requirement (or remove base_model_outputs exception) and ensure the
split/copy logic around base_model_outputs in _process_batch (the places reading
shape[0] and doing slices) uses the derived batch_size and safe slicing/copying
logic.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 173-202: The example currently hardcodes trust_remote_code=True in
calls to load_vlm_or_llm_with_kwargs and
transformers.AutoTokenizer.from_pretrained; add a new parameter
trust_remote_code: bool = False to the train() function signature (or the main
caller entry) and use that parameter in place of the hardcoded True in all four
calls (the calls to load_vlm_or_llm_with_kwargs and
transformers.AutoTokenizer.from_pretrained shown in the snippet). Ensure the
default remains False, propagate the parameter through any helper calls if
needed, and update any caller sites to pass the desired value when invoking
train().

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 890a12b7-8066-459d-bbd1-e02d93a548cd

📥 Commits

Reviewing files that changed from the base of the PR and between fff65b0 and 9811c678ffa41cd546c6bfe97219926ce001a6b7.

📒 Files selected for processing (6)
  • examples/llm_ptq/hf_ptq.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/utils/dataset_utils.py

Comment thread examples/speculative_decoding/eagle_utils.py Outdated
Comment on lines +161 to +162
if data_args.sample_size > 0:
    dumped_files = dumped_files[: data_args.sample_size]
Contributor


⚠️ Potential issue | 🟠 Major

sample_size=0 currently expands to the full offline dataset.

Because Line 161 only slices on > 0, the new PTQ path will treat --calib_size 0 as “use every dumped sample”, even though examples/speculative_decoding/main.py documents -1 as the only “use all” value. That can silently turn a tiny calibration run into a full-dataset pass and blow up runtime/memory.

Suggested fix
-        if data_args.sample_size > 0:
+        if data_args.sample_size == 0 or data_args.sample_size < -1:
+            raise ValueError("sample_size must be -1 or a positive integer")
+        if data_args.sample_size > 0:
             dumped_files = dumped_files[: data_args.sample_size]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if data_args.sample_size > 0:
dumped_files = dumped_files[: data_args.sample_size]
if data_args.sample_size == 0 or data_args.sample_size < -1:
raise ValueError("sample_size must be -1 or a positive integer")
if data_args.sample_size > 0:
dumped_files = dumped_files[: data_args.sample_size]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/eagle_utils.py` around lines 161 - 162, The
current check treats sample_size==0 as “use all”; update the logic around
data_args.sample_size and dumped_files so only -1 means “use all”: keep the
existing slice when data_args.sample_size > 0, explicitly set dumped_files to an
empty list when data_args.sample_size == 0, and leave dumped_files unchanged
when data_args.sample_size == -1; locate and modify the conditional around
data_args.sample_size in eagle_utils.py to implement those three branches
(referencing data_args.sample_size and dumped_files).

Comment thread modelopt/torch/export/unified_export_hf.py Outdated
Comment thread modelopt/torch/utils/dataset_utils.py Outdated
Comment on lines +522 to +526
assert all(
    torch.is_tensor(data) or data is None or key == "base_model_outputs"
    for key, data in batch_data.items()
), (
    "batch_data values must be tensors or None, except for 'base_model_outputs' which can be any type."
Contributor


⚠️ Potential issue | 🔴 Critical

Allowing arbitrary base_model_outputs is inconsistent with the rest of _process_batch.

This assertion now permits any type for base_model_outputs, but the function still assumes every non-None entry has shape[0] and supports tensor-style slicing on Lines 529, 541, 568, and 569. That means the new offline/specdec path will still fail as soon as _process_batch needs to infer batch size or split on OOM unless base_model_outputs happens to be sliceable already.

Either keep base_model_outputs restricted to tensor-like values here, or add explicit handling that (1) derives batch_size from a real tensor key and (2) slices/copies base_model_outputs safely during recursive splitting.

Possible fix
+def _get_batch_size(batch_data):
+    tensor_value = next((data for data in batch_data.values() if torch.is_tensor(data)), None)
+    if tensor_value is None:
+        raise TypeError("batch_data must contain at least one tensor value to infer batch size.")
+    return tensor_value.shape[0]
+
+
+def _slice_batch_value(value, start_idx, end_idx):
+    if value is None:
+        return None
+    if torch.is_tensor(value):
+        return value[start_idx:end_idx, ...]
+    if isinstance(value, list):
+        return value[start_idx:end_idx]
+    if isinstance(value, tuple):
+        return value[start_idx:end_idx]
+    raise TypeError(
+        f"Unsupported batch value type for recursive splitting: {type(value).__name__}"
+    )
+
+
 def _process_batch(batch_data, infer_method, max_working_batch_size=None):
     ...
-    batch_size = batch_data[next(iter(batch_data.keys()))].shape[0]
+    batch_size = _get_batch_size(batch_data)
     ...
             split_data = {}
             for key in batch_data:
-                if batch_data[key] is None:
-                    split_data[key] = None
-                else:
-                    split_data[key] = batch_data[key][i:end_idx, ...]
+                split_data[key] = _slice_batch_value(batch_data[key], i, end_idx)
     ...
-    split_data_1 = {key: batch_data[key][:mid, ...] for key in batch_data}
-    split_data_2 = {key: batch_data[key][mid:, ...] for key in batch_data}
+    split_data_1 = {key: _slice_batch_value(batch_data[key], 0, mid) for key in batch_data}
+    split_data_2 = {
+        key: _slice_batch_value(batch_data[key], mid, batch_size) for key in batch_data
+    }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/dataset_utils.py` around lines 522 - 526, The assertion
allowing arbitrary base_model_outputs is unsafe because _process_batch expects
tensor-like behavior; change the check and logic in _process_batch so that
batch_size is derived from an actual tensor entry (scan batch_data for the first
value where torch.is_tensor(value) and use its shape[0]) and either (A) tighten
the assertion to require base_model_outputs to be tensor-like or (B) add
explicit handling when splitting: when recursively splitting for OOM, if
base_model_outputs is a tensor use tensor slicing, if it is a sequence/type that
supports indexing slice per-sub-batch safely, and if it cannot be sliced raise a
clear error; update the initial assert to reflect the new requirement (or remove
base_model_outputs exception) and ensure the split/copy logic around
base_model_outputs in _process_batch (the places reading shape[0] and doing
slices) uses the derived batch_size and safe slicing/copying logic.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
examples/llm_ptq/hf_ptq.py (2)

899-919: Typo in data_args field name and potential unused attribute.

Line 904 contains devlazy_preprocessice=True which appears to be a typo (possibly lazy_preprocess?). Looking at the make_eagle_supervised_data_module function signature in the relevant snippets, this field does not appear to be used. Consider removing this line or correcting the field name if it was intended to be used.

         data_args = argparse.Namespace(
             vlm_processor=None,
             vlm_img_dir=None,
             offline_data_path=args.specdec_offline_dataset,
-            devlazy_preprocessice=True,
             sample_size=args.calib_size[0],
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 899 - 919, The data_args Namespace
passed when args.specdec_offline_dataset is set contains a typoed/unused key
devlazy_preprocessice=True; locate the Namespace construction near the
args.specdec_offline_dataset branch and either remove this bogus attribute or
rename it to the actual expected field (e.g., lazy_preprocess or the exact
parameter name used by make_eagle_supervised_data_module) after confirming the
function signature of make_eagle_supervised_data_module(tokenizer, data_args,
...); ensure only valid keys are passed so data_args matches what
make_eagle_supervised_data_module expects.

1023-1025: Consider eagerly fetching the calibration batch before dataloader iteration.

Using next(iter(calib_dataloader), None) after the dataloader has potentially been iterated during calibration may yield different results or cause issues with multi-worker dataloaders. Consider capturing the first batch earlier (e.g., at line 915 after creating the dataloader) and reusing it for export:

         calib_dataloader = DataLoader(
             data_module["train_dataset"],
             batch_size=args.batch_size,
             shuffle=False,
             collate_fn=data_module["data_collator"],
         )
+        # Capture first batch for export (used in fusion/resmoothing forward pass)
+        offline_specdec_batch = next(iter(calib_dataloader), None)

Then use offline_specdec_batch at line 1023 instead of calling next(iter(...)) again.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 1023 - 1025, The code currently
fetches the calibration batch inline with offline_specdec_input using
next(iter(calib_dataloader), None), which can produce inconsistent results for
multi-worker or previously-iterated dataloaders; instead, capture the first
calibration batch once immediately after creating calib_dataloader (e.g., assign
to a variable named offline_specdec_batch right after calib_dataloader is
constructed) and then replace the inline next(iter(...)) usage (the
offline_specdec_input argument) with that offline_specdec_batch variable when
calling the export logic; ensure the variable is defined even when
args.specdec_offline_dataset is None (set to None) so callers like
offline_specdec_input consistently reference the pre-captured batch.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 899-919: The data_args Namespace passed when
args.specdec_offline_dataset is set contains a typoed/unused key
devlazy_preprocessice=True; locate the Namespace construction near the
args.specdec_offline_dataset branch and either remove this bogus attribute or
rename it to the actual expected field (e.g., lazy_preprocess or the exact
parameter name used by make_eagle_supervised_data_module) after confirming the
function signature of make_eagle_supervised_data_module(tokenizer, data_args,
...); ensure only valid keys are passed so data_args matches what
make_eagle_supervised_data_module expects.
- Around line 1023-1025: The code currently fetches the calibration batch inline
with offline_specdec_input using next(iter(calib_dataloader), None), which can
produce inconsistent results for multi-worker or previously-iterated
dataloaders; instead, capture the first calibration batch once immediately after
creating calib_dataloader (e.g., assign to a variable named
offline_specdec_batch right after calib_dataloader is constructed) and then
replace the inline next(iter(...)) usage (the offline_specdec_input argument)
with that offline_specdec_batch variable when calling the export logic; ensure
the variable is defined even when args.specdec_offline_dataset is None (set to
None) so callers like offline_specdec_input consistently reference the
pre-captured batch.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 599e400d-1d7b-4176-b7dd-b813c0cd7b46

📥 Commits

Reviewing files that changed from the base of the PR and between 9811c678ffa41cd546c6bfe97219926ce001a6b7 and f7b8a7d.

📒 Files selected for processing (6)
  • examples/llm_ptq/hf_ptq.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/utils/dataset_utils.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • modelopt/torch/export/plugins/hf_spec_export.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)

173-201: ⚠️ Potential issue | 🔴 Critical

Remove hardcoded trust_remote_code=True from model and tokenizer loading.

This code unconditionally enables execution of repository-supplied Python when loading from user-provided checkpoint paths. Thread trust_remote_code as a configurable parameter and default it to False. Apply this across all four model/tokenizer loads in this section (lines 176, 178, 187, 201) and update load_vlm_or_llm_with_kwargs in the utility module to accept and propagate this parameter instead of hardcoding it.

Per security coding guidelines: "Do not hardcode trust_remote_code=True when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to False."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 173 - 201, The code
currently hardcodes trust_remote_code=True when loading models and tokenizers;
update all four calls in this block (the two load_vlm_or_llm_with_kwargs calls
and the two transformers.AutoTokenizer.from_pretrained calls) to accept a
trust_remote_code parameter instead of the literal True, default that parameter
to False in the function signatures/call sites, and pass it through; also modify
load_vlm_or_llm_with_kwargs in the utility module to accept a trust_remote_code
argument and propagate it into whatever HF loading logic it uses (instead of
hardcoding True) so callers can opt-in to trust_remote_code when needed.
♻️ Duplicate comments (3)
examples/llm_ptq/hf_ptq.py (1)

353-357: ⚠️ Potential issue | 🟠 Major

Reuse get_model() for the offline load branch.

This path skips the normal loader settings—device placement, sequential device map, GPU memory cap, and attn_implementation—so the offline PTQ flow falls back to default from_pretrained() behavior while the rest of the script assumes the standard loader semantics.

Suggested fix
     if args.specdec_offline_dataset is not None:
-        full_model = AutoModelForCausalLM.from_pretrained(
-            args.pyt_ckpt_path,
-            trust_remote_code=args.trust_remote_code,
-        )
+        full_model = get_model(
+            args.pyt_ckpt_path,
+            args.device,
+            gpu_mem_percentage=args.gpu_max_mem_percentage,
+            trust_remote_code=args.trust_remote_code,
+            use_seq_device_map=args.use_seq_device_map,
+            attn_implementation=args.attn_implementation,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 353 - 357, The offline branch that
currently calls AutoModelForCausalLM.from_pretrained(...) bypasses the script's
standard loader settings (device placement, sequential_device_map, GPU memory
cap, attn_implementation). Replace that direct from_pretrained call with the
common loader helper get_model(...) so the args.specdec_offline_dataset path
uses the same loader semantics; specifically, load into full_model by invoking
get_model with the checkpoint path and the same trust_remote_code and
loader-related arguments used elsewhere in the script to preserve device map and
memory caps.
modelopt/torch/utils/dataset_utils.py (1)

522-526: ⚠️ Potential issue | 🔴 Critical

Still unresolved: base_model_outputs is not actually safe here.

This assert now accepts any base_model_outputs, but Line 529 still reads .shape[0] from the first batch value and Lines 541/568-569 still slice every non-None entry like a tensor. A dict/list/tuple base_model_outputs will still break as soon as _process_batch() infers batch size or recursively splits on OOM; either keep it tensor-like here or add real batch-size inference plus safe slicing for non-tensor values.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/dataset_utils.py` around lines 522 - 526, The assert
allowing any type for "base_model_outputs" is unsafe because _process_batch()
still infers batch size via .shape[0] and slices entries as tensors; either
require base_model_outputs to be tensor-like or implement proper batch-size
inference and safe per-item slicing for non-tensors. Update the validation in
dataset_utils.py (the assert on batch_data and any callers of _process_batch) to
enforce that batch_data["base_model_outputs"] is a torch.Tensor (or None), or
alternatively add logic in _process_batch()/any slicing helpers to (1) derive
batch_size from the first torch.Tensor found, (2) validate that
base_model_outputs can be indexed/sliced for that batch_size (e.g., supports
__len__ and __getitem__ semantics), and (3) branch when slicing
base_model_outputs to use safe sequence/tuple/dict-aware splitting instead of
tensor slicing so recursive splits and OOM-handling won't raise AttributeError.
modelopt/torch/export/unified_export_hf.py (1)

383-385: ⚠️ Potential issue | 🟠 Major

Move offline_specdec_input onto the model device before this forward.

examples/llm_ptq/hf_ptq.py passes the first calibration batch straight through as offline_specdec_input, so this branch is receiving DataLoader output rather than model-placed tensors. Calling model(**offline_specdec_input) against a CUDA model will fail unless you recursively .to(target_device) every nested tensor first.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 383 - 385, The
offline_specdec_input dict/list may contain CPU tensors or nested structures
that must be moved to the model device before calling
model(**offline_specdec_input); implement a small recursive helper (e.g.,
move_to_device(obj, device)) that detects torch.Tensor and calls .to(device),
and recurses into dict, list, tuple (preserving types), then obtain the model
device (e.g., device = next(model.parameters()).device or model.device if
available) and call offline_specdec_input =
move_to_device(offline_specdec_input, device) right before
model(**offline_specdec_input) so all nested tensors are placed on the correct
device.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 924-940: The offline calibration path leaves tensors on CPU
causing a device mismatch in the calibration loop; fix by ensuring batches are
moved to the target device before model forward: either add a device parameter
to EagleOfflineDataCollator and have EagleOfflineDataCollator.__call__ transfer
all tensors (e.g., via .to(device)) and/or ensure
OfflineSupervisedDataset.__getitem__ returns tensors on the correct device, or
wrap the DataLoader produced for specdec_offline_dataset with a thin iterator
that moves each batch to device before it is consumed by
_forward_loop/_process_batch; update the dataloader creation branch that
constructs calib_dataloader so batches are guaranteed on the same device as the
model.

In `@examples/speculative_decoding/main.py`:
- Around line 87-90: The sample_size dataclass field currently allows 0 and
values < -1 which the truncation logic in eagle_utils.py ignores; add validation
after argument parsing (or in the dataclass __post_init__ if present) to ensure
args.sample_size is either -1 (use all) or a positive integer (>0), and raise a
clear ValueError (or argparse error) otherwise so accidental full-dataset runs
are prevented; reference the sample_size field in main.py and the truncation
logic in eagle_utils.py when locating where to add the check.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 173-201: The code currently hardcodes trust_remote_code=True when
loading models and tokenizers; update all four calls in this block (the two
load_vlm_or_llm_with_kwargs calls and the two
transformers.AutoTokenizer.from_pretrained calls) to accept a trust_remote_code
parameter instead of the literal True, default that parameter to False in the
function signatures/call sites, and pass it through; also modify
load_vlm_or_llm_with_kwargs in the utility module to accept a trust_remote_code
argument and propagate it into whatever HF loading logic it uses (instead of
hardcoding True) so callers can opt-in to trust_remote_code when needed.

---

Duplicate comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 353-357: The offline branch that currently calls
AutoModelForCausalLM.from_pretrained(...) bypasses the script's standard loader
settings (device placement, sequential_device_map, GPU memory cap,
attn_implementation). Replace that direct from_pretrained call with the common
loader helper get_model(...) so the args.specdec_offline_dataset path uses the
same loader semantics; specifically, load into full_model by invoking get_model
with the checkpoint path and the same trust_remote_code and loader-related
arguments used elsewhere in the script to preserve device map and memory caps.

In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 383-385: The offline_specdec_input dict/list may contain CPU
tensors or nested structures that must be moved to the model device before
calling model(**offline_specdec_input); implement a small recursive helper
(e.g., move_to_device(obj, device)) that detects torch.Tensor and calls
.to(device), and recurses into dict, list, tuple (preserving types), then obtain
the model device (e.g., device = next(model.parameters()).device or model.device
if available) and call offline_specdec_input =
move_to_device(offline_specdec_input, device) right before
model(**offline_specdec_input) so all nested tensors are placed on the correct
device.

In `@modelopt/torch/utils/dataset_utils.py`:
- Around line 522-526: The assert allowing any type for "base_model_outputs" is
unsafe because _process_batch() still infers batch size via .shape[0] and slices
entries as tensors; either require base_model_outputs to be tensor-like or
implement proper batch-size inference and safe per-item slicing for non-tensors.
Update the validation in dataset_utils.py (the assert on batch_data and any
callers of _process_batch) to enforce that batch_data["base_model_outputs"] is a
torch.Tensor (or None), or alternatively add logic in _process_batch()/any
slicing helpers to (1) derive batch_size from the first torch.Tensor found, (2)
validate that base_model_outputs can be indexed/sliced for that batch_size
(e.g., supports __len__ and __getitem__ semantics), and (3) branch when slicing
base_model_outputs to use safe sequence/tuple/dict-aware splitting instead of
tensor slicing so recursive splits and OOM-handling won't raise AttributeError.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 021ade3f-816c-45de-8bf6-2d4a5f1f4ddc

📥 Commits

Reviewing files that changed from the base of the PR and between f7b8a7d and 79c8d96d9fedc036e4962a30c3983921846a44c4.

📒 Files selected for processing (6)
  • examples/llm_ptq/hf_ptq.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/utils/dataset_utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/export/plugins/hf_spec_export.py
  • examples/speculative_decoding/eagle_utils.py

Comment thread examples/llm_ptq/hf_ptq.py Outdated
Comment thread examples/speculative_decoding/main.py
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
examples/llm_ptq/hf_ptq.py (2)

351-355: ⚠️ Potential issue | 🟠 Major

Don't bypass the normal model-loading path for offline PTQ.

Unlike the next branch, this one drops args.device, args.gpu_max_mem_percentage, args.use_seq_device_map, and args.attn_implementation. That leaves offline PTQ on the default CPU/default-attention load path and can also skip any checkpoint-specific handling centralized in get_model().

Suggested fix
-    if args.specdec_offline_dataset is not None:
-        full_model = AutoModelForCausalLM.from_pretrained(
-            args.pyt_ckpt_path,
-            trust_remote_code=args.trust_remote_code,
-        )
+    if args.specdec_offline_dataset is not None:
+        full_model = get_model(
+            args.pyt_ckpt_path,
+            args.device,
+            gpu_mem_percentage=args.gpu_max_mem_percentage,
+            trust_remote_code=args.trust_remote_code,
+            use_seq_device_map=args.use_seq_device_map,
+            attn_implementation=args.attn_implementation,
+        )
#!/bin/bash
set -euo pipefail

sed -n '348,395p' examples/llm_ptq/hf_ptq.py
echo "---- get_model helper ----"
rg -n "def get_model\\(" examples/llm_ptq -A40 -B5
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 351 - 355, The branch that handles
args.specdec_offline_dataset is bypassing your centralized loader by calling
AutoModelForCausalLM.from_pretrained directly; instead call the existing
get_model(...) helper with the same checkpoint (args.pyt_ckpt_path) so
checkpoint-specific logic and runtime options (args.device,
args.gpu_max_mem_percentage, args.use_seq_device_map, args.attn_implementation)
are honored; replace the direct AutoModelForCausalLM.from_pretrained use in the
specdec_offline_dataset branch with a call to get_model(...) passing the same
args (including trust_remote_code) so offline PTQ follows the normal
model-loading path.

922-940: ⚠️ Potential issue | 🟠 Major

Keep offline SpecDec batches on the same device as the model.

This branch replaces get_dataset_dataloader(..., device=device) with a raw DataLoader, and neither the local loader nor make_eagle_supervised_data_module() receives a device. The same loader is later used to supply offline_specdec_input for export, so please add an explicit transfer step here (or pass device into the offline collator/helper) before calibration/export.

#!/bin/bash
set -euo pipefail

sed -n '920,960p' examples/llm_ptq/hf_ptq.py
echo "---- offline helper ----"
sed -n '129,168p' examples/speculative_decoding/eagle_utils.py
echo "---- dataset utils device handling ----"
rg -n "def get_dataset_dataloader|def create_forward_loop|def _process_batch" modelopt/torch/utils/dataset_utils.py -A25 -B5
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 922 - 940, The offline SpecDec
branch constructs a raw DataLoader without ensuring batches live on the model
device, so ensure calib_dataloader yields tensors on device by either passing
device into make_eagle_supervised_data_module/collator or by moving each batch
to device after loading; specifically update the code around
make_eagle_supervised_data_module / data_module and calib_dataloader so that the
produced batches (data_module["train_dataset"] / data_module["data_collator"])
are created with the device or, if that’s not possible, wrap the DataLoader
output so offline_specdec_input is explicitly transferred to device (e.g., call
.to(device) or a small helper to move all tensors in the batch) before
calibration/export and before any use of offline_specdec_input.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1121-1128: The CLI accepts --specdec_offline_dataset but main()
only uses it in the dense/PTQ path while sparsity_main() ignores it and the
offline path expects a single calib_size; add upfront validation in the argument
handling (after converting args.calib_size in __main__ / where args are
normalized) to reject unsupported combinations: if args.specdec_offline_dataset
is not None ensure the run will use the dense/PTQ path and that
len(args.calib_size) == 1, otherwise raise a ValueError with a clear message;
update or add the same check before dispatching to sparsity_main() to fail fast
when the user supplied incompatible flags (reference symbols: main(),
sparsity_main(), args.specdec_offline_dataset, args.calib_size).

---

Duplicate comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 351-355: The branch that handles args.specdec_offline_dataset is
bypassing your centralized loader by calling
AutoModelForCausalLM.from_pretrained directly; instead call the existing
get_model(...) helper with the same checkpoint (args.pyt_ckpt_path) so
checkpoint-specific logic and runtime options (args.device,
args.gpu_max_mem_percentage, args.use_seq_device_map, args.attn_implementation)
are honored; replace the direct AutoModelForCausalLM.from_pretrained use in the
specdec_offline_dataset branch with a call to get_model(...) passing the same
args (including trust_remote_code) so offline PTQ follows the normal
model-loading path.
- Around line 922-940: The offline SpecDec branch constructs a raw DataLoader
without ensuring batches live on the model device, so ensure calib_dataloader
yields tensors on device by either passing device into
make_eagle_supervised_data_module/collator or by moving each batch to device
after loading; specifically update the code around
make_eagle_supervised_data_module / data_module and calib_dataloader so that the
produced batches (data_module["train_dataset"] / data_module["data_collator"])
are created with the device or, if that’s not possible, wrap the DataLoader
output so offline_specdec_input is explicitly transferred to device (e.g., call
.to(device) or a small helper to move all tensors in the batch) before
calibration/export and before any use of offline_specdec_input.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d17dbf99-f354-4423-bfa0-b6dc1def500e

📥 Commits

Reviewing files that changed from the base of the PR and between 79c8d96d9fedc036e4962a30c3983921846a44c4 and d544c0a33909beb3ec59ef5ebe98f47c5f9e3d2a.

📒 Files selected for processing (1)
  • examples/llm_ptq/hf_ptq.py

Comment thread examples/llm_ptq/hf_ptq.py
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (3)
examples/llm_ptq/hf_ptq.py (3)

921-940: ⚠️ Potential issue | 🟠 Major

Offline calibration path does not ensure batch tensors are on the model device.

The offline branch builds a DataLoader directly from Eagle collator but does not transfer batches to device before calibration forward steps.

Suggested fix
         data_module = make_eagle_supervised_data_module(
             tokenizer, data_args, train_len=args.calib_seq
         )
+        def _to_device(x):
+            if torch.is_tensor(x):
+                return x.to(device)
+            if isinstance(x, dict):
+                return {k: _to_device(v) for k, v in x.items()}
+            if isinstance(x, (list, tuple)):
+                return type(x)(_to_device(v) for v in x)
+            return x
+
+        def _collate_and_to_device(samples):
+            return _to_device(data_module["data_collator"](samples))
+
         calib_dataloader = DataLoader(
             data_module["train_dataset"],
             batch_size=args.batch_size,
             shuffle=False,
-            collate_fn=data_module["data_collator"],
+            collate_fn=_collate_and_to_device,
         )
#!/bin/bash
# Verify offline dataloader path and device transfer behavior.
sed -n '910,955p' examples/llm_ptq/hf_ptq.py

# Verify Eagle offline collator behavior and whether it performs device moves.
sed -n '1,260p' examples/speculative_decoding/eagle_utils.py | sed -n '120,240p'
rg -n "\\.to\\(|EagleOfflineDataCollator|OfflineSupervisedDataset" examples/speculative_decoding/eagle_utils.py -A8 -B4

# Verify forward loop/batch processing does not implicitly move tensors.
rg -n "def create_forward_loop|def _forward_loop|def _process_batch" modelopt/torch/utils/dataset_utils.py -A25 -B5
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 921 - 940, The offline calibration
branch that constructs calib_dataloader using make_eagle_supervised_data_module
and data_module["data_collator"] does not ensure tensors are moved to the model
device; update the calibration data path so each batch is transferred to device
before any forward/calibration call (either by wrapping the DataLoader with a
small helper that maps batches -> {k: v.to(device) for tensors} or by modifying
the collator returned by data_module["data_collator"] to move tensors to
device), ensuring changes touch the code that creates calib_dataloader and the
forward calibration loop that consumes it (references:
args.specdec_offline_dataset, make_eagle_supervised_data_module,
calib_dataloader, data_module["data_collator"]).

1121-1128: ⚠️ Potential issue | 🟠 Major

Validate unsupported --specdec_offline_dataset flag combinations up front.

The new flag is accepted unconditionally, but sparse path dispatch and multi-value --calib_size are not compatible with this offline flow.

Suggested fix
     args = parser.parse_args()
     if not (0.0 < args.moe_calib_experts_ratio <= 1.0):
         parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
+    if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
+        parser.error("--specdec_offline_dataset only supports PTQ (--sparsity_fmt dense).")

     return args
 if __name__ == "__main__":
     args = parse_args()
@@
     args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
     args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
+    if args.specdec_offline_dataset is not None and len(args.calib_size) != 1:
+        raise ValueError("--specdec_offline_dataset expects a single --calib_size value.")
     main(args)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 1121 - 1128, Validate the new
--specdec_offline_dataset flag immediately after argument parsing and error out
if used with incompatible options: when args.specdec_offline_dataset is not
None, reject sparse path dispatch (e.g., args.sparse_path or any flag
controlling sparse dispatch) and multi-value --calib_size (detect
len(args.calib_size) > 1 or comma-separated values) by raising/printing a clear
error and exiting; update the validation near where parser.add_argument is
processed (the main arg handling block) so incompatible combinations are
detected up front and prevent further execution.

350-355: ⚠️ Potential issue | 🟠 Major

Offline load path ignores configured device and attention implementation.

This branch loads with from_pretrained() but does not apply args.device/args.attn_implementation, so behavior diverges from other model-loading paths.

Suggested fix
     if args.specdec_offline_dataset is not None:
-        full_model = AutoModelForCausalLM.from_pretrained(
-            args.pyt_ckpt_path,
-            trust_remote_code=args.trust_remote_code,
-        )
+        model_kwargs = {"trust_remote_code": args.trust_remote_code}
+        if args.attn_implementation is not None:
+            model_kwargs["attn_implementation"] = args.attn_implementation
+        full_model = AutoModelForCausalLM.from_pretrained(args.pyt_ckpt_path, **model_kwargs)
+        full_model = full_model.to(args.device)
#!/bin/bash
# Verify the offline load branch and compare it with other load paths.
rg -n "if args.specdec_offline_dataset is not None|AutoModelForCausalLM.from_pretrained|get_model\\(|attn_implementation|\\.to\\(args\\.device\\)" examples/llm_ptq/hf_ptq.py -A8 -B4
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 350 - 355, The offline-load branch
that runs when args.specdec_offline_dataset is not None currently calls
AutoModelForCausalLM.from_pretrained(...) but does not apply the configured
device or attention implementation; update that branch to mirror the other load
paths by either calling the same helper used elsewhere (e.g., get_model(...)
with device=args.device and attn_implementation=args.attn_implementation) or,
after from_pretrained, explicitly apply the attention implementation and move
the model to args.device (e.g., set attn implementation and call
model.to(args.device)) so the offline path behaves identically to the other
branches.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 921-940: The offline calibration branch that constructs
calib_dataloader using make_eagle_supervised_data_module and
data_module["data_collator"] does not ensure tensors are moved to the model
device; update the calibration data path so each batch is transferred to device
before any forward/calibration call (either by wrapping the DataLoader with a
small helper that maps batches -> {k: v.to(device) for tensors} or by modifying
the collator returned by data_module["data_collator"] to move tensors to
device), ensuring changes touch the code that creates calib_dataloader and the
forward calibration loop that consumes it (references:
args.specdec_offline_dataset, make_eagle_supervised_data_module,
calib_dataloader, data_module["data_collator"]).
- Around line 1121-1128: Validate the new --specdec_offline_dataset flag
immediately after argument parsing and error out if used with incompatible
options: when args.specdec_offline_dataset is not None, reject sparse path
dispatch (e.g., args.sparse_path or any flag controlling sparse dispatch) and
multi-value --calib_size (detect len(args.calib_size) > 1 or comma-separated
values) by raising/printing a clear error and exiting; update the validation
near where parser.add_argument is processed (the main arg handling block) so
incompatible combinations are detected up front and prevent further execution.
- Around line 350-355: The offline-load branch that runs when
args.specdec_offline_dataset is not None currently calls
AutoModelForCausalLM.from_pretrained(...) but does not apply the configured
device or attention implementation; update that branch to mirror the other load
paths by either calling the same helper used elsewhere (e.g., get_model(...)
with device=args.device and attn_implementation=args.attn_implementation) or,
after from_pretrained, explicitly apply the attention implementation and move
the model to args.device (e.g., set attn implementation and call
model.to(args.device)) so the offline path behaves identically to the other
branches.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bb37374d-90cc-4873-9add-6a52b3b5accf

📥 Commits

Reviewing files that changed from the base of the PR and between d544c0a33909beb3ec59ef5ebe98f47c5f9e3d2a and fba558113c6bbabd69cc29edb85df17b046c9919.

📒 Files selected for processing (1)
  • examples/llm_ptq/hf_ptq.py

Collaborator

@ChenhanYu ChenhanYu left a comment


This PR adds PTQ support for offline speculative decoding models by threading offline_specdec_input through the export pipeline and adding --specdec_offline_dataset to hf_ptq.py. The feature makes sense, but there are some implementation issues (typo in a field name, fragile cross-example import, generic utility special-cased for one feature) and no tests.

Testing: Please provide information on where the end-to-end test was run (e.g., nmm-sandbox job link, internal CI run) so reviewers can verify the workflow works. Additionally, please add unit tests for:

  • sample_size truncation logic in eagle_utils (edge cases: 0, -1, N > len)
  • _process_batch with base_model_outputs (non-tensor values accepted, other keys still validated)
  • offline_specdec_input propagation through the export path

Comment thread examples/llm_ptq/hf_ptq.py
: len(args.dataset)
]
if extracted_lm is not None:
language_model = extracted_lm
Collaborator


sys.path.append to import from a sibling example directory is fragile — breaks if __file__ is unset or the relative path changes. Consider making eagle_utils a proper importable module, or at minimum wrap in try/except with a helpful error message.

Contributor Author


Fixed in 4586572d — replaced sys.path.append with importlib.util.spec_from_file_location to load eagle_utils directly by path without polluting sys.path. Since this is an examples directory (not a package), making it a proper module isn't appropriate.

Comment thread examples/llm_ptq/hf_ptq.py
@@ -519,8 +519,11 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None):
Returns:
The maximum batch size that worked successfully
"""
assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), (
Collaborator


Special-casing base_model_outputs by key name in a generic batch processing utility is a code smell — ties this utility to a specific feature. Consider either: (a) adding a parameter to skip validation for certain keys, or (b) normalizing base_model_outputs to tensors before it reaches this function.

Contributor Author


Fixed in 4586572d — refactored _process_batch, _forward_loop, and create_forward_loop to accept an allowed_non_tensor_keys: set | None parameter instead of hardcoding 'base_model_outputs'. The offline dataset call site now passes allowed_non_tensor_keys={"base_model_outputs"} explicitly.

Comment thread examples/speculative_decoding/main.py Outdated
Comment thread examples/speculative_decoding/eagle_utils.py
@yeyu-nvidia
Contributor Author

yeyu-nvidia commented Mar 19, 2026

Addressing the testing feedback from the review (a15a804):

Unit tests added:

  • tests/unit/torch/utils/test_dataset_utils.py: 3 new tests for _process_batch with allowed_non_tensor_keys — verifies non-tensor values are accepted under the allowlist, other keys are still validated, and the default (no allowlist) rejects non-tensors.
  • tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py: 5 tests for sample_size truncation edge cases (positive N, -1, 0, N > dataset len, empty dir) and 2 tests for offline_specdec_input propagation through export_speculative_decoding.

@yeyu-nvidia
Contributor Author

CI test added (217777c):

tests/examples/speculative_decoding/test_eagle_offline_ptq.py covers the full three-stage pipeline in a single CI job:

  1. test_collect_hidden_states — runs compute_hidden_states_hf.py with 2 conversations and verifies .pt files are produced with the correct input_ids/hidden_states keys.
  2. test_offline_eagle_training — runs launch_train.sh --offline-data with a tiny EAGLE config and verifies the output checkpoint directory is created.
  3. test_offline_ptq — runs hf_ptq.py with --specdec_offline_dataset and verifies model.safetensors + config.json are present with the expected EAGLE weight keys.

The test uses minimal data (2 conversations, --training_seq_len 64, --calib 2) to keep CI runtime short.

Collaborator

@ChenhanYu ChenhanYu left a comment


Code Review: Offline Speculative Decoding Model PTQ Support

1. Summary

This PR adds PTQ (post-training quantization) support for offline speculative decoding models (e.g., EAGLE3) by:

  • Threading offline_specdec_input through the export pipeline to provide real calibration data for dummy forward passes
  • Adding --specdec_offline_dataset flag to hf_ptq.py to load .pt files containing pre-computed hidden states
  • Extending dataset_utils.py to permit non-tensor batch values (via allowed_non_tensor_keys) for speculative decoding's custom input format
  • Adding two new test files covering the full PTQ workflow and unit tests for sample truncation

The approach is sound: it leverages existing calibration/export infrastructure and treats offline datasets as a special case rather than overhauling the entire pipeline. However, several implementation details need attention before merge.


2. Key Areas Requiring Attention

🔴 Critical: Device Placement for Offline Calibration

Files: examples/llm_ptq/hf_ptq.py (lines ~924–940), examples/speculative_decoding/eagle_utils.py
Issue: The offline dataloader yields CPU tensors (from torch.load() and torch.stack()), but these are passed directly to calibration without moving to the target device. This will cause device mismatch errors when the model is on CUDA.
Why it matters: The calibration loop (_forward_loop / _process_batch) does not automatically move batches to the target device, so calibration will fail at runtime with a device-mismatch error.
Fix: Either:

  • Add a device parameter to EagleOfflineDataCollator and apply .to(device) in its __call__ method, or
  • Wrap the offline DataLoader with a device-transfer step before passing to calibration

🔴 Critical: Model Loading Bypasses Device & Attention Config

File: examples/llm_ptq/hf_ptq.py (lines ~351–355)
Issue: When args.specdec_offline_dataset is not None, the code calls AutoModelForCausalLM.from_pretrained() directly without applying args.device or args.attn_implementation, unlike other load paths.
Why it matters: Model loads on CPU by default and misses architecture-specific optimizations, diverging from the standard PTQ flow.
Fix: Either call the centralized get_model() helper (as used in the elif branch) or explicitly apply attn_implementation and .to(args.device) after loading.


🔴 Critical: offline_specdec_input Device Mismatch in Export

File: modelopt/torch/export/unified_export_hf.py (lines ~383–385)
Issue: The calibration batch passed as offline_specdec_input is a DataLoader output (CPU tensors). Calling model(**offline_specdec_input) will fail if the model is on CUDA due to device mismatch.
Why it matters: The export forward pass will crash with a device error.
Fix: Before calling model(**offline_specdec_input), recursively move all tensors to the model's device via a helper function (suggested in existing review comments).


🟠 Important: sample_size Truncation Logic Unclear

File: examples/speculative_decoding/eagle_utils.py (line ~162)
Issue: The code only slices when sample_size > 0, so sample_size=0 and sample_size < -1 both use the full dataset. The docstring/help text suggests -1 is the sentinel for "use all," but the implementation is ambiguous.
Why it matters: Users could accidentally trigger full-dataset calibration when they intended a small sample.
Fix: Add validation in main.py to accept only -1 (use all) or a positive integer, rejecting 0 and values below -1 at parse time with a clear error. Document the sentinel value explicitly.
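A parse-time check along these lines would enforce the sentinel; `check_sample_size` is a hypothetical name for illustration:

```python
def check_sample_size(sample_size: int) -> int:
    """Accept -1 (use all samples) or a positive integer; reject 0 and values below -1."""
    if sample_size != -1 and sample_size <= 0:
        raise ValueError(
            f"sample_size must be -1 (use all) or a positive integer, got {sample_size}"
        )
    return sample_size
```

Placed in the dataclass `__post_init__` (or right after argument parsing), this turns an accidental full-dataset run into an immediate, explicit error.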


🟠 Important: allowed_non_tensor_keys Special-Cases Generic Utility

File: modelopt/torch/utils/dataset_utils.py (line ~522)
Issue: The new allowed_non_tensor_keys parameter ties a generic batch-processing utility to a specific feature (offline speculative decoding). This is a code smell and reduces reusability.
Why it matters: Future features with custom batch formats will need similar special-casing, making the code harder to maintain.
Context: The PR addresses this with the new parameter, which is an improvement, but consider whether this should be configurable at a higher level or normalized before reaching _process_batch.


🟠 Important: Fragile Module Loading via sys.path.append

File: examples/llm_ptq/hf_ptq.py (line ~472, now fixed in latest commit)
Status: Already fixed in commit 4586572d using importlib.util.spec_from_file_location. The new approach is robust and doesn't pollute sys.path.


🟡 Minor: Typo in Data Arguments (Fixed)

File: examples/llm_ptq/hf_ptq.py (line ~478, now fixed in latest commit)
Status: Already fixed in commit 4586572d — corrected to lazy_preprocess=True.


🟡 Minor: Missing Comment on Batch Re-fetching

File: examples/llm_ptq/hf_ptq.py (line ~608, now fixed in latest commit)
Status: Already fixed in commit 4586572d — added clarifying comment explaining why offline speculative decoding models require a real sample as the dummy input for export.


🟡 Minor: Unsupported Flag Combinations Not Validated (Fixed)

File: examples/llm_ptq/hf_ptq.py (lines ~1121–1128, now fixed in commit 04f83da)
Status: Already fixed — validation in parse_args() rejects --specdec_offline_dataset with non-dense sparsity formats, and in __main__ rejects multi-value --calib_size.


3. Architecture & Design Changes

⚠️ This PR contains architecture/design changes that may require design review before approval.

What changed:

  1. New parameter allowed_non_tensor_keys in _process_batch, _forward_loop, create_forward_loop: Allows batch dictionaries to contain non-tensor values for specific keys, relaxing the earlier strict tensor-only requirement.
  2. New offline_specdec_input parameter threaded through export pipeline: export_speculative_decoding() → SpeculativeDecodingExporter.export() → _export_transformers_checkpoint() → requantize_resmooth_fused_llm_layers() → dummy forward pass.
  3. Dataset utility now feature-aware: The generic _process_batch now has implicit knowledge of speculative decoding's base_model_outputs field.

Why it matters:
These changes couple generic utilities (dataset handling, export) to a specific feature (offline speculative decoding). While the allowed_non_tensor_keys parameter is a reasonable abstraction, it's worth confirming:

  • Does this abstraction generalize well if future features need similar customization?
  • Are there alternative approaches (e.g., normalizing custom batch formats before they reach _process_batch) that would be cleaner?

4. Test Coverage

What's covered:

  • tests/examples/speculative_decoding/test_eagle_offline_ptq.py (new): End-to-end integration test covering all three stages (hidden state collection → EAGLE training → PTQ export).
  • tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py (new): Unit tests for:
    • sample_size truncation logic (positive, -1, 0, larger-than-dataset)
    • offline_specdec_input propagation through export
  • tests/unit/torch/utils/test_dataset_utils.py (modified): Added tests for allowed_non_tensor_keys validation.

What's missing:

  • Device movement in calibration: No test verifying that offline calibration batches are correctly moved to the target device before forward passes.
  • Model loading consistency: No test comparing device/attention placement between the offline load path and standard paths.
  • Export forward pass with real input: No test verifying the dummy forward in requantize_resmooth_fused_llm_layers() succeeds with offline_specdec_input on CUDA.
  • End-to-end device flow: The integration test should exercise CUDA (if available) to catch device mismatches.

Recommendation:

Add a test in test_dataset_utils.py that:

  1. Creates a batch with base_model_outputs as a non-tensor (list/dict).
  2. Calls _process_batch() with allowed_non_tensor_keys={"base_model_outputs"}.
  3. Verifies it doesn't error and that other tensor keys are still sliced correctly.
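
The recommended test could be sketched against a stand-in that mirrors the documented contract of _process_batch (the real function lives in modelopt/torch/utils/dataset_utils.py; this stand-in models only the slice-tensors/allowlist behavior):

```python
def process_batch(batch, batch_size, allowed_non_tensor_keys=frozenset()):
    """Stand-in for _process_batch: slice tensor-like values to batch_size;
    allow non-tensor values only under allowlisted keys."""
    out = {}
    for key, value in batch.items():
        if hasattr(value, "shape"):            # tensor-like: slice along dim 0
            out[key] = value[:batch_size]
        elif key in allowed_non_tensor_keys:   # e.g. {"base_model_outputs"}
            out[key] = value                   # passed through untouched
        else:
            raise TypeError(f"batch[{key!r}] must be a tensor")
    return out
```

The test then builds a batch where base_model_outputs is a non-tensor, calls the function with and without the allowlist, and checks that tensor keys are still sliced.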

5. Verdict

⚠️ Request changes

The three critical device-placement issues above must be resolved before merge. The fixes are straightforward but essential for runtime correctness:

  1. Move offline calibration batches to the model device (wrap dataloader or modify collator).
  2. Apply device & attention config to the offline model load path.
  3. Move offline_specdec_input to model device before the export forward pass.

After those fixes, the PR is on solid ground. The test coverage is good, the abstraction with allowed_non_tensor_keys is reasonable, and the overall design fits well into the existing PTQ pipeline. A human reviewer should verify:

  • The device-movement fixes are correct and don't break other code paths.
  • The end-to-end integration test actually runs and passes.
  • No regressions in existing PTQ workflows (dense/sparse/VLM paths).

Remember: This is an AI-assisted review. A human maintainer must approve and sign off before merging, regardless of verdict.

Collaborator

Would it be cleaner if we separate the speculative decoding PTQ into examples/speculative_decoding?

Contributor Author

Depends. Our online PTQ will also use hf_ptq, so I would prefer we keep all PTQ code together.

Contributor Author

Adding to the above — the offline PTQ code in hf_ptq.py is ~50 lines that reuse the existing PTQ infrastructure (model loading, calibration loop, export). Separating it to examples/speculative_decoding/ would require duplicating all of that shared code, and as mentioned, the upcoming online PTQ path will also live in hf_ptq.py. Keeping it together avoids divergence.

Collaborator

OK

@ChenhanYu
Collaborator

Also, we need to add a Qwen3-8B EAGLE PTQ example to tools/launcher/examples to complete this PR

@yeyu-nvidia
Contributor Author

Added in the latest push — tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3_ptq.yaml is a 5-step pipeline (query → dump hidden states → train EAGLE3 → PTQ with --specdec_offline_dataset → VLLM benchmark). Also added the wrapper script at tools/launcher/common/eagle3/hf_ptq.sh.

```python
        return {k: _to_device(v) for k, v in value.items()}
    return value

model(**_to_device(offline_specdec_input))
```

Contributor

Is offline_specdec_input necessary here? Can we just use dummy inputs?

Contributor Author

We could construct dummy inputs from the model config (hidden_size, etc.), but that would couple the export pipeline to EAGLE's internal data format — if the format changes, the dummy construction breaks silently. Using real offline_specdec_input is more robust since it always matches what the model actually expects. The trade-off isn't worth it IMO.

Contributor

@h-guo18 h-guo18 Apr 7, 2026

How about we implement a get_dummy_offline_inputs() for class HFEagleModel and all incoming SD algorithms like dflash? Then we keep the eagle-specific logic localized, without introducing new arguments all the way from export_quantized() -> export_speculative_decoding() -> exporter -> _export_transformers_checkpoint() -> requantize_resmooth_fused_llm_layers()

Contributor

@h-guo18 h-guo18 Apr 7, 2026

Then we simply call model(**_to_device(model.get_dummy_inputs())) here, regardless of which algorithm it is

Contributor Author

Great idea — implemented in 0b7b960. Added get_dummy_inputs() to HFEagleModel that constructs the right input format (with base_model_outputs for offline models, plain input_ids otherwise). The export pipeline now just does model(**model.get_dummy_inputs()) when get_dummy_inputs() is available. Removed offline_specdec_input from the entire chain: export_quantized → export_speculative_decoding → exporter.export → _export_transformers_checkpoint → requantize_resmooth_fused_llm_layers. This is extensible for future SD algorithms like dflash — they just implement their own get_dummy_inputs().
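
The shape of that hook can be sketched on a toy class (everything beyond the get_dummy_inputs name and the base_model_outputs key is illustrative, not the actual ModelOpt class):

```python
class HFEagleModel:
    """Toy skeleton of the get_dummy_inputs() idea, not the real class."""

    def __init__(self, offline: bool = False, hidden_size: int = 8):
        self.offline = offline
        self.hidden_size = hidden_size

    def get_dummy_inputs(self) -> dict:
        """Build inputs in exactly the format this model's forward expects,
        keeping algorithm-specific knowledge out of the export pipeline."""
        inputs = {"input_ids": [[1, 2, 3]]}
        if self.offline:
            # Offline EAGLE consumes precomputed base-model outputs.
            inputs["base_model_outputs"] = {
                "hidden_states": [[0.0] * self.hidden_size for _ in range(3)]
            }
        return inputs

# The export pipeline then only needs:
#     model(**_to_device(model.get_dummy_inputs()))
# and future SD algorithms opt in by implementing get_dummy_inputs().
```

The design keeps the export code algorithm-agnostic: it probes for the method and otherwise falls back to its existing dummy-input path.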

Comment thread examples/llm_ptq/hf_ptq.py Outdated
```python
)
if args.specdec_offline_dataset is not None:
    _eagle_utils_path = os.path.join(
        os.path.dirname(__file__), "../speculative_decoding/eagle_utils.py"
```
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about moving OfflineSupervisedDataset and EagleOfflineDataCollator to some importable module like modelopt/torch/speculative/eagle/utils.py? I think they are critical enough to be put there

Contributor

Otherwise this import seems a little hacky to put into the PTQ entry point

Contributor Author

Done in ac825fb — moved OfflineSupervisedDataset and EagleOfflineDataCollator to modelopt/torch/speculative/eagle/utils.py. The old location in eagle_utils.py re-imports for backward compat.

Contributor Author

Fixed in ac825fb — the importlib hack is gone. Now imports directly from modelopt.torch.speculative.eagle.utils.

yeyu-nvidia and others added 20 commits April 8, 2026 12:09
- Fix typo devlazy_preprocessice -> lazy_preprocess in offline dataset args
- Replace sys.path.append with importlib.util for eagle_utils import to avoid
  polluting sys.path with a relative path
- Add comment explaining why next(iter(calib_dataloader)) is used as the
  dummy input for offline speculative decoding export
- Refactor _process_batch / _forward_loop / create_forward_loop to accept
  allowed_non_tensor_keys parameter instead of hardcoding 'base_model_outputs'
  in generic utility; pass {"base_model_outputs"} at the offline dataset call site
- Clarify sample_size sentinel: non-positive values use all samples

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- test_dataset_utils: add tests for _process_batch with allowed_non_tensor_keys
  (non-tensor values accepted under allowlist, other keys still validated)
- test_hf_speculative_offline: add tests for sample_size truncation logic in
  eagle_utils (edge cases: positive N, -1, 0, N > len, empty dir) and
  offline_specdec_input propagation through export_speculative_decoding

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Three-stage E2E test: collect hidden states → train offline EAGLE → PTQ
export. Uses minimal data (2 conversations, seq_len=64) to keep CI fast.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Reject upfront when --specdec_offline_dataset is combined with
--sparsity_fmt != dense (offline PTQ only supported in dense path),
and when multiple --calib values are given (offline path expects one).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…specdec PTQ

- hf_ptq.py: Add attn_implementation kwarg and .to(device) for offline model load
  path to match the non-offline code path
- unified_export_hf.py: Move offline_specdec_input tensors (including nested dicts)
  to model device before forward pass in requantize_resmooth
- eagle_utils.py: Add device parameter to EagleOfflineDataCollator so calibration
  batches are moved to GPU; reject sample_size=0 and sample_size<-1 with clear error
- main.py: Add __post_init__ validation for sample_size in DataArguments

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Add a 5-step pipeline YAML (query -> dump hidden states -> train ->
PTQ with --specdec_offline_dataset -> benchmark) and a common shell
script for running hf_ptq.py on EAGLE3 models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Move OfflineSupervisedDataset and EagleOfflineDataCollator to
  modelopt/torch/speculative/eagle/utils.py as importable module classes
  (Comments 2, 5)
- Remove device parameter from EagleOfflineDataCollator; move device-transfer
  logic to a _DeviceDataLoader wrapper in make_calib_dataloader to avoid
  interfering with dataloader prefetching (Comment 3)
- Move offline dataset/dataloader creation into make_calib_dataloader for a
  cleaner interface (Comment 4)
- Move export/skip-quantization branch for offline specdec into post_quantize
  as an early-exit, following existing branching patterns (Comment 6)
- Remove hacky importlib-based import in hf_ptq.py (Comment 5)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…emory_mode validation

- Add early return in pre_quantize() for offline specdec (no preview needed)
- Add parse_args() validation rejecting --specdec_offline_dataset + --low_memory_mode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…y_inputs()

Add get_dummy_inputs() to HFEagleModel that constructs the right input
format (including base_model_outputs for offline models). This localizes
EAGLE-specific logic in the model class and eliminates the
offline_specdec_input parameter that was threaded through
export_quantized -> export_speculative_decoding -> exporter.export ->
_export_transformers_checkpoint -> requantize_resmooth_fused_llm_layers.

The export pipeline now just calls model.get_dummy_inputs() when
available, making it extensible for future SD algorithms without
additional parameter threading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Addresses reviewer feedback to use weights_only=True for security
best practices when loading offline data files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Revert unnecessary formatting change in unified_export_hf.py
- Add dataset format documentation to OfflineSupervisedDataset docstring

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…odels

The offline PTQ path previously loaded the base model onto a single GPU
with .to(device), which fails for models that exceed single-GPU memory
(e.g., 70B+ at ~140GB in BF16). Switch to device_map="auto" to match
the non-offline get_model() path, allowing HF accelerate to distribute
the model across available GPUs. Also add torch_dtype="auto" to preserve
the checkpoint's native dtype.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Remove the separate model loading path for --specdec_offline_dataset and
reuse get_model() which already handles device_map="auto", torch_dtype,
trust_remote_code, and attn_implementation. The original motivation for a
separate path was a single-device constraint, but that constraint doesn't
exist — HF accelerate dispatch hooks handle multi-device correctly.

Since --specdec_offline_dataset and --low_memory_mode are mutually
exclusive (validated at argparse level), the two conditions merge cleanly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Add AutoConfig.from_pretrained() call before load_vlm_or_llm to define
  model_config used for offline training (main.py:214 NameError)
- Add assert for calib_dataloader not None to satisfy mypy type check
  (hf_ptq.py:816 incompatible type error)
- Remove unused 'patch' import (test_hf_speculative_offline.py)

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
sample_size=0 is rejected by both DataArguments.__post_init__ and
make_eagle_supervised_data_module. The test incorrectly expected it
to use all samples; it should assert the ValueError is raised.

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Move AutoConfig.from_pretrained() into the offline training branch
and use getattr with fallback for models like Kimi-K2.5 that use
non-standard config attributes instead of num_hidden_layers.

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
apply_chat_template with return_tensors="pt" may return a raw tensor
or a BatchEncoding dict depending on the tokenizer. Handle both cases
to fix test_collect_hidden_states failure.

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…aCollator

Cover dataset loading, label shifting, collator truncation/padding,
and multi-sample batching to improve code coverage for the new offline
speculative decoding classes.

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
After rebase, launch_train.sh requires --config <yaml> with OmegaConf
dotlist overrides instead of individual --flag arguments. Updated
test_eagle_offline_ptq.py to match the new interface.

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The huggingface_example.sh wrapper doesn't support --specdec_offline_dataset.
Call hf_ptq.py directly with the correct args instead.

Signed-off-by: Ye Yu <yey@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
yeyu-nvidia and others added 2 commits April 8, 2026 15:07
The dummy hidden state tensors used during export were created as float32
but the model weights are bfloat16, causing a dtype mismatch in the
lm_head forward pass during requantize_resmooth_fused_llm_layers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
When use_aux_hidden_state is True (EAGLE3), the export dummy forward
pass calls eagle_module.fc(aux_hiddens), which was None because
get_dummy_inputs did not provide aux_hidden_states.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia merged commit cccfded into main Apr 8, 2026
45 checks passed
@yeyu-nvidia yeyu-nvidia deleted the yeyu/offline_quant branch April 8, 2026 23:56
Edwardf0t1 pushed a commit that referenced this pull request Apr 9, 2026
kinjalpatel27 pushed a commit that referenced this pull request Apr 13, 2026