Skip to content

Commit 409a24a

Browse files
committed
only xfail thd tests if cuda arch is unsupported
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent e213d40 commit 409a24a

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

recipes/esm2_native_te_nvfsdp_thd/test_train.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616
from pathlib import Path
1717

1818
import pytest
19+
import torch
1920
from hydra import compose, initialize_config_dir
2021

2122
from train import main
2223

2324

24-
@pytest.mark.xfail(reason="CUDNN padded packed sequences not supported on all hardware currently.")
25+
@pytest.mark.xfail(
26+
torch.cuda.get_device_capability() != (10, 0),
27+
reason="CUDNN padded packed sequences not supported on all hardware currently (nvbugs/5458694).",
28+
)
2529
def test_main_invocation(monkeypatch, tmp_path):
2630
"""Test that the main function can be invoked with the correct arguments."""
2731

@@ -43,7 +47,10 @@ def test_main_invocation(monkeypatch, tmp_path):
4347
main(sanity_config)
4448

4549

46-
@pytest.mark.xfail(reason="CUDNN padded packed sequences not supported on all hardware currently.")
50+
@pytest.mark.xfail(
51+
torch.cuda.get_device_capability() != (10, 0),
52+
reason="CUDNN padded packed sequences not supported on all hardware currently (nvbugs/5458694).",
53+
)
4754
def test_main_invocation_ddp(monkeypatch, tmp_path):
4855
"""Test that the main function can be invoked wrapping the model in DDP."""
4956

0 commit comments

Comments
 (0)