Skip to content

Commit 3073626

Browse files
committed
add esm2 gradient tests
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent a29272f commit 3073626

3 files changed

Lines changed: 255 additions & 5 deletions

File tree

models/esm2/pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ dependencies = [
1313
"fiddle",
1414
"hydra-core",
1515
"lightning",
16-
"megatron-core",
17-
"nemo_toolkit[lightning]==2.3.1",
16+
"megatron-core@git+https://github.com/NVIDIA/Megatron-LM.git", # Currently at ToT until mfsdp is in a release.
17+
"megatron-fsdp",
18+
"nemo_toolkit[lightning]", # tested with 2.3.1
1819
"omegaconf",
1920
"pytest",
2021
"torch",
21-
# "transformer_engine[pytorch]",
22+
"transformer_engine[pytorch]",
2223
"transformers",
2324
]
2425

models/esm2/tests/conftest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def tokenizer():
3434
return AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
3535

3636

37-
@pytest.fixture
38-
def input_data(tokenizer):
37+
def get_input_data(tokenizer):
3938
torch.manual_seed(42)
4039

4140
test_proteins = [
@@ -87,3 +86,8 @@ def tokenize_function(examples):
8786

8887
batch = next(iter(dataloader))
8988
return batch
89+
90+
91+
@pytest.fixture
def input_data(tokenizer):
    """Expose the batch produced by ``get_input_data`` as a pytest fixture.

    The batch construction lives in the plain function ``get_input_data``;
    this fixture only forwards the ``tokenizer`` fixture to it.
    """
    batch = get_input_data(tokenizer)
    return batch
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
import logging
18+
import os
19+
import subprocess
20+
21+
import pytest
22+
import torch
23+
24+
25+
# Module-level logger; the INFO threshold is set on this logger only.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Skip marker for tests that need two CUDA devices. The device count is only
# queried when CUDA reports itself as available (short-circuit evaluation).
requires_multi_gpu = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.device_count() >= 2),
    reason="Test requires at least 2 GPUs",
)
33+
34+
35+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires at least 1 GPU")
@pytest.mark.parametrize(
    "strategy",
    [
        "fsdp2",
        "mfsdp",
    ],
)
@pytest.mark.parametrize("backend", ["te", "eager"])
def test_ddp_vs_fsdp2_single_gpu(strategy, backend):
    """Re-run this file under ``torchrun`` with one process and require a zero exit code.

    The launched script initializes an NCCL process group and places the model on
    a CUDA device, so the test is skipped (rather than failed) on hosts without CUDA.

    Args:
        strategy: Value forwarded to the script's ``--strategy`` option.
        backend: ``"te"`` appends ``--test_te`` to the command; ``"eager"`` adds nothing.
    """
    cmd = [
        "torchrun",
        "--nproc_per_node=1",
        # The script being launched is this very file (its ``__main__`` section).
        os.path.relpath(__file__),
        "--strategy",
        strategy,
    ]
    if backend == "te":
        cmd.append("--test_te")

    # check=False: the return code is inspected below so the captured output can
    # be printed before the test is marked as failed.
    result = subprocess.run(
        cmd,
        check=False,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=240,
    )
    if result.returncode != 0:
        print(f"STDOUT:\n{result.stdout}")
        print(f"STDERR:\n{result.stderr}")
        pytest.fail(f"Command failed with exit code {result.returncode}")
