specific language governing permissions and limitations under the License.

# DFlash
[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion speculative decoding scheme. A small
diffusion *draft* model proposes a block of tokens conditioned on hidden features extracted from intermediate layers
of a frozen *target* causal LM; the target then verifies the proposed block in a single forward pass and accepts the
longest matching prefix. The draft model shares the target's tokenizer, so no calibration is needed.

`DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's
posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat
until `max_new_tokens` is reached or a stop token is emitted. Compatible draft/target pairs include
`z-lab/Qwen3-8B-DFlash-b16` with `Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is
a hybrid-attention target; see the rollback note below).

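At temperature 0, the accept step reduces to finding the longest prefix of the drafted block that the target would
have produced itself. Below is a minimal sketch of that comparison, assuming the target's per-position greedy
predictions have already been gathered; `accept_longest_prefix` and its argument layout are illustrative, not the
pipeline's internals:

```py
import torch

def accept_longest_prefix(draft_ids: torch.Tensor, target_ids: torch.Tensor):
    # draft_ids:  (block,) tokens proposed by the diffusion draft model.
    # target_ids: (block,) the target's greedy prediction at each position,
    #             aligned so that entry i predicts the i-th drafted token.
    matches = (draft_ids == target_ids).long()
    accepted = int(matches.cumprod(dim=0).sum())  # longest all-match prefix
    # On a mismatch, the target's own token is committed in place of the
    # rejected draft token, so every verify step emits at least one token.
    resample = None if accepted == len(draft_ids) else target_ids[accepted]
    return accepted, resample
```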
## Usage
```py
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from diffusers import DFlashPipeline

# Load the diffusion draft model and the frozen causal-LM target in matching dtypes.
draft = AutoModel.from_pretrained(
    "z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
target = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-4B", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True)

pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer)
output = pipe(
    prompt="What is 2 + 2? Answer in one sentence.",
    max_new_tokens=128,
    temperature=0.0,  # greedy verification
    chat_template_kwargs={"enable_thinking": False},
)
print(output.texts[0])
```
`DFlashPipeline` currently runs with `batch_size=1` only. Multi-prompt batching requires per-row partial-accept
tracking and is not yet supported.

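Until batching lands, a simple workaround is to process prompts sequentially; this is plain Python over the call
shown above, no pipeline internals assumed:

```py
prompts = ["What is 2 + 2?", "Name three prime numbers."]
# One pipeline call per prompt; the models stay resident on device, so only
# the per-prompt prefill and decode costs are repeated.
texts = [pipe(prompt=p, max_new_tokens=128, temperature=0.0).texts[0] for p in prompts]
```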
## Hybrid-attention targets
For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` silently
no-ops on those layers, so a partial-accept block would otherwise leak rejected speculative tokens into the
recurrent state. The pipeline detects linear-attention caches via
[`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a snapshot/restore plus accepted-prefix
re-forward pattern to advance both layer types cleanly. This adds one extra target forward per partial-accept
block but is required for correctness.

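A minimal sketch of the snapshot/restore idea, assuming a cache object that can be deep-copied and a target model
that accepts `past_key_values`; `forward_and_accept` is a hypothetical helper, not the pipeline's API:

```py
import copy

def verify_with_rollback(target, cache, block_ids, block_size):
    # Snapshot the full cache (attention KV and recurrent state) before
    # speculatively advancing it with the drafted block.
    snapshot = copy.deepcopy(cache)
    accepted, resample, cache = forward_and_accept(target, cache, block_ids)  # hypothetical helper
    if accepted < block_size:
        # DynamicCache.crop() cannot rewind linear-attention layers, so
        # restore the snapshot and re-forward only the accepted prefix to
        # rebuild a consistent state for both layer types.
        cache = snapshot
        if accepted > 0:
            target(input_ids=block_ids[:, :accepted], past_key_values=cache)
    return accepted, resample, cache
```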
## Fast path
When the draft model exposes a `spec_generate(...)` method (e.g. `z-lab/Qwen3-8B-DFlash-b16`), the pipeline
delegates to it; that loop is the upstream-canonical implementation and avoids re-running the rollback bookkeeping.
Newer drafts (`z-lab/Qwen3.5-4B-DFlash`) drop `spec_generate`, so the pipeline falls back to its explicit verify loop.

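The dispatch amounts to a capability check on the draft model. A sketch of the idea, where the `draft_model`
attribute follows the constructor argument above and `_verify_loop` is a hypothetical fallback name:

```py
def run_decode_loop(pipe, *args, **kwargs):
    # Prefer the draft checkpoint's own speculative loop when it ships one;
    # its exact signature is checkpoint-defined, so arguments are forwarded.
    spec = getattr(pipe.draft_model, "spec_generate", None)
    if callable(spec):
        return spec(*args, **kwargs)
    # Otherwise run the pipeline's explicit draft/verify/rollback loop.
    return pipe._verify_loop(*args, **kwargs)  # hypothetical fallback name
```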
## Callbacks
Callbacks run after each block-verify step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. Allowed keys: `block_output_ids` (the drafted block), `draft_logits`,
`accepted_length`, `next_token`, and `output_ids` (the running output buffer). Return `{"output_ids": ...}` from the
callback to replace the buffer.
```py
def on_step_end(pipe, step, timestep, callback_kwargs):
    # Inspect (or rewrite) the running output buffer after each verify step.
    output_ids = callback_kwargs["output_ids"]
    # Returning the key replaces the pipeline's buffer with this tensor.
    return {"output_ids": output_ids}

out = pipe(
    prompt="...",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["output_ids"],
)
```
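Since `accepted_length` is among the allowed keys, a callback can also track how much of each drafted block the
target accepts, which is the quantity that determines the speedup. A small example; the list and closure are
ordinary Python, and returning an empty dict (leaving the buffer untouched) is an assumption:

```py
accepted_lengths = []

def track_acceptance(pipe, step, timestep, callback_kwargs):
    # Record how many drafted tokens the target accepted this block.
    accepted_lengths.append(int(callback_kwargs["accepted_length"]))
    return {}

out = pipe(
    prompt="Explain speculative decoding in two sentences.",
    max_new_tokens=128,
    callback_on_step_end=track_acceptance,
    callback_on_step_end_tensor_inputs=["accepted_length"],
)
print(f"mean accepted tokens per block: {sum(accepted_lengths) / len(accepted_lengths):.2f}")
```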

## DFlashPipeline
[[autodoc]] DFlashPipeline