
Commit d08af30

Add I-DLM pipeline + scheduler (Introspective Diffusion Language Models)
Adds `IDLMPipeline` + `IDLMBlockDiffusionScheduler` implementing Introspective Strided Decoding (Yu et al., 2026). Mirrors the SDAR/DFlash conventions: the scheduler owns the pure-math accept/resample logic (min(1, p/(alpha*q)) with max(0, p - alpha*q) resampling on reject) and per-round new-spec sampling; the pipeline orchestrates the block-N ISD loop, cache management via DynamicCache, and chat-template handling.

Each round is a single target-model forward over [pending, spec_0, ..., spec_{K-1}, MASK, ..., MASK] (length 2*N - 1) under strict causal attention. Under I-DLM's Dream-style logit shift, `logits[:, i, :]` predicts the token at input position `i+1`, so the same forward both verifies the pending specs (against the anchor p at their now-clean positions) and samples the next batch of specs from the MASK-position anchors. On partial accept, the corrected token seeds a cold-start next round.

Also: as part of aligning with the standard transformers v5 cache convention (use_cache=True + past_key_values always stores), switch `SDARPipeline` from the dual `store_kv=True/False` kwarg to a `DynamicCache.crop(prev_seq_len)` snapshot-and-rollback after read-only denoising forwards. Same behavior, no custom model-side flag required.

- src/diffusers/pipelines/idlm/ (pipeline + package init)
- src/diffusers/schedulers/scheduling_idlm_block_diffusion.py
- tests/pipelines/idlm/test_idlm.py (7 unit tests)
- examples/discrete_diffusion/sample_idlm.py (inference)
- examples/discrete_diffusion/train_idlm.py (reference training loop)
- docs/source/en/api/{pipelines,schedulers}/idlm*.md + TOC entries
- src/diffusers/{__init__,pipelines/__init__,schedulers/__init__}.py registrations
- src/diffusers/pipelines/sdar/pipeline_sdar.py: crop-based retrieve-only

End-to-end verified with `yifanyu/I-DLM-8B` (loaded via `refs/pr/2`): prompt: "What is 2+2?" -> "2 + 2 equals 4."
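The accept/resample rule the commit describes can be sketched in plain Python. This is a minimal single-token illustration of the criterion, not the actual `IDLMBlockDiffusionScheduler.step` implementation; `isd_verify_token` and its signature are hypothetical names for this sketch:

```python
import random


def isd_verify_token(spec_token, p, q, alpha=1.0, rng=None):
    """Speculative verification of one token.

    p: anchor distribution at the now-clean position (target model).
    q: proposal distribution the spec token was sampled from (assumed > 0
       at spec_token). Accept with probability min(1, p[x] / (alpha * q[x]));
    on reject, resample from the residual distribution max(0, p - alpha * q),
    renormalized. Returns (token, accepted).
    """
    rng = rng or random.Random()
    accept_prob = min(1.0, p[spec_token] / (alpha * q[spec_token]))
    if rng.random() < accept_prob:
        return spec_token, True
    # Rejected: sample from the renormalized residual max(0, p - alpha * q).
    residual = [max(0.0, pi - alpha * qi) for pi, qi in zip(p, q)]
    threshold, acc = rng.random() * sum(residual), 0.0
    for token, weight in enumerate(residual):
        acc += weight
        if threshold <= acc:
            return token, False
    return len(residual) - 1, False
```

With alpha = 1 this is standard speculative-sampling verification, which is what makes the committed output distribution match the anchor p.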
1 parent d67155c commit d08af30

14 files changed

Lines changed: 1492 additions & 25 deletions

File tree

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -560,6 +560,8 @@
         title: HunyuanImage2.1
       - local: api/pipelines/hybrid_token_diffusion
         title: Hybrid Token Diffusion
+      - local: api/pipelines/idlm
+        title: I-DLM
       - local: api/pipelines/pix2pix
         title: InstructPix2Pix
       - local: api/pipelines/kandinsky
@@ -749,6 +751,8 @@
         title: HeunDiscreteScheduler
       - local: api/schedulers/hybrid_token_diffusion
         title: HybridTokenDiffusionScheduler
+      - local: api/schedulers/idlm_block_diffusion
+        title: IDLMBlockDiffusionScheduler
       - local: api/schedulers/ipndm
         title: IPNDMScheduler
       - local: api/schedulers/stochastic_karras_ve
```
docs/source/en/api/pipelines/idlm.md

Lines changed: 39 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# I-DLM

[Introspective Diffusion Language Models (I-DLM)](https://arxiv.org/abs/2604.11035) are diffusion LLMs that recover the AR self-consistency property (the "introspective acceptance rate" of ~0.98) via strict causal attention, Dream-style logit shift, and all-masked training. At inference time, *Introspective Strided Decoding* (ISD) runs a single forward per round that both **verifies** previously proposed speculative tokens (against the anchor distribution `p` at now-visible clean positions) and **generates** the next batch of specs (from the MASK-position proposal distribution `q`). Acceptance via `min(1, p(x) / (alpha * q(x)))` guarantees the output matches the base AR distribution.

Published I-DLM checkpoints (e.g. [`yifanyu/I-DLM-8B`](https://huggingface.co/yifanyu/I-DLM-8B)) are finetuned from standard Qwen3 weights and load via `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import IDLMPipeline, IDLMBlockDiffusionScheduler

model = AutoModelForCausalLM.from_pretrained("yifanyu/I-DLM-8B", trust_remote_code=True, dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("yifanyu/I-DLM-8B", trust_remote_code=True)

scheduler = IDLMBlockDiffusionScheduler(gen_block_size=4)
pipe = IDLMPipeline(model=model, tokenizer=tokenizer, scheduler=scheduler)
out = pipe(prompt="Prove that sqrt(2) is irrational.", max_new_tokens=256, use_chat_template=True)
print(out.texts[0])
```
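The per-round input window that ISD feeds the model can be sketched as follows. Per the commit description, each round is one forward over `[pending, spec_0, ..., spec_{K-1}, MASK, ..., MASK]` of length `2*N - 1` with `N = gen_block_size`; `build_isd_round_input` is a hypothetical helper for illustration, not part of the diffusers API:

```python
def build_isd_round_input(pending_token, spec_tokens, gen_block_size, mask_token_id):
    """Assemble the token window for one ISD round.

    Under the Dream-style logit shift, the logits at the pending/spec
    positions verify the specs against the anchor distribution, while the
    logits at the MASK positions propose next round's speculative tokens.
    """
    n = gen_block_size
    # One pending token + K specs + MASK slots filling the rest of the window.
    num_masks = (2 * n - 1) - 1 - len(spec_tokens)
    assert num_masks >= 0, "too many speculative tokens for this block size"
    return [pending_token] + list(spec_tokens) + [mask_token_id] * num_masks
```

A cold-start round (no surviving specs, e.g. right after a partial accept) simply carries more MASK slots in the same window.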
## IDLMPipeline

[[autodoc]] IDLMPipeline
	- all
	- __call__

## IDLMPipelineOutput

[[autodoc]] pipelines.IDLMPipelineOutput
docs/source/en/api/schedulers/idlm_block_diffusion.md

Lines changed: 21 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# IDLMBlockDiffusionScheduler

`IDLMBlockDiffusionScheduler` implements the block-N Introspective Strided Decoding step for I-DLM: speculative verification via `min(1, p/(alpha*q))` with `max(0, p - alpha*q)` resampling on reject, plus sampling of the next batch of speculative tokens from the MASK-position anchor logits. It is stateless and pure math; the pipeline owns model and cache I/O.

## IDLMBlockDiffusionScheduler

[[autodoc]] IDLMBlockDiffusionScheduler

## IDLMBlockDiffusionSchedulerOutput

[[autodoc]] schedulers.scheduling_idlm_block_diffusion.IDLMBlockDiffusionSchedulerOutput
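As a sanity check on the correctness guarantee, the exact output distribution of one accept/resample step can be computed in closed form; with `alpha = 1` it reproduces the anchor distribution `p` exactly. A minimal sketch, where `isd_output_distribution` is a hypothetical helper (assuming `q > 0` elementwise), not part of the scheduler API:

```python
def isd_output_distribution(p, q, alpha=1.0):
    """Exact output distribution of one accept/resample verification step.

    Pr[out = x] = q(x) * min(1, p(x) / (alpha * q(x)))
                + Pr[reject] * residual(x),  residual ∝ max(0, p - alpha * q).
    With alpha = 1 this equals p, the anchor distribution.
    """
    accept = [qi * min(1.0, pi / (alpha * qi)) for pi, qi in zip(p, q)]
    p_reject = 1.0 - sum(accept)
    residual = [max(0.0, pi - alpha * qi) for pi, qi in zip(p, q)]
    z = sum(residual)
    # If z == 0 (p <= alpha*q everywhere), rejection never happens.
    return [a + p_reject * (r / z if z else 0.0) for a, r in zip(accept, residual)]
```

Values of `alpha` below 1 trade exactness for a higher acceptance rate, which is what the `verify_alpha` leniency knob exposes.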
examples/discrete_diffusion/sample_idlm.py

Lines changed: 140 additions & 0 deletions

```python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample script for I-DLM (Introspective Diffusion Language Model) block-N decoding.

Example:
    python sample_idlm.py \
        --model_id yifanyu/I-DLM-8B \
        --prompt "Prove that sqrt(2) is irrational." \
        --gen_block_size 4 \
        --max_new_tokens 256
"""

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import IDLMBlockDiffusionScheduler, IDLMPipeline


def main():
    parser = argparse.ArgumentParser(description="Run I-DLM introspective strided decoding.")
    parser.add_argument("--model_id", type=str, default="yifanyu/I-DLM-8B", help="Model ID or local path.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="Prove that sqrt(2) is irrational.",
        help="Prompt text to generate from.",
    )
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument(
        "--gen_block_size",
        type=int,
        default=4,
        help="Block size N: each ISD round commits up to N tokens. `block_size = 2*N - 1`.",
    )
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument(
        "--verify_alpha",
        type=float,
        default=1.0,
        help="Leniency in the min(1, p/(alpha*q)) accept criterion. 1.0 = standard verify.",
    )
    parser.add_argument("--mask_token_id", type=int, default=None)
    parser.add_argument("--use_chat_template", action="store_true")
    parser.add_argument("--add_generation_prompt", action="store_true")
    parser.add_argument(
        "--enable_thinking",
        action="store_true",
        help="Enable <think>...</think> reasoning in the chat template. I-DLM is a Qwen3 derivative, "
        "whose template enables thinking by default; this script passes enable_thinking=False "
        "unless this flag is set.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
    )
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)

    args = parser.parse_args()

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    torch_dtype = dtype_map.get(args.dtype)

    print(f"Loading model: {args.model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        dtype=torch_dtype if torch_dtype is not None else "auto",
        device_map=args.device,
        revision=args.revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
    if tokenizer.mask_token is None:
        tokenizer.add_special_tokens({"mask_token": "<|MASK|>"})

    scheduler = IDLMBlockDiffusionScheduler(
        gen_block_size=args.gen_block_size,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        verify_alpha=args.verify_alpha,
    )
    pipe = IDLMPipeline(model=model, tokenizer=tokenizer, scheduler=scheduler)

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=args.device).manual_seed(args.seed)

    print(f"\nPrompt: {args.prompt}")
    chat_template_kwargs = {"enable_thinking": bool(args.enable_thinking)}
    output = pipe(
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        gen_block_size=args.gen_block_size,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        verify_alpha=args.verify_alpha,
        mask_token_id=args.mask_token_id,
        use_chat_template=args.use_chat_template,
        add_generation_prompt=args.add_generation_prompt,
        chat_template_kwargs=chat_template_kwargs,
        generator=generator,
    )

    print("\nGenerated text:")
    print(
        output.texts[0]
        if output.texts is not None
        else tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    )
    print(f"\nGenerated {output.sequences.shape[1]} tokens")


if __name__ == "__main__":
    main()
```
