Commit 5ab4ebb

Merge branch 'main' into jingyux/diffusion-skip-softmax
2 parents d2d6d83 + c9b1155 commit 5ab4ebb

58 files changed

Lines changed: 5047 additions & 192 deletions

.github/workflows/gpu_tests.yml

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,11 @@ jobs:
     .github/workflows/gpu_tests.yml
     modelopt/**
     tests/gpu/**
+    tests/gpu_regression/**
+    examples/speculative_decoding/**
+    examples/dataset/**
+    modelopt_recipes/general/speculative_decoding/**
+    tools/launcher/**
     pyproject.toml
     tox.ini
   fail_on_initial_diff_error: true
@@ -66,6 +71,9 @@ jobs:
     timeout: 45
     container_image: pytorch:26.01-py3
     # tests/gpu/_extensions/test_onnx_extensions.py fails for newer containers until https://github.com/tbenthompson/cppimport/pull/98
+  - example: gpu-regression
+    timeout: 15
+    container_image: pytorch:26.01-py3
   - example: gpu-megatron
     timeout: 45
     container_image: pytorch:26.01-py3

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,12 @@ repos:
 
   - repo: local
     hooks:
+      - id: normalize-yaml-ext
+        name: normalize .yml to .yaml in required places, right now only yaml files in modelopt_recipes
+        entry: python tools/precommit/normalize_yaml_ext.py
+        language: system
+        files: ^modelopt_recipes/.*\.yml$
+
       - id: check-modelopt-recipes
         name: validate modelopt recipes
         entry: python tools/precommit/check_modelopt_recipes.py
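
The contents of `tools/precommit/normalize_yaml_ext.py` are not shown in this part of the diff. As a rough sketch of what a hook with this `files` pattern could do, the Python below renames each `.yml` path it is given to `.yaml`. The function names, the refuse-to-overwrite policy, and the exit-code convention are assumptions made for illustration, not the repository's implementation.

```python
from pathlib import Path


def normalize_yaml_ext(paths: list[str]) -> list[tuple[str, str]]:
    """Rename each existing `.yml` file to `.yaml`; return the (old, new) pairs."""
    renamed = []
    for name in paths:
        src = Path(name)
        if src.suffix != ".yml" or not src.is_file():
            continue  # not a target, or the file is missing
        dst = src.with_suffix(".yaml")
        if dst.exists():
            # Hypothetical policy: never overwrite an existing .yaml file.
            raise FileExistsError(f"{dst} already exists")
        src.rename(dst)
        renamed.append((str(src), str(dst)))
    return renamed


def main(argv: list[str]) -> int:
    """Return 1 if any file was renamed so the commit is retried, else 0."""
    changed = normalize_yaml_ext(argv)
    for old, new in changed:
        print(f"renamed {old} -> {new}")
    return 1 if changed else 0
```

pre-commit passes the matched file names as positional arguments, so a script built this way would end with `sys.exit(main(sys.argv[1:]))`.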

examples/dataset/README.md

Lines changed: 13 additions & 0 deletions
@@ -219,3 +219,16 @@ python -m modelopt.torch.utils.plugins.megatron_preprocess_data \
     --workers 32 \
     --reasoning_content inline
 ```
+
+## Synthetic Test Dataset
+
+`synthetic_conversations_1k.jsonl` is a 1,000-sample dataset in OpenAI messages format
+(900 single-turn + 100 two-turn conversations) covering writing, reasoning, math, coding,
+STEM, extraction, humanities, and roleplay categories.
+
+This dataset was synthesized by Claude (Anthropic) and is licensed under Apache-2.0.
+It is intended for testing and CI regression — not for production training.
+
+```json
+{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
+```
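
The record shape described in the added README section can be checked with a few lines of Python. This is a minimal sketch, not part of the commit. The helper names are hypothetical, and it assumes that a "turn" means one `user` message.

```python
import json
from collections import Counter


def count_user_turns(record: dict) -> int:
    """Validate one record in messages format and return its number of user turns."""
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError("record must contain a non-empty 'messages' list")
    for msg in messages:
        if not isinstance(msg, dict) or not {"role", "content"} <= msg.keys():
            raise ValueError(f"malformed message: {msg!r}")
    return sum(1 for msg in messages if msg["role"] == "user")


def turn_histogram(lines) -> Counter:
    """Map number of user turns -> number of records, for an iterable of JSONL lines."""
    hist = Counter()
    for line in lines:
        if line.strip():  # skip blank lines
            hist[count_user_turns(json.loads(line))] += 1
    return hist
```

Passing an open handle on `synthetic_conversations_1k.jsonl` to `turn_histogram` would show whether the file matches the 900/100 split stated in the README, under the turn definition assumed here.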

examples/dataset/make_dataset.py

Lines changed: 4 additions & 1 deletion
@@ -522,7 +522,10 @@ async def main(args: argparse.Namespace) -> None:
     )
     if "conversation_id" not in entry:
         entry["conversation_id"] = id_for_conversation(entry["conversations"])
-    f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+    # Output in OpenAI messages format (rename conversations → messages)
+    output_entry = {k: v for k, v in entry.items() if k != "conversations"}
+    output_entry["messages"] = entry["conversations"]
+    f.write(json.dumps(output_entry, ensure_ascii=False) + "\n")
 
 
 if __name__ == "__main__":
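
The effect of the added lines is easiest to see on a single record. The sketch below wraps the same two statements in a standalone function. The function name and the sample record are illustrative assumptions, not code from the commit.

```python
import json


def to_messages_format(entry: dict) -> dict:
    """Return a copy of `entry` with the `conversations` key renamed to `messages`."""
    output_entry = {k: v for k, v in entry.items() if k != "conversations"}
    output_entry["messages"] = entry["conversations"]
    return output_entry


legacy = {
    "conversation_id": "demo-0",  # hypothetical ID, for illustration only
    "conversations": [
        {"role": "user", "content": "ping"},
        {"role": "assistant", "content": "pong"},
    ],
}
line = json.dumps(to_messages_format(legacy), ensure_ascii=False)
```

Two properties follow from this construction: the input dict is left unmodified, and `messages` becomes the last key of each output line because it is added after the other keys are copied.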

examples/dataset/synthetic_conversations_1k.jsonl

Lines changed: 1000 additions & 0 deletions
Large diffs are not rendered by default.

examples/speculative_decoding/README.md

Lines changed: 44 additions & 2 deletions
@@ -217,8 +217,7 @@ To use your own datasets, please preprocess your data into a `.jsonl` file with
 
 ```json
 {
-    "conversation_id": <unique id>,
-    "conversations": [{"role":<user or assistant>, "content":<content>}]
+    "messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
 }
 ```
 
@@ -350,3 +349,46 @@ More models coming soon!
 - 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html)
 - 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md)
 - [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md)
+
+## DFlash (Block Diffusion for Speculative Decoding)
+
+DFlash is a parallel speculative decoding method based on [Block Diffusion](https://arxiv.org/abs/2602.06036).
+Unlike autoregressive draft models (EAGLE3), DFlash predicts an entire block of tokens in a single forward pass
+using masked parallel prediction with KV injection from the target model's hidden states.
+
+### Quick Start
+
+For a complete end-to-end example (training + evaluation), see the
+[launcher example](../../tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml):
+
+```bash
+uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes
+```
+
+### Key Configuration ([dflash.yaml](../../modelopt_recipes/general/speculative_decoding/dflash.yaml))
+
+| Field | Default | Description |
+|-------|---------|-------------|
+| `dflash.dflash_block_size` | 8 | Block size for parallel prediction |
+| `dflash.dflash_num_anchors` | 512 | Number of anchor positions per sample |
+| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables) |
+| `dflash.dflash_self_logit_distillation` | true | Use logit distillation from target |
+| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
+| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
+| `training.answer_only_loss` | false | Mask loss on non-assistant tokens |
+
+Qwen3 sliding window attention is automatically supported — draft layers inherit
+`layer_types` and `sliding_window` from the config, matching the target model's
+attention pattern.
+
+### Export
+
+```bash
+python scripts/export_hf_checkpoint.py \
+    --model_path /path/to/training/output \
+    --export_path /path/to/exported/model
+```
+
+### Results
+
+See [doc/dflash.md](doc/dflash.md) for design details, benchmark results, and open items.
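
The configuration table describes `dflash.dflash_loss_decay_factor` only as an "exponential decay gamma (0 disables)". The sketch below shows one plausible reading, in which the loss weight at block offset `k` is `exp(-k / gamma)`. The exact formula, the starting offset, and the function name are assumptions that the README does not confirm. See `doc/dflash.md` for the actual definition.

```python
import math


def block_loss_weights(block_size: int, gamma: float) -> list[float]:
    """Per-position loss weights within one block (assumed form exp(-k / gamma))."""
    if gamma == 0:
        # The table says 0 disables decay: every position gets equal weight.
        return [1.0] * block_size
    return [math.exp(-k / gamma) for k in range(block_size)]


# Defaults from the table: block size 8, gamma 4.0.
weights = block_loss_weights(block_size=8, gamma=4.0)
```

Under this reading, later positions in a block contribute less to the training loss, and a larger gamma flattens the weights toward uniform.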

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ async def submit_generates():
     for entry in dataset:
         conversation_id = entry.get("conversation_id", entry.get("uuid"))
 
-        conversations = entry["conversations"]
+        conversations = entry.get("messages") or entry["conversations"]
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
             continue
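
The changed line accepts both record layouts, and the short sketch below isolates that expression to show how it behaves on each. `get_turns` mirrors the committed expression. `get_turns_lenient` is a hypothetical variant shown only for comparison, and the sample records are made up.

```python
def get_turns(entry: dict):
    """Same expression as the changed line: prefer `messages`, else `conversations`."""
    return entry.get("messages") or entry["conversations"]


def get_turns_lenient(entry: dict):
    """Hypothetical variant: return None instead of raising when both keys are unusable."""
    return entry.get("messages") or entry.get("conversations")


new_style = {"messages": [{"role": "user", "content": "hi"}]}
old_style = {"conversations": [{"role": "user", "content": "hi"}]}
empty_new = {"messages": []}  # falsy, so the expression falls through to `conversations`
```

One consequence of the `or` fallback is that a record with an empty `messages` list and no `conversations` key raises `KeyError` in `get_turns`, before the validity check that follows in the script can count it as invalid. The lenient form returns `None` in that case, which the existing `if not conversations` check would then catch.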
