
[2/3][Feat]: Offline DFlash training #1295

Draft
h-guo18 wants to merge 5 commits into haoguo/spec-file-reorg from haoguo/dflash-offline

Conversation

h-guo18 (Contributor) commented Apr 19, 2026

What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:

Changes:

  • Add a dflash_offline flag to DFlashConfig for training from pre-computed hidden states; when enabled, the base model's decoder layers are deleted to save memory.
  • Add Pydantic validators on DFlashConfig (a sketch of this pattern follows the list):
    • _derive_dflash_offline: derives dflash_offline from data_args.offline_data_path in the validation context.
    • _resolve_mask_token_id: auto-detects dflash_mask_token_id from tokenizer.mask_token_id.
    • _check_mask_token_id: fails fast if the mask token ID is still unset after resolution.
  • HFDFlashModel.modify(): use num_orig_hidden_layers when offline, pick the device from _base_model_lm_head when no base layers are present, and drop the base-model layers module.
  • HFDFlashModel.forward(): add an offline branch that consumes precomputed base_model_outputs via DFlashBaseModelOutput.from_offline_dict; when dflash_self_logit_distillation is enabled and base_model_logits is absent, recompute logits from base_model_hidden_states via _base_model_lm_head.
  • Add a DFlashBaseModelOutput dataclass in modeling_dflash.py (with a from_offline_dict classmethod) to unify online and offline output shapes.
  • examples/speculative_decoding/main.py: replace inline mask_token_id auto-detect with DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}).
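
A minimal sketch of that validator pattern, assuming Pydantic v2's model_validator with ValidationInfo.context; the real DFlashConfig has more fields and its validator bodies may differ:

# Sketch only: simplified stand-in for DFlashConfig (not the actual class).
from pydantic import BaseModel, ValidationInfo, model_validator

class DFlashConfig(BaseModel):
    dflash_offline: bool = False
    dflash_mask_token_id: int | None = None

    @model_validator(mode="after")
    def _derive_dflash_offline(self, info: ValidationInfo):
        # Enable offline mode when data_args in the validation context carries
        # an offline_data_path; fall through when no context is provided.
        data_args = (info.context or {}).get("data_args")
        if data_args is not None and getattr(data_args, "offline_data_path", None):
            self.dflash_offline = True
        return self

    @model_validator(mode="after")
    def _resolve_mask_token_id(self, info: ValidationInfo):
        # Auto-detect the mask token from the tokenizer when not set explicitly.
        tokenizer = (info.context or {}).get("tokenizer")
        if self.dflash_mask_token_id is None and tokenizer is not None:
            self.dflash_mask_token_id = tokenizer.mask_token_id
        return self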

Usage

# Training YAML
dflash:
  dflash_offline: true  # auto-derived when data_args.offline_data_path is set
  # dflash_mask_token_id optional; auto-detected from the tokenizer
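
The programmatic equivalent, mirroring the main.py change above. The simplified DFlashConfig sketch from the description stands in for the real class, and data_args is a stand-in for the training script's argument object:

# Sketch: validate a raw config dict with runtime context.
from types import SimpleNamespace
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer with a mask token
data_args = SimpleNamespace(offline_data_path="/path/to/hidden_states")

config = DFlashConfig.model_validate(
    {},  # raw dict parsed from the training YAML
    context={"tokenizer": tokenizer, "data_args": data_args},
)
assert config.dflash_offline                                   # derived from offline_data_path
assert config.dflash_mask_token_id == tokenizer.mask_token_id  # auto-detected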

Testing

  • Validated with the offline DFlash training script on pre-computed hidden states (re-run after rebasing onto the updated tip of [1/3][Refactor]: File reorg; deprecate ParallelDraft #1296).
  • New tests (an illustrative validator test follows this list):
    • tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py: CPU unit tests for the convert path (online keeps base layers, offline deletes them; num_orig_hidden_layers drives target_layer_ids in offline mode) and for the DFlashConfig._derive_dflash_offline validator.
    • TestDFlashOfflineForwardGPU in tests/gpu/torch/speculative/plugins/test_hf_dflash.py: a GPU forward smoke test with precomputed base_model_outputs, plus the dflash_self_logit_distillation logit-recompute path.
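
For illustration, a minimal CPU test of the derive validator could look like this (reusing the simplified DFlashConfig sketch above; the actual tests in the PR are more thorough):

# Sketch of a validator unit test.
from types import SimpleNamespace

def test_derive_dflash_offline():
    data_args = SimpleNamespace(offline_data_path="/data/hidden_states")
    cfg = DFlashConfig.model_validate({}, context={"data_args": data_args})
    assert cfg.dflash_offline is True

    # Without a validation context the validator falls through and keeps the default.
    cfg = DFlashConfig.model_validate({})
    assert cfg.dflash_offline is False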

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ additive dflash_offline flag defaulting to False; validators fall through when the context is not provided.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅ — see Testing section above.
  • Did you update Changelog?: ❌

TODO (follow-up)

  • Update examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py to support DFlash offline data. Current scripts are Eagle-specific: they hardcode the [2, N/2, N-3] aux-layer selection and emit {input_ids, hidden_states, aux_hidden_states}. DFlash offline needs (see the record sketch after this list):
    • Aux layer indices driven by build_target_layer_ids(num_orig_hidden_layers, num_draft_layers) (or a configurable list), not the Eagle triplet.
    • A base_model_hidden_states key (last-layer hidden states) so DFlashBaseModelOutput.from_offline_dict and the dflash_self_logit_distillation recompute path can consume it.
    • An optional base_model_logits dump so offline training can skip the self-distillation logit recomputation when logits are available.
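
For concreteness, one pre-computed sample for DFlash offline training might look like the following (key names follow the bullets above; shapes and the number of aux layers are purely illustrative):

# Sketch: hypothetical per-sample offline record for DFlash.
import torch

batch, seq, hidden, vocab, num_aux = 1, 4, 4096, 32000, 3
offline_record = {
    "input_ids": torch.randint(0, vocab, (batch, seq)),
    # Last-layer hidden states, consumed by DFlashBaseModelOutput.from_offline_dict
    # and by the dflash_self_logit_distillation recompute path:
    "base_model_hidden_states": torch.randn(batch, seq, hidden),
    # Aux layers chosen by build_target_layer_ids(...), not the Eagle triplet;
    # shown here concatenated along the hidden dimension:
    "aux_hidden_states": torch.randn(batch, seq, num_aux * hidden),
    # Optional: pre-computed logits so training can skip the lm_head recompute.
    "base_model_logits": torch.randn(batch, seq, vocab),
}
# DFlashBaseModelOutput.from_offline_dict(offline_record) would then map these
# keys onto the same structure the online forward pass produces.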

Additional Information

Base branch is #1296 (file reorg). Retarget to main once #1296 merges.

copy-pr-bot bot commented Apr 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

coderabbitai bot commented Apr 19, 2026

Review skipped: draft detected. To trigger a single review, invoke the @coderabbitai review command.

Configuration used: .coderabbit.yaml (review profile: CHILL).

github-actions bot commented Apr 19, 2026

PR Preview Action v1.8.1

🚀 View preview at https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1295/

Built to branch gh-pages at 2026-04-19 21:56 UTC. Preview will be ready when the GitHub Pages deployment is complete.


codecov bot commented Apr 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 75.56%. Comparing base (f488231) to head (f208109).

Additional details and impacted files
@@                    Coverage Diff                     @@
##           haoguo/spec-file-reorg    #1295      +/-   ##
==========================================================
- Coverage                   75.56%   75.56%   -0.01%     
==========================================================
  Files                         466      466              
  Lines                       50238    50232       -6     
==========================================================
- Hits                        37962    37957       -5     
+ Misses                      12276    12275       -1     

☔ View full report in Codecov by Sentry.

h-guo18 force-pushed the haoguo/dflash-offline branch 2 times, most recently from 9e4eeb0 to f208109, on Apr 19, 2026 at 21:53
h-guo18 changed the base branch from main to haoguo/spec-file-reorg on Apr 19, 2026 at 21:54
h-guo18 changed the title from "offline dflash" to "[2/3][Feat]: Offline DFlash training" on Apr 19, 2026
- Add `dflash_offline` config flag for training from pre-computed hidden states;
  deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig`
  Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming
  pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
h-guo18 self-assigned this on Apr 19, 2026
h-guo18 force-pushed the haoguo/dflash-offline branch from f208109 to 178b191 on Apr 19, 2026 at 23:40
h-guo18 added 4 commits on April 20, 2026, each signed off by h-guo18 <67671475+h-guo18@users.noreply.github.com>.