Skip to content

Commit f08bbee

Browse files
committed
change where dtype is found in checkpoint export
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 1a2a4ed commit f08bbee

2 files changed

Lines changed: 9 additions & 1 deletion

File tree

sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def apply(self, output_path: Path) -> Path:
134134
)
135135
source, _ = self.nemo_load(self, trainer=trainer, cpu=cpu)
136136

137-
dtype = torch.bfloat16 if source.config.bf16 else torch.float32
137+
dtype = source.dtype
138138

139139
# Not sure why we need to do this, for some reason lm_head stays as fp32
140140
source.module.lm_head.to(dtype)

sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ def test_nemo2_conversion_equivalent_650m(tmp_path):
104104
assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4)
105105

106106

107+
@pytest.mark.slow
108+
def test_nemo2_export_equivalent_650m(tmp_path):
109+
ckpt_path = load("esm2/nv_650m:2.1")
110+
output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
111+
with megatron_parallel_state_utils.distributed_model_parallel_state():
112+
assert_esm2_equivalence(ckpt_path, output_path, precision="bf16")
113+
114+
107115
def test_cli_nemo2_conversion_equivalent_8m(tmp_path):
108116
"""Test that the CLI conversion functions maintain model equivalence."""
109117
model_tag = "facebook/esm2_t6_8M_UR50D"

0 commit comments

Comments
 (0)