Skip to content

Commit 040eed7

Browse files
committed
Remove BIONEMO-2840 sm120 fused attention workarounds
The THD implementation for fused attention on sm120 (Blackwell) is now available in Transformer Engine, so these workarounds are no longer needed. Removes: - pytest.xfail guards for sm120 in test_modeling_common.py (6 files) - monkeypatch.setenv("NVTE_FUSED_ATTN", "0") blocks in esm2 recipe tests - Unused monkeypatch parameters and torch import where applicable
1 parent f0d4bfd commit 040eed7

8 files changed

Lines changed: 10 additions & 64 deletions

File tree

bionemo-recipes/models/codonfm/tests/common/test_modeling_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):
724724

725725
if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
726726
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
727-
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
728-
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
729727

730728
input_data_bshd = self.get_test_input_data(format="bshd")
731729
input_data_thd = self.get_test_input_data(format="thd")

bionemo-recipes/models/codonfm/tests/test_modeling_codonfm_te.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,6 @@ def test_golden_values_thd(self, te_attn_backend):
328328

329329
if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
330330
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
331-
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
332-
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
333331

334332
golden_dir = Path(__file__).parent
335333
golden_sd_path = golden_dir / "golden_state_dict.safetensors"

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,8 +718,6 @@ def test_golden_values_thd(self, te_attn_backend):
718718

719719
if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
720720
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
721-
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
722-
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
723721

724722
input_data_bshd = self.get_test_input_data(format="bshd")
725723
input_data_thd = self.get_test_input_data(format="thd")

bionemo-recipes/models/llama3/tests/common/test_modeling_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):
724724

725725
if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
726726
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
727-
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
728-
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
729727

730728
input_data_bshd = self.get_test_input_data(format="bshd")
731729
input_data_thd = self.get_test_input_data(format="thd")

bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):
724724

725725
if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
726726
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
727-
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
728-
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
729727

730728
input_data_bshd = self.get_test_input_data(format="bshd")
731729
input_data_thd = self.get_test_input_data(format="thd")

bionemo-recipes/models/qwen/tests/common/test_modeling_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):
724724

725725
if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
726726
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
727-
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
728-
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
729727

730728
input_data_bshd = self.get_test_input_data(format="bshd")
731729
input_data_thd = self.get_test_input_data(format="thd")

bionemo-recipes/recipes/esm2_native_te/tests/test_train.py

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,8 @@ def test_sanity_convergence_fsdp2_fp8_and_model_init(tmp_path, recipe_path):
364364
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
365365

366366

367-
def test_sanity_convergence_fsdp2_thd(tmp_path, monkeypatch, recipe_path):
367+
def test_sanity_convergence_fsdp2_thd(tmp_path, recipe_path):
368368
"""For FSDP2, we check that the script can run successfully with FP8 and check convergence."""
369-
if torch.cuda.get_device_capability() == (12, 0):
370-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
371-
# but it's missing this THD implementation.
372-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
373369

374370
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
375371
sanity_config = compose(
@@ -386,12 +382,8 @@ def test_sanity_convergence_fsdp2_thd(tmp_path, monkeypatch, recipe_path):
386382

387383

388384
@requires_fp8
389-
def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, monkeypatch, recipe_path):
385+
def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, recipe_path):
390386
"""For FSDP2, we check that the script can run successfully with THD + FP8 and check convergence."""
391-
if torch.cuda.get_device_capability() == (12, 0):
392-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
393-
# but it's missing this THD implementation.
394-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
395387

396388
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
397389
sanity_config = compose(
@@ -408,11 +400,7 @@ def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, monkeypatch, recipe_path):
408400
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
409401

410402

411-
def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
412-
if torch.cuda.get_device_capability() == (12, 0):
413-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
414-
# but it's missing this THD implementation.
415-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
403+
def test_sanity_ddp_thd(tmp_path, recipe_path):
416404

