Skip to content

Commit 2360734

Browse files
authored
Update for transformers 5.0 (#1436)
Currently blocked by huggingface/transformers#43510

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added checkpoint loading validation test.
  * Enhanced distributed testing infrastructure with improved port configuration.
* **Improvements**
  * Removed upper-bound version constraints on the transformers library for broader compatibility.
  * Improved weight initialization and checkpoint state management for TransformerEngine-optimized models.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 5778321 commit 2360734

28 files changed

Lines changed: 264 additions & 88 deletions

File tree

.devcontainer/recipes/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
accelerate
1+
accelerate @ git+https://github.com/huggingface/accelerate.git # Until huggingface/accelerate#3852 is released.
22
datasets
33
deepspeed
44
hydra-core
@@ -12,7 +12,7 @@ torchdata
1212
torchmetrics
1313
tqdm
1414
transformer_engine
15-
transformers<5.0
15+
transformers
1616
typer
1717
wandb
1818
zstandard

bionemo-recipes/models/amplify/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"pytest",
1919
"torch==2.6.0a0+ecf3bae40a.nv25.01",
2020
"transformer_engine[pytorch]",
21-
"transformers<5.0",
21+
"transformers<5.0", # TODO(BIO-143): update AMPLIFY to support Transformers v5
2222
"xformers",
2323
]
2424

bionemo-recipes/models/esm2/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"torch",
2020
"torchao!=0.14.0",
2121
"transformer_engine[pytorch]",
22-
"transformers<5.0",
22+
"transformers",
2323
]
2424

2525

bionemo-recipes/models/esm2/src/esm/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
131131
],
132132
)
133133

134-
output_model.tie_weights()
134+
output_model.post_init()
135135

136136
# Note: contact_head parameters are not preserved in TE models
137137
# They are lost during HF -> TE conversion and cannot be recovered

bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Adapted from `modeling_esm.py` in huggingface/transformers.
2323
"""
2424

25-
from typing import Literal, Optional, Unpack
25+
from typing import ClassVar, Literal, Optional, Unpack
2626

2727
# TODO: put import guard around transformer_engine here, with an informative error message around
2828
# installation and the nvidia docker container.
@@ -256,10 +256,34 @@ def init_empty_weights(self):
256256
# Meta-device init seems to break weight tying, so we re-tie the weights here.
257257
self.tie_weights()
258258

259-
@classmethod
260-
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
261-
"""Override the default get_init_context method to allow for fp8 model initialization."""
262-
return []
259+
def _init_weights(self, module):
260+
"""Initialize module weights.
261+
262+
We only use this method for standard pytorch modules, TE modules handle their own weight initialization through
263+
`init_method` parameters and the `reset_parameters` method.
264+
"""
265+
if module.__module__.startswith("transformer_engine.pytorch"):
266+
# Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will
267+
# assume any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking
268+
# `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and
269+
# `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the
270+
# weights are not in fp8. We still need to figure out why this raises an error if we're using
271+
# `quantized_model_init`.
272+
if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
273+
module.reset_parameters()
274+
return
275+
276+
super()._init_weights(module)
277+
278+
def state_dict(self, *args, **kwargs):
279+
"""Override state_dict to filter out TransformerEngine's _extra_state keys.
280+
281+
TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
282+
These are filtered out to ensure checkpoints can be loaded with from_pretrained().
283+
"""
284+
state_dict = super().state_dict(*args, **kwargs)
285+
# Filter out _extra_state keys which are TransformerEngine-specific and not loadable
286+
return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
263287

264288

265289
class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ def forward(
367391
class NVEsmForMaskedLM(NVEsmPreTrainedModel):
368392
"""NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""
369393

370-
_tied_weights_keys = ("lm_head.decoder.weight",)
394+
_tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}
371395

