Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cuda_bindings/tests/cufile.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
// e.g : export CUFILE_ENV_PATH_JSON="/home/<xxx>/cufile.json"


"properties" : {
"allow_compat_mode" : true
},

"execution" : {
// max number of workitems in the queue;
"max_io_queue_depth": 128,
Expand Down
35 changes: 32 additions & 3 deletions cuda_bindings/tests/test_cufile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,8 +1394,16 @@ def test_set_get_parameter_size_t():
(cufile.SizeTConfigParameter.EXECUTION_MAX_REQUEST_PARALLELISM, 4), # Max 4 parallel requests
)

# Snapshot baselines after driver_open so getters reflect merged config (defaults + JSON),
# not pre-open pending state that could restore invalid values (e.g. 0 for per-buffer cache).
cufile.driver_open()
try:
originals = {param: cufile.get_parameter_size_t(param) for param, _ in param_val_pairs}
finally:
cufile.driver_close()

Comment thread
rsarpangalav marked this conversation as resolved.
def test_param(param, val):
orig_val = cufile.get_parameter_size_t(param)
orig_val = originals[param]
cufile.set_parameter_size_t(param, val)
retrieved_val = cufile.get_parameter_size_t(param)
assert retrieved_val == val
Expand All @@ -1412,6 +1420,14 @@ def test_param(param, val):
@pytest.mark.usefixtures("ctx")
def test_set_get_parameter_bool():
"""Test setting and getting boolean parameters with cuFile validation."""
# Do not exercise allow/force compat via set_parameter_bool before any driver_open:
# pending API values are applied after JSON load on first open and can overwrite
# cufile.json (e.g. allow_compat_mode: true), causing DRIVER_NOT_INITIALIZED when
# nvidia-fs is not loaded. Other tests cover compat behavior where appropriate.
_COMPAT_PARAMS = (
cufile.BoolConfigParameter.PROPERTIES_ALLOW_COMPAT_MODE,
cufile.BoolConfigParameter.FORCE_COMPAT_MODE,
)
param_val_pairs = (
(cufile.BoolConfigParameter.PROPERTIES_USE_POLL_MODE, True),
(cufile.BoolConfigParameter.PROPERTIES_ALLOW_COMPAT_MODE, False),
Expand All @@ -1426,9 +1442,16 @@ def test_set_get_parameter_bool():
(cufile.BoolConfigParameter.SKIP_TOPOLOGY_DETECTION, False),
(cufile.BoolConfigParameter.STREAM_MEMOPS_BYPASS, True),
)
param_val_pairs = tuple((p, v) for p, v in param_val_pairs if p not in _COMPAT_PARAMS)

cufile.driver_open()
try:
originals = {param: cufile.get_parameter_bool(param) for param, _ in param_val_pairs}
finally:
cufile.driver_close()

def test_param(param, val):
orig_val = cufile.get_parameter_bool(param)
orig_val = originals[param]
cufile.set_parameter_bool(param, val)
retrieved_val = cufile.get_parameter_bool(param)
assert retrieved_val is val
Expand Down Expand Up @@ -1468,8 +1491,14 @@ def test_set_get_parameter_string(tmp_path):
), # Test log directory
)

cufile.driver_open()
try:
originals = {param: cufile.get_parameter_string(param, 256) for param, _, _ in param_val_pairs}
finally:
cufile.driver_close()

def test_param(param, val, default_val):
orig_val = cufile.get_parameter_string(param, 256)
orig_val = originals[param]

val_b = val.encode("utf-8")
val_buf = ctypes.create_string_buffer(val_b)
Expand Down