Skip to content

Commit c89283c

Browse files
authored
chore: clean up cufile fixtures (NVIDIA#1313)
* chore: clean up cufile fixtures * chore: remove somewhat useless `raw_driver`
1 parent 1b72df0 commit c89283c

1 file changed

Lines changed: 51 additions & 88 deletions

File tree

cuda_bindings/tests/test_cufile.py

Lines changed: 51 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -24,46 +24,18 @@
2424
force=True, # Override any existing logging configuration
2525
)
2626

27-
28-
def platform_is_tegra_linux():
29-
return pathlib.Path("/etc/nv_tegra_release").exists()
30-
31-
32-
if platform_is_tegra_linux():
33-
pytest.skip("skipping cuFile tests on Tegra Linux", allow_module_level=True)
34-
35-
36-
def platform_is_wsl():
37-
"""Check if running on Windows Subsystem for Linux (WSL)."""
38-
return platform.system() == "Linux" and "microsoft" in pathlib.Path("/proc/version").read_text().lower()
39-
40-
41-
if platform_is_wsl():
42-
pytest.skip("skipping cuFile tests on WSL", allow_module_level=True)
43-
44-
45-
from cuda.bindings.cufile import cuFileError
27+
cufile = pytest.importorskip("cuda.bindings.cufile", reason="skipping tests on Windows")
4628

4729

4830
@pytest.fixture
49-
def cufile_env_json():
31+
def cufile_env_json(monkeypatch):
5032
"""Set CUFILE_ENV_PATH_JSON environment variable for async tests."""
51-
original_value = os.environ.get("CUFILE_ENV_PATH_JSON")
52-
5333
# Get absolute path to cufile.json in the same directory as this test file
5434
test_dir = os.path.dirname(os.path.abspath(__file__))
5535
config_path = os.path.join(test_dir, "cufile.json")
56-
logging.info(f"Using cuFile config: {config_path}")
5736
assert os.path.isfile(config_path)
58-
os.environ["CUFILE_ENV_PATH_JSON"] = config_path
59-
60-
yield
61-
62-
# Restore original value or remove if it wasn't set
63-
if original_value is not None:
64-
os.environ["CUFILE_ENV_PATH_JSON"] = original_value
65-
else:
66-
del os.environ["CUFILE_ENV_PATH_JSON"]
37+
monkeypatch.setenv("CUFILE_ENV_PATH_JSON", config_path)
38+
logging.info(f"Using cuFile config: {config_path}")
6739

6840

6941
@cache
@@ -108,11 +80,18 @@ def isSupportedFilesystem():
10880

10981

11082
# Global skip condition for all tests if cuFile library is not available
111-
pytestmark = pytest.mark.skipif(not cufileLibraryAvailable(), reason="cuFile library not available on this system")
83+
pytestmark = [
84+
pytest.mark.skipif(not cufileLibraryAvailable(), reason="cuFile library not available on this system"),
85+
pytest.mark.skipif(
86+
platform.system() == "Linux" and "microsoft" in pathlib.Path("/proc/version").read_text().lower(),
87+
reason="skipping cuFile tests on WSL",
88+
),
89+
pytest.mark.skipif(pathlib.Path("/etc/nv_tegra_release").exists(), reason="skipping cuFile tests on Tegra Linux"),
90+
]
11291

11392
xfail_handle_register = pytest.mark.xfail(
11493
condition=isSupportedFilesystem() and os.environ.get("CI") is not None,
115-
raises=cuFileError,
94+
raises=cufile.cuFileError,
11695
reason="handle_register call fails in CI for unknown reasons",
11796
)
11897

@@ -615,13 +594,10 @@ def test_cufile_read_write_large():
615594

616595

617596
@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
618-
@pytest.mark.usefixtures("ctx")
597+
@pytest.mark.usefixtures("ctx", "cufile_env_json", "driver")
619598
@xfail_handle_register
620-
def test_cufile_write_async(cufile_env_json):
599+
def test_cufile_write_async():
621600
"""Test cuFile asynchronous write operations."""
622-
# Open cuFile driver
623-
cufile.driver_open()
624-
625601
# Create test file
626602
file_path = "test_cufile_write_async.bin"
627603
fd = os.open(file_path, os.O_CREAT | os.O_RDWR | os.O_DIRECT, 0o600)
@@ -693,17 +669,13 @@ def test_cufile_write_async(cufile_env_json):
693669
os.close(fd)
694670
with suppress(OSError):
695671
os.unlink(file_path)
696-
cufile.driver_close()
697672

698673

699674
@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
700-
@pytest.mark.usefixtures("ctx")
675+
@pytest.mark.usefixtures("ctx", "cufile_env_json", "driver")
701676
@xfail_handle_register
702-
def test_cufile_read_async(cufile_env_json):
677+
def test_cufile_read_async():
703678
"""Test cuFile asynchronous read operations."""
704-
# Open cuFile driver
705-
cufile.driver_open()
706-
707679
# Create test file
708680
file_path = "test_cufile_read_async.bin"
709681

@@ -788,17 +760,13 @@ def test_cufile_read_async(cufile_env_json):
788760
os.close(fd)
789761
with suppress(OSError):
790762
os.unlink(file_path)
791-
cufile.driver_close()
792763

793764

