Commit 715ca50

[Docs] flesh out DFlash pipeline + scheduler pages
1 parent 7dd7d9c commit 715ca50

2 files changed

Lines changed: 86 additions & 4 deletions

docs/source/en/api/pipelines/dflash.md

Lines changed: 73 additions & 2 deletions
@@ -12,8 +12,79 @@ specific language governing permissions and limitations under the License.
# DFlash

-`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM.
-The draft model is conditioned on target hidden features extracted during prefill and verification steps.

[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion speculative decoding scheme. A small
diffusion *draft* model proposes a block of tokens conditioned on hidden features extracted from intermediate layers
of a frozen *target* causal LM; the target then verifies the proposed block in a single forward pass and accepts the
longest matching prefix. The draft model shares the target's tokenizer, so no calibration is needed.

`DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's
posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat
until `max_new_tokens` or a stop token. Compatible draft/target pairs include `z-lab/Qwen3-8B-DFlash-b16` with
`Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is a hybrid-attention target — see
the rollback note below).
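At a glance, the loop looks like the sketch below. The callables `target_prefill`, `draft_propose`, and
`target_verify`, and the `scheduler.step` signature, are illustrative stand-ins for pipeline internals, not public
API:

```py
def dflash_generate(target_prefill, draft_propose, target_verify, scheduler,
                    prompt_ids, max_new_tokens, eos_token_id):
    # Sketch only: the callables stand in for the pipeline's internal
    # prefill / draft / verify steps.
    hidden = target_prefill(prompt_ids)        # prefill fills the target KV cache
    output_ids = []
    while len(output_ids) < max_new_tokens:
        block = draft_propose(hidden)          # diffusion draft proposes a block
        logits, hidden = target_verify(block)  # one target forward over the block
        accepted, next_token = scheduler.step(block, logits)
        # commit the accepted prefix plus the resampled token, then repeat
        output_ids += list(block[:accepted]) + [next_token]
        if next_token == eos_token_id:
            break
    return output_ids
```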
## Usage

```py
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from diffusers import DFlashPipeline

draft = AutoModel.from_pretrained(
    "z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
target = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-4B", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True)

pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer)
output = pipe(
    prompt="What is 2 + 2? Answer in one sentence.",
    max_new_tokens=128,
    temperature=0.0,
    chat_template_kwargs={"enable_thinking": False},
)
print(output.texts[0])
```

`DFlashPipeline` currently runs `batch_size=1` only. Multi-prompt batching requires per-row partial-accept tracking
and is not yet supported.

## Hybrid-attention targets

For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` silently
no-ops on those layers, so a partial-accept block would otherwise leak rejected speculative tokens into the
recurrent state. The pipeline detects linear-attention caches via
[`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a snapshot/restore + accepted-prefix
re-forward pattern to advance both layer types cleanly. This adds one extra target forward per partial-accept
block but is required for correctness.
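A rough sketch of that rollback pattern, assuming the scheduler helpers documented on the
[`DFlashTokenDiffusionScheduler`] page; `target_forward`, `block`, and `committed_len` are hypothetical stand-ins
for pipeline internals:

```py
def verify_block(scheduler, target_forward, block, past_key_values, committed_len):
    # Sketch of the hybrid-cache rollback; `target_forward` stands in for
    # the target model call, not the pipeline's real helper.
    hybrid = scheduler.cache_has_linear_attention(past_key_values)
    if hybrid:
        snapshot = scheduler.snapshot_cache(past_key_values)  # before speculating
    logits = target_forward(block, past_key_values)           # verify forward
    accepted, next_token = scheduler.step(block, logits)
    if accepted < len(block):                                 # partial accept
        if hybrid:
            # crop() silently skips linear-attention layers, so restore the
            # snapshot and re-advance on just the accepted prefix (the extra
            # target forward mentioned above).
            scheduler.restore_cache(past_key_values, snapshot)
            target_forward(block[:accepted], past_key_values)
        else:
            past_key_values.crop(committed_len + accepted)    # attention-only path
    return accepted, next_token
```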

## Fast path

When the draft model exposes a `spec_generate(...)` method (e.g. `z-lab/Qwen3-8B-DFlash-b16`), the pipeline
delegates to it — that loop is the upstream-canonical implementation and avoids re-running the rollback bookkeeping.
Newer drafts (`z-lab/Qwen3.5-4B-DFlash`) drop `spec_generate`; the pipeline falls back to its explicit verify loop.
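The dispatch itself is just a capability check, roughly as below; the `spec_generate` arguments are defined by the
draft checkpoint's remote code, so the call shown here is an assumption, not the real signature:

```py
# Hypothetical dispatch; `run_explicit_verify_loop` is an illustrative name.
if hasattr(draft, "spec_generate"):
    out = draft.spec_generate(...)        # upstream-canonical fast path
else:
    out = run_explicit_verify_loop(...)   # pipeline's fallback verify loop
```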

## Callbacks

Callbacks run after each block-verify step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. Allowed keys: `block_output_ids` (the drafted block), `draft_logits`,
`accepted_length`, `next_token`, and `output_ids` (the running output buffer). Return `{"output_ids": ...}` from the
callback to replace the buffer.
```py
def on_step_end(pipe, step, timestep, callback_kwargs):
    # `output_ids` is the running output buffer requested via
    # `callback_on_step_end_tensor_inputs`; returning it (optionally
    # modified) replaces the pipeline's buffer for subsequent steps.
    output_ids = callback_kwargs["output_ids"]
    return {"output_ids": output_ids}

out = pipe(
    prompt="...",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["output_ids"],
)
```
## DFlashPipeline

[[autodoc]] DFlashPipeline

docs/source/en/api/schedulers/dflash_token_diffusion.md

Lines changed: 13 additions & 2 deletions
@@ -12,8 +12,19 @@ specific language governing permissions and limitations under the License.
# DFlashTokenDiffusionScheduler

-`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block
-diffusion speculative decoding.

[`DFlashTokenDiffusionScheduler`] implements the verification step for DFlash-style block-diffusion speculative
decoding. It samples a posterior block from the target logits, computes the acceptance length as the longest prefix
where the draft proposal matches the posterior, and exposes the resampled `next_token` for the first rejected
position. Used by [`DFlashPipeline`].
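For intuition, a minimal sketch of the greedy (temperature 0) acceptance rule, assuming the target logits are
already aligned with the drafted block; the function name and alignment are illustrative, not the scheduler's API:

```py
import torch

def greedy_acceptance(draft_block: torch.Tensor, target_logits: torch.Tensor):
    # draft_block: (block_size,) proposed token ids
    # target_logits: (block_size, vocab) target scores for the same positions
    posterior = target_logits.argmax(dim=-1)          # greedy posterior block
    still_matching = (draft_block == posterior).long().cumprod(dim=-1)
    accepted = int(still_matching.sum().item())       # longest matching prefix
    # Resample at the first rejected position; the fully-accepted case is
    # handled separately by the pipeline, so the index is clamped here.
    next_token = int(posterior[min(accepted, len(posterior) - 1)].item())
    return accepted, next_token
```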

The scheduler also owns three helpers used by the pipeline's verify loop on hybrid-attention targets:

- `cache_has_linear_attention(cache)` — detect whether a `DynamicCache` contains any linear-attention layers.
- `snapshot_cache(cache)` / `restore_cache(cache, snapshot)` — clone and restore the full per-layer state so a
  partial-accept block can be rolled back and the target re-advanced on just the accepted prefix.

These exist because `DynamicCache.crop()` silently no-ops on linear-attention layers, which would otherwise let
rejected speculative tokens permanently contaminate the recurrent state.
## DFlashTokenDiffusionScheduler

[[autodoc]] DFlashTokenDiffusionScheduler
