Skip to content

Commit 2c1405b

Browse files
committed
[LLaDA2] address review findings from #13598
Fixes the six in-scope issues raised in the llada2 model/pipeline review: 1. Carry tokenizer `attention_mask` through `_prepare_input_ids` and add an `attention_mask` arg to `__call__` for pre-tokenized inputs. The runtime mask now reflects prompt padding and zeros out the block-aligned tail past `prompt_length + gen_length` instead of treating those positions as valid context. 2. Thread the per-call `block_length` into `BlockRefinementScheduler.set_timesteps` so the transfer schedule matches the requested block size (previously the scheduler only read its constructor default). 3. Drop `x0`/`x0_p`/`confidence` from `_callback_tensor_inputs` (never bound locals) and bind `sampled_tokens`, `sampled_probs`, `editing_transfer_index`, `active_block` so all advertised callback keys resolve. 4. Allow EOS exactly at index `prompt_length` (the first generated position) to mark a row finished. 5. Freeze rows that have already emitted EOS so subsequent block refinement doesn't extend them, and trim per-row at decode (previously gated on batch_size==1) so post-EOS positions don't leak into decoded text. 6. Stop calling `self.set_progress_bar_config(...)` from inside `__call__`; build a local config dict for the inner block bar so user-supplied flags (in particular `disable=True`) survive the call. Adds regression tests pinning each of the six fixes.
1 parent a851ce1 commit 2c1405b

3 files changed

Lines changed: 278 additions & 31 deletions

File tree

src/diffusers/pipelines/llada2/pipeline_llada2.py

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,14 @@ class LLaDA2Pipeline(DiffusionPipeline):
7171
scheduler: BlockRefinementScheduler
7272
tokenizer: Any
7373

74-
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
74+
_callback_tensor_inputs = [
75+
"block_x",
76+
"transfer_index",
77+
"editing_transfer_index",
78+
"sampled_tokens",
79+
"sampled_probs",
80+
"active_block",
81+
]
7582

7683
def __init__(
7784
self,
@@ -99,16 +106,28 @@ def _prepare_input_ids(
99106
use_chat_template: bool,
100107
add_generation_prompt: bool,
101108
chat_template_kwargs: dict[str, Any] | None,
102-
) -> torch.LongTensor:
103-
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
109+
attention_mask: torch.LongTensor | None = None,
110+
) -> tuple[torch.LongTensor, torch.LongTensor]:
111+
"""Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`."""
104112
if input_ids is not None:
105113
if input_ids.ndim == 1:
106114
input_ids = input_ids.unsqueeze(0)
107115
if input_ids.ndim != 2:
108116
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
109117
if input_ids.dtype != torch.long:
110118
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
111-
return input_ids
119+
if attention_mask is None:
120+
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
121+
else:
122+
if attention_mask.ndim == 1:
123+
attention_mask = attention_mask.unsqueeze(0)
124+
if attention_mask.shape != input_ids.shape:
125+
raise ValueError(
126+
f"`attention_mask` shape {tuple(attention_mask.shape)} must match `input_ids` shape "
127+
f"{tuple(input_ids.shape)}."
128+
)
129+
attention_mask = attention_mask.to(dtype=torch.long)
130+
return input_ids, attention_mask
112131

113132
if self.tokenizer is None:
114133
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
@@ -129,7 +148,11 @@ def _prepare_input_ids(
129148
return_dict=True,
130149
**chat_template_kwargs,
131150
)
132-
return encoded["input_ids"]
151+
ids = encoded["input_ids"]
152+
mask = encoded.get("attention_mask")
153+
if mask is None:
154+
mask = torch.ones_like(ids, dtype=torch.long)
155+
return ids, mask.to(dtype=torch.long)
133156

134157
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
135158
if isinstance(prompt, list):
@@ -142,10 +165,18 @@ def _prepare_input_ids(
142165
return_dict=True,
143166
**chat_template_kwargs,
144167
)
145-
return encoded["input_ids"]
168+
ids = encoded["input_ids"]
169+
mask = encoded.get("attention_mask")
170+
if mask is None:
171+
mask = torch.ones_like(ids, dtype=torch.long)
172+
return ids, mask.to(dtype=torch.long)
146173

