@@ -11,10 +11,10 @@ from cuda.core._graph._utils cimport _attach_host_callback_to_graph
1111from cuda.core._resource_handles cimport as_cu
1212from cuda.core._stream cimport Stream
1313from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
14+ from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
15+
1416from cuda.core._utils.cuda_utils import (
1517 driver,
16- get_binding_version,
17- get_driver_version,
1818 handle_return,
1919)
2020
@@ -169,7 +169,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) ->
169169 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
170170 raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
171171 elif (
172- get_binding_version () >= (12, 8)
172+ cy_binding_version () >= (12, 8, 0 )
173173 and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
174174 ):
175175 raise RuntimeError (" One or more conditional handles are not associated with conditional builders." )
@@ -449,10 +449,10 @@ class GraphBuilder:
449449 The newly created conditional handle.
450450
451451 """
452- if get_driver_version () < 12030 :
453- raise RuntimeError(f"Driver version {get_driver_version( )} does not support conditional handles")
454- if get_binding_version () < (12, 3):
455- raise RuntimeError (f" Binding version {get_binding_version( )} does not support conditional handles" )
452+ if cy_driver_version () < (12, 3, 0) :
453+ raise RuntimeError (f" Driver version {'.'.join(map(str, cy_driver_version()) )} does not support conditional handles" )
454+ if cy_binding_version () < (12 , 3 , 0 ):
455+ raise RuntimeError (f" Binding version {'.'.join(map(str, cy_binding_version()) )} does not support conditional handles" )
456456 if default_value is not None :
457457 flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
458458 else :
@@ -522,10 +522,10 @@ class GraphBuilder:
522522 The newly created conditional graph builder.
523523
524524 """
525- if get_driver_version () < 12030 :
526- raise RuntimeError(f"Driver version {get_driver_version( )} does not support conditional if")
527- if get_binding_version () < (12, 3):
528- raise RuntimeError (f" Binding version {get_binding_version( )} does not support conditional if" )
525+ if cy_driver_version () < (12, 3, 0) :
526+ raise RuntimeError (f" Driver version {'.'.join(map(str, cy_driver_version()) )} does not support conditional if" )
527+ if cy_binding_version () < (12 , 3 , 0 ):
528+ raise RuntimeError (f" Binding version {'.'.join(map(str, cy_binding_version()) )} does not support conditional if" )
529529 node_params = driver.CUgraphNodeParams()
530530 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
531531 node_params.conditional.handle = handle
@@ -553,10 +553,10 @@ class GraphBuilder:
553553 A tuple of two new graph builders , one for the if branch and one for the else branch.
554554
555555 """
556- if get_driver_version () < 12080 :
557- raise RuntimeError(f"Driver version {get_driver_version( )} does not support conditional if-else")
558- if get_binding_version () < (12, 8):
559- raise RuntimeError (f" Binding version {get_binding_version( )} does not support conditional if-else" )
556+ if cy_driver_version () < (12, 8, 0) :
557+ raise RuntimeError (f" Driver version {'.'.join(map(str, cy_driver_version()) )} does not support conditional if-else" )
558+ if cy_binding_version () < (12 , 8 , 0 ):
559+ raise RuntimeError (f" Binding version {'.'.join(map(str, cy_binding_version()) )} does not support conditional if-else" )
560560 node_params = driver.CUgraphNodeParams()
561561 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
562562 node_params.conditional.handle = handle
@@ -587,10 +587,10 @@ class GraphBuilder:
587587 A tuple of new graph builders , one for each branch.
588588
589589 """
590- if get_driver_version () < 12080 :
591- raise RuntimeError(f"Driver version {get_driver_version( )} does not support conditional switch")
592- if get_binding_version () < (12, 8):
593- raise RuntimeError (f" Binding version {get_binding_version( )} does not support conditional switch" )
590+ if cy_driver_version () < (12, 8, 0) :
591+ raise RuntimeError (f" Driver version {'.'.join(map(str, cy_driver_version()) )} does not support conditional switch" )
592+ if cy_binding_version () < (12 , 8 , 0 ):
593+ raise RuntimeError (f" Binding version {'.'.join(map(str, cy_binding_version()) )} does not support conditional switch" )
594594 node_params = driver.CUgraphNodeParams()
595595 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
596596 node_params.conditional.handle = handle
@@ -618,10 +618,10 @@ class GraphBuilder:
618618 The newly created while loop graph builder.
619619
620620 """
621- if get_driver_version () < 12030 :
622- raise RuntimeError(f"Driver version {get_driver_version( )} does not support conditional while loop")
623- if get_binding_version () < (12, 3):
624- raise RuntimeError (f" Binding version {get_binding_version( )} does not support conditional while loop" )
621+ if cy_driver_version () < (12, 3, 0) :
622+ raise RuntimeError (f" Driver version {'.'.join(map(str, cy_driver_version()) )} does not support conditional while loop" )
623+ if cy_binding_version () < (12 , 3 , 0 ):
624+ raise RuntimeError (f" Binding version {'.'.join(map(str, cy_binding_version()) )} does not support conditional while loop" )
625625 node_params = driver.CUgraphNodeParams()
626626 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
627627 node_params.conditional.handle = handle
@@ -649,12 +649,6 @@ class GraphBuilder:
649649 child_graph : :obj:`~_graph.GraphBuilder`
650650 The child graph builder. Must have finished building.
651651 """
652- if (get_driver_version() < 12000 ) or (get_binding_version() < (12 , 0 )):
653- raise NotImplementedError (
654- f" Launching child graphs is not implemented for versions older than CUDA 12."
655- f" Found driver version is {get_driver_version()} and binding version is {get_binding_version()}"
656- )
657-
658652 if not child_graph._building_ended:
659653 raise ValueError (" Child graph has not finished building." )
660654
0 commit comments