@@ -284,6 +284,34 @@ def get_num_steps_confidence(self, xt: Tensor, num_tokens_unmask: int = 1):
284284 else :
285285 return int (max (math .ceil (nsteps // num_tokens_unmask ), 1 ))
286286
def step_auto_regressive(
    self,
    logits: Tensor,
    xt: Tensor,
    logit_temperature: float = 1.0,
) -> Tensor:
    """Unmask the left-most masked token of every sequence (auto-regressive sampling).

    A candidate token is sampled at every position from the temperature-scaled
    softmax of the parameterized logits. Only the first (left-most) position of
    each row that still holds ``self.mask_index`` is overwritten with its
    sample; every other position is returned unchanged.

    The position is chosen independently per row, so sequences in one batch may
    be at different stages of decoding. Rows without any masked token are
    returned as they are.

    Args:
        logits: Predicted logits; indexed as (batch, sequence, vocabulary).
        xt: Current token ids; indexed as (batch, sequence).
        logit_temperature: Divisor applied to the log-probabilities before the softmax.

    Returns:
        A copy of ``xt`` with at most one additional token unmasked per row.
        The input tensor is not modified.
    """
    xt = xt.clone()
    # NOTE(review): _subs_parameterization is defined elsewhere; presumably it
    # returns log-probabilities with the same shape as `logits` - confirm.
    log_p_x0 = self._subs_parameterization(logits, xt)
    # Sample a candidate token for every position from the softmax prediction.
    probs = torch.softmax(log_p_x0 / logit_temperature, dim=-1)
    preds = torch.distributions.Categorical(probs=probs).sample()

    # Never overwrite tokens that are already unmasked.
    mask = xt == self.mask_index

    # The running count of masked positions equals 1 exactly at the first masked
    # position of each row (and at the unmasked positions that follow it, which
    # `mask &` removes). Rows with no mask yield an all-False selection.
    to_replace = mask & (mask.cumsum(dim=-1) == 1)

    xt[to_replace] = preds[to_replace]
    return xt
314+
287315 def step_confidence (
288316 self ,
289317 logits : Tensor ,
@@ -344,10 +372,87 @@ def step_confidence(
344372 # choose the predicted token with the highest confidence
345373 confidence_threshold , idx_mask = torch .topk (confidence , k = num_tokens_unmask , dim = - 1 )
346374 confidence_threshold = confidence_threshold [:, - 1 ].unsqueeze (- 1 )
375+ # Rather than take all tokens with confidence above the threshold, we use the topk indices to replace the chosen tokens
376+ # replace the chosen tokens
377+ to_replace = torch .zeros_like (confidence )
378+ to_replace .scatter_ (1 , idx_mask , 1 )
379+ to_replace = to_replace .bool () & mask .bool ()
380+ # to_replace = confidence >= confidence_threshold
381+ # to_replace = (mask.float() * to_replace.float()).bool()
382+ xt [to_replace ] = preds [to_replace ]
383+ return xt
384+
def step_confidence_margin(
    self,
    logits: Tensor,
    xt: Tensor,
    curr_step: int,
    num_steps: int,
    logit_temperature: float = 1.0,
    randomness: float = 1.0,
    confidence_temperature: float = 1.0,
    num_tokens_unmask: int = 1,
) -> Tensor:
    """Unmask the tokens with the largest probability margin.

    Kim et al., Train for the Worst, Plan for the Best: Understanding Token
    Ordering in Masked Diffusions, ICML, 2025.

    Two distinct tokens are sampled per position without replacement. The first
    sample is the candidate prediction, and the confidence of a position is the
    probability of the first sample minus the probability of the second one.
    Gumbel noise, scaled by ``randomness`` and decreasing linearly with
    ``curr_step``, is added to the log-margin. The ``num_tokens_unmask`` masked
    positions with the highest score in each row are then replaced by their
    candidate prediction.

    A non-positive margin (the candidate is not more likely than the second
    sample) is clamped to the smallest positive normal float so that its
    logarithm is finite and very low instead of NaN or ``-inf``.

    Args:
        logits: Predicted logits; indexed as (batch, sequence, vocabulary).
        xt: Input sequence of token ids; indexed as (batch, sequence).
        curr_step: Current step (0-based).
        num_steps: Total number of steps.
        logit_temperature: Temperature for softmax over logits.
        randomness: Scale for Gumbel noise.
        confidence_temperature: Temperature for Gumbel confidence.
        num_tokens_unmask: Number of tokens to unmask each step.

    Returns:
        A copy of ``xt`` with up to ``num_tokens_unmask`` additional tokens
        unmasked per row (fewer if the row has fewer masked tokens).

    Raises:
        NotImplementedError: If ``xt`` has more than three dimensions.
        ValueError: If ``curr_step < 0``, ``num_steps < 1`` or ``num_tokens_unmask < 1``.
    """
    if xt.ndim > 3:
        raise NotImplementedError(
            "step_confidence_margin is implemented for Batch x Sequence x State Space shaped tensors."
        )
    if curr_step < 0 or num_steps < 1 or num_tokens_unmask < 1:
        raise ValueError("Invalid input values for curr_step, num_steps, or num_tokens_unmask.")
    xt = xt.clone()
    # NOTE(review): _subs_parameterization is defined elsewhere; presumably it
    # returns log-probabilities with the same shape as `logits` - confirm.
    log_p_x0 = self._subs_parameterization(logits, xt)
    probs = torch.softmax(log_p_x0 / logit_temperature, dim=-1)

    # Draw two distinct tokens per position; torch.multinomial accepts at most
    # 2-D input, hence the loop over the batch dimension.
    # NOTE(review): sampling without replacement needs at least two tokens with
    # non-zero probability at every position, including unmasked ones - verify
    # that the parameterization guarantees this.
    samples = torch.stack([torch.multinomial(prob, num_samples=2, replacement=False) for prob in probs])
    preds = samples[:, :, 0]
    confidence_first = probs.gather(-1, samples[:, :, 0:1]).squeeze(-1)
    confidence_second = probs.gather(-1, samples[:, :, 1:2]).squeeze(-1)
    margin = confidence_first - confidence_second
    # The samples are in draw order, not sorted by probability, so the margin
    # can be <= 0. Clamp it so the log below is finite rather than NaN / -inf.
    margin = margin.clamp_min(torch.finfo(margin.dtype).tiny)

    # Gumbel noise decreases over the sampling process. With a single step
    # there is no schedule to follow; the denominator is floored at 1 to avoid
    # a division by zero, which leaves the noise at full strength.
    ratio = curr_step / max(num_steps - 1, 1)
    # Manual definition of a (0, 1) Gumbel sample so the generator can be passed
    # in; creating it directly on the target device avoids a transfer.
    gumbel_sample = -torch.log(
        -torch.log(torch.rand(xt.shape, device=logits.device, generator=self.rng_generator))
    )
    gumbel_noise = gumbel_sample * randomness * (1 - ratio)
    # Temperature plays the role of tau in torch.nn.functional.gumbel_softmax.
    confidence = (torch.log(margin) + gumbel_noise) / confidence_temperature

    # Do not predict on already predicted tokens.
    mask = xt == self.mask_index
    confidence[~mask] = -torch.inf

    # Use the top-k indices directly rather than thresholding on the k-th score,
    # so ties cannot unmask more than k tokens. k is capped at the sequence
    # length because topk rejects larger values.
    k = min(num_tokens_unmask, confidence.shape[-1])
    _, idx_mask = torch.topk(confidence, k=k, dim=-1)
    to_replace = torch.zeros_like(confidence)
    to_replace.scatter_(1, idx_mask, 1)
    # top-k may include unmasked (-inf) positions when fewer than k are masked.
    to_replace = to_replace.bool() & mask
    xt[to_replace] = preds[to_replace]
    return xt
353458
0 commit comments