2828from train_mfsdp import main as main_mfsdp
2929
3030
31- random .seed (42 )
32- torch .manual_seed (42 )
33- if torch .cuda .is_available ():
34- torch .cuda .manual_seed_all (42 )
31+ @pytest .fixture (autouse = True )
32+ def set_seed ():
33+ random .seed (42 )
34+ torch .manual_seed (42 )
35+ if torch .cuda .is_available ():
36+ torch .cuda .manual_seed_all (42 )
3537
3638
3739requires_multi_gpu = pytest .mark .skipif (
@@ -85,7 +87,7 @@ def mock_distributed_config(monkeypatch):
8587 _mesh_resources .mesh_dim_group_options .clear ()
8688
8789
88- def test_main_invocation_mfsdp (mock_distributed_config , tmp_path ):
90+ def test_sanity_convergence_mfsdp (mock_distributed_config , tmp_path ):
8991 """Test that the main function can be invoked with the correct arguments."""
9092
9193 # Run the training script with Hydra configuration overrides
@@ -96,8 +98,8 @@ def test_main_invocation_mfsdp(mock_distributed_config, tmp_path):
9698 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
9799
98100
99- @pytest .mark .xfail (reason = "MFSDP meta-device init seems to be failing with this model (BIONEMO-2583)" )
100- def test_main_invocation_mfsdp_meta_device (mock_distributed_config , tmp_path ):
101+ @pytest .mark .xfail (reason = "MFSDP meta-device init seems to be failing with both TE and eager models (BIONEMO-2583)" )
102+ def test_sanity_convergence_mfsdp_meta_device (mock_distributed_config , tmp_path ):
101103 """Test that the main function can be invoked with the correct arguments."""
102104
103105 # Run the training script with Hydra configuration overrides
@@ -106,15 +108,34 @@ def test_main_invocation_mfsdp_meta_device(mock_distributed_config, tmp_path):
106108 config_name = "L0_sanity" ,
107109 overrides = [
108110 f"+wandb_init_args.dir={ tmp_path } " ,
109- "fully_shard_kwargs.init_model_with_meta_device =true" ,
111+ "use_meta_device =true" ,
110112 ],
111113 )
112114
113115 final_loss = main_mfsdp (sanity_config )
114116 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
115117
116118
117- def test_main_invocation_ddp (mock_distributed_config , tmp_path ):
119+ @pytest .mark .xfail (reason = "MFSDP meta-device init seems to be failing with both TE and eager models (BIONEMO-2583)" )
120+ def test_sanity_convergence_mfsdp_eager_meta_device (mock_distributed_config , tmp_path ):
121+ """Test that the main function can be invoked with the correct arguments."""
122+
123+ # Run the training script with Hydra configuration overrides
124+ with initialize_config_dir (config_dir = str (recipe_dir / "hydra_config" ), version_base = "1.2" ):
125+ sanity_config = compose (
126+ config_name = "L0_sanity" ,
127+ overrides = [
128+ f"+wandb_init_args.dir={ tmp_path } " ,
129+ "model_name=facebook/esm2_t6_8M_UR50D" ,
130+ "use_meta_device=true" ,
131+ ],
132+ )
133+
134+ final_loss = main_mfsdp (sanity_config )
135+ assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
136+
137+
138+ def test_sanity_convergence_ddp (mock_distributed_config , tmp_path ):
118139 """Test that the main function can be invoked wrapping the model in DDP."""
119140
120141 # Run the training script with Hydra configuration overrides
@@ -125,18 +146,41 @@ def test_main_invocation_ddp(mock_distributed_config, tmp_path):
125146 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
126147
127148
128- def test_main_invocation_fsdp2 (mock_distributed_config , tmp_path ):
149+ def test_sanity_convergence_fsdp2 (mock_distributed_config , tmp_path ):
129150 """Test that the main function can be invoked wrapping the model in FSDP2."""
130151
131152 # Run the training script with Hydra configuration overrides
132153 with initialize_config_dir (config_dir = str (recipe_dir / "hydra_config" ), version_base = "1.2" ):
133- sanity_config = compose (config_name = "L0_sanity" , overrides = [f"+wandb_init_args.dir={ tmp_path } " ])
154+ sanity_config = compose (
155+ config_name = "L0_sanity" ,
156+ overrides = [
157+ f"+wandb_init_args.dir={ tmp_path } " ,
158+ ],
159+ )
134160
135161 final_loss = main_fsdp2 (sanity_config )
136162 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
137163
138164
139- def test_main_invocation_mfsdp_eager (mock_distributed_config , tmp_path ):
165+ @pytest .mark .xfail (reason = "FSDP2 meta-device init seems doesn't have the same convergence (BIONEMO-2719)" )
166+ def test_sanity_convergence_fsdp2_meta_device (mock_distributed_config , tmp_path ):
167+ """Test that the main function can be invoked wrapping the model in FSDP2."""
168+
169+ # Run the training script with Hydra configuration overrides
170+ with initialize_config_dir (config_dir = str (recipe_dir / "hydra_config" ), version_base = "1.2" ):
171+ sanity_config = compose (
172+ config_name = "L0_sanity" ,
173+ overrides = [
174+ f"+wandb_init_args.dir={ tmp_path } " ,
175+ "use_meta_device=true" ,
176+ ],
177+ )
178+
179+ final_loss = main_fsdp2 (sanity_config )
180+ assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
181+
182+
183+ def test_sanity_convergence_mfsdp_eager (mock_distributed_config , tmp_path ):
140184 """Test that the main function can be invoked with the correct arguments."""
141185
142186 # Run the training script with Hydra configuration overrides
@@ -150,7 +194,7 @@ def test_main_invocation_mfsdp_eager(mock_distributed_config, tmp_path):
150194 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
151195
152196
153- def test_main_invocation_ddp_eager (mock_distributed_config , tmp_path ):
197+ def test_sanity_convergence_ddp_eager (mock_distributed_config , tmp_path ):
154198 """Test that the main function can be invoked wrapping the model in DDP."""
155199
156200 # Run the training script with Hydra configuration overrides
@@ -164,7 +208,7 @@ def test_main_invocation_ddp_eager(mock_distributed_config, tmp_path):
164208 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
165209
166210
167- def test_main_invocation_fsdp2_eager (mock_distributed_config , tmp_path ):
211+ def test_sanity_convergence_fsdp2_eager (mock_distributed_config , tmp_path ):
168212 """Test that the main function can be invoked wrapping the model in FSDP2."""
169213
170214 # Run the training script with Hydra configuration overrides
@@ -178,6 +222,28 @@ def test_main_invocation_fsdp2_eager(mock_distributed_config, tmp_path):
178222 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
179223
180224
225+ @pytest .mark .xfail (reason = "This passes on my local 5090 but fails on CI (L4) (BIONEMO-2719)" )
226+ def test_sanity_convergence_fsdp2_eager_meta_device (mock_distributed_config , tmp_path ):
227+ """Test that the main function can be invoked wrapping the model in FSDP2 and using meta-device init."""
228+
229+ # Run the training script with Hydra configuration overrides
230+ with initialize_config_dir (config_dir = str (recipe_dir / "hydra_config" ), version_base = "1.2" ):
231+ sanity_config = compose (
232+ config_name = "L0_sanity" ,
233+ overrides = [
234+ f"+wandb_init_args.dir={ tmp_path } " ,
235+ "model_name=facebook/esm2_t6_8M_UR50D" ,
236+ "use_meta_device=true" ,
237+ ],
238+ )
239+
240+ final_loss = main_fsdp2 (sanity_config )
241+ assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
242+
243+
244+ # These tests don't check convergence, they just check that the training script runs successfully on multiple GPUs.
245+
246+
181247@requires_multi_gpu
182248def test_multi_gpu_train_te_ddp (tmp_path ):
183249 # Run 'accelerate launch train.py' as a subprocess
@@ -197,7 +263,7 @@ def test_multi_gpu_train_te_ddp(tmp_path):
197263
198264
199265@requires_multi_gpu
200- def test_multi_gpu_train_te_mfsdp_no_meta_device (tmp_path ):
266+ def test_multi_gpu_train_te_mfsdp (tmp_path ):
201267 # Run 'accelerate launch train.py' as a subprocess
202268 run_train_cmd (
203269 [
@@ -209,14 +275,13 @@ def test_multi_gpu_train_te_mfsdp_no_meta_device(tmp_path):
209275 "train_mfsdp.py" ,
210276 "--config-name" ,
211277 "L0_sanity" ,
212- "fully_shard_kwargs.init_model_with_meta_device=false" ,
213278 "num_train_steps=4" ,
214279 ]
215280 )
216281
217282
218283@requires_multi_gpu
219- def test_multi_gpu_train_eager_mfsdp (tmp_path ):
284+ def test_multi_gpu_train_te_fsdp2 (tmp_path ):
220285 # Run 'accelerate launch train.py' as a subprocess
221286 run_train_cmd (
222287 [
@@ -225,17 +290,16 @@ def test_multi_gpu_train_eager_mfsdp(tmp_path):
225290 "2" ,
226291 "--master_port" ,
227292 f"{ random .randint (20000 , 40000 )} " ,
228- "train_mfsdp .py" ,
293+ "train_fsdp2 .py" ,
229294 "--config-name" ,
230295 "L0_sanity" ,
231- "model_name=facebook/esm2_t6_8M_UR50D" ,
232296 "num_train_steps=4" ,
233297 ]
234298 )
235299
236300
237301@requires_multi_gpu
238- def test_multi_gpu_train_te_fsdp2 (tmp_path ):
302+ def test_multi_gpu_train_eager_fsdp2_meta_device (tmp_path ):
239303 # Run 'accelerate launch train.py' as a subprocess
240304 run_train_cmd (
241305 [
@@ -247,6 +311,8 @@ def test_multi_gpu_train_te_fsdp2(tmp_path):
247311 "train_fsdp2.py" ,
248312 "--config-name" ,
249313 "L0_sanity" ,
314+ "model_name=facebook/esm2_t6_8M_UR50D" ,
315+ "use_meta_device=true" ,
250316 "num_train_steps=4" ,
251317 ]
252318 )
0 commit comments