Skip to content

Commit b794224

Browse files
authored
ESM-2 mfsdp recipe expanded tests (#1101)
Adds additional tests to ESM-2 mfsdp recipe to characterize where convergence issues are occurring, also sets the seed for the THD recipe to hopefully avoid flaky errors seen on nightly <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Added a config toggle to enable/disable meta-device initialization. - Introduced new sharded-training options (overlap reduce/param gather, per-step sync, collective averaging). - Enabled use of Nvidia-hosted ESM2 models. - Refactor - Progress bar now shows a precomputed loss value for consistency. - Updated recipe defaults, including reduced micro-batch sizes for 650M and 3B variants. - Tests - Reorganized suite around convergence checks; added meta-device and eager scenarios, multi-GPU cases, and xfail markers. - Centralized seeding via fixture for reproducibility. - Chores - Removed redundant optimizer settings. - Updated run naming conventions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 4a47964 commit b794224

9 files changed

Lines changed: 122 additions & 45 deletions

File tree

recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,3 @@ wandb_init_args:
1414
# Learning rate scheduler config
1515
lr_scheduler_kwargs:
1616
num_warmup_steps: 0
17-
18-
adamw_kwargs:
19-
lr: 1e-2

recipes/esm2_native_te_mfsdp/hydra_config/L1_3B_ddp.yaml renamed to recipes/esm2_native_te_mfsdp/hydra_config/L1_3B.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ defaults:
22
- defaults
33

44
# Training config
5-
model_name: esm2_t33_650M_UR50D
6-
micro_batch_size: 32
5+
model_name: nvidia/esm2_t36_3B_UR50D
6+
micro_batch_size: 16
77
num_train_steps: 10_000
88

99
# WandB config

recipes/esm2_native_te_mfsdp/hydra_config/L1_650M.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ defaults:
33
- _self_
44

55
# Training config
6-
model_name: esm2_t33_650M_UR50D
7-
micro_batch_size: 16
6+
model_name: nvidia/esm2_t33_650M_UR50D
7+
micro_batch_size: 4
88
num_train_steps: 200
99

1010
# WandB config
1111
wandb_init_args:
12-
name: "esm2_t33_650M_UR50D_nvfsdp"
12+
name: "esm2_t33_650M_UR50D_mfsdp"
1313
project: "bionemo-recipes-pstjohn"
1414
mode: "offline"

recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@ max_seq_length: 1024
55
data_path: .
66
num_train_steps: ???
77

8+
# TODO: Once BIONEMO-2583 and BIONEMO-2719 are fixed, enable this by default and simplify training scripts to remove the
9+
# meta-device conditional.
10+
use_meta_device: false
11+
812
# WandB config
913
wandb_init_args:
1014
name: ???
1115

12-
# nvFSDP config
16+
# mFSDP config
1317
fully_shard_kwargs:
1418
zero_dp_strategy: "optim_grads_params"
1519
calculate_per_token_loss: false
16-
init_model_with_meta_device: false
20+
init_model_with_meta_device: ${use_meta_device}
1721
check_for_nan_in_grad: true
1822
grad_reduce_in_fp32: false
1923
preserve_fp32_weights: true

recipes/esm2_native_te_mfsdp/test_train.py

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
from train_mfsdp import main as main_mfsdp
2929

3030

31-
random.seed(42)
32-
torch.manual_seed(42)
33-
if torch.cuda.is_available():
34-
torch.cuda.manual_seed_all(42)
31+
@pytest.fixture(autouse=True)
32+
def set_seed():
33+
random.seed(42)
34+
torch.manual_seed(42)
35+
if torch.cuda.is_available():
36+
torch.cuda.manual_seed_all(42)
3537

3638

3739
requires_multi_gpu = pytest.mark.skipif(
@@ -85,7 +87,7 @@ def mock_distributed_config(monkeypatch):
8587
_mesh_resources.mesh_dim_group_options.clear()
8688

8789

88-
def test_main_invocation_mfsdp(mock_distributed_config, tmp_path):
90+
def test_sanity_convergence_mfsdp(mock_distributed_config, tmp_path):
8991
"""Test that the main function can be invoked with the correct arguments."""
9092

9193
# Run the training script with Hydra configuration overrides
@@ -96,8 +98,8 @@ def test_main_invocation_mfsdp(mock_distributed_config, tmp_path):
9698
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
9799

98100

99-
@pytest.mark.xfail(reason="MFSDP meta-device init seems to be failing with this model (BIONEMO-2583)")
100-
def test_main_invocation_mfsdp_meta_device(mock_distributed_config, tmp_path):
101+
@pytest.mark.xfail(reason="MFSDP meta-device init seems to be failing with both TE and eager models (BIONEMO-2583)")
102+
def test_sanity_convergence_mfsdp_meta_device(mock_distributed_config, tmp_path):
101103
"""Test that the main function can be invoked with the correct arguments."""
102104

103105
# Run the training script with Hydra configuration overrides
@@ -106,15 +108,34 @@ def test_main_invocation_mfsdp_meta_device(mock_distributed_config, tmp_path):
106108
config_name="L0_sanity",
107109
overrides=[
108110
f"+wandb_init_args.dir={tmp_path}",
109-
"fully_shard_kwargs.init_model_with_meta_device=true",
111+
"use_meta_device=true",
110112
],
111113
)
112114

113115
final_loss = main_mfsdp(sanity_config)
114116
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
115117

116118

117-
def test_main_invocation_ddp(mock_distributed_config, tmp_path):
119+
@pytest.mark.xfail(reason="MFSDP meta-device init seems to be failing with both TE and eager models (BIONEMO-2583)")
120+
def test_sanity_convergence_mfsdp_eager_meta_device(mock_distributed_config, tmp_path):
121+
"""Test that the main function can be invoked with the correct arguments."""
122+
123+
# Run the training script with Hydra configuration overrides
124+
with initialize_config_dir(config_dir=str(recipe_dir / "hydra_config"), version_base="1.2"):
125+
sanity_config = compose(
126+
config_name="L0_sanity",
127+
overrides=[
128+
f"+wandb_init_args.dir={tmp_path}",
129+
"model_name=facebook/esm2_t6_8M_UR50D",
130+
"use_meta_device=true",
131+
],
132+
)
133+
134+
final_loss = main_mfsdp(sanity_config)
135+
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
136+
137+
138+
def test_sanity_convergence_ddp(mock_distributed_config, tmp_path):
118139
"""Test that the main function can be invoked wrapping the model in DDP."""
119140

120141
# Run the training script with Hydra configuration overrides
@@ -125,18 +146,41 @@ def test_main_invocation_ddp(mock_distributed_config, tmp_path):
125146
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
126147

127148

128-
def test_main_invocation_fsdp2(mock_distributed_config, tmp_path):
149+
def test_sanity_convergence_fsdp2(mock_distributed_config, tmp_path):
129150
"""Test that the main function can be invoked wrapping the model in FSDP2."""
130151

131152
# Run the training script with Hydra configuration overrides
132153
with initialize_config_dir(config_dir=str(recipe_dir / "hydra_config"), version_base="1.2"):
133-
sanity_config = compose(config_name="L0_sanity", overrides=[f"+wandb_init_args.dir={tmp_path}"])
154+
sanity_config = compose(
155+
config_name="L0_sanity",
156+
overrides=[
157+
f"+wandb_init_args.dir={tmp_path}",
158+
],
159+
)
134160

135161
final_loss = main_fsdp2(sanity_config)
136162
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
137163

138164

139-
def test_main_invocation_mfsdp_eager(mock_distributed_config, tmp_path):
165+
@pytest.mark.xfail(reason="FSDP2 meta-device init seems doesn't have the same convergence (BIONEMO-2719)")
166+
def test_sanity_convergence_fsdp2_meta_device(mock_distributed_config, tmp_path):
167+
"""Test that the main function can be invoked wrapping the model in FSDP2."""
168+
169+
# Run the training script with Hydra configuration overrides
170+
with initialize_config_dir(config_dir=str(recipe_dir / "hydra_config"), version_base="1.2"):
171+
sanity_config = compose(
172+
config_name="L0_sanity",
173+
overrides=[
174+
f"+wandb_init_args.dir={tmp_path}",
175+
"use_meta_device=true",
176+
],
177+
)
178+
179+
final_loss = main_fsdp2(sanity_config)
180+
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
181+
182+
183+
def test_sanity_convergence_mfsdp_eager(mock_distributed_config, tmp_path):
140184
"""Test that the main function can be invoked with the correct arguments."""
141185

142186
# Run the training script with Hydra configuration overrides
@@ -150,7 +194,7 @@ def test_main_invocation_mfsdp_eager(mock_distributed_config, tmp_path):
150194
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
151195

152196

153-
def test_main_invocation_ddp_eager(mock_distributed_config, tmp_path):
197+
def test_sanity_convergence_ddp_eager(mock_distributed_config, tmp_path):
154198
"""Test that the main function can be invoked wrapping the model in DDP."""
155199

156200
# Run the training script with Hydra configuration overrides
@@ -164,7 +208,7 @@ def test_main_invocation_ddp_eager(mock_distributed_config, tmp_path):
164208
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
165209

166210

167-
def test_main_invocation_fsdp2_eager(mock_distributed_config, tmp_path):
211+
def test_sanity_convergence_fsdp2_eager(mock_distributed_config, tmp_path):
168212
"""Test that the main function can be invoked wrapping the model in FSDP2."""
169213

170214
# Run the training script with Hydra configuration overrides
@@ -178,6 +222,28 @@ def test_main_invocation_fsdp2_eager(mock_distributed_config, tmp_path):
178222
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
179223

180224

225+
@pytest.mark.xfail(reason="This passes on my local 5090 but fails on CI (L4) (BIONEMO-2719)")
226+
def test_sanity_convergence_fsdp2_eager_meta_device(mock_distributed_config, tmp_path):
227+
"""Test that the main function can be invoked wrapping the model in FSDP2 and using meta-device init."""
228+
229+
# Run the training script with Hydra configuration overrides
230+
with initialize_config_dir(config_dir=str(recipe_dir / "hydra_config"), version_base="1.2"):
231+
sanity_config = compose(
232+
config_name="L0_sanity",
233+
overrides=[
234+
f"+wandb_init_args.dir={tmp_path}",
235+
"model_name=facebook/esm2_t6_8M_UR50D",
236+
"use_meta_device=true",
237+
],
238+
)
239+
240+
final_loss = main_fsdp2(sanity_config)
241+
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
242+
243+
244+
# These tests don't check convergence, they just check that the training script runs successfully on multiple GPUs.
245+
246+
181247
@requires_multi_gpu
182248
def test_multi_gpu_train_te_ddp(tmp_path):
183249
# Run 'accelerate launch train.py' as a subprocess
@@ -197,7 +263,7 @@ def test_multi_gpu_train_te_ddp(tmp_path):
197263

198264

199265
@requires_multi_gpu
200-
def test_multi_gpu_train_te_mfsdp_no_meta_device(tmp_path):
266+
def test_multi_gpu_train_te_mfsdp(tmp_path):
201267
# Run 'accelerate launch train.py' as a subprocess
202268
run_train_cmd(
203269
[
@@ -209,14 +275,13 @@ def test_multi_gpu_train_te_mfsdp_no_meta_device(tmp_path):
209275
"train_mfsdp.py",
210276
"--config-name",
211277
"L0_sanity",
212-
"fully_shard_kwargs.init_model_with_meta_device=false",
213278
"num_train_steps=4",
214279
]
215280
)
216281

217282

218283
@requires_multi_gpu
219-
def test_multi_gpu_train_eager_mfsdp(tmp_path):
284+
def test_multi_gpu_train_te_fsdp2(tmp_path):
220285
# Run 'accelerate launch train.py' as a subprocess
221286
run_train_cmd(
222287
[
@@ -225,17 +290,16 @@ def test_multi_gpu_train_eager_mfsdp(tmp_path):
225290
"2",
226291
"--master_port",
227292
f"{random.randint(20000, 40000)}",
228-
"train_mfsdp.py",
293+
"train_fsdp2.py",
229294
"--config-name",
230295
"L0_sanity",
231-
"model_name=facebook/esm2_t6_8M_UR50D",
232296
"num_train_steps=4",
233297
]
234298
)
235299

236300

237301
@requires_multi_gpu
238-
def test_multi_gpu_train_te_fsdp2(tmp_path):
302+
def test_multi_gpu_train_eager_fsdp2_meta_device(tmp_path):
239303
# Run 'accelerate launch train.py' as a subprocess
240304
run_train_cmd(
241305
[
@@ -247,6 +311,8 @@ def test_multi_gpu_train_te_fsdp2(tmp_path):
247311
"train_fsdp2.py",
248312
"--config-name",
249313
"L0_sanity",
314+
"model_name=facebook/esm2_t6_8M_UR50D",
315+
"use_meta_device=true",
250316
"num_train_steps=4",
251317
]
252318
)

recipes/esm2_native_te_mfsdp/train_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def main(args: DictConfig) -> float | None:
170170
)
171171