794765
@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
795-
@pytest.mark.usefixtures("ctx")
796766
@xfail_handle_register
797-
def test_cufile_async_read_write(cufile_env_json):
767+
@pytest.mark.usefixtures("ctx", "cufile_env_json", "driver")
768+
def test_cufile_async_read_write():
798769
"""Test cuFile asynchronous read and write operations in sequence."""
799-
# Open cuFile driver
800-
cufile.driver_open()
801-
802770
# Create test file
803771
file_path = "test_cufile_async_rw.bin"
804772
fd = os.open(file_path, os.O_CREAT | os.O_RDWR | os.O_DIRECT, 0o600)
@@ -906,7 +874,6 @@ def test_cufile_async_read_write(cufile_env_json):
906874
os.close(fd)
907875
with suppress(OSError):
908876
os.unlink(file_path)
909-
cufile.driver_close()
910877

911878

912879
@pytest.mark.skipif(not isSupportedFilesystem(), reason="cuFile handle_register requires ext4 or xfs filesystem")
@@ -1864,30 +1831,30 @@ def test_get_bar_size_in_kb():
18641831
logging.info(f"GPU BAR size: {bar_size_kb} KB ({bar_size_kb / 1024 / 1024:.2f} GB)")
18651832

18661833

1867-
@pytest.mark.skipif(
1868-
cufileVersionLessThan(1150), reason="cuFile parameter APIs require cuFile library version 13.0 or later"
1869-
)
1870-
@pytest.mark.usefixtures("ctx")
1871-
def test_set_parameter_posix_pool_slab_array():
1872-
"""Test cuFile POSIX pool slab array configuration."""
1873-
# Define slab sizes for POSIX I/O pool (common I/O buffer sizes) - BEFORE driver open
1874-
import ctypes
1875-
1876-
slab_sizes = [
1834+
@pytest.fixture(scope="module")
1835+
def slab_sizes():
1836+
"""Define slab sizes for POSIX I/O pool (common I/O buffer sizes) - BEFORE driver open"""
1837+
return [
18771838
4096, # 4KB - small files
18781839
65536, # 64KB - medium files
18791840
1048576, # 1MB - large files
18801841
16777216, # 16MB - very large files
18811842
]
18821843

1883-
# Define counts for each slab size (number of buffers)
1884-
slab_counts = [
1844+
1845+
@pytest.fixture(scope="module")
1846+
def slab_counts():
1847+
"""Define counts for each slab size (number of buffers)"""
1848+
return [
18851849
10, # 10 buffers of 4KB
18861850
5, # 5 buffers of 64KB
18871851
3, # 3 buffers of 1MB
18881852
2, # 2 buffers of 16MB
18891853
]
18901854

1855+
1856+
@pytest.fixture
1857+
def driver_config(slab_sizes, slab_counts):
18911858
# Convert to ctypes arrays
18921859
size_array_type = ctypes.c_size_t * len(slab_sizes)
18931860
count_array_type = ctypes.c_size_t * len(slab_counts)
@@ -1899,32 +1866,28 @@ def test_set_parameter_posix_pool_slab_array():
18991866
ctypes.addressof(size_array), ctypes.addressof(count_array), len(slab_sizes)
19001867
)
19011868

1902-
# Open cuFile driver AFTER setting parameters
1903-
cufile.driver_open()
1904-
1905-
try:
1906-
# After setting parameters, retrieve them back to verify
1907-
retrieved_sizes = (ctypes.c_size_t * len(slab_sizes))()
1908-
retrieved_counts = (ctypes.c_size_t * len(slab_counts))()
1909-
1910-
cufile.get_parameter_posix_pool_slab_array(
1911-
ctypes.addressof(retrieved_sizes), ctypes.addressof(retrieved_counts), len(slab_sizes)
1912-
)
19131869

1914-
# Verify they match what we set
1915-
for i in range(len(slab_sizes)):
1916-
assert retrieved_sizes[i] == slab_sizes[i], (
1917-
f"Size mismatch at index {i}: expected {slab_sizes[i]}, got {retrieved_sizes[i]}"
1918-
)
1919-
assert retrieved_counts[i] == slab_counts[i], (
1920-
f"Count mismatch at index {i}: expected {slab_counts[i]}, got {retrieved_counts[i]}"
1921-
)
1870+
@pytest.mark.skipif(
1871+
cufileVersionLessThan(1150), reason="cuFile parameter APIs require cuFile library version 13.0 or later"
1872+
)
1873+
@pytest.mark.usefixtures("ctx")
1874+
def test_set_parameter_posix_pool_slab_array(slab_sizes, slab_counts, driver_config):
1875+
"""Test cuFile POSIX pool slab array configuration."""
1876+
# After setting parameters, retrieve them back to verify
1877+
n_slab_sizes = len(slab_sizes)
1878+
retrieved_sizes = (ctypes.c_size_t * n_slab_sizes)()
1879+
retrieved_counts = (ctypes.c_size_t * len(slab_counts))()
19221880

1923-
# Verify configuration was accepted successfully
1924-
logging.info(f"POSIX pool slab array configured with {len(slab_sizes)} slab sizes")
1925-
logging.info(f"Slab sizes: {[f'{size // 1024}KB' for size in slab_sizes]}")
1926-
logging.info("Round-trip verification successful: set and retrieved values match")
1881+
retrieved_sizes_addr = ctypes.addressof(retrieved_sizes)
1882+
retrieved_counts_addr = ctypes.addressof(retrieved_counts)
19271883

1884+
# Open cuFile driver AFTER setting parameters
1885+
cufile.driver_open()
1886+
try:
1887+
cufile.get_parameter_posix_pool_slab_array(retrieved_sizes_addr, retrieved_counts_addr, n_slab_sizes)
19281888
finally:
1929-
# Close cuFile driver
19301889
cufile.driver_close()
1890+
1891+
# Verify they match what we set
1892+
assert list(retrieved_sizes) == slab_sizes
1893+
assert list(retrieved_counts) == slab_counts

0 commit comments

Comments
 (0)