Skip to content

Commit 41555d7

Browse files
authored
Fix missing binding version checks alongside driver version checks (#2054)
Two source locations and one test helper only checked the driver version when gating features that also require the corresponding cuda-bindings version. When bindings are older than the driver, the driver check passes but the binding attribute/symbol is missing, causing AttributeError or similar runtime failures. - graph/_subclasses.pyx: _check_node_get_params() now also checks binding_version() >= (13, 2, 0) - _module.pyx: _get_arguments_info() now also checks cy_binding_version() >= (12, 4, 0) - tests/graph/test_graph_definition.py: _driver_has_node_get_params() renamed to _has_node_get_params() and checks both versions Closes #2052
1 parent 9cc3420 commit 41555d7

3 files changed

Lines changed: 14 additions & 7 deletions

File tree

cuda_core/cuda/core/_module.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ from cuda.core._utils.clear_error_support import (
3232
raise_code_path_meant_to_be_unreachable,
3333
)
3434
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
35-
from cuda.core._utils.version cimport cy_driver_version
35+
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
3636
from cuda.core._utils.cuda_utils import driver
3737
from cuda.bindings cimport cydriver
3838

@@ -463,6 +463,11 @@ cdef class Kernel:
463463
"Driver version 12.4 or newer is required for this function. "
464464
f"Using driver version {'.'.join(map(str, cy_driver_version()))}"
465465
)
466+
if cy_binding_version() < (12, 4, 0):
467+
raise NotImplementedError(
468+
"cuda.bindings 12.4 or newer is required for this function. "
469+
f"Using binding version {'.'.join(map(str, cy_binding_version()))}"
470+
)
466471
cdef size_t arg_pos = 0
467472
cdef list param_info_data = []
468473
cdef cydriver.CUkernel cu_kernel = as_cu(self._h_kernel)

cuda_core/cuda/core/graph/_subclasses.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ cdef bint _version_checked = False
6060
cdef bint _check_node_get_params():
6161
global _has_cuGraphNodeGetParams, _version_checked
6262
if not _version_checked:
63-
from cuda.core._utils.version import driver_version
64-
_has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0)
63+
from cuda.core._utils.version import binding_version, driver_version
64+
_has_cuGraphNodeGetParams = (
65+
driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0)
66+
)
6567
_version_checked = True
6668
return _has_cuGraphNodeGetParams
6769

cuda_core/tests/graph/test_graph_definition.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ def _skip_if_no_managed_mempool():
4848
pytest.skip("Device does not support managed memory pool operations")
4949

5050

51-
def _driver_has_node_get_params():
52-
from cuda.core._utils.version import driver_version
51+
def _has_node_get_params():
52+
from cuda.core._utils.version import binding_version, driver_version
5353

54-
return driver_version() >= (13, 2, 0)
54+
return driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0)
5555

5656

57-
_HAS_NODE_GET_PARAMS = _driver_has_node_get_params()
57+
_HAS_NODE_GET_PARAMS = _has_node_get_params()
5858

5959

6060
def _bindings_major_version():

0 commit comments

Comments
 (0)