372396
def __init__(self, config: NVEsmConfig):
373397
"""Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@ def __init__(self, config: NVEsmConfig):
386410
self.esm = NVEsmModel(config, add_pooling_layer=False)
387411
self.lm_head = NVEsmLMHead(config)
388412

389-
self.init_weights()
390413
self.post_init()
391414

392415
def get_output_embeddings(self):
@@ -614,7 +637,6 @@ def __init__(self, config):
614637
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
615638
)
616639

617-
self.init_weights()
618640
self.post_init()
619641

620642
def forward(

bionemo-recipes/models/esm2/tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import importlib
1717
import os
18+
import socket
1819

1920
import pytest
2021
import transformer_engine.pytorch
@@ -32,6 +33,16 @@
3233
os.environ["TRITON_LIBCUDA_PATH"] = "/usr/local/cuda/lib64"
3334

3435

36+
@pytest.fixture
37+
def unused_tcp_port():
38+
"""Find and return an unused TCP port for torchrun rendezvous."""
39+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
40+
s.bind(("", 0))
41+
s.listen(1)
42+
port = s.getsockname()[1]
43+
return port
44+
45+
3546
@pytest.fixture(autouse=True)
3647
def use_te_debug(monkeypatch):
3748
monkeypatch.setenv("NVTE_DEBUG", "1")

bionemo-recipes/models/esm2/tests/test_convert.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def test_convert_te_to_hf_roundtrip():
3838
torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5)
3939

4040

41+
def test_load_from_converted_checkpoint(te_model_checkpoint):
42+
from esm.modeling_esm_te import NVEsmForMaskedLM
43+
44+
NVEsmForMaskedLM.from_pretrained(te_model_checkpoint)
45+
46+
4147
def test_qkv_unpacking():
4248
"""Test that QKV unpacking works correctly."""
4349
from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf

bionemo-recipes/models/esm2/tests/test_distributed_fp8.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@ def requires_fp8(func):
3838
"strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
3939
)
4040
@requires_fp8
41-
def test_single_process_attaches_correct_fp8_recipe(strategy):
41+
def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
4242
cmd = [
4343
"torchrun",
4444
"--nproc_per_node=1",
45+
"--rdzv-backend=c10d",
46+
f"--rdzv-endpoint=localhost:{unused_tcp_port}",
4547
os.path.relpath(__file__),
4648
"--strategy",
4749
strategy,
@@ -66,10 +68,12 @@ def test_single_process_attaches_correct_fp8_recipe(strategy):
6668
)
6769
@requires_fp8
6870
@requires_multi_gpu
69-
def test_multi_process_fp8_recipes_are_synced(strategy):
71+
def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
7072
cmd = [
7173
"torchrun",
7274
"--nproc_per_node=2",
75+
"--rdzv-backend=c10d",
76+
f"--rdzv-endpoint=localhost:{unused_tcp_port}",
7377
os.path.relpath(__file__),
7478
"--strategy",
7579
strategy,
@@ -207,11 +211,14 @@ def is_main_process(self) -> bool:
207211

208212
outputs.loss.backward()
209213

210-
fp8_extra_states = {key: val for key, val in model.state_dict().items() if key.endswith("_extra_state")}
211-
212-
# For some reason, this one doesn't get an fp8 recipe? It's the only te.LayerNorm.
213-
key = filter(lambda x: x.endswith("encoder.emb_layer_norm_after._extra_state"), fp8_extra_states.keys())
214-
fp8_extra_states.pop(next(key))
214+
# Access FP8 extra states directly from modules instead of state_dict()
215+
# since state_dict() now filters them out for HuggingFace compatibility
216+
fp8_extra_states = {}
217+
for name, module in model.named_modules():
218+
if hasattr(module, "_extra_state") and callable(module._extra_state):
219+
extra_state = module._extra_state()
220+
if extra_state is not None and len(extra_state) > 0:
221+
fp8_extra_states[f"{name}._extra_state"] = extra_state
215222

216223
# lm_head.dense and lm_head.decoder are BF16, not FP8, so exclude them from FP8 checks
217224
fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}

bionemo-recipes/models/esm2/tests/test_distributed_strategies.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@
4141
],
4242
)
4343
@pytest.mark.parametrize("backend", ["te", "eager"])
44-
def test_ddp_vs_fsdp_single_gpu(strategy, backend):
44+
def test_ddp_vs_fsdp_single_gpu(strategy, backend, unused_tcp_port):
4545
cmd = [
4646
"torchrun",
4747
"--nproc_per_node=1",
48+
"--rdzv-backend=c10d",
49+
f"--rdzv-endpoint=localhost:{unused_tcp_port}",
4850
os.path.relpath(__file__),
4951
"--strategy",
5052
strategy,
@@ -69,10 +71,12 @@ def test_ddp_vs_fsdp_single_gpu(strategy, backend):
6971
@requires_multi_gpu
7072
@pytest.mark.parametrize("strategy", ["fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2726"))])
7173
@pytest.mark.parametrize("backend", ["te", "eager"])
72-
def test_ddp_vs_fsdp_multi_gpu(strategy, backend):
74+
def test_ddp_vs_fsdp_multi_gpu(strategy, backend, unused_tcp_port):
7375
cmd = [
7476
"torchrun",
7577
"--nproc_per_node=2",
78+
"--rdzv-backend=c10d",
79+
f"--rdzv-endpoint=localhost:{unused_tcp_port}",
7680
os.path.relpath(__file__),
7781
"--strategy",
7882
strategy,
@@ -160,20 +164,28 @@ def is_main_process(self) -> bool:
160164
return self.rank == 0
161165

162166
def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dist_config: DistributedConfig):
167+
# Set seed for reproducible model initialization across strategies
168+
torch.manual_seed(42)
169+
torch.cuda.manual_seed_all(42)
170+
163171
device_mesh = init_device_mesh(
164172
"cuda",
165-
mesh_shape=(dist_config.world_size,),
166-
mesh_dim_names=("dp",),
173+
mesh_shape=(dist_config.world_size, 1),
174+
mesh_dim_names=("dp", "tp"), # mfsdp requires us to give a tp mesh dimension.
167175
)
168176

169177
device = f"cuda:{dist_config.local_rank}"
170178

171179
if use_te:
172-
model = AutoModelForMaskedLM.from_pretrained(
173-
"nvidia/esm2_t6_8M_UR50D",
180+
# Import local model classes to avoid using outdated code from HF Hub
181+
from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
182+
183+
config = NVEsmConfig.from_pretrained(
184+
"facebook/esm2_t6_8M_UR50D",
174185
dtype=torch.bfloat16,
175-
trust_remote_code=True,
186+
revision="c731040f",
176187
)
188+
model = NVEsmForMaskedLM(config)
177189
transformer_layers = model.esm.encoder.layers
178190
else:
179191
model = AutoModelForMaskedLM.from_pretrained(

bionemo-recipes/models/esm2/tests/test_meta_device_init.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@
4444
)
4545

4646

47-
def msg(x):
48-
return f"Mismatch in module {name}: {x}"
49-
50-
5147
def verify_model_parameters_initialized_correctly(
5248
model: NVEsmForMaskedLM, atol=1e-3, rtol=1e-4, should_be_fp8: bool = False
5349
):
@@ -57,6 +53,10 @@ def verify_model_parameters_initialized_correctly(
5753
assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device"
5854

5955
for name, module in model.named_modules():
56+
57+
def msg(x):
58+
return f"Mismatch in module {name}: {x}"
59+
6060
if isinstance(module, torch.nn.Embedding):
6161
torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg)
6262
torch.testing.assert_close(
@@ -118,8 +118,12 @@ def verify_model_parameters_initialized_correctly(
118118
torch.testing.assert_close(module.inv_freq, expected_inv_freq, msg=msg)
119119

120120

121-
def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-3, rtol=1e-4):
121+
def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-2, rtol=1e-3):
122122
for name, p in model.named_parameters():
123+
124+
def msg(x):
125+
return f"Mismatch in parameter {name}: {x}"
126+
123127
assert p.numel() > 0, f"{name} is empty"
124128
assert torch.isfinite(p).all(), f"{name} has NaN/Inf"
125129

@@ -187,14 +191,12 @@ def test_meta_fp8_init(fp8_recipe):
187191

188192

189193
def test_model_for_token_classification_init(te_model_checkpoint):
190-
config = NVEsmConfig.from_pretrained(te_model_checkpoint, trust_remote_code=True)
191-
192194
set_seed(42)
193-
model = NVEsmForTokenClassification.from_pretrained(
194-
te_model_checkpoint, config=config, dtype=torch.bfloat16, trust_remote_code=True
195-
)
196-
model.to("cuda")
197195

196+
config = NVEsmConfig.from_pretrained(te_model_checkpoint)
197+
model = NVEsmForTokenClassification.from_pretrained(te_model_checkpoint, config=config, dtype=torch.bfloat16)
198+
# model.classifier.reset_parameters()
199+
model.to("cuda")
198200
verify_pretrained_model_sanity(model)
199201

200202

0 commit comments

Comments
 (0)