Commit 7dd7d9c

[Discrete Diffusion] Add DFlash pipeline
Adds DFlashPipeline + DFlashTokenDiffusionScheduler for block-diffusion speculative decoding with a draft DFlash model and a target causal LM. Verified against the six bug patterns surfaced in the LLaDA2 review (#13598). DFlash sidesteps most of them by being batch_size=1 only and relying on the causal default for attention; the applicable patterns (#3 callback bindings, #4 EOS at first generated position, #6 inner progress-bar config preservation) are pinned by regression tests. Public surface mirrors the LLaDA2 / SDAR / IDLM conventions: lazy import, dummy objects, scheduler + output dataclass, pipeline + output dataclass, fast tests for both, scheduler doc page, pipeline doc page. Sample/train scripts under examples/discrete_diffusion/.
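Of the regression-tested patterns, #4 (EOS at the first generated position) reduces to trimming logic along these lines. This is an illustrative toy, not the pipeline's actual code; `trim_at_eos` is a hypothetical helper name:

```python
def trim_at_eos(generated, eos_id):
    """Cut a generated token list at the first EOS, keeping the EOS itself.
    Must behave correctly even when EOS is the very first generated token."""
    for i, tok in enumerate(generated):
        if tok == eos_id:
            return generated[: i + 1]
    return generated

print(trim_at_eos([5, 7, 2, 9], eos_id=2))  # [5, 7, 2]
print(trim_at_eos([2, 7, 9], eos_id=2))     # [2]  <- EOS at first position
print(trim_at_eos([5, 7], eos_id=2))        # [5, 7]  <- no EOS at all
```

A naive implementation that searches from index 1 (assuming at least one "real" token precedes EOS) would return the whole sequence in the second case, which is exactly the bug pattern the tests pin.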
1 parent a851ce1 commit 7dd7d9c

16 files changed

Lines changed: 2227 additions & 0 deletions

File tree

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -648,6 +648,8 @@
        title: Z-Image
      title: Image
  - sections:
+      - local: api/pipelines/dflash
+        title: DFlash
      - local: api/pipelines/llada2
        title: LLaDA2
      title: Text
@@ -711,6 +713,8 @@
        title: DDPMScheduler
      - local: api/schedulers/deis
        title: DEISMultistepScheduler
+      - local: api/schedulers/dflash_token_diffusion
+        title: DFlashTokenDiffusionScheduler
      - local: api/schedulers/multistep_dpm_solver_inverse
        title: DPMSolverMultistepInverse
      - local: api/schedulers/multistep_dpm_solver
```
docs/source/en/api/pipelines/dflash.md

Lines changed: 24 additions & 0 deletions
```md
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DFlash

`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM.
The draft model is conditioned on target hidden features extracted during prefill and verification steps.

## DFlashPipeline

[[autodoc]] DFlashPipeline
  - all
  - __call__

## DFlashPipelineOutput

[[autodoc]] pipelines.DFlashPipelineOutput
```
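The target-side verification in speculative decoding can be sketched with a toy greedy loop. This is purely illustrative, with plain Python stand-ins for the models (`verify_block` and `target_argmax` are hypothetical names; the real acceptance logic lives in the scheduler and operates on model logits):

```python
def verify_block(drafted, target_argmax, context):
    """Greedy speculative verification: accept the longest prefix of the
    drafted block that the target model would itself have produced, then
    append the target's own token at the first mismatch."""
    accepted = []
    for tok in drafted:
        expected = target_argmax(context + accepted)
        if tok != expected:
            accepted.append(expected)  # target's correction ends the block
            break
        accepted.append(tok)
    return accepted


# Toy "target model": always predicts the successor of the last token.
def target_argmax(seq):
    return seq[-1] + 1


context = [0, 1, 2]
print(verify_block([3, 4, 9, 10], target_argmax, context))  # [3, 4, 5]
```

The draft's first two tokens match the target's own greedy choices and are accepted; at the first mismatch the target's token replaces the drafted one and the block ends. A fully correct draft block is accepted whole, which is where the speedup comes from.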
docs/source/en/api/schedulers/dflash_token_diffusion.md

Lines changed: 22 additions & 0 deletions
```md
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DFlashTokenDiffusionScheduler

`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block
diffusion speculative decoding.

## DFlashTokenDiffusionScheduler

[[autodoc]] DFlashTokenDiffusionScheduler

## DFlashTokenDiffusionSchedulerOutput

[[autodoc]] schedulers.scheduling_dflash_token_diffusion.DFlashTokenDiffusionSchedulerOutput
```
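The drafting side of block diffusion can be illustrated with a confidence-ordered unmasking step, as used in MaskGIT-style token diffusion. This toy is not the actual scheduler (`unmask_step` is a hypothetical name and the confidence scheme is one common choice):

```python
MASK = None  # sentinel for a still-masked position in the block

def unmask_step(block, scores, k):
    """One toy denoising step: reveal the k masked positions whose draft
    confidence is highest; everything else stays masked for a later step."""
    masked = [i for i, t in enumerate(block) if t is MASK]
    chosen = sorted(masked, key=lambda i: scores[i][1], reverse=True)[:k]
    return [scores[i][0] if i in chosen else t for i, t in enumerate(block)]

# scores[i] = (proposed token, confidence) for each position in the block.
scores = [(7, 0.9), (3, 0.2), (5, 0.8), (1, 0.4)]
step1 = unmask_step([MASK] * 4, scores, k=2)  # reveal the two most confident
step2 = unmask_step(step1, scores, k=2)       # reveal the remaining two
print(step1)  # [7, None, 5, None]
print(step2)  # [7, 3, 5, 1]
```

Running the block through a few such steps fills every masked position, yielding the drafted block that the target model then verifies.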
examples/discrete_diffusion/sample_dflash.py

Lines changed: 145 additions & 0 deletions
```python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample script for DFlash speculative decoding.

Example:
    python sample_dflash.py \
        --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \
        --target_model_id Qwen/Qwen3-8B \
        --prompt "How many positive whole-number divisors does 196 have?" \
        --max_new_tokens 256
"""

import argparse

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from diffusers import DFlashPipeline


def main():
    parser = argparse.ArgumentParser(description="Run DFlash speculative decoding.")
    parser.add_argument(
        "--draft_model_id",
        type=str,
        default="z-lab/Qwen3-8B-DFlash-b16",
        help="Draft model ID or local path.",
    )
    parser.add_argument(
        "--target_model_id",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Target model ID or local path.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="How many positive whole-number divisors does 196 have?",
        help="Prompt text to generate from.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=2048,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature.",
    )
    parser.add_argument(
        "--use_chat_template",
        action="store_true",
        help="Use the tokenizer chat template for the prompt.",
    )
    parser.add_argument(
        "--add_generation_prompt",
        action="store_true",
        help="Add the generation prompt when using the chat template.",
    )
    parser.add_argument(
        "--enable_thinking",
        action="store_true",
        help="Enable chat-template thinking mode if supported by the tokenizer.",
    )
    parser.add_argument(
        "--mask_token",
        type=str,
        default="<|MASK|>",
        help="Mask token to add if the tokenizer does not define one.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model dtype.",
    )

    args = parser.parse_args()

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    torch_dtype = dtype_map.get(args.dtype)

    print(f"Loading draft model: {args.draft_model_id}")
    print(f"Loading target model: {args.target_model_id}")
    dtype_arg = torch_dtype if torch_dtype is not None else "auto"
    # Draft model is a custom DFlashDraftModel; use AutoModel so trust_remote_code routes to the class in `auto_map`.
    draft_model = AutoModel.from_pretrained(
        args.draft_model_id,
        trust_remote_code=True,
        dtype=dtype_arg,
        device_map=args.device,
    )
    target_model = AutoModelForCausalLM.from_pretrained(
        args.target_model_id,
        dtype=dtype_arg,
        device_map=args.device,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.target_model_id)
    if tokenizer.mask_token is None:
        tokenizer.add_special_tokens({"mask_token": args.mask_token})
    pipe = DFlashPipeline(draft_model=draft_model, target_model=target_model, tokenizer=tokenizer)

    chat_kwargs = {"enable_thinking": args.enable_thinking}

    print(f"\nPrompt: {args.prompt}")
    output = pipe(
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        use_chat_template=args.use_chat_template,
        add_generation_prompt=args.add_generation_prompt,
        chat_template_kwargs=chat_kwargs,
    )

    print("\nGenerated text:")
    print(output.texts[0])
    print(f"\nGenerated {output.sequences.shape[1]} tokens")


if __name__ == "__main__":
    main()
```