147174
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
148-
return encoded["input_ids"]
175+
ids = encoded["input_ids"]
176+
mask = encoded.get("attention_mask")
177+
if mask is None:
178+
mask = torch.ones_like(ids, dtype=torch.long)
179+
return ids, mask.to(dtype=torch.long)
149180

150181
def check_inputs(
151182
self,
@@ -215,6 +246,7 @@ def __call__(
215246
prompt: str | list[str] | None = None,
216247
messages: list[dict[str, str]] | None = None,
217248
input_ids: torch.LongTensor | None = None,
249+
attention_mask: torch.LongTensor | None = None,
218250
use_chat_template: bool = True,
219251
add_generation_prompt: bool = True,
220252
gen_length: int = 2048,
@@ -252,6 +284,11 @@ def __call__(
252284
when provided. Requires a tokenizer with `apply_chat_template`.
253285
input_ids (`torch.LongTensor`, *optional*):
254286
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
287+
attention_mask (`torch.LongTensor`, *optional*):
288+
Per-token mask (1 for valid prompt tokens, 0 for padding) matching the shape of `input_ids`. Only used
289+
when `input_ids` is provided. When omitted (and `input_ids` is given), all positions are treated as
290+
valid. When constructing inputs from `prompt` / `messages`, the tokenizer's mask is carried through
291+
automatically.
255292
use_chat_template (`bool`, defaults to `True`):
256293
Whether to wrap the prompt in a chat template.
257294
add_generation_prompt (`bool`, defaults to `True`):
@@ -299,8 +336,8 @@ def __call__(
299336
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
300337
timestep: int, callback_kwargs: Dict)`.
301338
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
302-
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
303-
`confidence`, `active_block`.
339+
Tensor keys to pass to the callback. Allowed keys: `block_x`, `transfer_index`,
340+
`editing_transfer_index`, `sampled_tokens`, `sampled_probs`, `active_block`.
304341
305342
Examples:
306343
"""
@@ -328,10 +365,11 @@ def __call__(
328365
)
329366

330367
# 2. Prepare input IDs from prompt/messages/input_ids
331-
prompt_ids = self._prepare_input_ids(
368+
prompt_ids, prompt_attention_mask = self._prepare_input_ids(
332369
prompt=prompt,
333370
messages=messages,
334371
input_ids=input_ids,
372+
attention_mask=attention_mask,
335373
use_chat_template=use_chat_template,
336374
add_generation_prompt=add_generation_prompt,
337375
chat_template_kwargs=None,
@@ -342,6 +380,7 @@ def __call__(
342380
if prompt_ids.ndim == 1:
343381
prompt_ids = prompt_ids.unsqueeze(0)
344382
prompt_ids = prompt_ids.to(device=device)
383+
prompt_attention_mask = prompt_attention_mask.to(device=device)
345384
batch_size, prompt_length = prompt_ids.shape
346385

347386
if eos_token_id is None:
@@ -353,14 +392,18 @@ def __call__(
353392

354393
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
355394

356-
self.scheduler.set_timesteps(num_inference_steps, device=device)
395+
self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=block_length)
357396

358397
# 3. Build attention mask and position IDs
359398
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
360399
total_length = num_blocks * block_length
361400

362-
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
363-
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
401+
# 2D attention mask: prompt tokenizer mask + ones over generated positions + zeros over the
402+
# block-aligned tail past `prompt_length + gen_length`. The model handles backend-specific
403+
# conversion internally; this just tells it which positions are real context.
404+
attn_mask = torch.zeros((batch_size, total_length), device=device, dtype=torch.long)
405+
attn_mask[:, :prompt_length] = prompt_attention_mask
406+
attn_mask[:, prompt_length : prompt_length + gen_length] = 1
364407

365408
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
366409

@@ -377,9 +420,8 @@ def __call__(
377420
global_step = 0
378421

379422
# 5. Block-wise refinement loop
380-
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
381-
block_progress_bar_config["position"] = 0
382-
block_progress_bar_config["desc"] = "Blocks"
423+
outer_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
424+
block_progress_bar_config = {**outer_progress_bar_config, "position": 0, "desc": "Blocks"}
383425
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
384426
current_window_end = (num_block + 1) * block_length
385427
block_x = x[:, :current_window_end]
@@ -396,8 +438,13 @@ def __call__(
396438
post_steps = 0
397439
step_idx = 0
398440
should_continue = True
399-
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
400-
progress_bar = self.progress_bar(total=num_inference_steps)
441+
inner_progress_bar_config = {
442+
**outer_progress_bar_config,
443+
"position": 1,
444+
"leave": False,
445+
"desc": f"Block {num_block} Inference Steps",
446+
}
447+
progress_bar = tqdm(total=num_inference_steps, **inner_progress_bar_config)
401448

402449
while should_continue:
403450
block_tokens = block_x[:, -block_length:]
@@ -428,10 +475,19 @@ def __call__(
428475

429476
transfer_index = scheduler_output.transfer_index
430477
editing_transfer_index = scheduler_output.editing_transfer_index
478+
sampled_tokens = scheduler_output.sampled_tokens
479+
sampled_probs = scheduler_output.sampled_probs
480+
active_block = block_tokens == mask_token_id
431481
final_transfer = transfer_index | editing_transfer_index
432482

483+
# Freeze rows that already emitted EOS so further blocks don't extend them.
484+
if eos_early_stop and finished.any():
485+
final_transfer = final_transfer & ~finished[:, None]
486+
433487
if final_transfer.any():
434-
block_x[:, -block_length:] = scheduler_output.prev_sample
488+
block_x[:, -block_length:] = torch.where(
489+
final_transfer, scheduler_output.prev_sample, block_tokens
490+
)
435491

436492
if eos_early_stop and eos_token_id is not None:
437493
finished = self.scheduler.check_eos_finished(
@@ -474,14 +530,21 @@ def __call__(
474530
# 6. Post-process output
475531
generated = x[:, : prompt_length + gen_length]
476532
sequences = generated[:, prompt_length:]
477-
if eos_token_id is not None and batch_size == 1:
478-
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
479-
if len(eos_positions) > 0:
480-
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
533+
534+
# For decode, trim each row at the first EOS so post-EOS positions (which may still hold
535+
# mask tokens or refined content for unfinished blocks) don't leak into the decoded text.
536+
decode_sequences: list[torch.LongTensor] | torch.LongTensor = sequences
537+
if eos_token_id is not None:
538+
decode_sequences = [
539+
seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1]
540+
if (seq == eos_token_id).any()
541+
else seq
542+
for seq in sequences
543+
]
481544

482545
texts = None
483546
if output_type == "text" and self.tokenizer is not None:
484-
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
547+
texts = self.tokenizer.batch_decode(decode_sequences, skip_special_tokens=True)
485548

486549
if not return_dict:
487550
return sequences.to(device=device), texts

src/diffusers/schedulers/scheduling_block_refinement.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,21 @@ def __init__(
7575
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
7676
self._transfer_schedule: torch.LongTensor | None = None
7777

78-
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
78+
def set_timesteps(
79+
self,
80+
num_inference_steps: int,
81+
device: str | torch.device | None = None,
82+
block_length: int | None = None,
83+
) -> None:
7984
if num_inference_steps <= 0:
8085
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
86+
if block_length is None:
87+
block_length = self.config.block_length
88+
elif block_length <= 0:
89+
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
8190
self.num_inference_steps = num_inference_steps
8291
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
83-
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
92+
self._transfer_schedule = self.get_num_transfer_tokens(block_length, self.num_inference_steps).to(
8493
device=device if device is not None else "cpu"
8594
)
8695

@@ -343,7 +352,8 @@ def check_eos_finished(
343352
if len(eos_pos[0]) == 0:
344353
continue
345354
eos_pos = int(eos_pos[0][0].item())
346-
if prompt_length >= eos_pos:
355+
# The first generated token sits at index `prompt_length`; allow EOS there.
356+
if eos_pos < prompt_length:
347357
continue
348358
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
349359
finished[b] = True

0 commit comments

Comments
 (0)