417405
# For DDP, we only check that the script can run successfully with THD, not convergence.
418406
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
@@ -429,11 +417,7 @@ def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
429417
main_ddp(sanity_config)
430418

431419

432-
def test_sanity_mfsdp_thd(tmp_path, monkeypatch, recipe_path):
433-
if torch.cuda.get_device_capability() == (12, 0):
434-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
435-
# but it's missing this THD implementation.
436-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
420+
def test_sanity_mfsdp_thd(tmp_path, recipe_path):
437421

438422
# For MFSDP, we only check that the script can run successfully with THD, not convergence.
439423
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
@@ -451,11 +435,7 @@ def test_sanity_mfsdp_thd(tmp_path, monkeypatch, recipe_path):
451435

452436

453437
@requires_fp8
454-
def test_sanity_ddp_thd_fp8(tmp_path, monkeypatch, recipe_path):
455-
if torch.cuda.get_device_capability() == (12, 0):
456-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
457-
# but it's missing this THD implementation.
458-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
438+
def test_sanity_ddp_thd_fp8(tmp_path, recipe_path):
459439

460440
# For DDP, we only check that the script can run successfully with THD, not convergence.
461441
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
@@ -474,11 +454,7 @@ def test_sanity_ddp_thd_fp8(tmp_path, monkeypatch, recipe_path):
474454

475455

476456
@requires_fp8
477-
def test_sanity_mfsdp_thd_fp8(tmp_path, monkeypatch, recipe_path):
478-
if torch.cuda.get_device_capability() == (12, 0):
479-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
480-
# but it's missing this THD implementation.
481-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
457+
def test_sanity_mfsdp_thd_fp8(tmp_path, recipe_path):
482458

483459
# For MFSDP, we only check that the script can run successfully with THD, not convergence.
484460
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
@@ -571,11 +547,7 @@ def test_sanity_convergence_fsdp2_huggingface_model(tmp_path, recipe_path):
571547
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
572548

573549

574-
def test_sanity_ddp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
575-
if torch.cuda.get_device_capability() == (12, 0):
576-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
577-
# but it's missing this THD implementation.
578-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
550+
def test_sanity_ddp_thd_token_packing(tmp_path, recipe_path):
579551

580552
# For DDP, we only check that the script can run successfully with THD, not convergence.
581553
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
@@ -592,11 +564,7 @@ def test_sanity_ddp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
592564
main_ddp(sanity_config)
593565

594566

595-
def test_sanity_mfsdp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
596-
if torch.cuda.get_device_capability() == (12, 0):
597-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
598-
# but it's missing this THD implementation.
599-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
567+
def test_sanity_mfsdp_thd_token_packing(tmp_path, recipe_path):
600568

601569
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
602570
sanity_config = compose(
@@ -612,11 +580,7 @@ def test_sanity_mfsdp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
612580
main_mfsdp(sanity_config)
613581

614582

615-
def test_sanity_fsdp2_thd_token_packing(tmp_path, monkeypatch, recipe_path):
616-
if torch.cuda.get_device_capability() == (12, 0):
617-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
618-
# but it's missing this THD implementation.
619-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
583+
def test_sanity_fsdp2_thd_token_packing(tmp_path, recipe_path):
620584

621585
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
622586
sanity_config = compose(

bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import torch
1716
from hydra import compose, initialize_config_dir
1817

1918
from train_lora_ddp import main as main_ddp
@@ -54,12 +53,7 @@ def test_sanity_convergence_ddp_non_streaming_dataset(tmp_path, recipe_path):
5453
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
5554

5655

57-
def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
58-
if torch.cuda.get_device_capability() == (12, 0):
59-
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
60-
# but it's missing this THD implementation.
61-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
62-
56+
def test_sanity_ddp_thd(tmp_path, recipe_path):
6357
# For DDP, we only check that the script can run successfully with THD, not convergence.
6458
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
6559
sanity_config = compose(

0 commit comments

Comments
 (0)