172172
progress_bar.update(1)
173-
progress_bar.set_postfix({"loss": loss.item()})
173+
progress_bar.set_postfix({"loss": loss_value})
174174

175175
# Clean up distributed training
176176
if dist_config.is_main_process():

recipes/esm2_native_te_mfsdp/train_fsdp2.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import os
1818
import time
19+
from contextlib import nullcontext
1920
from dataclasses import dataclass, field
2021

2122
import hydra
@@ -85,7 +86,7 @@ def main(args: DictConfig) -> float | None: # noqa: C901
8586
# Create an empty ESM-2 model with a masked language model head.
8687
if "facebook" in args.model_name:
8788
config = AutoConfig.from_pretrained(args.model_name, dtype=torch.bfloat16)
88-
with torch.device("meta"):
89+
with torch.device("meta") if args.use_meta_device else nullcontext():
8990
model = AutoModelForMaskedLM.from_config(config, attn_implementation="flash_attention_2")
9091
del model.esm.contact_head
9192
transformer_stack = model.esm.encoder.layer
@@ -94,7 +95,7 @@ def main(args: DictConfig) -> float | None: # noqa: C901
9495
config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True, dtype=torch.bfloat16)
9596
config.max_seq_length = args.max_seq_length
9697
config.micro_batch_size = args.micro_batch_size
97-
with torch.device("meta"):
98+
with torch.device("meta") if args.use_meta_device else nullcontext():
9899
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)
99100
transformer_stack = model.esm.encoder.layers
100101

