@@ -71,7 +71,14 @@ class LLaDA2Pipeline(DiffusionPipeline):
7171 scheduler : BlockRefinementScheduler
7272 tokenizer : Any
7373
74- _callback_tensor_inputs = ["block_x" , "x0" , "x0_p" , "transfer_index" , "confidence" , "active_block" ]
74+ _callback_tensor_inputs = [
75+ "block_x" ,
76+ "transfer_index" ,
77+ "editing_transfer_index" ,
78+ "sampled_tokens" ,
79+ "sampled_probs" ,
80+ "active_block" ,
81+ ]
7582
7683 def __init__ (
7784 self ,
@@ -99,16 +106,28 @@ def _prepare_input_ids(
99106 use_chat_template : bool ,
100107 add_generation_prompt : bool ,
101108 chat_template_kwargs : dict [str , Any ] | None ,
102- ) -> torch .LongTensor :
103- """Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
109+ attention_mask : torch .LongTensor | None = None ,
110+ ) -> tuple [torch .LongTensor , torch .LongTensor ]:
111+ """Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`."""
104112 if input_ids is not None :
105113 if input_ids .ndim == 1 :
106114 input_ids = input_ids .unsqueeze (0 )
107115 if input_ids .ndim != 2 :
108116 raise ValueError (f"`input_ids` must be 2D, got shape { tuple (input_ids .shape )} ." )
109117 if input_ids .dtype != torch .long :
110118 raise ValueError (f"`input_ids` must be int64 token IDs, got dtype={ input_ids .dtype } ." )
111- return input_ids
119+ if attention_mask is None :
120+ attention_mask = torch .ones_like (input_ids , dtype = torch .long )
121+ else :
122+ if attention_mask .ndim == 1 :
123+ attention_mask = attention_mask .unsqueeze (0 )
124+ if attention_mask .shape != input_ids .shape :
125+ raise ValueError (
126+ f"`attention_mask` shape { tuple (attention_mask .shape )} must match `input_ids` shape "
127+ f"{ tuple (input_ids .shape )} ."
128+ )
129+ attention_mask = attention_mask .to (dtype = torch .long )
130+ return input_ids , attention_mask
112131
113132 if self .tokenizer is None :
114133 raise ValueError ("Tokenizer is required when `input_ids` is not provided." )
@@ -129,7 +148,11 @@ def _prepare_input_ids(
129148 return_dict = True ,
130149 ** chat_template_kwargs ,
131150 )
132- return encoded ["input_ids" ]
151+ ids = encoded ["input_ids" ]
152+ mask = encoded .get ("attention_mask" )
153+ if mask is None :
154+ mask = torch .ones_like (ids , dtype = torch .long )
155+ return ids , mask .to (dtype = torch .long )
133156
134157 if use_chat_template and getattr (self .tokenizer , "chat_template" , None ):
135158 if isinstance (prompt , list ):
@@ -142,10 +165,18 @@ def _prepare_input_ids(
142165 return_dict = True ,
143166 ** chat_template_kwargs ,
144167 )
145- return encoded ["input_ids" ]
168+ ids = encoded ["input_ids" ]
169+ mask = encoded .get ("attention_mask" )
170+ if mask is None :
171+ mask = torch .ones_like (ids , dtype = torch .long )
172+ return ids , mask .to (dtype = torch .long )
146173
147174 encoded = self .tokenizer (prompt , return_tensors = "pt" , padding = isinstance (prompt , list ))
148- return encoded ["input_ids" ]
175+ ids = encoded ["input_ids" ]
176+ mask = encoded .get ("attention_mask" )
177+ if mask is None :
178+ mask = torch .ones_like (ids , dtype = torch .long )
179+ return ids , mask .to (dtype = torch .long )
149180
150181 def check_inputs (
151182 self ,
@@ -215,6 +246,7 @@ def __call__(
215246 prompt : str | list [str ] | None = None ,
216247 messages : list [dict [str , str ]] | None = None ,
217248 input_ids : torch .LongTensor | None = None ,
249+ attention_mask : torch .LongTensor | None = None ,
218250 use_chat_template : bool = True ,
219251 add_generation_prompt : bool = True ,
220252 gen_length : int = 2048 ,
@@ -252,6 +284,11 @@ def __call__(
252284 when provided. Requires a tokenizer with `apply_chat_template`.
253285 input_ids (`torch.LongTensor`, *optional*):
254286 Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
287+ attention_mask (`torch.LongTensor`, *optional*):
288+ Per-token mask (1 for valid prompt tokens, 0 for padding) matching the shape of `input_ids`. Only used
289+ when `input_ids` is provided. When omitted (and `input_ids` is given), all positions are treated as
290+ valid. When constructing inputs from `prompt` / `messages`, the tokenizer's mask is carried through
291+ automatically.
255292 use_chat_template (`bool`, defaults to `True`):
256293 Whether to wrap the prompt in a chat template.
257294 add_generation_prompt (`bool`, defaults to `True`):
@@ -299,8 +336,8 @@ def __call__(
299336 Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
300337 timestep: int, callback_kwargs: Dict)`.
301338 callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
302- Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, ` transfer_index`,
303- `confidence `, `active_block`.
339+ Tensor keys to pass to the callback. Allowed keys: `block_x`, `transfer_index`,
340+ `editing_transfer_index`, `sampled_tokens`, `sampled_probs `, `active_block`.
304341
305342 Examples:
306343 """
@@ -328,10 +365,11 @@ def __call__(
328365 )
329366
330367 # 2. Prepare input IDs from prompt/messages/input_ids
331- prompt_ids = self ._prepare_input_ids (
368+ prompt_ids , prompt_attention_mask = self ._prepare_input_ids (
332369 prompt = prompt ,
333370 messages = messages ,
334371 input_ids = input_ids ,
372+ attention_mask = attention_mask ,
335373 use_chat_template = use_chat_template ,
336374 add_generation_prompt = add_generation_prompt ,
337375 chat_template_kwargs = None ,
@@ -342,6 +380,7 @@ def __call__(
342380 if prompt_ids .ndim == 1 :
343381 prompt_ids = prompt_ids .unsqueeze (0 )
344382 prompt_ids = prompt_ids .to (device = device )
383+ prompt_attention_mask = prompt_attention_mask .to (device = device )
345384 batch_size , prompt_length = prompt_ids .shape
346385
347386 if eos_token_id is None :
@@ -353,14 +392,18 @@ def __call__(
353392
354393 num_inference_steps = min (num_inference_steps , gen_length // minimal_topk )
355394
356- self .scheduler .set_timesteps (num_inference_steps , device = device )
395+ self .scheduler .set_timesteps (num_inference_steps , device = device , block_length = block_length )
357396
358397 # 3. Build attention mask and position IDs
359398 num_blocks = (prompt_length + gen_length + block_length - 1 ) // block_length
360399 total_length = num_blocks * block_length
361400
362- # 2D attention mask (no padding) — the model handles backend-specific conversion internally.
363- attn_mask = torch .ones ((batch_size , total_length ), device = device , dtype = torch .long )
401+ # 2D attention mask: prompt tokenizer mask + ones over generated positions + zeros over the
402+ # block-aligned tail past `prompt_length + gen_length`. The model handles backend-specific
403+ # conversion internally; this just tells it which positions are real context.
404+ attn_mask = torch .zeros ((batch_size , total_length ), device = device , dtype = torch .long )
405+ attn_mask [:, :prompt_length ] = prompt_attention_mask
406+ attn_mask [:, prompt_length : prompt_length + gen_length ] = 1
364407
365408 position_ids = torch .arange (total_length , device = device , dtype = torch .long ).unsqueeze (0 ).expand (batch_size , - 1 )
366409
@@ -377,9 +420,8 @@ def __call__(
377420 global_step = 0
378421
379422 # 5. Block-wise refinement loop
380- block_progress_bar_config = getattr (self , "_progress_bar_config" , {}).copy ()
381- block_progress_bar_config ["position" ] = 0
382- block_progress_bar_config ["desc" ] = "Blocks"
423+ outer_progress_bar_config = getattr (self , "_progress_bar_config" , {}).copy ()
424+ block_progress_bar_config = {** outer_progress_bar_config , "position" : 0 , "desc" : "Blocks" }
383425 for num_block in tqdm (range (prefill_blocks , num_blocks ), ** block_progress_bar_config ):
384426 current_window_end = (num_block + 1 ) * block_length
385427 block_x = x [:, :current_window_end ]
@@ -396,8 +438,13 @@ def __call__(
396438 post_steps = 0
397439 step_idx = 0
398440 should_continue = True
399- self .set_progress_bar_config (position = 1 , leave = False , desc = f"Block { num_block } Inference Steps" )
400- progress_bar = self .progress_bar (total = num_inference_steps )
441+ inner_progress_bar_config = {
442+ ** outer_progress_bar_config ,
443+ "position" : 1 ,
444+ "leave" : False ,
445+ "desc" : f"Block { num_block } Inference Steps" ,
446+ }
447+ progress_bar = tqdm (total = num_inference_steps , ** inner_progress_bar_config )
401448
402449 while should_continue :
403450 block_tokens = block_x [:, - block_length :]
@@ -428,10 +475,19 @@ def __call__(
428475
429476 transfer_index = scheduler_output .transfer_index
430477 editing_transfer_index = scheduler_output .editing_transfer_index
478+ sampled_tokens = scheduler_output .sampled_tokens
479+ sampled_probs = scheduler_output .sampled_probs
480+ active_block = block_tokens == mask_token_id
431481 final_transfer = transfer_index | editing_transfer_index
432482
483+ # Freeze rows that already emitted EOS so further blocks don't extend them.
484+ if eos_early_stop and finished .any ():
485+ final_transfer = final_transfer & ~ finished [:, None ]
486+
433487 if final_transfer .any ():
434- block_x [:, - block_length :] = scheduler_output .prev_sample
488+ block_x [:, - block_length :] = torch .where (
489+ final_transfer , scheduler_output .prev_sample , block_tokens
490+ )
435491
436492 if eos_early_stop and eos_token_id is not None :
437493 finished = self .scheduler .check_eos_finished (
@@ -474,14 +530,21 @@ def __call__(
474530 # 6. Post-process output
475531 generated = x [:, : prompt_length + gen_length ]
476532 sequences = generated [:, prompt_length :]
477- if eos_token_id is not None and batch_size == 1 :
478- eos_positions = (sequences [0 ] == eos_token_id ).nonzero (as_tuple = True )[0 ]
479- if len (eos_positions ) > 0 :
480- sequences = sequences [:, : int (eos_positions [0 ].item ()) + 1 ]
533+
534+ # For decode, trim each row at the first EOS so post-EOS positions (which may still hold
535+ # mask tokens or refined content for unfinished blocks) don't leak into the decoded text.
536+ decode_sequences : list [torch .LongTensor ] | torch .LongTensor = sequences
537+ if eos_token_id is not None :
538+ decode_sequences = [
539+ seq [: int ((seq == eos_token_id ).nonzero (as_tuple = True )[0 ][0 ]) + 1 ]
540+ if (seq == eos_token_id ).any ()
541+ else seq
542+ for seq in sequences
543+ ]
481544
482545 texts = None
483546 if output_type == "text" and self .tokenizer is not None :
484- texts = self .tokenizer .batch_decode (sequences , skip_special_tokens = True )
547+ texts = self .tokenizer .batch_decode (decode_sequences , skip_special_tokens = True )
485548
486549 if not return_dict :
487550 return sequences .to (device = device ), texts
0 commit comments