Skip to content

Commit 965d885

Browse files
authored
Merge branch 'main' into savitha/og2-readme-metrics-update
2 parents 4877ca2 + 34aad73 commit 965d885

2 files changed

Lines changed: 27 additions & 3 deletions

File tree

bionemo-recipes/models/amplify/src/amplify/amplify_te.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
206206
# Initialize weights and apply final processing
207207
self.post_init()
208208

209+
def get_input_embeddings(self):
210+
"""Get the input embeddings of the model."""
211+
return self.encoder
212+
213+
def set_input_embeddings(self, value: nn.Embedding):
214+
"""Set the input embeddings of the model.
215+
216+
Args:
217+
value (nn.Embedding): The input embeddings.
218+
"""
219+
self.encoder = value
220+
209221
def forward(
210222
self,
211223
input_ids,
@@ -288,6 +300,18 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
288300
config.hidden_size, config.vocab_size, params_dtype=config.dtype
289301
)
290302

303+
def get_input_embeddings(self):
304+
"""Get the input embeddings of the model."""
305+
return self.amplify.get_input_embeddings()
306+
307+
def set_input_embeddings(self, value: nn.Embedding):
308+
"""Set the input embeddings of the model.
309+
310+
Args:
311+
value (nn.Embedding): The input embeddings.
312+
"""
313+
self.amplify.set_input_embeddings(value)
314+
291315
def forward(
292316
self,
293317
input_ids,

bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat
257257
ref_val = reference_logits_step_10.flatten()[max_idx].item()
258258
reload_val = reloaded_logits_step_5.flatten()[max_idx].item()
259259

260-
# BF16 tolerance: max diff of ~0.013 is normal for BF16 after 10 training steps
261-
# Using atol=0.015 to account for BF16 precision limitations
262-
assert torch.allclose(reference_logits_step_10, reloaded_logits_step_5, rtol=1e-2, atol=1.5e-2), (
260+
# BF16 tolerance: max diff of ~0.017 is normal for BF16 after 10 training steps
261+
# Using atol=0.02 to account for BF16 precision limitations
262+
assert torch.allclose(reference_logits_step_10, reloaded_logits_step_5, rtol=1e-2, atol=2.0e-2), (
263263
f"Logits don't match - max abs diff: {max_diff:.6f}, mean abs diff: {mean_diff:.6f}\n"
264264
f"Max diff at position {max_idx_tuple}: reference={ref_val:.6f}, reloaded={reload_val:.6f}"
265265
)

0 commit comments

Comments
 (0)