Skip to content

Commit 9ac34df

Browse files
committed
Remove _lazy_init from _graph_builder; add cached get_driver_version
Replace the per-module _lazy_init / _inited / _driver_ver / _py_major_minor pattern in _graph_builder.pyx with direct calls to centralized cached functions in cuda_utils: - Add get_driver_version() with @functools.cache alongside get_binding_version - Switch get_binding_version from @functools.lru_cache to @functools.cache (cleaner for nullary functions) - Fix split() to return tuple(result) — Cython enforces return type annotations unlike pure Python - Fix _cond_with_params annotation from -> GraphBuilder to -> tuple to match actual return value Made-with: Cursor
1 parent 631a74c commit 9ac34df

2 files changed

Lines changed: 31 additions & 43 deletions

File tree

cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,10 @@ from cuda.core._stream cimport Stream
99
from cuda.core._utils.cuda_utils import (
1010
driver,
1111
get_binding_version,
12+
get_driver_version,
1213
handle_return,
1314
)
1415

15-
_inited = False
16-
_driver_ver = None
17-
18-
19-
def _lazy_init():
20-
global _inited
21-
if _inited:
22-
return
23-
24-
global _py_major_minor, _driver_ver
25-
# binding availability depends on cuda-python version
26-
_py_major_minor = get_binding_version()
27-
_driver_ver = handle_return(driver.cuDriverGetVersion())
28-
_inited = True
29-
30-
3116
@dataclass
3217
class GraphDebugPrintOptions:
3318
"""Customizable options for :obj:`_graph.GraphBuilder.debug_dot_print()`
@@ -179,7 +164,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) ->
179164
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
180165
raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
181166
elif (
182-
_py_major_minor >= (12, 8)
167+
get_binding_version() >= (12, 8)
183168
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
184169
):
185170
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
@@ -242,7 +227,6 @@ class GraphBuilder:
242227
@classmethod
243228
def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
244229
self = cls.__new__(cls)
245-
_lazy_init()
246230
self._mnff = GraphBuilder._MembersNeededForFinalize(
247231
self, stream, is_stream_owner, conditional_graph, is_join_required
248232
)
@@ -398,7 +382,7 @@ class GraphBuilder:
398382
GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
399383
)
400384
event.close()
401-
return result
385+
return tuple(result)
402386

403387
@staticmethod
404388
def join(*graph_builders) -> GraphBuilder:
@@ -460,10 +444,10 @@ class GraphBuilder:
460444
The newly created conditional handle.
461445

462446
"""
463-
if _driver_ver < 12030:
464-
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles")
465-
if _py_major_minor < (12, 3):
466-
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles")
447+
if get_driver_version() < 12030:
448+
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional handles")
449+
if get_binding_version() < (12, 3):
450+
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional handles")
467451
if default_value is not None:
468452
flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
469453
else:
@@ -478,7 +462,7 @@ class GraphBuilder:
478462
driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags)
479463
)
480464

481-
def _cond_with_params(self, node_params) -> GraphBuilder:
465+
def _cond_with_params(self, node_params) -> tuple:
482466
# Get current capture info to ensure we're in a valid state
483467
status, _, graph, *deps_info, num_dependencies = handle_return(
484468
driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
@@ -533,10 +517,10 @@ class GraphBuilder:
533517
The newly created conditional graph builder.
534518

535519
"""
536-
if _driver_ver < 12030:
537-
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if")
538-
if _py_major_minor < (12, 3):
539-
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if")
520+
if get_driver_version() < 12030:
521+
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if")
522+
if get_binding_version() < (12, 3):
523+
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if")
540524
node_params = driver.CUgraphNodeParams()
541525
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
542526
node_params.conditional.handle = handle
@@ -564,10 +548,10 @@ class GraphBuilder:
564548
A tuple of two new graph builders, one for the if branch and one for the else branch.
565549

566550
"""
567-
if _driver_ver < 12080:
568-
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else")
569-
if _py_major_minor < (12, 8):
570-
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else")
551+
if get_driver_version() < 12080:
552+
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if-else")
553+
if get_binding_version() < (12, 8):
554+
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if-else")
571555
node_params = driver.CUgraphNodeParams()
572556
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
573557
node_params.conditional.handle = handle
@@ -598,10 +582,10 @@ class GraphBuilder:
598582
A tuple of new graph builders, one for each branch.
599583

600584
"""
601-
if _driver_ver < 12080:
602-
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch")
603-
if _py_major_minor < (12, 8):
604-
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch")
585+
if get_driver_version() < 12080:
586+
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional switch")
587+
if get_binding_version() < (12, 8):
588+
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional switch")
605589
node_params = driver.CUgraphNodeParams()
606590
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
607591
node_params.conditional.handle = handle
@@ -629,10 +613,10 @@ class GraphBuilder:
629613
The newly created while loop graph builder.
630614

631615
"""
632-
if _driver_ver < 12030:
633-
raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop")
634-
if _py_major_minor < (12, 3):
635-
raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop")
616+
if get_driver_version() < 12030:
617+
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional while loop")
618+
if get_binding_version() < (12, 3):
619+
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional while loop")
636620
node_params = driver.CUgraphNodeParams()
637621
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
638622
node_params.conditional.handle = handle
@@ -660,10 +644,10 @@ class GraphBuilder:
660644
child_graph : :obj:`~_graph.GraphBuilder`
661645
The child graph builder. Must have finished building.
662646
"""
663-
if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
647+
if (get_driver_version() < 12000) or (get_binding_version() < (12, 0)):
664648
raise NotImplementedError(
665649
f"Launching child graphs is not implemented for versions older than CUDA 12."
666-
f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
650+
f"Found driver version is {get_driver_version()} and binding version is {get_binding_version()}"
667651
)
668652

669653
if not child_graph._building_ended:

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,18 @@ def is_nested_sequence(obj):
298298
return is_sequence(obj) and any(is_sequence(elem) for elem in obj)
299299

300300

301-
@functools.lru_cache
301+
@functools.cache
302302
def get_binding_version():
303303
try:
304304
major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
305305
except importlib.metadata.PackageNotFoundError:
306306
major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
307307
return tuple(int(v) for v in major_minor)
308308

309+
@functools.cache
310+
def get_driver_version():
311+
return handle_return(driver.cuDriverGetVersion())
312+
309313

310314
class Transaction:
311315
"""

0 commit comments

Comments
 (0)