
Commit 5d207e7

kashifyiyixuxudg845 authored

[Discrete Diffusion] Add LLaDA2 pipeline (#13226)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

  Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
  - BlockRefinementPipeline: block-wise iterative refinement with confidence-based token commitment, supporting an editing threshold for LLaDA2.1 models
  - LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
  - DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p, temperature) and prompt/prefix helpers
  - compute_confidence_aware_loss: CAP-style training loss
  - Examples: sampling scripts for LLaDA2 and block refinement, training scripts with a Qwen causal LM
  - Docs and tests included

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

  Extract the confidence-based token-commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions. The scheduler owns:
  - Transfer-schedule computation (get_num_transfer_tokens)
  - Timestep management (set_timesteps)
  - Step logic: confidence-based mask filling and optional token editing

  The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__.

* test: add unit tests for BlockRefinementScheduler

  12 tests covering set_timesteps, get_num_transfer_tokens, and step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output).
* docs: add toctree entries and a standalone scheduler doc page
  - Add BlockRefinement and LLaDA2 to the docs sidebar navigation
  - Add BlockRefinementScheduler to the schedulers sidebar navigation
  - Move the scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py
  - Add a --revision argument for loading model revisions from the Hub
  - Replace the deprecated torch_dtype with dtype for transformers 5.x compatibility

* fix: use a 1/0 attention mask instead of 0/-inf for LLaDA2 compatibility

  LLaDA2 models expect a boolean-style (1/0) attention mask, not an additive (0/-inf) mask. The model converts the mask to additive internally, so passing 0/-inf caused double masking and gibberish output.

* refactor: consolidate training scripts into a single train_block_refinement.py
  - Remove the toy train_block_refinement_cap.py (a self-contained demo with a tiny model)
  - Rename train_block_refinement_qwen_cap.py to train_block_refinement.py (it already works with any causal LM via AutoModelForCausalLM)
  - Fix the torch_dtype deprecation and update the README with the correct script names

* fix formatting

* docs: improve the LLaDA2 and BlockRefinement documentation
  - Add usage examples with real model IDs and working code
  - Add a recommended-parameters table for LLaDA2.1 quality/speed modes
  - Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
  - Remove the misleading config-defaults section from the BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to the recommended model parameters
  - threshold: 0.95 -> 0.7 (quality mode)
  - max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
  - eos_early_stop: False -> True (stop at the EOS token)

  block_length=32, steps=32, and temperature=0.0 were already correct. editing_threshold remains None (users enable it for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

  LLaDA2.1 is the current generation. Users with LLaDA2.0 models can disable editing by passing editing_threshold=None.
* fix: align the sampling utilities with the official LLaDA2 implementation
  - top_p filtering: add a shift-right to preserve at least one token above the threshold (matches official code, line 1210)
  - temperature ordering: apply scaling before top-k/top-p filtering so filtering operates on the scaled logits (matches official code, lines 1232-1235)
  - greedy branch: return the argmax directly when temperature=0, without filtering (matches official code, lines 1226-1230)

* refactor: remove duplicate prompt encoding, reuse the mixin's _prepare_input_ids

  LLaDA2Pipeline._prepare_prompt_ids was a near copy of DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate and call the mixin method directly. Also simplify _extract_input_ids, since we always pass return_dict=True.

* formatting

* fix: replace the deprecated torch_dtype with dtype in examples and docstrings
  - Update EXAMPLE_DOC_STRING to use dtype= and the LLaDA2.1-mini model ID
  - Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline

* cleanup

* fix readme

* Apply review suggestions to src/diffusers/pipelines/llada2/pipeline_llada2.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)

* removed DiscreteDiffusionPipelineMixin

* add support for 2d masks for flash attn

* Apply review suggestions to src/diffusers/training_utils.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

* fix issues from review

* added tests

* formatting

* add check_eos_finished to scheduler

* Apply review suggestions to src/diffusers/pipelines/llada2/pipeline_llada2.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

* Apply review suggestions to src/diffusers/schedulers/scheduling_block_refinement.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

* fix renaming issues and types

* remove duplicate check

* Apply review suggestions to docs/source/en/api/pipelines/llada2.md (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
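The top-p shift-right fix mentioned in the commit message can be illustrated with a minimal standalone sketch. This is not the code added in the PR — just the general technique: after sorting, shifting the removal mask right by one keeps the token that crossed the cumulative-probability threshold, so the nucleus is never empty.

```python
import torch


def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Mask out tokens outside the top-p nucleus, keeping at least one token."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability exceeds top_p ...
    remove = cum_probs > top_p
    # ... shifted right by one, so the token that crosses the threshold
    # (and therefore always the top-1 token) survives the cutoff.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    # Map the sorted-order mask back to the original token order.
    indices_to_remove = remove.scatter(-1, sorted_idx, remove)
    return logits.masked_fill(indices_to_remove, float("-inf"))
```

Without the shift, a very peaked distribution (where the top token alone exceeds `top_p`) would mask out every token.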
1 parent e358ddc commit 5d207e7

File tree

20 files changed: +2663 −1 lines changed

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions

```diff
@@ -580,6 +580,8 @@
       title: Latent Diffusion
     - local: api/pipelines/ledits_pp
       title: LEDITS++
+    - local: api/pipelines/llada2
+      title: LLaDA2
     - local: api/pipelines/longcat_image
       title: LongCat-Image
     - local: api/pipelines/lumina2
@@ -718,6 +720,8 @@
   - sections:
     - local: api/schedulers/overview
       title: Overview
+    - local: api/schedulers/block_refinement
+      title: BlockRefinementScheduler
     - local: api/schedulers/cm_stochastic_iterative
       title: CMStochasticIterativeScheduler
     - local: api/schedulers/ddim_cogvideox
```
docs/source/en/api/pipelines/llada2.md

Lines changed: 83 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LLaDA2

[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
steps.

## Usage

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

model_id = "inclusionAI/LLaDA2.1-mini"
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
scheduler = BlockRefinementScheduler()

pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
output = pipe(
    prompt="Write a short poem about the ocean.",
    gen_length=256,
    block_length=32,
    num_inference_steps=32,
    threshold=0.7,
    editing_threshold=0.5,
    max_post_steps=16,
    temperature=0.0,
)
print(output.texts[0])
```

## Callbacks

Callbacks run after each refinement step and can inspect or modify the current tokens.

```py
def on_step_end(pipe, step, timestep, callback_kwargs):
    cur_x = callback_kwargs["cur_x"]
    # Inspect or modify `cur_x` here.
    return {"cur_x": cur_x}


out = pipe(
    prompt="Write a short poem.",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["cur_x"],
)
```

## Recommended parameters

LLaDA2.1 models support two modes:

| Mode    | `threshold` | `editing_threshold` | `max_post_steps` |
|---------|-------------|---------------------|------------------|
| Quality | 0.7         | 0.5                 | 16               |
| Speed   | 0.5         | 0.0                 | 16               |

For LLaDA2.0 models, disable editing by passing `editing_threshold=None`.

For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.

## LLaDA2Pipeline
[[autodoc]] LLaDA2Pipeline
	- all
	- __call__

## LLaDA2PipelineOutput
[[autodoc]] pipelines.LLaDA2PipelineOutput
docs/source/en/api/pipelines/overview.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -63,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
 | [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
 | [Latte](latte) | text2image |
 | [LEDITS++](ledits_pp) | image editing |
+| [LLaDA2](llada2) | text2text |
 | [Lumina-T2X](lumina) | text2image |
 | [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
 | [MultiDiffusion](panorama) | text2image |
```
docs/source/en/api/schedulers/block_refinement.md

Lines changed: 25 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# BlockRefinementScheduler

The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
token with high confidence.

This scheduler is used by [`LLaDA2Pipeline`].

## BlockRefinementScheduler
[[autodoc]] BlockRefinementScheduler

## BlockRefinementSchedulerOutput
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput
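The commit-by-confidence step described above can be sketched as follows. This is a minimal single-sequence illustration of the idea, not the scheduler's actual implementation; `MASK_ID`, the tensor shapes, and the `num_transfer` argument are hypothetical.

```py
import torch

MASK_ID = 156895  # hypothetical mask token id


def commit_step(logits, tokens, num_transfer, threshold=0.7, editing_threshold=None):
    """One confidence-based refinement step over a 1-D tensor of token ids.

    Masked positions whose predicted confidence exceeds `threshold` are
    committed; if fewer than `num_transfer` qualify, the most confident
    masked positions are committed anyway so the schedule keeps moving.
    With `editing_threshold` set, already-committed tokens are replaced
    when the model predicts a different token above that confidence.
    """
    probs = torch.softmax(logits.float(), dim=-1)
    conf, pred = probs.max(dim=-1)  # per-position confidence and prediction
    masked = tokens == MASK_ID

    # Commit masked positions above the threshold ...
    commit = masked & (conf >= threshold)
    # ... but always commit at least `num_transfer` of the most confident ones.
    if commit.sum() < num_transfer and masked.any():
        k = min(num_transfer, int(masked.sum()))
        topk = conf.masked_fill(~masked, -1.0).topk(k).indices
        commit[topk] = True
    new_tokens = torch.where(commit, pred, tokens)

    # Optional editing: revise previously committed tokens the model now disagrees with.
    if editing_threshold is not None:
        edit = (~masked) & (pred != tokens) & (conf >= editing_threshold)
        new_tokens = torch.where(edit, pred, new_tokens)
    return new_tokens
```

Looping this step until no `MASK_ID` remains in the current block, then advancing to the next block, gives the block-wise refinement loop the pipeline runs.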
Lines changed: 50 additions & 0 deletions

# Discrete Token Diffusion (Experimental)

This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.

## LLaDA2

[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.

### Train

The training script uses a confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):

```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
  --model_name_or_path Qwen/Qwen2.5-0.5B \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --text_column text \
  --output_dir llada2-output \
  --max_train_steps 1000 \
  --prompt_length 32 \
  --block_length 32 \
  --lambda_conf 2.0 \
  --conf_temperature 0.5
```

If you don't want to download a dataset, you can use random-token data:

```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
  --model_name_or_path Qwen/Qwen2.5-0.5B \
  --output_dir llada2-output \
  --use_dummy_data \
  --num_dummy_samples 2048
```

### Sample

```bash
python examples/discrete_diffusion/sample_llada2.py \
  --model_id inclusionAI/LLaDA2.1-mini \
  --prompt "Write a short poem about the ocean." \
  --gen_length 256 \
  --num_inference_steps 32 \
  --threshold 0.7 \
  --editing_threshold 0.5 \
  --max_post_steps 16 \
  --use_chat_template \
  --add_generation_prompt
```
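The confidence-aware loss used by the training script can be sketched roughly as follows. This is a hypothetical illustration only — the actual `compute_confidence_aware_loss` in `diffusers` may be defined differently; `lambda_conf` and `conf_temperature` mirror the script flags above, and the weighting scheme is an assumption.

```py
import torch
import torch.nn.functional as F


def confidence_aware_loss(logits, labels, mask, lambda_conf=2.0, conf_temperature=0.5):
    """Sketch of a confidence-weighted masked cross-entropy loss (hypothetical).

    Plain cross-entropy on masked positions, plus a term up-weighted by the
    model's own (detached) confidence sharpened with `conf_temperature` and
    scaled by `lambda_conf`.
    """
    vocab = logits.size(-1)
    ce = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction="none")
    ce = ce.view_as(labels).float()

    with torch.no_grad():
        # Sharpened per-token confidence; detached so it acts as a weight only.
        probs = torch.softmax(logits.float() / conf_temperature, dim=-1)
        conf = probs.max(dim=-1).values

    mask = mask.float()
    denom = mask.sum().clamp(min=1.0)
    plain = (ce * mask).sum() / denom
    weighted = (ce * conf * mask).sum() / denom
    return plain + lambda_conf * weighted
```

The intuition: tokens the model is already confident about contribute more to the loss when wrong, which is one way to realize a CAP-style objective.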