66+
67+
68+
@requires_multi_gpu
69+
@pytest.mark.parametrize("strategy", ["fsdp2", "mfsdp"])
70+
@pytest.mark.parametrize("backend", ["te", "eager"])
71+
def test_ddp_vs_fsdp2_multi_gpu(strategy, backend):
72+
cmd = [
73+
"torchrun",
74+
"--nproc_per_node=2",
75+
os.path.relpath(__file__),
76+
"--strategy",
77+
strategy,
78+
]
79+
if backend == "te":
80+
cmd.append("--test_te")
81+
82+
result = subprocess.run(
83+
cmd,
84+
check=False,
85+
text=True,
86+
stdout=subprocess.PIPE,
87+
stderr=subprocess.PIPE,
88+
timeout=240,
89+
)
90+
if result.returncode != 0:
91+
print(f"STDOUT:\n{result.stdout}")
92+
print(f"STDERR:\n{result.stderr}")
93+
pytest.fail(f"Command failed with exit code {result.returncode}")
94+
95+
96+
if __name__ == "__main__":
97+
import argparse
98+
import enum
99+
from dataclasses import dataclass, field
100+
101+
import torch.distributed as dist
102+
import transformer_engine.pytorch
103+
import transformers
104+
from megatron_fsdp.fully_shard import fully_shard as megatron_fsdp_fully_shard
105+
from torch.distributed.device_mesh import init_device_mesh
106+
from torch.distributed.fsdp import fully_shard
107+
from torch.optim import AdamW
108+
from transformers import AutoModelForMaskedLM, AutoTokenizer
109+
110+
class Strategy(enum.StrEnum):
111+
DDP = "ddp"
112+
FSDP2 = "fsdp2"
113+
MFSDP = "mfsdp"
114+
115+
# Command-line options: --test_te is forwarded as ``use_te`` to the runs below,
# and --strategy picks the sharded wrapper that is compared against DDP.
parser = argparse.ArgumentParser()
parser.add_argument("--test_te", action="store_true")
parser.add_argument(
    "--strategy",
    type=Strategy,
    choices=(Strategy.FSDP2, Strategy.MFSDP),
    default=Strategy.FSDP2,
)
args = parser.parse_args()

from conftest import get_input_data

# Build the batch once so that every run consumes the same inputs.
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
input_data = get_input_data(esm_tokenizer)
123+
124+
@dataclass
class DistributedConfig:
    """Rank information for the current process.

    Any field not passed explicitly is filled in at construction time:
    ``rank`` and ``world_size`` come from ``torch.distributed`` and
    ``local_rank`` is parsed from the ``LOCAL_RANK`` environment variable.
    """

    rank: int = field(default_factory=dist.get_rank)
    local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"]))
    world_size: int = field(default_factory=dist.get_world_size)

    def is_main_process(self) -> bool:
        """Return True only for the process whose global rank is 0 (e.g. for logging)."""
        is_rank_zero = self.rank == 0
        return is_rank_zero
