@@ -364,12 +364,8 @@ def test_sanity_convergence_fsdp2_fp8_and_model_init(tmp_path, recipe_path):
364364 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
365365
366366
367- def test_sanity_convergence_fsdp2_thd (tmp_path , monkeypatch , recipe_path ):
367+ def test_sanity_convergence_fsdp2_thd (tmp_path , recipe_path ):
368368 """For FSDP2, we check that the script can run successfully with FP8 and check convergence."""
369- if torch .cuda .get_device_capability () == (12 , 0 ):
370- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
371- # but it's missing this THD implementation.
372- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
373369
374370 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
375371 sanity_config = compose (
@@ -386,12 +382,8 @@ def test_sanity_convergence_fsdp2_thd(tmp_path, monkeypatch, recipe_path):
386382
387383
388384@requires_fp8
389- def test_sanity_convergence_fsdp2_thd_fp8 (tmp_path , monkeypatch , recipe_path ):
385+ def test_sanity_convergence_fsdp2_thd_fp8 (tmp_path , recipe_path ):
390386 """For FSDP2, we check that the script can run successfully with THD + FP8 and check convergence."""
391- if torch .cuda .get_device_capability () == (12 , 0 ):
392- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
393- # but it's missing this THD implementation.
394- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
395387
396388 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
397389 sanity_config = compose (
@@ -408,11 +400,7 @@ def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, monkeypatch, recipe_path):
408400 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
409401
410402
411- def test_sanity_ddp_thd (tmp_path , monkeypatch , recipe_path ):
412- if torch .cuda .get_device_capability () == (12 , 0 ):
413- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
414- # but it's missing this THD implementation.
415- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
403+ def test_sanity_ddp_thd (tmp_path , recipe_path ):
416404
417405 # For DDP, we only check that the script can run successfully with THD, not convergence.
418406 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
@@ -429,11 +417,7 @@ def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
429417 main_ddp (sanity_config )
430418
431419
432- def test_sanity_mfsdp_thd (tmp_path , monkeypatch , recipe_path ):
433- if torch .cuda .get_device_capability () == (12 , 0 ):
434- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
435- # but it's missing this THD implementation.
436- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
420+ def test_sanity_mfsdp_thd (tmp_path , recipe_path ):
437421
438422 # For MFSDP, we only check that the script can run successfully with THD, not convergence.
439423 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
@@ -451,11 +435,7 @@ def test_sanity_mfsdp_thd(tmp_path, monkeypatch, recipe_path):
451435
452436
453437@requires_fp8
454- def test_sanity_ddp_thd_fp8 (tmp_path , monkeypatch , recipe_path ):
455- if torch .cuda .get_device_capability () == (12 , 0 ):
456- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
457- # but it's missing this THD implementation.
458- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
438+ def test_sanity_ddp_thd_fp8 (tmp_path , recipe_path ):
459439
460440 # For DDP, we only check that the script can run successfully with THD, not convergence.
461441 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
@@ -474,11 +454,7 @@ def test_sanity_ddp_thd_fp8(tmp_path, monkeypatch, recipe_path):
474454
475455
476456@requires_fp8
477- def test_sanity_mfsdp_thd_fp8 (tmp_path , monkeypatch , recipe_path ):
478- if torch .cuda .get_device_capability () == (12 , 0 ):
479- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
480- # but it's missing this THD implementation.
481- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
457+ def test_sanity_mfsdp_thd_fp8 (tmp_path , recipe_path ):
482458
483459 # For MFSDP, we only check that the script can run successfully with THD, not convergence.
484460 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
@@ -571,11 +547,7 @@ def test_sanity_convergence_fsdp2_huggingface_model(tmp_path, recipe_path):
571547 assert final_loss < 3.0 , f"Final loss { final_loss } is too high"
572548
573549
574- def test_sanity_ddp_thd_token_packing (tmp_path , monkeypatch , recipe_path ):
575- if torch .cuda .get_device_capability () == (12 , 0 ):
576- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
577- # but it's missing this THD implementation.
578- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
550+ def test_sanity_ddp_thd_token_packing (tmp_path , recipe_path ):
579551
580552 # For DDP, we only check that the script can run successfully with THD, not convergence.
581553 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
@@ -592,11 +564,7 @@ def test_sanity_ddp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
592564 main_ddp (sanity_config )
593565
594566
595- def test_sanity_mfsdp_thd_token_packing (tmp_path , monkeypatch , recipe_path ):
596- if torch .cuda .get_device_capability () == (12 , 0 ):
597- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
598- # but it's missing this THD implementation.
599- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
567+ def test_sanity_mfsdp_thd_token_packing (tmp_path , recipe_path ):
600568
601569 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
602570 sanity_config = compose (
@@ -612,11 +580,7 @@ def test_sanity_mfsdp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
612580 main_mfsdp (sanity_config )
613581
614582
615- def test_sanity_fsdp2_thd_token_packing (tmp_path , monkeypatch , recipe_path ):
616- if torch .cuda .get_device_capability () == (12 , 0 ):
617- # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
618- # but it's missing this THD implementation.
619- monkeypatch .setenv ("NVTE_FUSED_ATTN" , "0" )
583+ def test_sanity_fsdp2_thd_token_packing (tmp_path , recipe_path ):
620584
621585 with initialize_config_dir (config_dir = str (recipe_path / "hydra_config" ), version_base = "1.2" ):
622586 sanity_config = compose (
0 commit comments