Commit 5f8d004

ChenhanYu and claude committed
add: DFlash block diffusion speculative decoding
DFlash (Block Diffusion for Flash Speculative Decoding) predicts an entire block of tokens in a single forward pass using masked parallel prediction with KV injection from the target model's hidden states.

Key features:
- Feature fusion (multi-layer hidden states -> FC + RMSNorm)
- KV injection (fused features as K/V in every draft layer with QK-norm)
- Random anchor sampling with bidirectional intra-block attention
- Logit distillation with exponential loss decay (gamma weighting)
- Multi-node DDP training with checkpoint resume
- Export to z-lab compatible HF format
- Online validation (context-dependent ground truth)

Training recipe: modelopt_recipes/general/speculative_decoding/dflash.yaml
Results: examples/speculative_decoding/doc/dflash_results.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
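The core claim above, "predicts an entire block of tokens in a single forward pass", can be sketched as a toy contrast with an autoregressive drafter. This is an illustrative sketch only: `MASK`, `draft_block_parallel`, and the `predict` callbacks are hypothetical stand-ins, not the actual DFlash code, and the real draft model additionally consumes fused target hidden states as injected K/V.

```python
MASK = -1  # hypothetical mask token id (DFlash reads mask_token_id from its config)

def draft_block_parallel(context, block_size, predict):
    """DFlash-style draft step (toy): append block_size mask tokens to the
    context and fill them all in a single call to the draft model."""
    block = [MASK] * block_size
    return predict(context, block)  # one forward pass fills every position

def draft_block_autoregressive(context, block_size, predict_next):
    """EAGLE-style baseline (toy): one forward pass per drafted token."""
    out = []
    for _ in range(block_size):
        out.append(predict_next(context + out))
    return out

# Toy stand-ins that just count forward passes (no real model here).
calls = {"parallel": 0, "autoregressive": 0}

def toy_predict(context, block):
    calls["parallel"] += 1
    return [0] * len(block)

def toy_predict_next(seq):
    calls["autoregressive"] += 1
    return 0

draft_block_parallel([1, 2, 3], 8, toy_predict)          # 1 forward pass
draft_block_autoregressive([1, 2, 3], 8, toy_predict_next)  # 8 forward passes
```

The draft-side work per block drops from `block_size` sequential calls to one parallel call, which is where the speedup over autoregressive drafters comes from.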
1 parent 4a70040 commit 5f8d004
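The "exponential loss decay (gamma weighting)" feature listed in the commit message plausibly amounts to down-weighting the loss at later block positions. A minimal sketch, assuming the decay form `exp(-j / gamma)`; the helper name and exact formula are assumptions, not the DFlash implementation:

```python
import math

def position_weights(block_size: int, gamma: float) -> list[float]:
    """Hypothetical per-position loss weights for a predicted block.

    Position 0 (adjacent to the known context) keeps full weight; later,
    harder-to-predict positions are down-weighted exponentially.
    gamma=0 is taken to disable the decay, matching the config table.
    """
    if gamma == 0:
        return [1.0] * block_size
    # Assumed decay form: w_j = exp(-j / gamma) for block position j.
    return [math.exp(-j / gamma) for j in range(block_size)]

weights = position_weights(8, 4.0)
# weights[0] is 1.0 and the weights decrease monotonically across the block
```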

File tree

25 files changed: +3234 -47 lines

doc/results/dflash_results.html

Whitespace-only changes.

examples/speculative_decoding/README.md

Lines changed: 38 additions & 0 deletions

````diff
@@ -350,3 +350,41 @@ More models coming soon!
 - 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html)
 - 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md)
 - [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md)
+
+## DFlash (Block Diffusion for Speculative Decoding)
+
+DFlash is a parallel speculative decoding method based on [Block Diffusion](https://arxiv.org/abs/2602.06036).
+Unlike autoregressive draft models (EAGLE3), DFlash predicts an entire block of tokens in a single forward pass
+using masked parallel prediction with KV injection from the target model's hidden states.
+
+### Quick Start
+
+```bash
+./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/dflash.yaml \
+    model.model_name_or_path=/path/to/Qwen3-8B \
+    data.data_path=/path/to/train.jsonl \
+    training.output_dir=/path/to/output
+```
+
+### Key Configuration (dflash.yaml)
+
+| Field | Default | Description |
+|-------|---------|-------------|
+| `dflash.dflash_block_size` | 8 | Block size for parallel prediction |
+| `dflash.dflash_num_anchors` | 512 | Number of anchor positions per sample |
+| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables) |
+| `dflash.dflash_self_logit_distillation` | true | Use logit distillation from target |
+| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
+| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
+
+### Export
+
+```bash
+python scripts/export_hf_checkpoint.py \
+    --model_path /path/to/training/output \
+    --export_path /path/to/exported/model
+```
+
+### Results
+
+See [doc/dflash_results.md](doc/dflash_results.md) for benchmark results on Qwen3-8B.
````
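Pieced together from the configuration table above, the `dflash` section of the recipe would look roughly like the following. This is a hypothetical reconstruction from the documented defaults; the actual dflash.yaml may order or nest these fields differently.

```yaml
# Hypothetical sketch of the dflash section of
# modelopt_recipes/general/speculative_decoding/dflash.yaml,
# built from the defaults in the configuration table.
dflash:
  dflash_block_size: 8                   # tokens predicted per draft forward pass
  dflash_num_anchors: 512                # anchor positions sampled per training sample
  dflash_loss_decay_factor: 4.0          # exponential decay gamma (0 disables)
  dflash_self_logit_distillation: true   # distill from the target model's logits
  dflash_architecture_config:
    num_hidden_layers: 5                 # draft decoder layers
    mask_token_id: auto                  # token ID used for masked positions
```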
examples/speculative_decoding/doc/dflash_results.md

Lines changed: 85 additions & 0 deletions

````diff
@@ -0,0 +1,85 @@
+# DFlash Block Diffusion — ModelOpt Training Results
+
+Qwen3-8B target model, trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples)
+
+## Key Metrics
+
+| Benchmark | Acceptance Rate |
+|-----------|-----------------|
+| **gsm8k** | **5.19** |
+| **MT-Bench** | **4.36** |
+
+> Online validation, block_size=8, osl=512
+
+## Training Configuration
+
+| Parameter | Value |
+|-----------|-------|
+| Target Model | Qwen3-8B |
+| Draft Layers | 5 |
+| Block Size | 8 |
+| Sequence Length | 4096 |
+| Anchors per Sample | 512 |
+| Loss | KD (logit distillation) + exponential decay (gamma=4) |
+| Learning Rate | 6e-4 (linear decay) |
+| Epochs | 10 |
+| GPUs | 64 (8 nodes x 8 H100) |
+| Total Steps | 306,620 |
+| Final Loss | 1.129 |
+| Final Per-Token Acc | 67.0% |
+
+## MT-Bench Per-Category AR (Online Validation)
+
+80 prompts, block_size=8, osl=512, steps=7
+
+| Category | 80K | 150K | 306K (final) |
+|----------|-----|------|--------------|
+| math | 5.44 | 5.54 | **5.52** |
+| extraction | 4.81 | 4.82 | **4.88** |
+| coding | 4.40 | 4.53 | **4.60** |
+| reasoning | 4.34 | 4.41 | **4.44** |
+| stem | 4.05 | 4.15 | **4.17** |
+| writing | 3.76 | 3.79 | **3.84** |
+| roleplay | 3.58 | 3.73 | **3.78** |
+| humanities | 3.55 | 3.62 | **3.65** |
+| **ALL** | **4.24** | **4.32** | **4.36** |
+
+## Comparison with z-lab/Qwen3-8B-DFlash-b16
+
+### ModelOpt Eval (online validation, osl=512)
+
+| Dataset | z-lab | ModelOpt (306K) | Diff |
+|---------|-------|-----------------|------|
+| gsm8k | 4.10 | **5.19** | **+1.09** |
+| MT-Bench | 3.58 | **4.36** | **+0.78** |
+
+### z-lab Official Eval (dflash.benchmark, osl=512)
+
+| Dataset | z-lab | ModelOpt (306K) | Diff |
+|---------|-------|-----------------|------|
+| gsm8k | **5.00** | 4.08 | -0.92 |
+| MT-Bench | **3.28** | 2.99 | -0.29 |
+
+> z-lab model trained with block_size=16. ModelOpt trained with block_size=8.
+
+## Evaluation Method Impact (gsm8k)
+
+| Eval Method | z-lab checkpoint | ModelOpt (306K) |
+|-------------|------------------|-----------------|
+| Fixed GT (ModelOpt eval) | 2.95 | 4.23 |
+| Online GT (ModelOpt eval) | 4.10 | **5.19** |
+| z-lab official eval | **5.00** | 4.08 |
+
+- **Fixed GT**: pre-compute greedy ground truth, check draft against it.
+- **Online GT**: recompute ground truth after each accepted draft (context-dependent).
+- **z-lab official**: actual speculative decoding with draft KV cache.
+
+## Key Findings
+
+| Finding | Evidence |
+|---------|----------|
+| Loss decay boosts AR | +0.12 AR at 55K steps (gamma=7, bs16); consistent across all checkpoints |
+| Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic at same checkpoint |
+| Online validation essential | Fixed GT underestimates by ~1.0 AR; context-dependent GT matches actual spec-decode |
+| Forward pass identical to z-lab | Max diff 0.5 (bf16 noise) on same mask_token_id; 6/7 draft tokens match |
+| sdpa vs flash_attn: negligible | Overall AR 3.31 vs 3.31; hidden states identical, logits differ <2% |
````
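The three evaluation modes above differ only in where the ground-truth tokens come from; the acceptance count itself is, in each case, the length of the longest draft prefix matching the target's greedy tokens. A minimal sketch with a hypothetical helper (not the ModelOpt eval code):

```python
def accepted_prefix_len(draft_block: list[int], target_tokens: list[int]) -> int:
    """Number of draft tokens accepted under greedy verification: the length
    of the longest matching prefix (the first mismatch rejects the rest)."""
    n = 0
    for d, t in zip(draft_block, target_tokens):
        if d != t:
            break
        n += 1
    return n

# Fixed GT compares every block against one pre-computed continuation;
# online GT recomputes target_tokens from the actually-accepted context
# before each block, which matches real speculative decoding behavior.
accepted = accepted_prefix_len([5, 9, 2, 7], [5, 9, 4, 7])
# accepted == 2: positions 0 and 1 match, position 2 mismatches
```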

examples/speculative_decoding/eagle_utils.py

Lines changed: 50 additions & 19 deletions

````diff
@@ -141,6 +141,7 @@ def make_eagle_supervised_data_module(
     tokenizer: transformers.PreTrainedTokenizer,
     data_args,
     train_len=None,
+    answer_only_loss=False,
 ) -> dict:
     if data_args.offline_data_path is None:
         train_dataset = ShardedDataset("json", data_files=data_args.data_path)
@@ -150,6 +151,7 @@
             tokenizer=tokenizer,
             train_len=train_len,
             return_labels=True,
+            answer_only_loss=answer_only_loss,
         )
     else:
         data_collator = VisionLanguageDataCollator(
@@ -205,6 +207,12 @@ def on_log(self, args, state, control, **kwargs):
         if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
             return control
         average_acc = np.mean(state.training_accs, axis=0)
+        # Always print accuracy to console
+        try:
+            acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten())
+            print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]")
+        except Exception:
+            print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}")
         if self.estimate_ar:
             # Calculate mean training AR since last log
             # NOTE: This is only an estimate of the real AR.
@@ -219,41 +227,64 @@
             est_ar += acc_cumprod
         print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
 
+        # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard)
+        logs = kwargs.get("logs") or {}
+        for i, draft_acc in enumerate(average_acc):
+            for j, step_acc in enumerate(draft_acc):
+                logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc)
+        if self.estimate_ar:
+            logs["estimated_training_ar"] = est_ar
+
         # log to wandb
-        if wandb and is_master():
-            logs = kwargs.get("logs") or {}
+        if hasattr(wandb, "init") and is_master():
             if logs:
                 wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step)
-            for i, draft_acc in enumerate(average_acc):
-                for j, step_acc in enumerate(draft_acc):
-                    wandb.log(
-                        {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
-                    )
-            if self.estimate_ar:
-                wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
 
         # reset training_accs
         state.training_accs = []
         return control
 
     def on_step_end(self, args, state, control, **kwargs):
-        """Run AR validation periodically, if available."""
+        """Run AR validation periodically (single-GPU only).
+
+        AR validation with DDP is not supported because pseudo_speculative_generate
+        runs only on rank 0 while other ranks deadlock waiting for collective ops.
+        When world_size > 1, AR validation is skipped with a one-time warning.
+        Use post-training AR validation instead (online_training.sh runs it after training).
+        """
         if self.ar_validate_steps <= 0:
             return control
         if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
+            if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
+                if not hasattr(self, "_ar_ddp_warned"):
+                    self._ar_ddp_warned = True
+                    print_rank_0(
+                        "=== WARNING === AR validation during training is not supported with "
+                        "DDP (world_size > 1). Skipping. Use post-training AR validation."
+                    )
+                return control
+
+            model = kwargs["model"]
+            raw_model = model.module if hasattr(model, "module") else model
+            was_training = raw_model.training
+            raw_model.eval()
             print_rank_0("Running AR validation...")
             try:
-                ars = validate_ar(
-                    model=kwargs["model"],
-                    tokenizer=kwargs["processing_class"],
-                    ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
-                    device=kwargs["model"].device,
-                )
+                with torch.no_grad():
+                    ars = validate_ar(
+                        model=raw_model,
+                        tokenizer=kwargs["processing_class"],
+                        ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"],
+                        device=next(raw_model.parameters()).device,
+                        num_samples=8,
+                    )
                 print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-                if wandb and is_master():
+                if wandb:
                     wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
-            except Exception:
-                print_rank_0("AR validation not available.")
+            except Exception as e:
+                print_rank_0(f"AR validation failed: {e}")
+            if was_training:
+                raw_model.train()
         return control
````
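The `Estimated Training AR` printed by the `on_log` callback above is built from per-draft-step token accuracies via a cumulative product (`est_ar += acc_cumprod` in the loop). A sketch of that arithmetic, assuming the estimate counts 1 for the token the target model emits itself plus the expected accepted draft length; the initialization is not visible in the diff, so the leading 1.0 is an assumption:

```python
def estimated_ar(step_accs: list[float]) -> float:
    """Expected tokens per verification step from per-draft-step accuracies.

    Draft token k is accepted only if every earlier draft token was, so the
    expected accepted length is the sum of cumulative products of the
    per-step accuracies. The leading 1.0 (an assumption here) counts the
    token the target model produces itself on every step.
    """
    est = 1.0
    cum = 1.0
    for acc in step_accs:
        cum *= acc  # probability the draft prefix is still unbroken here
        est += cum
    return est

# e.g. accuracies [0.8, 0.7] give 1 + 0.8 + 0.8*0.7 = 2.36 expected tokens/step
```

As the docstring in the diff notes, this is only an estimate of the real AR: it assumes per-step acceptances are independent, which actual speculative decoding does not guarantee.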