2 changes: 1 addition & 1 deletion src/diffusers/models/__init__.py
@@ -114,7 +114,7 @@
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_joyimage"] = [
"JoyImageEditTransformer3DModel",
"JoyImageEditTransformer3DModel"
]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"]
111 changes: 87 additions & 24 deletions src/diffusers/pipelines/llada2/pipeline_llada2.py
@@ -71,7 +71,14 @@ class LLaDA2Pipeline(DiffusionPipeline):
scheduler: BlockRefinementScheduler
tokenizer: Any

_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
_callback_tensor_inputs = [
"block_x",
"transfer_index",
"editing_transfer_index",
"sampled_tokens",
"sampled_probs",
"active_block",
]

def __init__(
self,
@@ -99,16 +106,28 @@ def _prepare_input_ids(
use_chat_template: bool,
add_generation_prompt: bool,
chat_template_kwargs: dict[str, Any] | None,
) -> torch.LongTensor:
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
attention_mask: torch.LongTensor | None = None,
) -> tuple[torch.LongTensor, torch.LongTensor]:
"""Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`."""
if input_ids is not None:
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if input_ids.ndim != 2:
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
return input_ids
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
else:
if attention_mask.ndim == 1:
attention_mask = attention_mask.unsqueeze(0)
if attention_mask.shape != input_ids.shape:
raise ValueError(
f"`attention_mask` shape {tuple(attention_mask.shape)} must match `input_ids` shape "
f"{tuple(input_ids.shape)}."
)
attention_mask = attention_mask.to(dtype=torch.long)
return input_ids, attention_mask

if self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
@@ -129,7 +148,11 @@ def _prepare_input_ids(
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
ids = encoded["input_ids"]
mask = encoded.get("attention_mask")
if mask is None:
mask = torch.ones_like(ids, dtype=torch.long)
return ids, mask.to(dtype=torch.long)

if use_chat_template and getattr(self.tokenizer, "chat_template", None):
if isinstance(prompt, list):
@@ -142,10 +165,18 @@
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
ids = encoded["input_ids"]
mask = encoded.get("attention_mask")
if mask is None:
mask = torch.ones_like(ids, dtype=torch.long)
return ids, mask.to(dtype=torch.long)

encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]
ids = encoded["input_ids"]
mask = encoded.get("attention_mask")
if mask is None:
mask = torch.ones_like(ids, dtype=torch.long)
return ids, mask.to(dtype=torch.long)

def check_inputs(
self,
@@ -215,6 +246,7 @@ def __call__(
prompt: str | list[str] | None = None,
messages: list[dict[str, str]] | None = None,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.LongTensor | None = None,
use_chat_template: bool = True,
add_generation_prompt: bool = True,
gen_length: int = 2048,
@@ -252,6 +284,11 @@
when provided. Requires a tokenizer with `apply_chat_template`.
input_ids (`torch.LongTensor`, *optional*):
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
attention_mask (`torch.LongTensor`, *optional*):
Per-token mask (1 for valid prompt tokens, 0 for padding) matching the shape of `input_ids`. Only used
when `input_ids` is provided. When omitted (and `input_ids` is given), all positions are treated as
valid. When constructing inputs from `prompt` / `messages`, the tokenizer's mask is carried through
automatically.
use_chat_template (`bool`, defaults to `True`):
Whether to wrap the prompt in a chat template.
add_generation_prompt (`bool`, defaults to `True`):
@@ -299,8 +336,8 @@ def __call__(
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
timestep: int, callback_kwargs: Dict)`.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
`confidence`, `active_block`.
Tensor keys to pass to the callback. Allowed keys: `block_x`, `transfer_index`,
`editing_transfer_index`, `sampled_tokens`, `sampled_probs`, `active_block`.

Examples:
"""
@@ -328,10 +365,11 @@ def __call__(
)

# 2. Prepare input IDs from prompt/messages/input_ids
prompt_ids = self._prepare_input_ids(
prompt_ids, prompt_attention_mask = self._prepare_input_ids(
prompt=prompt,
messages=messages,
input_ids=input_ids,
attention_mask=attention_mask,
use_chat_template=use_chat_template,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=None,
@@ -342,6 +380,7 @@
if prompt_ids.ndim == 1:
prompt_ids = prompt_ids.unsqueeze(0)
prompt_ids = prompt_ids.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
batch_size, prompt_length = prompt_ids.shape

if eos_token_id is None:
@@ -353,14 +392,18 @@

num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)

self.scheduler.set_timesteps(num_inference_steps, device=device)
self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=block_length)

# 3. Build attention mask and position IDs
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
total_length = num_blocks * block_length

# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
# 2D attention mask: prompt tokenizer mask + ones over generated positions + zeros over the
# block-aligned tail past `prompt_length + gen_length`. The model handles backend-specific
# conversion internally; this just tells it which positions are real context.
attn_mask = torch.zeros((batch_size, total_length), device=device, dtype=torch.long)
attn_mask[:, :prompt_length] = prompt_attention_mask
attn_mask[:, prompt_length : prompt_length + gen_length] = 1

position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
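For reference, a minimal standalone sketch of the mask layout these lines produce, using toy sizes (the numbers are illustrative, not from the PR): the prompt span keeps the tokenizer's padding mask, the generation span is all ones, and the block-aligned tail past `prompt_length + gen_length` stays zero.

```python
import torch

# Toy sizes, purely illustrative: a 5-token prompt whose last token is padding,
# 6 generated tokens, block_length 4 -> num_blocks 3, total_length 12.
prompt_length, gen_length, block_length = 5, 6, 4
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
total_length = num_blocks * block_length

prompt_attention_mask = torch.tensor([[1, 1, 1, 1, 0]])  # tokenizer mask, 0 = padding

attn_mask = torch.zeros((1, total_length), dtype=torch.long)
attn_mask[:, :prompt_length] = prompt_attention_mask
attn_mask[:, prompt_length : prompt_length + gen_length] = 1

print(attn_mask)
# tensor([[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0]])
# prompt (padding masked) | generated positions | block-aligned tail stays 0
```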

@@ -377,9 +420,8 @@
global_step = 0

# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
outer_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config = {**outer_progress_bar_config, "position": 0, "desc": "Blocks"}
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
current_window_end = (num_block + 1) * block_length
block_x = x[:, :current_window_end]
@@ -396,8 +438,13 @@
post_steps = 0
step_idx = 0
should_continue = True
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = self.progress_bar(total=num_inference_steps)
inner_progress_bar_config = {
**outer_progress_bar_config,
"position": 1,
"leave": False,
"desc": f"Block {num_block} Inference Steps",
}
progress_bar = tqdm(total=num_inference_steps, **inner_progress_bar_config)

while should_continue:
block_tokens = block_x[:, -block_length:]
@@ -428,10 +475,19 @@

transfer_index = scheduler_output.transfer_index
editing_transfer_index = scheduler_output.editing_transfer_index
sampled_tokens = scheduler_output.sampled_tokens
sampled_probs = scheduler_output.sampled_probs
active_block = block_tokens == mask_token_id
final_transfer = transfer_index | editing_transfer_index

# Freeze rows that already emitted EOS so further blocks don't extend them.
if eos_early_stop and finished.any():
final_transfer = final_transfer & ~finished[:, None]

if final_transfer.any():
block_x[:, -block_length:] = scheduler_output.prev_sample
block_x[:, -block_length:] = torch.where(
final_transfer, scheduler_output.prev_sample, block_tokens
)

if eos_early_stop and eos_token_id is not None:
finished = self.scheduler.check_eos_finished(
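As a toy illustration of the selective write above (all values invented), only positions flagged in `final_transfer` take the scheduler's proposal, and a row already marked `finished` keeps its current tokens:

```python
import torch

mask_token_id = 151666  # placeholder id, not taken from the PR
block_tokens = torch.tensor([[11, mask_token_id, mask_token_id],
                             [22, mask_token_id, 99]])
prev_sample = torch.tensor([[11, 5, 7],
                            [22, 8, 99]])
transfer_index = torch.tensor([[False, True, False],
                               [False, True, False]])
editing_transfer_index = torch.zeros_like(transfer_index)
finished = torch.tensor([False, True])  # row 1 already emitted EOS in an earlier block

final_transfer = transfer_index | editing_transfer_index
final_transfer = final_transfer & ~finished[:, None]  # freeze finished rows

updated = torch.where(final_transfer, prev_sample, block_tokens)
# Row 0 accepts the proposed token 5 at position 1; row 1 is frozen and keeps its mask token.
print(updated)
```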
@@ -474,14 +530,21 @@
# 6. Post-process output
generated = x[:, : prompt_length + gen_length]
sequences = generated[:, prompt_length:]
if eos_token_id is not None and batch_size == 1:
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
if len(eos_positions) > 0:
sequences = sequences[:, : int(eos_positions[0].item()) + 1]

# For decode, trim each row at the first EOS so post-EOS positions (which may still hold
# mask tokens or refined content for unfinished blocks) don't leak into the decoded text.
decode_sequences: list[torch.LongTensor] | torch.LongTensor = sequences
if eos_token_id is not None:
decode_sequences = [
seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1]
if (seq == eos_token_id).any()
else seq
for seq in sequences
]

texts = None
if output_type == "text" and self.tokenizer is not None:
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
texts = self.tokenizer.batch_decode(decode_sequences, skip_special_tokens=True)

if not return_dict:
return sequences.to(device=device), texts
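Taken together, the pipeline-side changes let callers supply a padded, pre-tokenized batch and observe the new per-step tensors. A rough usage sketch follows; the checkpoint id, token ids, and the top-level `LLaDA2Pipeline` import are placeholders and assumptions rather than something this diff establishes:

```python
import torch
from diffusers import LLaDA2Pipeline  # assumed top-level export

pipe = LLaDA2Pipeline.from_pretrained("org/llada2-checkpoint")  # placeholder checkpoint id

# Two right-padded prompts; zeros in the mask mark padding (token ids are made up).
input_ids = torch.tensor([[101, 42, 17, 9], [101, 42, 0, 0]], dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.long)

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Only keys listed in `_callback_tensor_inputs` may be requested.
    transfer_index = callback_kwargs["transfer_index"]
    sampled_probs = callback_kwargs["sampled_probs"]
    print(step, int(transfer_index.sum()), float(sampled_probs.mean()))
    return callback_kwargs  # returning the kwargs mirrors the usual diffusers callback convention

sequences, texts = pipe(
    input_ids=input_ids,
    attention_mask=attention_mask,
    gen_length=256,
    output_type="text",
    return_dict=False,  # the diff shows this path returning (sequences, texts)
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["transfer_index", "sampled_probs"],
)
print(texts)
```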
16 changes: 13 additions & 3 deletions src/diffusers/schedulers/scheduling_block_refinement.py
@@ -75,12 +75,21 @@ def __init__(
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
self._transfer_schedule: torch.LongTensor | None = None

def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
def set_timesteps(
self,
num_inference_steps: int,
device: str | torch.device | None = None,
block_length: int | None = None,
) -> None:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
if block_length is None:
block_length = self.config.block_length
elif block_length <= 0:
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
self._transfer_schedule = self.get_num_transfer_tokens(block_length, self.num_inference_steps).to(
device=device if device is not None else "cpu"
)

@@ -343,7 +352,8 @@ def check_eos_finished(
if len(eos_pos[0]) == 0:
continue
eos_pos = int(eos_pos[0][0].item())
if prompt_length >= eos_pos:
# The first generated token sits at index `prompt_length`; allow EOS there.
if eos_pos < prompt_length:
continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
finished[b] = True
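Finally, a small sketch of the new `set_timesteps` call path as the pipeline now exercises it; constructing the scheduler with no arguments and reading the private `_transfer_schedule` attribute are assumptions made for illustration only:

```python
from diffusers.schedulers.scheduling_block_refinement import BlockRefinementScheduler

scheduler = BlockRefinementScheduler()  # assumes the default config is constructible

# The pipeline now passes its own block_length so the per-step transfer budget
# is computed for the blocks actually being refined, not the configured default.
scheduler.set_timesteps(num_inference_steps=8, device="cpu", block_length=32)

print(scheduler.timesteps)           # tensor([7, 6, 5, 4, 3, 2, 1, 0])
print(scheduler._transfer_schedule)  # per-step token-transfer counts for a 32-token block
```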