135+
136+
def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dist_config: DistributedConfig):
    """Run one forward/backward pass of ESM-2 wrapped with ``strategy``.

    Args:
        use_te: Load ``nvidia/esm2_t6_8M_UR50D`` (with remote code enabled) instead of
            ``facebook/esm2_t6_8M_UR50D``.
        strategy: Wrapper to apply to the model (DDP, FSDP2 or Megatron-FSDP).
        input_data: Keyword arguments for the model's forward call; each value is
            moved to this rank's CUDA device first.
        dist_config: Rank information for the current process.

    Returns:
        A tuple ``(outputs, grads)``: the object returned by the model's forward call and
        a dict mapping parameter names to gradients, for parameters that received one.
    """
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(dist_config.world_size,),
        mesh_dim_names=("dp",),
    )
    device = f"cuda:{dist_config.local_rank}"

    # Both checkpoints are loaded in bfloat16; only the TE one needs remote code.
    load_kwargs = {"torch_dtype": torch.bfloat16}
    if use_te:
        checkpoint = "nvidia/esm2_t6_8M_UR50D"
        load_kwargs["trust_remote_code"] = True
    else:
        checkpoint = "facebook/esm2_t6_8M_UR50D"
    model = AutoModelForMaskedLM.from_pretrained(checkpoint, **load_kwargs)

    if use_te:
        encoder_blocks = model.esm.encoder.layers
    else:
        encoder_blocks = model.esm.encoder.layer
        del model.esm.contact_head  # Unused in backwards pass.

    if strategy is Strategy.FSDP2:
        # Shard every encoder block first, then the root module, on the "dp" mesh dim.
        dp_mesh = mesh["dp"]
        for module in (*encoder_blocks, model):
            fully_shard(module, mesh=dp_mesh)
        model.to(device)
    elif strategy is Strategy.DDP:
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist_config.local_rank],
            output_device=dist_config.local_rank,
            device_mesh=mesh["dp"],
        )

    optimizer = AdamW(model.parameters())

    if strategy is Strategy.MFSDP:
        # Megatron-FSDP takes the optimizer as well and returns wrapped versions of both.
        fsdp_units = [
            transformer_engine.pytorch.TransformerLayer,
            transformer_engine.pytorch.LayerNorm,
            transformer_engine.pytorch.LayerNormLinear,
            transformers.models.esm.modeling_esm.EsmLayer,
        ]
        model, optimizer = megatron_fsdp_fully_shard(
            module=model,
            optimizer=optimizer,
            fsdp_unit_modules=fsdp_units,
            device_mesh=mesh,
            dp_shard_dim="dp",
            # NOTE(review): the mesh built above only names a "dp" dim -- confirm "tp" is intended.
            tp_dim="tp",
            sync_grads_each_step=True,
            # TODO(cory): any idea why this is needed?
            preserve_fp32_weights=False,
        )

    model.train()
    batch = {key: tensor.to(device) for key, tensor in input_data.items()}

    optimizer.zero_grad()
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(**batch)
    outputs.loss.backward()

    # Collect gradients. DDP gradients are used as-is; the sharded strategies
    # materialize each gradient with full_tensor(). DDP and Megatron-FSDP expose
    # the underlying model as ``.module``.
    if strategy is Strategy.DDP:
        grads = {name: p.grad for name, p in model.module.named_parameters() if p.grad is not None}
    else:
        grad_source = model if strategy is Strategy.FSDP2 else model.module
        grads = {name: p.grad.full_tensor() for name, p in grad_source.named_parameters() if p.grad is not None}

    del model
    torch.cuda.empty_cache()
    return outputs, grads
215+
216+
# Bring up the NCCL process group and bind this process to its local GPU.
dist.init_process_group(backend="nccl")
dist_config = DistributedConfig()
logger.info(f"Distributed config: {dist_config}")
torch.cuda.set_device(dist_config.local_rank)

# Reference run with DDP, then the run with the strategy under test, on the same inputs.
shared_kwargs = {"use_te": args.test_te, "input_data": input_data, "dist_config": dist_config}
ddp, ddp_grads = run_forward_backward(strategy=Strategy.DDP, **shared_kwargs)
fsdp, fsdp_grads = run_forward_backward(strategy=args.strategy, **shared_kwargs)

torch.testing.assert_close(fsdp.loss, ddp.loss, msg=lambda x: f"Loss mismatch: {x}")
torch.testing.assert_close(fsdp.logits, ddp.logits, msg=lambda x: f"Logits mismatch: {x}")

# Both runs must report gradients for exactly the same parameter names.
ddp_names, fsdp_names = set(ddp_grads), set(fsdp_grads)
shared_grads = ddp_names & fsdp_names
missing_grads = ddp_names ^ fsdp_names
assert not missing_grads, f"Missing gradients: {missing_grads}"

for name in shared_grads:
    ddp_grad, fsdp_grad = ddp_grads[name], fsdp_grads[name]
    torch.testing.assert_close(ddp_grad, fsdp_grad, msg=lambda x: f"Gradient mismatch for {name}: {x}")

    # Check that the gradients are different when the last dimension is shuffled
    rolled = torch.roll(fsdp_grad, -1, -1)
    assert not torch.allclose(ddp_grad, rolled)

dist.destroy_process_group()

0 commit comments

Comments
 (0)