Skip to content

Commit 2d3d3c8

Browse files
authored
Add ESM-2 model gradient tests (#1077)
Adds gradient tests to ensure the gradients we get from ddp, fsdp2, and nvfsdp are consistent <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - Chores - Simplified image builds by switching to standard pip for editable installs. - Dependencies - Updated core ML dependencies and enabled/unpinned select components for better compatibility. - Tests - Added comprehensive distributed-training validation tests comparing strategies across single- and multi-GPU setups with automatic GPU gating and detailed checks. - Refactor - Extracted test data generation into a reusable helper to streamline test setup. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent be9c0db commit 2d3d3c8

5 files changed

Lines changed: 269 additions & 8 deletions

File tree

models/amplify/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
88
WORKDIR /workspace/bionemo
99
COPY . .
1010
RUN --mount=type=cache,target=/root/.cache/uv \
11-
uv pip install --system --break-system-packages -e .
11+
PIP_CONSTRAINT= pip install -e .

models/esm2/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
33
WORKDIR /workspace/bionemo
44
COPY . .
55
RUN --mount=type=cache,target=/root/.cache/uv \
6-
uv pip install --system --break-system-packages -e .
6+
PIP_CONSTRAINT= pip install -e .

models/esm2/pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ dependencies = [
1313
"fiddle",
1414
"hydra-core",
1515
"lightning",
16-
"megatron-core",
17-
"nemo_toolkit[lightning]==2.3.1",
16+
"megatron-core@git+https://github.com/NVIDIA/Megatron-LM.git", # Currently at ToT until mfsdp is in a release.
17+
"megatron-fsdp",
18+
"nemo_toolkit[lightning]", # tested with 2.3.1
1819
"omegaconf",
1920
"pytest",
2021
"torch",
21-
# "transformer_engine[pytorch]",
22-
"transformers<4.56", # TODO: fix me, currently failing with a modelopt import from nemo.
22+
"transformer_engine[pytorch]",
23+
"transformers<4.56", # TODO: fix me, currently failing with a modelopt import from nemo.
2324
]
2425

2526

models/esm2/tests/conftest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def tokenizer():
3434
return AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
3535

3636

37-
@pytest.fixture
38-
def input_data(tokenizer):
37+
def get_input_data(tokenizer):
3938
torch.manual_seed(42)
4039

4140
test_proteins = [
@@ -87,3 +86,8 @@ def tokenize_function(examples):
8786

8887
batch = next(iter(dataloader))
8988
return batch
89+
90+
91+
@pytest.fixture
92+
def input_data(tokenizer):
93+
return get_input_data(tokenizer)
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
import logging
18+
import os
19+
import subprocess
20+
21+
import pytest
22+
import torch
23+
24+
25+
logger = logging.getLogger(__name__)
26+
logger.setLevel(logging.INFO)
27+
28+
29+
requires_multi_gpu = pytest.mark.skipif(
30+
not torch.cuda.is_available() or torch.cuda.device_count() < 2,
31+
reason="Test requires at least 2 GPUs",
32+
)
33+
34+
35+
@pytest.mark.parametrize(
36+
"strategy",
37+
[
38+
"fsdp2",
39+
"mfsdp",
40+
],
41+
)
42+
@pytest.mark.parametrize("backend", ["te", "eager"])
43+
def test_ddp_vs_fsdp_single_gpu(strategy, backend):
44+
cmd = [
45+
"torchrun",
46+
"--nproc_per_node=1",
47+
os.path.relpath(__file__),
48+
"--strategy",
49+
strategy,
50+
]
51+
if backend == "te":
52+
cmd.append("--test_te")
53+
54+
result = subprocess.run(
55+
cmd,
56+
check=False,
57+
text=True,
58+
stdout=subprocess.PIPE,
59+
stderr=subprocess.PIPE,
60+
timeout=240,
61+
)
62+
if result.returncode != 0:
63+
print(f"STDOUT:\n{result.stdout}")
64+
print(f"STDERR:\n{result.stderr}")
65+
pytest.fail(f"Command failed with exit code {result.returncode}")
66+
67+
68+
@requires_multi_gpu
69+
@pytest.mark.parametrize("strategy", ["fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2726"))])
70+
@pytest.mark.parametrize("backend", ["te", "eager"])
71+
def test_ddp_vs_fsdp_multi_gpu(strategy, backend):
72+
cmd = [
73+
"torchrun",
74+
"--nproc_per_node=2",
75+
os.path.relpath(__file__),
76+
"--strategy",
77+
strategy,
78+
]
79+
if backend == "te":
80+
cmd.append("--test_te")
81+
82+
result = subprocess.run(
83+
cmd,
84+
check=False,
85+
text=True,
86+
stdout=subprocess.PIPE,
87+
stderr=subprocess.PIPE,
88+
timeout=240,
89+
)
90+
if result.returncode != 0:
91+
print(f"STDOUT:\n{result.stdout}")
92+
print(f"STDERR:\n{result.stderr}")
93+
pytest.fail(f"Command failed with exit code {result.returncode}")
94+
95+
96+
if __name__ == "__main__":
97+
import argparse
98+
import enum
99+
from dataclasses import dataclass, field
100+
101+
import torch.distributed as dist
102+
import transformer_engine.pytorch
103+
import transformers
104+
from megatron_fsdp.fully_shard import fully_shard as megatron_fsdp_fully_shard
105+
from torch.distributed.device_mesh import init_device_mesh
106+
from torch.distributed.fsdp import fully_shard
107+
from torch.optim import AdamW
108+
from transformers import AutoModelForMaskedLM, AutoTokenizer
109+
110+
class Strategy(enum.StrEnum):
111+
DDP = "ddp"
112+
FSDP2 = "fsdp2"
113+
MFSDP = "mfsdp"
114+
115+
parser = argparse.ArgumentParser()
116+
parser.add_argument("--test_te", action="store_true", default=False)
117+
parser.add_argument("--strategy", type=Strategy, default=Strategy.FSDP2, choices=[Strategy.FSDP2, Strategy.MFSDP])
118+
args = parser.parse_args()
119+
120+
from conftest import get_input_data
121+
122+
input_data = get_input_data(AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D"))
123+
124+
@dataclass
125+
class DistributedConfig:
126+
"""Class to track distributed ranks."""
127+
128+
rank: int = field(default_factory=dist.get_rank)
129+
local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"]))
130+
world_size: int = field(default_factory=dist.get_world_size)
131+
132+
def is_main_process(self) -> bool:
133+
"""This is the global rank 0 process, to be used for wandb logging, etc."""
134+
return self.rank == 0
135+
136+
def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dist_config: DistributedConfig):
137+
device_mesh = init_device_mesh(
138+
"cuda",
139+
mesh_shape=(dist_config.world_size,),
140+
mesh_dim_names=("dp",),
141+
)
142+
143+
device = f"cuda:{dist_config.local_rank}"
144+
145+
if use_te:
146+
model = AutoModelForMaskedLM.from_pretrained(
147+
"nvidia/esm2_t6_8M_UR50D",
148+
torch_dtype=torch.bfloat16,
149+
trust_remote_code=True,
150+
)
151+
transformer_layers = model.esm.encoder.layers
152+
else:
153+
model = AutoModelForMaskedLM.from_pretrained(
154+
"facebook/esm2_t6_8M_UR50D",
155+
torch_dtype=torch.bfloat16,
156+
)
157+
transformer_layers = model.esm.encoder.layer
158+
del model.esm.contact_head # Unused in backwards pass.
159+
160+
if strategy is Strategy.FSDP2:
161+
for layer in transformer_layers:
162+
fully_shard(layer, mesh=device_mesh["dp"])
163+
fully_shard(model, mesh=device_mesh["dp"])
164+
model.to(device)
165+
166+
elif strategy is Strategy.DDP:
167+
model.to(device)
168+
model = torch.nn.parallel.DistributedDataParallel(
169+
model,
170+
device_ids=[dist_config.local_rank],
171+
output_device=dist_config.local_rank,
172+
device_mesh=device_mesh["dp"],
173+
)
174+
175+
optimizer = AdamW(model.parameters())
176+
177+
if strategy is Strategy.MFSDP:
178+
model, optimizer = megatron_fsdp_fully_shard(
179+
module=model,
180+
optimizer=optimizer,
181+
fsdp_unit_modules=[
182+
transformer_engine.pytorch.TransformerLayer,
183+
transformer_engine.pytorch.LayerNorm,
184+
transformer_engine.pytorch.LayerNormLinear,
185+
transformers.models.esm.modeling_esm.EsmLayer,
186+
],
187+
device_mesh=device_mesh,
188+
dp_shard_dim="dp",
189+
tp_dim="tp",
190+
sync_grads_each_step=True,
191+
preserve_fp32_weights=False, # TODO: cory, any idea why this is needed?
192+
)
193+
194+
model.train()
195+
input_data = {k: v.to(device) for k, v in input_data.items()}
196+
197+
optimizer.zero_grad()
198+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
199+
outputs = model(**input_data)
200+
outputs.loss.backward()
201+
202+
# get gradients
203+
if strategy is Strategy.FSDP2:
204+
grads = {name: p.grad.full_tensor() for name, p in model.named_parameters() if p.grad is not None}
205+
206+
elif strategy is Strategy.DDP:
207+
grads = {name: p.grad for name, p in model.module.named_parameters() if p.grad is not None}
208+
209+
elif strategy is Strategy.MFSDP:
210+
# Because of uneven sharding, we need to manually gather the gradients.
211+
sharded_grads = [(name, p.grad) for name, p in model.module.named_parameters()]
212+
grads = {}
213+
for name, grad in sharded_grads:
214+
grad_shards = [None] * device_mesh["dp"].size()
215+
# For FSDP, we are not strided sharding, so gathering across dp_shard_cp is sufficient.
216+
# For HSDP, we need to first gather across dp_shard_cp, then gather across dp_inter,
217+
# not the other way around or you'll get wrong zig-zags.
218+
torch.distributed.all_gather_object(grad_shards, grad, group=device_mesh["dp"].get_group())
219+
all_valid_shards = [shard for shard in grad_shards if shard is not None]
220+
# Megatron-FSDP is always sharded across dim=0.
221+
grads[name] = torch.cat([s.to_local().to(device) for s in all_valid_shards], dim=0)
222+
223+
del model
224+
torch.cuda.empty_cache()
225+
return outputs, grads
226+
227+
dist.init_process_group(backend="nccl")
228+
dist_config = DistributedConfig()
229+
logger.info(f"Distributed config: {dist_config}")
230+
torch.cuda.set_device(dist_config.local_rank)
231+
232+
ddp, ddp_grads = run_forward_backward(
233+
use_te=args.test_te, strategy=Strategy.DDP, input_data=input_data, dist_config=dist_config
234+
)
235+
236+
fsdp, fsdp_grads = run_forward_backward(
237+
use_te=args.test_te, strategy=args.strategy, input_data=input_data, dist_config=dist_config
238+
)
239+
240+
torch.testing.assert_close(fsdp.loss, ddp.loss, msg=lambda x: f"Loss mismatch: {x}")
241+
torch.testing.assert_close(fsdp.logits, ddp.logits, msg=lambda x: f"Logits mismatch: {x}")
242+
243+
shared_grads = set(ddp_grads) & set(fsdp_grads)
244+
missing_grads = set(ddp_grads) ^ set(fsdp_grads)
245+
246+
assert not missing_grads, f"Missing gradients: {missing_grads}"
247+
248+
for name in shared_grads:
249+
ddp_grad = ddp_grads[name]
250+
fsdp_grad = fsdp_grads[name]
251+
torch.testing.assert_close(ddp_grad, fsdp_grad, msg=lambda x: f"Gradient mismatch for {name}: {x}")
252+
253+
# Check that the gradients are different when the last dimension is shuffled
254+
assert not torch.allclose(ddp_grad, torch.roll(fsdp_grad, -1, -1))
255+
256+
dist.destroy_process_group()

0 commit comments

Comments
 (0)