3f89ea9
add: DFlash block diffusion speculative decoding
ChenhanYu Mar 27, 2026
190cb3a
fix: rewrite DFlash to match SpecForge reference
ChenhanYu Mar 28, 2026
b7a2a7b
fix: correct mask_token_id and base model forward dispatch
ChenhanYu Mar 29, 2026
a310d96
add: auto-detect mask_token_id for DFlash across model families
ChenhanYu Mar 29, 2026
972dfaa
fix: prevent DDP deadlock during AR validation
ChenhanYu Mar 29, 2026
6c4eb80
fix: avoid DynamicModule dispatch loop in forward/training paths
ChenhanYu Mar 29, 2026
2c42363
fix: revert training/eval to super().forward() matching EAGLE pattern
ChenhanYu Mar 30, 2026
a279960
fix: DDP deadlock when no valid loss positions on a rank
ChenhanYu Mar 30, 2026
cbddc30
add: logit distillation option for DFlash training
ChenhanYu Mar 30, 2026
c53a66a
fix: print training accuracy to console at each log step
ChenhanYu Mar 30, 2026
2eabf57
fix: use response-only loss mask for DFlash training
ChenhanYu Mar 31, 2026
2a16232
fix: apply assistant_masks to labels in LanguageDataCollator
ChenhanYu Mar 31, 2026
e3b9930
fix: robust response-only loss mask via regex assistant span detection
ChenhanYu Mar 31, 2026
07066c2
docs: add DFlash section to speculative decoding README
ChenhanYu Mar 31, 2026
a32de63
fix: resolve DFlash components from base model architecture
ChenhanYu Mar 31, 2026
6a6a9ca
fix: enable response-only loss mask for DFlash training
ChenhanYu Mar 31, 2026
a777849
add: DFlash launcher example for Qwen3-8B
ChenhanYu Apr 1, 2026
2c56aca
fix: inline values in DFlash launcher YAML for --yaml compatibility
ChenhanYu Apr 1, 2026
306fc3e
add: unit tests for DFlash speculative decoding
ChenhanYu Apr 1, 2026
c4a3ecb
fix: add docstrings to DFlash classes for coverage check
ChenhanYu Apr 1, 2026
1c23ced
add: AR validation step to DFlash launcher pipeline
ChenhanYu Apr 1, 2026
38450b0
fix: split DFlash tests into CPU (unit) and GPU tests
ChenhanYu Apr 1, 2026
4c2fc77
fix: correct DFlash attention mask test for reverse-causal pattern
ChenhanYu Apr 1, 2026
bce17cf
fix: remove __init__.py from GPU test dirs to avoid conftest conflict
ChenhanYu Apr 1, 2026
1165272
fix: match dtype in DFlash GPU tests to model dtype
ChenhanYu Apr 1, 2026
273ba32
fix: use Optional types for nullable DFlash arguments
ChenhanYu Apr 1, 2026
9bf9c34
fix: merge AR validation into DFlash training script
ChenhanYu Apr 1, 2026
d19cd3b
fix: align pseudo_speculative_generate with training masks
ChenhanYu Apr 2, 2026
73bb0cc
fix: use standard causal mask within DFlash blocks
ChenhanYu Apr 2, 2026
3fa0d64
fix: increase DDP timeout to 1800s for DFlash training
ChenhanYu Apr 2, 2026
80afde2
fix: revert to SpecForge's reverse-causal mask (j >= i)
ChenhanYu Apr 2, 2026
bfdd582
fix: use continuing position IDs for DFlash inference block
ChenhanYu Apr 2, 2026
fb7acab
fix: remove attention mask at DFlash inference, matching SpecForge
ChenhanYu Apr 2, 2026
290670f
add: standalone DFlash training script with SpecForge data pipeline
ChenhanYu Apr 2, 2026
eb6a0c9
fix: create attention mask in f32 then cast, matching SpecForge
ChenhanYu Apr 2, 2026
2c853c1
fix: use HF attention dispatch in DFlashAttention for SpecForge parity
ChenhanYu Apr 2, 2026
d6adadb
fix: default DFlash attention to sdpa matching SpecForge
ChenhanYu Apr 2, 2026
65df160
fix: initialize DFlash weights with normal_(std=0.02) matching SpecForge
ChenhanYu Apr 2, 2026
4451101
debug: add attn_fn resolution and per-layer comparison prints
ChenhanYu Apr 2, 2026
2726068
feat: update DFlash training to match SpecForge latest (post-PR #473)
ChenhanYu Apr 3, 2026
3516c0b
fix: remove extra unsqueeze in DFlash training attention mask
ChenhanYu Apr 3, 2026
606e31d
fix: create training attention mask in f32 to avoid bf16 overflow
ChenhanYu Apr 3, 2026
e1237f7
fix: add dflash_num_anchors/loss_decay_gamma to launch_train.sh
ChenhanYu Apr 3, 2026
b8e5eb7
feat: add logit distillation to new random-anchor DFlash training
ChenhanYu Apr 3, 2026
b0df28c
fix: add dflash_use_logit_distillation to launch_train.sh
ChenhanYu Apr 3, 2026
4226349
fix: shift teacher logits by -1 for DFlash logit distillation
ChenhanYu Apr 3, 2026
818eb74
fix: mask all tokens when assistant pattern not found
ChenhanYu Apr 3, 2026
49038c7
feat: auto-inject generation tags for reliable answer_only_loss
ChenhanYu Apr 3, 2026
c49f6d9
fix: add <think> wrapper to simplified ChatML template for Qwen3
ChenhanYu Apr 3, 2026
ae2e7bd
fix: remove think wrapper from simplified ChatML template
ChenhanYu Apr 3, 2026
82eedb2
feat: add chatml_think template variant for Qwen3 think injection
ChenhanYu Apr 3, 2026
4ebb9de
docs: document simplified generation templates and limitations
ChenhanYu Apr 3, 2026
4561515
fix: ensure zero-loss path has gradient for DDP sync
ChenhanYu Apr 3, 2026
047ba1d
fix: prefer conversations field when messages lacks assistant turn
ChenhanYu Apr 3, 2026
bdcc0de
cleanup: remove debug prints and regex fallback in DFlash and dataset…
ChenhanYu Apr 3, 2026
e526016
fix: unwrap DDP model for AR validation to avoid deadlock
ChenhanYu Apr 3, 2026
633da55
fix: skip samples without assistant turns instead of crashing
ChenhanYu Apr 3, 2026
43afb06
fix: handle empty batch with dummy assistant turn
ChenhanYu Apr 3, 2026
90e9b4b
fix: AR validation deadlock - eval all ranks, validate on rank 0
ChenhanYu Apr 3, 2026
870db23
feat: add TensorBoard logging for DFlash training
ChenhanYu Apr 3, 2026
0d3f1fa
fix: skip AR validation during DDP training to prevent deadlock
ChenhanYu Apr 3, 2026
9b76b7d
feat: add DFlash export to z-lab compatible HF format
ChenhanYu Apr 4, 2026
1cfd558
fix: checkpoint resume + export-then-validate pipeline
ChenhanYu Apr 4, 2026
6a153bf
feat: auto-detect HEAD_NODE_IP for multi-node DFlash training
ChenhanYu Apr 4, 2026
c0c4330
fix: use explicit bfloat16 and device_map for export loading
ChenhanYu Apr 4, 2026
6684f47
fix: improve HEAD_NODE_IP auto-detection for multi-node
ChenhanYu Apr 4, 2026
dd6e282
fix: multi-method HEAD_NODE_IP detection for multi-node
ChenhanYu Apr 4, 2026
3efd659
fix: force dp_shard_size=1 for DFlash DDP training
ChenhanYu Apr 4, 2026
7d0028e
fix: support both DDP and FSDP for DFlash training
ChenhanYu Apr 4, 2026
496830d
fix: use AutoModelForCausalLM directly for export loading
ChenhanYu Apr 4, 2026
7255969
fix: use export_hf_checkpoint.py script for DFlash export
ChenhanYu Apr 4, 2026
3a8ff9c
fix: load model from output_dir for checkpoint resume
ChenhanYu Apr 4, 2026
ba0132c
add: DFlash results page + online validation for AR evaluation
ChenhanYu Apr 8, 2026
d5a5200
fix: rename dataset to nvidia/Nemotron-Post-Training-Dataset-v2
ChenhanYu Apr 8, 2026
7605414
chg: replace HTML results page with Markdown for GitHub rendering
ChenhanYu Apr 8, 2026
95 changes: 88 additions & 7 deletions examples/speculative_decoding/README.md
@@ -319,15 +319,96 @@ trainer.save_state()
trainer.save_model("<path to the output directory>")
```

## DFlash: Block Diffusion for Flash Speculative Decoding

DFlash ([arXiv:2602.06036](https://arxiv.org/abs/2602.06036)) is a parallel speculative decoding method that predicts multiple tokens simultaneously using block diffusion. Unlike autoregressive methods (EAGLE, Medusa) that draft one token at a time, DFlash predicts an entire block of tokens in parallel, then iteratively denoises them.
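
The draft-then-verify cycle behind block-parallel speculation can be sketched in a few lines. This is illustrative only: `target_step` and `draft_block` are hypothetical stand-ins for one target forward step and one draft-block proposal, not ModelOpt APIs.

```python
def draft_and_verify(target_step, draft_block, prefix, block_size=4):
    """Toy block-speculation step: the draft proposes `block_size` tokens at
    once; the target accepts the longest matching prefix and supplies one
    corrected token at the first mismatch."""
    proposal = draft_block(prefix, block_size)  # block_size draft token ids
    accepted = []
    for tok in proposal:
        expected = target_step(prefix + accepted)  # target's next token
        if tok != expected:
            accepted.append(expected)  # target's correction ends the block
            break
        accepted.append(tok)
    return accepted
```

A real implementation verifies all block positions with a single batched target forward pass; the loop above only illustrates the acceptance rule.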

### Architecture

DFlash uses three key mechanisms:

- **Feature Fusion**: Multi-layer hidden states from the target model are projected via a fully-connected layer and RMSNorm to create context features
- **KV Injection**: Context features are injected as K/V in every draft decoder layer, while Q comes from the noise embeddings. QK-Norm (RMSNorm on Q and K before RoPE) stabilizes attention
- **Parallel Drafting**: Within each block of size B, unknown positions are filled with the `mask_token_id` token; only the block-start position carries a real token. The attention mask lets these noise tokens attend to all context tokens from previous blocks and, causally, to earlier positions within the same block
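
The block attention pattern described in the last bullet can be sketched as a boolean mask (True = may attend); this is an illustration of the description above, not the ModelOpt implementation:

```python
def block_draft_mask(ctx_len: int, block_size: int) -> list[list[bool]]:
    """Illustrative attention mask for one draft block: every draft position
    sees all `ctx_len` context tokens, plus earlier positions within its own
    block (causal in-block visibility)."""
    rows = []
    for i in range(block_size):
        ctx = [True] * ctx_len                        # full context visibility
        block = [j <= i for j in range(block_size)]   # causal within the block
        rows.append(ctx + block)
    return rows
```

Note that the commit history also experimented with SpecForge's reverse-causal (`j >= i`) in-block variant; the sketch follows the causal description given here.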

### Training

```bash
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data input_conversations/train.jsonl \
--num_epochs $NUM_EPOCH \
--mode dflash \
--dflash_block_size 16 \
--dflash_num_layers 5
```

Key arguments:

| Flag | Default | Description |
|------|---------|-------------|
| `--mode dflash` | - | Enable DFlash mode |
| `--dflash_block_size` | 16 | Block size for parallel prediction |
| `--dflash_num_layers` | 5 | Number of decoder layers in draft module |
| `--dflash_config` | None | Path to JSON config for custom architecture |
| `--dflash_mask_token_id` | auto | Mask token ID (auto-detected from model) |
| `--dflash_disable_torch_compile` | False | Disable torch.compile |
| `--dflash_use_logit_distillation` | False | Use KD from target model logits instead of hard CE |

### mask_token_id

The `mask_token_id` is critical for DFlash training and inference. It must be consistent between training and deployment. Auto-detection logic:

| Model Family | mask_token_id | Source |
|-------------|---------------|--------|
| Qwen3.5 | 248070 | Built-in `[MASK]` token |
| Qwen3 (8B) | 151643 | `eos_token_id` |
| Llama 3 | 128002 | `reserved_special_token_0` |
| Others | `pad_token_id` | Fallback |

Override with `--dflash_mask_token_id <id>` if auto-detection is incorrect.
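
The fallback order in the table can be approximated as follows. This sketches only the priority order; the actual ModelOpt detection also knows family-specific reserved tokens such as Llama 3's `reserved_special_token_0`.

```python
def detect_mask_token_id(tokenizer):
    """Return the first available candidate in the priority order described
    above: an explicit mask token, then EOS, then PAD."""
    for attr in ("mask_token_id", "eos_token_id", "pad_token_id"):
        token_id = getattr(tokenizer, attr, None)
        if token_id is not None:
            return token_id
    raise ValueError("No usable mask token; pass --dflash_mask_token_id explicitly")
```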

### Configuring the Draft Model

Similar to EAGLE, provide a JSON config to customize the draft architecture:

```json
{
"num_hidden_layers": 5,
"rms_norm_eps": 1e-6
}
```

Model dimensions (hidden_size, num_attention_heads, etc.) are automatically inherited from the base model.
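
The merge behavior can be pictured as a sketch, under the assumption that dimension keys are copied from the base config while the user's JSON wins on conflicts; key names are illustrative, not the exact ModelOpt schema:

```python
def build_draft_config(base_cfg, overrides, num_layers=5):
    """Assemble a draft config: inherit model dimensions from the base
    model, apply the user's JSON on top, and default the layer count."""
    inherited = ("hidden_size", "num_attention_heads",
                 "num_key_value_heads", "vocab_size")
    cfg = {k: base_cfg[k] for k in inherited if k in base_cfg}
    cfg.update(overrides)                       # user JSON wins on conflicts
    cfg.setdefault("num_hidden_layers", num_layers)
    return cfg
```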

### Current Status (WIP)

| Feature | Status |
|---------|--------|
| Architecture (Feature Fusion, KV Injection, Parallel Drafting) | Working |
| Online training with HF Trainer | Working |
| Inference / AR validation (`pseudo_speculative_generate`) | Working |
| z-lab checkpoint loading and inference (AR 7-9) | Working |
| Logit distillation option | Working |
| Response-only loss masking | Working |
| DDP training | Working (with `find_unused_parameters=True`) |

**Known gap**: Training with ModelOpt achieves ~35% per-token accuracy (matching SpecForge's ~30%), but acceptance rate (AR) is lower than SpecForge-trained checkpoints (1.15 vs 1.95). Investigation shows the **data pipeline** differs significantly:

- SpecForge uses its own tokenizer template with system prompt and response-only loss mask
- ModelOpt's `LanguageDataCollator` uses `apply_chat_template` with different formatting

Aligning the data pipeline is the next step to close the AR gap.

## Support Matrix

| Model | Medusa | EAGLE1/2 | EAGLE3 |
| :---: | :---: | :---: | :---: |
| LLAMA 2 | ✅ | ✅ | ✅ |
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ |
| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |
| Model | Medusa | EAGLE1/2 | EAGLE3 | DFlash |
| :---: | :---: | :---: | :---: | :---: |
| LLAMA 2 | ✅ | ✅ | ✅ | ✅ |
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ | ✅ |
| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | ✅ |

## Speculation Module Checkpoints

44 changes: 30 additions & 14 deletions examples/speculative_decoding/eagle_utils.py
@@ -139,6 +139,7 @@ def make_eagle_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
train_len=None,
answer_only_loss=False,
) -> dict:
if data_args.offline_data_path is None:
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
@@ -148,6 +149,7 @@
tokenizer=tokenizer,
train_len=train_len,
return_labels=True,
answer_only_loss=answer_only_loss,
)
else:
data_collator = VisionLanguageDataCollator(
@@ -203,6 +205,12 @@ def on_log(self, args, state, control, **kwargs):
if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
return control
average_acc = np.mean(state.training_accs, axis=0)
# Always print accuracy to console
try:
acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten())
print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]")
except Exception:
print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}")
if self.estimate_ar:
# Calculate mean training AR since last log
# NOTE: This is only an estimate of the real AR.
@@ -235,23 +243,31 @@ def on_log(self, args, state, control, **kwargs):
return control

def on_step_end(self, args, state, control, **kwargs):
"""Run AR validation periodically, if available."""
"""Run AR validation periodically, if available.

Only runs on rank 0 to avoid DDP deadlock — other ranks skip and
synchronize via barrier.
"""
if self.ar_validate_steps <= 0:
return control
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
print_rank_0("Running AR validation...")
try:
ars = validate_ar(
model=kwargs["model"],
tokenizer=kwargs["processing_class"],
ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb and is_master():
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
except Exception:
print_rank_0("AR validation not available.")
if is_master():
print_rank_0("Running AR validation...")
try:
ars = validate_ar(
model=kwargs["model"],
tokenizer=kwargs["processing_class"],
ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb:
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
except Exception:
print_rank_0("AR validation not available.")
# Barrier to synchronize all ranks after validation
if torch.distributed.is_initialized():
torch.distributed.barrier()
return control


44 changes: 37 additions & 7 deletions examples/speculative_decoding/launch_train.sh
@@ -134,6 +134,22 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
FSDP="${1#*=}"
;;
--dflash_block_size*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_BLOCK_SIZE="${1#*=}"
;;
--dflash_num_layers*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_NUM_LAYERS="${1#*=}"
;;
--dflash_config*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_CONFIG="${1#*=}"
;;
--dflash_mask_token_id*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_MASK_TOKEN_ID="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -195,8 +211,20 @@ if [[ "$MODE" == "eagle3" ]]; then
else
SPECULATIVE_ARGS=""
fi
elif [[ "$MODE" == "dflash" ]]; then
DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16}
DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5}
SPECULATIVE_ARGS="--dflash_block_size $DFLASH_BLOCK_SIZE --dflash_num_layers $DFLASH_NUM_LAYERS --dflash_disable_torch_compile"
if [[ -n "$DFLASH_CONFIG" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_config $DFLASH_CONFIG"
fi
if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID"
fi
# DFlash uses DDP instead of FSDP
FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 300"
else
echo "Only eagle3 supported for now!"
echo "Unsupported mode: $MODE. Supported: eagle3, dflash"
exit 1
fi

@@ -218,12 +246,14 @@ else
VLM_ARGS=""
fi

if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
if [[ "$MODE" != "dflash" ]]; then
if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
fi
fi


54 changes: 50 additions & 4 deletions examples/speculative_decoding/main.py
@@ -104,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments):
)
dataloader_drop_last: bool = field(default=True)
bf16: bool = field(default=True)
mode: Literal["eagle3", "medusa"] = "eagle3"
mode: Literal["eagle3", "medusa", "dflash"] = "eagle3"
estimate_ar: bool = field(
default=False, metadata={"help": "Whether to estimate AR during training for logging."}
)
@@ -144,6 +144,32 @@ class EagleArguments:
)


@dataclass
class DFlashArguments:
dflash_block_size: int = field(
default=16, metadata={"help": "Block size for DFlash parallel prediction."}
)
dflash_num_layers: int = field(
default=5, metadata={"help": "Number of decoder layers in the DFlash draft module."}
)
dflash_config: str = field(default=None, metadata={"help": "Path to dflash_config.json"})
dflash_disable_torch_compile: bool = field(
default=False,
metadata={"help": "Disable torch.compile on DFlash forward/loss methods."},
)
dflash_mask_token_id: int = field(
default=None,
metadata={"help": "Mask token ID for DFlash. If not set, auto-detected from model."},
)
dflash_use_logit_distillation: bool = field(
default=False,
metadata={
"help": "Use logit distillation (KD from target model) instead of hard CE. "
"Enables training with data not synthesized by the target model."
},
)


def train():
parser = transformers.HfArgumentParser(
(
@@ -152,9 +178,10 @@ def train():
TrainingArguments,
MedusaArguments,
EagleArguments,
DFlashArguments,
)
)
model_args, data_args, training_args, medusa_args, eagle_args = (
model_args, data_args, training_args, medusa_args, eagle_args, dflash_args = (
parser.parse_args_into_dataclasses()
)
if not data_args.data_path and not data_args.offline_data_path:
@@ -236,13 +263,32 @@ def train():
)
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache)
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
elif training_args.mode == "dflash":
custom_config = (
json.load(open(dflash_args.dflash_config)) if dflash_args.dflash_config else {}
)
custom_config.setdefault("num_hidden_layers", dflash_args.dflash_num_layers)
if dflash_args.dflash_mask_token_id is not None:
custom_config["mask_token_id"] = dflash_args.dflash_mask_token_id

config = {
"dflash_block_size": dflash_args.dflash_block_size,
"dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile,
"dflash_self_logit_distillation": dflash_args.dflash_use_logit_distillation,
"dflash_architecture_config": custom_config,
}

mtsp.convert(model, [("dflash", config)])
else:
raise Exception(f"{training_args.mode} is not supported!")

print_rank_0("Loading dataset...")
if training_args.mode == "eagle3":
if training_args.mode in ("eagle3", "dflash"):
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
tokenizer,
data_args,
train_len=training_args.training_seq_len,
answer_only_loss=(training_args.mode == "dflash"),
)

trainer = EagleTrainerWithAccLog(
48 changes: 48 additions & 0 deletions modelopt/torch/speculative/config.py
@@ -46,6 +46,54 @@
}


def _get_dflash_default_config():
from .dflash.default_config import default_dflash_config

return default_dflash_config


DFLASH_DEFAULT_CFG = {
"algorithm": "dflash",
"config": {
"dflash_architecture_config": {}, # merged with default at convert time
},
}


class DFlashConfig(ModeloptBaseConfig):
"""DFlash config for block-wise parallel speculative decoding."""

dflash_block_size: int = ModeloptField(
default=16,
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
)

dflash_freeze_base_model: bool = ModeloptField(
default=True, description="Whether to freeze base model during DFlash module training."
)

dflash_self_logit_distillation: bool = ModeloptField(
default=True, description="Whether to use logit distillation from base model."
)

dflash_loss_decay_factor: float = ModeloptField(
default=0.9, description="Decay factor for per-block loss weighting."
)

dflash_report_acc: bool = ModeloptField(
default=True, description="Whether to report eval accuracy."
)

dflash_architecture_config: dict = ModeloptField(
default={}, description="Config for the DFlash draft module architecture."
)

dflash_use_torch_compile: bool = ModeloptField(
default=True,
description="Whether to use torch.compile on DFlash forward/loss methods.",
)


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
