Skip to content

Commit 77fdb9a

Browse files
fix confidence MDLM sampling and add ar and margin (#1029)
### Description
Fixes an issue in MDLM confidence sampling (tokens are now selected by their top-k indices rather than by a confidence threshold) and adds two new step functions: `step_auto_regressive` and `step_confidence_margin`. Bumps the version to 0.0.2.2.

### Type of changes
<!-- Mark the relevant option with an [x] -->
- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)

### Pre-submit Checklist
<!--- Ensure all items are completed before submitting -->
- [x] I have tested these changes locally
- [x] I have updated the documentation accordingly
- [x] I have added/updated tests as needed
- [x] All existing tests pass successfully
1 parent 2956b15 commit 77fdb9a

2 files changed

Lines changed: 108 additions & 3 deletions

File tree

  • sub-packages/bionemo-moco

sub-packages/bionemo-moco/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.2.1
1+
0.0.2.2

sub-packages/bionemo-moco/src/bionemo/moco/interpolants/continuous_time/discrete/mdlm.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,34 @@ def get_num_steps_confidence(self, xt: Tensor, num_tokens_unmask: int = 1):
284284
else:
285285
return int(max(math.ceil(nsteps // num_tokens_unmask), 1))
286286

287+
def step_auto_regressive(
    self,
    logits: Tensor,
    xt: Tensor,
    logit_temperature: float = 1.0,
) -> Tensor:
    """Auto-regressive sampling from MDLM.

    Samples a token for every position from the temperature-scaled
    distribution produced by ``self._subs_parameterization`` and writes the
    sample into the first (left-most) masked position of each row only.
    Rows that contain no masked token are returned unchanged.

    Args:
        logits: Predicted logits; the last dimension is the token dimension.
        xt: Current sequences, indexed as (batch, sequence). Not modified;
            a clone is updated and returned.
        logit_temperature: Temperature dividing the log-probabilities before
            the softmax.

    Returns:
        A copy of ``xt`` in which, for every row, the first position equal to
        ``self.mask_index`` has been replaced by a sampled token.
    """
    xt = xt.clone()
    log_p_x0 = self._subs_parameterization(logits, xt)
    # sample the code from the softmax prediction
    probs = torch.softmax(log_p_x0 / logit_temperature, dim=-1)
    preds = torch.distributions.Categorical(probs=probs).sample()

    # do not predict on already predicted tokens
    mask = xt == self.mask_index

    # The running count of masked positions equals 1 exactly from the first
    # masked position up to (but excluding) the second one; AND-ing with the
    # mask keeps only the first masked position of each row. This is computed
    # per row so batch elements at different stages each advance by one
    # token, and it is all-False (no IndexError) when nothing is masked.
    to_replace = mask & (mask.cumsum(dim=1) == 1)

    xt[to_replace] = preds[to_replace]
    return xt
314+
287315
def step_confidence(
288316
self,
289317
logits: Tensor,
@@ -344,10 +372,87 @@ def step_confidence(
344372
# choose the predicted token with the highest confidence
345373
confidence_threshold, idx_mask = torch.topk(confidence, k=num_tokens_unmask, dim=-1)
346374
confidence_threshold = confidence_threshold[:, -1].unsqueeze(-1)
375+
# Rather than take all tokens with confidence above the threshold, we use the topk indices to replace the chosen tokens
376+
# replace the chosen tokens
377+
to_replace = torch.zeros_like(confidence)
378+
to_replace.scatter_(1, idx_mask, 1)
379+
to_replace = to_replace.bool() & mask.bool()
380+
# to_replace = confidence >= confidence_threshold
381+
# to_replace = (mask.float() * to_replace.float()).bool()
382+
xt[to_replace] = preds[to_replace]
383+
return xt
384+
385+
def step_confidence_margin(
    self,
    logits: Tensor,
    xt: Tensor,
    curr_step: int,
    num_steps: int,
    logit_temperature: float = 1.0,
    randomness: float = 1.0,
    confidence_temperature: float = 1.0,
    num_tokens_unmask: int = 1,
) -> Tensor:
    """Kim et al., Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions, ICML, 2025.

    This method is similar to step_confidence, but it ranks positions by a margin instead of a single probability.
    Two distinct tokens are sampled per position; the margin is the probability of the first sampled token minus
    the probability of the second. Gumbel noise that decays linearly over the sampling process is added to the
    log-margin, and the masked positions with the highest resulting score are replaced by their first sampled token.

    Args:
        logits: Predicted logits
        xt: Input sequence, indexed as (batch, sequence)
        curr_step: Current step
        num_steps: Total number of steps
        logit_temperature: Temperature for softmax over logits
        randomness: Scale for Gumbel noise
        confidence_temperature: Temperature for Gumbel confidence
        num_tokens_unmask: number of tokens to unmask each step

    Returns:
        Updated input sequence xt unmasking at most num_tokens_unmask tokens per row each step.

    Raises:
        NotImplementedError: If xt has more than 3 dimensions.
        ValueError: If curr_step < 0, num_steps < 1 or num_tokens_unmask < 1.
    """
    if xt.ndim > 3:
        raise NotImplementedError(
            "step_confidence_margin is implemented for Batch x Sequence x State Space shaped tensors."
        )
    if curr_step < 0 or num_steps < 1 or num_tokens_unmask < 1:
        raise ValueError("Invalid input values for curr_step, num_steps, or num_tokens_unmask.")
    xt = xt.clone()
    log_p_x0 = self._subs_parameterization(logits, xt)
    # sample two distinct codes per position from the softmax prediction
    probs = torch.softmax(log_p_x0 / logit_temperature, dim=-1)
    preds = torch.stack([torch.multinomial(prob, num_samples=2, replacement=False) for prob in probs])
    confidence_first = probs.gather(-1, preds[:, :, 0].unsqueeze(-1)).squeeze(-1)
    confidence_second = probs.gather(-1, preds[:, :, 1].unsqueeze(-1)).squeeze(-1)
    margin = confidence_first - confidence_second
    preds = preds[:, :, 0]
    # The second sampled token can be more likely than the first, making the margin <= 0.
    # Clamp to the smallest positive float so the log stays finite instead of producing
    # NaN / -inf, which would corrupt the top-k ranking below.
    margin = margin.clamp_min(torch.finfo(margin.dtype).tiny)
    # add Gumbel noise decreasing over the sampling process
    # (max(..., 1) avoids a division by zero when num_steps == 1)
    ratio = curr_step / max(num_steps - 1, 1)
    # Using manual definition of 0,1 Gumbel to pass in generator, manually specifying the device is faster than transfer
    gumbel_sample = -torch.log(
        -torch.log(torch.rand(xt.shape, device=logits.device, generator=self.rng_generator))
    )
    gumbel_noise = gumbel_sample * randomness * (1 - ratio)  # type: ignore
    confidence = (
        (torch.log(margin) + gumbel_noise) / confidence_temperature
    )  # stems from tau of https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    # do not predict on already predicted tokens
    mask = xt == self.mask_index
    confidence[~mask] = -torch.inf

    # choose the predicted tokens with the highest confidence; k is capped at the
    # sequence length so requesting more tokens than exist does not raise in topk
    k = min(num_tokens_unmask, confidence.shape[-1])
    _, idx_mask = torch.topk(confidence, k=k, dim=-1)
    # Rather than take all tokens with confidence above a threshold, we use the topk indices to replace the chosen
    # tokens; the AND with mask drops any already-unmasked position that topk returned.
    to_replace = torch.zeros_like(confidence)
    to_replace.scatter_(1, idx_mask, 1)
    to_replace = to_replace.bool() & mask.bool()
    xt[to_replace] = preds[to_replace]
    return xt
353458

0 commit comments

Comments
 (0)