@@ -71,7 +71,14 @@ class LLaDA2Pipeline(DiffusionPipeline):
7171 scheduler : BlockRefinementScheduler
7272 tokenizer : Any
7373
74- _callback_tensor_inputs = ["block_x" , "x0" , "x0_p" , "transfer_index" , "confidence" , "active_block" ]
74+ _callback_tensor_inputs = [
75+ "block_x" ,
76+ "transfer_index" ,
77+ "editing_transfer_index" ,
78+ "sampled_tokens" ,
79+ "sampled_probs" ,
80+ "active_block" ,
81+ ]
7582
7683 def __init__ (
7784 self ,
@@ -99,16 +106,28 @@ def _prepare_input_ids(
99106 use_chat_template : bool ,
100107 add_generation_prompt : bool ,
101108 chat_template_kwargs : dict [str , Any ] | None ,
102- ) -> torch .LongTensor :
103- """Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
109+ attention_mask : torch .LongTensor | None = None ,
110+ ) -> tuple [torch .LongTensor , torch .LongTensor ]:
111+ """Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`."""
104112 if input_ids is not None :
105113 if input_ids .ndim == 1 :
106114 input_ids = input_ids .unsqueeze (0 )
107115 if input_ids .ndim != 2 :
108116 raise ValueError (f"`input_ids` must be 2D, got shape { tuple (input_ids .shape )} ." )
109117 if input_ids .dtype != torch .long :
110118 raise ValueError (f"`input_ids` must be int64 token IDs, got dtype={ input_ids .dtype } ." )
111- return input_ids
119+ if attention_mask is None :
120+ attention_mask = torch .ones_like (input_ids , dtype = torch .long )
121+ else :
122+ if attention_mask .ndim == 1 :
123+ attention_mask = attention_mask .unsqueeze (0 )
124+ if attention_mask .shape != input_ids .shape :
125+ raise ValueError (
126+ f"`attention_mask` shape { tuple (attention_mask .shape )} must match `input_ids` shape "
127+ f"{ tuple (input_ids .shape )} ."
128+ )
129+ attention_mask = attention_mask .to (dtype = torch .long )
130+ return input_ids , attention_mask
112131
113132 if self .tokenizer is None :
114133 raise ValueError ("Tokenizer is required when `input_ids` is not provided." )
@@ -129,7 +148,11 @@ def _prepare_input_ids(
129148 return_dict = True ,
130149 ** chat_template_kwargs ,
131150 )
132- return encoded ["input_ids" ]
151+ ids = encoded ["input_ids" ]
152+ mask = encoded .get ("attention_mask" )
153+ if mask is None :
154+ mask = torch .ones_like (ids , dtype = torch .long )
155+ return ids , mask .to (dtype = torch .long )
133156
134157 if use_chat_template and getattr (self .tokenizer , "chat_template" , None ):
135158 if isinstance (prompt , list ):
@@ -142,10 +165,18 @@ def _prepare_input_ids(
142165 return_dict = True ,
143166 ** chat_template_kwargs ,
144167 )
145- return encoded ["input_ids" ]
168+ ids = encoded ["input_ids" ]
169+ mask = encoded .get ("attention_mask" )
170+ if mask is None :
171+ mask = torch .ones_like (ids , dtype = torch .long )
172+ return ids , mask .to (dtype = torch .long )
146173
147174 encoded = self .tokenizer (prompt , return_tensors = "pt" , padding = isinstance (prompt , list ))
148- return encoded ["input_ids" ]
175+ ids = encoded ["input_ids" ]
176+ mask = encoded .get ("attention_mask" )
177+ if mask is None :
178+ mask = torch .ones_like (ids , dtype = torch .long )
179+ return ids , mask .to (dtype = torch .long )
149180
150181 def check_inputs (
151182 self ,
@@ -215,10 +246,11 @@ def __call__(
215246 prompt : str | list [str ] | None = None ,
216247 messages : list [dict [str , str ]] | None = None ,
217248 input_ids : torch .LongTensor | None = None ,
249+ attention_mask : torch .LongTensor | None = None ,
218250 use_chat_template : bool = True ,
219251 add_generation_prompt : bool = True ,
220252 gen_length : int = 2048 ,
221- block_length : int = 32 ,
253+ block_length : int | None = None ,
222254 num_inference_steps : int = 32 ,
223255 temperature : float = 0.0 ,
224256 top_p : float | None = None ,
@@ -252,14 +284,19 @@ def __call__(
252284 when provided. Requires a tokenizer with `apply_chat_template`.
253285 input_ids (`torch.LongTensor`, *optional*):
254286 Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
287+ attention_mask (`torch.LongTensor`, *optional*):
288+ Per-token mask (1 for valid prompt tokens, 0 for padding) matching the shape of `input_ids`. Only used
289+ when `input_ids` is provided. When omitted (and `input_ids` is given), all positions are treated as
290+ valid. When constructing inputs from `prompt` / `messages`, the tokenizer's mask is carried through
291+ automatically.
255292 use_chat_template (`bool`, defaults to `True`):
256293 Whether to wrap the prompt in a chat template.
257294 add_generation_prompt (`bool`, defaults to `True`):
258295 Whether to add the generation prompt when using chat templates.
259296 gen_length (`int`):
260297 Number of tokens to generate.
261- block_length (`int`):
262- Block size for refinement.
298+ block_length (`int`, *optional* ):
299+ Block size for refinement. If not provided, the scheduler's configured `block_length` is used.
263300 num_inference_steps (`int`):
264301 Number of refinement steps per block.
265302 temperature (`float`):
@@ -299,8 +336,8 @@ def __call__(
299336 Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
300337 timestep: int, callback_kwargs: Dict)`.
301338 callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
302- Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, ` transfer_index`,
303- `confidence `, `active_block`.
339+ Tensor keys to pass to the callback. Allowed keys: `block_x`, `transfer_index`,
340+ `editing_transfer_index`, `sampled_tokens`, `sampled_probs `, `active_block`.
304341
305342 Examples:
306343 """
@@ -312,6 +349,9 @@ def __call__(
312349 if callback_on_step_end_tensor_inputs is None :
313350 callback_on_step_end_tensor_inputs = ["block_x" ]
314351
352+ if block_length is None :
353+ block_length = self .scheduler .config .block_length
354+
315355 self .check_inputs (
316356 prompt = prompt ,
317357 messages = messages ,
@@ -328,10 +368,11 @@ def __call__(
328368 )
329369
330370 # 2. Prepare input IDs from prompt/messages/input_ids
331- prompt_ids = self ._prepare_input_ids (
371+ prompt_ids , prompt_attention_mask = self ._prepare_input_ids (
332372 prompt = prompt ,
333373 messages = messages ,
334374 input_ids = input_ids ,
375+ attention_mask = attention_mask ,
335376 use_chat_template = use_chat_template ,
336377 add_generation_prompt = add_generation_prompt ,
337378 chat_template_kwargs = None ,
@@ -342,6 +383,7 @@ def __call__(
342383 if prompt_ids .ndim == 1 :
343384 prompt_ids = prompt_ids .unsqueeze (0 )
344385 prompt_ids = prompt_ids .to (device = device )
386+ prompt_attention_mask = prompt_attention_mask .to (device = device )
345387 batch_size , prompt_length = prompt_ids .shape
346388
347389 if eos_token_id is None :
@@ -353,14 +395,18 @@ def __call__(
353395
354396 num_inference_steps = min (num_inference_steps , gen_length // minimal_topk )
355397
356- self .scheduler .set_timesteps (num_inference_steps , device = device )
398+ self .scheduler .set_timesteps (num_inference_steps , device = device , block_length = block_length )
357399
358400 # 3. Build attention mask and position IDs
359401 num_blocks = (prompt_length + gen_length + block_length - 1 ) // block_length
360402 total_length = num_blocks * block_length
361403
362- # 2D attention mask (no padding) — the model handles backend-specific conversion internally.
363- attn_mask = torch .ones ((batch_size , total_length ), device = device , dtype = torch .long )
404+ # 2D attention mask: prompt tokenizer mask + ones over generated positions + zeros over the
405+ # block-aligned tail past `prompt_length + gen_length`. The model handles backend-specific
406+ # conversion internally; this just tells it which positions are real context.
407+ attn_mask = torch .zeros ((batch_size , total_length ), device = device , dtype = torch .long )
408+ attn_mask [:, :prompt_length ] = prompt_attention_mask
409+ attn_mask [:, prompt_length : prompt_length + gen_length ] = 1
364410
365411 position_ids = torch .arange (total_length , device = device , dtype = torch .long ).unsqueeze (0 ).expand (batch_size , - 1 )
366412
@@ -377,9 +423,8 @@ def __call__(
377423 global_step = 0
378424
379425 # 5. Block-wise refinement loop
380- block_progress_bar_config = getattr (self , "_progress_bar_config" , {}).copy ()
381- block_progress_bar_config ["position" ] = 0
382- block_progress_bar_config ["desc" ] = "Blocks"
426+ outer_progress_bar_config = getattr (self , "_progress_bar_config" , {}).copy ()
427+ block_progress_bar_config = {** outer_progress_bar_config , "position" : 0 , "desc" : "Blocks" }
383428 for num_block in tqdm (range (prefill_blocks , num_blocks ), ** block_progress_bar_config ):
384429 current_window_end = (num_block + 1 ) * block_length
385430 block_x = x [:, :current_window_end ]
@@ -396,8 +441,13 @@ def __call__(
396441 post_steps = 0
397442 step_idx = 0
398443 should_continue = True
399- self .set_progress_bar_config (position = 1 , leave = False , desc = f"Block { num_block } Inference Steps" )
400- progress_bar = self .progress_bar (total = num_inference_steps )
444+ inner_progress_bar_config = {
445+ ** outer_progress_bar_config ,
446+ "position" : 1 ,
447+ "leave" : False ,
448+ "desc" : f"Block { num_block } Inference Steps" ,
449+ }
450+ progress_bar = tqdm (total = num_inference_steps , ** inner_progress_bar_config )
401451
402452 while should_continue :
403453 block_tokens = block_x [:, - block_length :]
@@ -428,10 +478,19 @@ def __call__(
428478
429479 transfer_index = scheduler_output .transfer_index
430480 editing_transfer_index = scheduler_output .editing_transfer_index
481+ sampled_tokens = scheduler_output .sampled_tokens
482+ sampled_probs = scheduler_output .sampled_probs
483+ active_block = block_tokens == mask_token_id
431484 final_transfer = transfer_index | editing_transfer_index
432485
486+ # Freeze rows that already emitted EOS so further blocks don't extend them.
487+ if eos_early_stop and finished .any ():
488+ final_transfer = final_transfer & ~ finished [:, None ]
489+
433490 if final_transfer .any ():
434- block_x [:, - block_length :] = scheduler_output .prev_sample
491+ block_x [:, - block_length :] = torch .where (
492+ final_transfer , scheduler_output .prev_sample , block_tokens
493+ )
435494
436495 if eos_early_stop and eos_token_id is not None :
437496 finished = self .scheduler .check_eos_finished (
@@ -474,14 +533,21 @@ def __call__(
474533 # 6. Post-process output
475534 generated = x [:, : prompt_length + gen_length ]
476535 sequences = generated [:, prompt_length :]
477- if eos_token_id is not None and batch_size == 1 :
478- eos_positions = (sequences [0 ] == eos_token_id ).nonzero (as_tuple = True )[0 ]
479- if len (eos_positions ) > 0 :
480- sequences = sequences [:, : int (eos_positions [0 ].item ()) + 1 ]
536+
537+ # For decode, trim each row at the first EOS so post-EOS positions (which may still hold
538+ # mask tokens or refined content for unfinished blocks) don't leak into the decoded text.
539+ decode_sequences : list [torch .LongTensor ] | torch .LongTensor = sequences
540+ if eos_token_id is not None :
541+ decode_sequences = [
542+ seq [: int ((seq == eos_token_id ).nonzero (as_tuple = True )[0 ][0 ]) + 1 ]
543+ if (seq == eos_token_id ).any ()
544+ else seq
545+ for seq in sequences
546+ ]
481547
482548 texts = None
483549 if output_type == "text" and self .tokenizer is not None :
484- texts = self .tokenizer .batch_decode (sequences , skip_special_tokens = True )
550+ texts = self .tokenizer .batch_decode (decode_sequences , skip_special_tokens = True )
485551
486552 if not return_dict :
487553 return sequences .to (device = device ), texts
0 commit comments