Skip to content

Commit 1b513d5

Browse files
committed
cufile tests: snapshot parameter baselines after driver_open
Open driver once to read size_t/bool/string originals, then close before set/get/restore round-trips so pending does not restore invalid pre-open values (e.g. per-buffer cache 0). Aligns with review feedback.
1 parent 39c9cad commit 1b513d5

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

cuda_bindings/tests/test_cufile.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,8 +1394,16 @@ def test_set_get_parameter_size_t():
13941394
(cufile.SizeTConfigParameter.EXECUTION_MAX_REQUEST_PARALLELISM, 4), # Max 4 parallel requests
13951395
)
13961396

1397+
# Snapshot baselines after driver_open so getters reflect merged config (defaults + JSON),
1398+
# not pre-open pending state that could restore invalid values (e.g. 0 for per-buffer cache).
1399+
cufile.driver_open()
1400+
try:
1401+
originals = {param: cufile.get_parameter_size_t(param) for param, _ in param_val_pairs}
1402+
finally:
1403+
cufile.driver_close()
1404+
13971405
def test_param(param, val):
1398-
orig_val = cufile.get_parameter_size_t(param)
1406+
orig_val = originals[param]
13991407
cufile.set_parameter_size_t(param, val)
14001408
retrieved_val = cufile.get_parameter_size_t(param)
14011409
assert retrieved_val == val
@@ -1436,8 +1444,14 @@ def test_set_get_parameter_bool():
14361444
)
14371445
param_val_pairs = tuple((p, v) for p, v in param_val_pairs if p not in _COMPAT_PARAMS)
14381446

1447+
cufile.driver_open()
1448+
try:
1449+
originals = {param: cufile.get_parameter_bool(param) for param, _ in param_val_pairs}
1450+
finally:
1451+
cufile.driver_close()
1452+
14391453
def test_param(param, val):
1440-
orig_val = cufile.get_parameter_bool(param)
1454+
orig_val = originals[param]
14411455
cufile.set_parameter_bool(param, val)
14421456
retrieved_val = cufile.get_parameter_bool(param)
14431457
assert retrieved_val is val
@@ -1477,8 +1491,14 @@ def test_set_get_parameter_string(tmp_path):
14771491
), # Test log directory
14781492
)
14791493

1494+
cufile.driver_open()
1495+
try:
1496+
originals = {param: cufile.get_parameter_string(param, 256) for param, _, _ in param_val_pairs}
1497+
finally:
1498+
cufile.driver_close()
1499+
14801500
def test_param(param, val, default_val):
1481-
orig_val = cufile.get_parameter_string(param, 256)
1501+
orig_val = originals[param]
14821502

14831503
val_b = val.encode("utf-8")
14841504
val_buf = ctypes.create_string_buffer(val_b)

0 commit comments

Comments
 (0)