@@ -111,10 +112,11 @@ def main(args: DictConfig) -> float | None: # noqa: C901
111112
optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
112113
scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
113114

114-
model.to_empty(device=device)
115-
for module in model.modules():
116-
if hasattr(module, "reset_parameters"):
117-
module.reset_parameters()
115+
if args.use_meta_device:
116+
model.to_empty(device=device)
117+
for module in model.modules():
118+
if hasattr(module, "reset_parameters"):
119+
module.reset_parameters()
118120

119121
# Training loop.
120122
model.train()

recipes/esm2_native_te_mfsdp/train_mfsdp.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import os
1818
import time
19+
from contextlib import nullcontext
1920
from dataclasses import dataclass, field
2021

2122
import hydra
@@ -86,15 +87,19 @@ def main(args: DictConfig) -> float | None:
8687
config = AutoConfig.from_pretrained(args.model_name, dtype=torch.bfloat16)
8788
from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401
8889

89-
with torch.device("meta" if args.fully_shard_kwargs.get("init_model_with_meta_device", True) else device):
90+
with (
91+
torch.device("meta") if args.fully_shard_kwargs.get("init_model_with_meta_device", True) else nullcontext()
92+
):
9093
model = AutoModelForMaskedLM.from_config(config, attn_implementation="flash_attention_2")
91-
del model.esm.contact_head
94+
del model.esm.contact_head
9295

9396
else:
9497
config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True, dtype=torch.bfloat16)
9598
config.max_seq_length = args.max_seq_length
9699
config.micro_batch_size = args.micro_batch_size
97-
with torch.device("meta" if args.fully_shard_kwargs.get("init_model_with_meta_device", True) else device):
100+
with (
101+
torch.device("meta") if args.fully_shard_kwargs.get("init_model_with_meta_device", True) else nullcontext()
102+
):
98103
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)
99104

100105
# Log model and number of parameters on main process.
@@ -188,7 +193,7 @@ def main(args: DictConfig) -> float | None:
188193
)
189194

190195
progress_bar.update(1)
191-
progress_bar.set_postfix({"loss": loss.item()})
196+
progress_bar.set_postfix({"loss": loss_value})
192197

193198
# Clean up distributed training
194199
if dist_config.is_main_process():

0 commit comments

Comments
 (0)