@@ -9,25 +9,10 @@ from cuda.core._stream cimport Stream
99from cuda.core._utils.cuda_utils import (
1010 driver,
1111 get_binding_version,
12+ get_driver_version,
1213 handle_return,
1314)
1415
15- _inited = False
16- _driver_ver = None
17-
18-
19- def _lazy_init ():
20- global _inited
21- if _inited:
22- return
23-
24- global _py_major_minor, _driver_ver
25- # binding availability depends on cuda-python version
26- _py_major_minor = get_binding_version()
27- _driver_ver = handle_return(driver.cuDriverGetVersion())
28- _inited = True
29-
30-
3116@dataclass
3217class GraphDebugPrintOptions :
3318 """ Customizable options for :obj:`_graph.GraphBuilder.debug_dot_print()`
@@ -179,7 +164,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) ->
179164 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
180165 raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
181166 elif (
182- _py_major_minor >= (12, 8)
167+ get_binding_version() >= (12, 8)
183168 and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
184169 ):
185170 raise RuntimeError (" One or more conditional handles are not associated with conditional builders." )
@@ -242,7 +227,6 @@ class GraphBuilder:
242227 @classmethod
243228 def _init (cls , stream , is_stream_owner , conditional_graph = None , is_join_required = False ):
244229 self = cls .__new__ (cls )
245- _lazy_init()
246230 self ._mnff = GraphBuilder._MembersNeededForFinalize(
247231 self , stream, is_stream_owner, conditional_graph, is_join_required
248232 )
@@ -398,7 +382,7 @@ class GraphBuilder:
398382 GraphBuilder._init(stream = stream, is_stream_owner = True , conditional_graph = None , is_join_required = True )
399383 )
400384 event.close()
401- return result
385+ return tuple ( result)
402386
403387 @staticmethod
404388 def join (*graph_builders ) -> GraphBuilder:
@@ -460,10 +444,10 @@ class GraphBuilder:
460444 The newly created conditional handle.
461445
462446 """
463- if _driver_ver < 12030:
464- raise RuntimeError(f"Driver version {_driver_ver } does not support conditional handles")
465- if _py_major_minor < (12, 3):
466- raise RuntimeError (f" Binding version {_py_major_minor } does not support conditional handles" )
447+ if get_driver_version() < 12030:
448+ raise RuntimeError(f"Driver version {get_driver_version() } does not support conditional handles")
449+ if get_binding_version() < (12, 3):
450+ raise RuntimeError (f" Binding version {get_binding_version() } does not support conditional handles" )
467451 if default_value is not None :
468452 flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
469453 else :
@@ -478,7 +462,7 @@ class GraphBuilder:
478462 driver.cuGraphConditionalHandleCreate(graph, self ._get_conditional_context(), default_value, flags)
479463 )
480464
481- def _cond_with_params (self , node_params ) -> GraphBuilder :
465+ def _cond_with_params (self , node_params ) -> tuple :
482466 # Get current capture info to ensure we're in a valid state
483467 status , _ , graph , *deps_info , num_dependencies = handle_return(
484468 driver.cuStreamGetCaptureInfo(self ._mnff.stream.handle)
@@ -533,10 +517,10 @@ class GraphBuilder:
533517 The newly created conditional graph builder.
534518
535519 """
536- if _driver_ver < 12030:
537- raise RuntimeError(f"Driver version {_driver_ver } does not support conditional if")
538- if _py_major_minor < (12, 3):
539- raise RuntimeError (f" Binding version {_py_major_minor } does not support conditional if" )
520+ if get_driver_version() < 12030:
521+ raise RuntimeError(f"Driver version {get_driver_version() } does not support conditional if")
522+ if get_binding_version() < (12, 3):
523+ raise RuntimeError (f" Binding version {get_binding_version() } does not support conditional if" )
540524 node_params = driver.CUgraphNodeParams()
541525 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
542526 node_params.conditional.handle = handle
@@ -564,10 +548,10 @@ class GraphBuilder:
564548 A tuple of two new graph builders , one for the if branch and one for the else branch.
565549
566550 """
567- if _driver_ver < 12080:
568- raise RuntimeError(f"Driver version {_driver_ver } does not support conditional if-else")
569- if _py_major_minor < (12, 8):
570- raise RuntimeError (f" Binding version {_py_major_minor } does not support conditional if-else" )
551+ if get_driver_version() < 12080:
552+ raise RuntimeError(f"Driver version {get_driver_version() } does not support conditional if-else")
553+ if get_binding_version() < (12, 8):
554+ raise RuntimeError (f" Binding version {get_binding_version() } does not support conditional if-else" )
571555 node_params = driver.CUgraphNodeParams()
572556 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
573557 node_params.conditional.handle = handle
@@ -598,10 +582,10 @@ class GraphBuilder:
598582 A tuple of new graph builders , one for each branch.
599583
600584 """
601- if _driver_ver < 12080:
602- raise RuntimeError(f"Driver version {_driver_ver } does not support conditional switch")
603- if _py_major_minor < (12, 8):
604- raise RuntimeError (f" Binding version {_py_major_minor } does not support conditional switch" )
585+ if get_driver_version() < 12080:
586+ raise RuntimeError(f"Driver version {get_driver_version() } does not support conditional switch")
587+ if get_binding_version() < (12, 8):
588+ raise RuntimeError (f" Binding version {get_binding_version() } does not support conditional switch" )
605589 node_params = driver.CUgraphNodeParams()
606590 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
607591 node_params.conditional.handle = handle
@@ -629,10 +613,10 @@ class GraphBuilder:
629613 The newly created while loop graph builder.
630614
631615 """
632- if _driver_ver < 12030:
633- raise RuntimeError(f"Driver version {_driver_ver } does not support conditional while loop")
634- if _py_major_minor < (12, 3):
635- raise RuntimeError (f" Binding version {_py_major_minor } does not support conditional while loop" )
616+ if get_driver_version() < 12030:
617+ raise RuntimeError(f"Driver version {get_driver_version() } does not support conditional while loop")
618+ if get_binding_version() < (12, 3):
619+ raise RuntimeError (f" Binding version {get_binding_version() } does not support conditional while loop" )
636620 node_params = driver.CUgraphNodeParams()
637621 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
638622 node_params.conditional.handle = handle
@@ -660,10 +644,10 @@ class GraphBuilder:
660644 child_graph : :obj:`~_graph.GraphBuilder`
661645 The child graph builder. Must have finished building.
662646 """
663- if (_driver_ver < 12000 ) or (_py_major_minor < (12 , 0 )):
647+ if (get_driver_version() < 12000 ) or (get_binding_version() < (12 , 0 )):
664648 raise NotImplementedError (
665649 f" Launching child graphs is not implemented for versions older than CUDA 12."
666- f" Found driver version is {_driver_ver } and binding version is {_py_major_minor }"
650+ f" Found driver version is {get_driver_version() } and binding version is {get_binding_version() }"
667651 )
668652
669653 if not child_graph._building_ended:
0 commit comments