22#
33# SPDX-License-Identifier: Apache-2.0
44
5- from dataclasses import dataclass
6- from typing import Optional , Union
75
8- from cuda .core .experimental ._device import Device
96from cuda .core .experimental ._kernel_arg_handler import ParamHolder
7+ from cuda .core .experimental ._launch_config import LaunchConfig , _to_native_launch_config
108from cuda .core .experimental ._module import Kernel
119from cuda .core .experimental ._stream import Stream
1210from cuda .core .experimental ._utils .clear_error_support import assert_type
1311from cuda .core .experimental ._utils .cuda_utils import (
14- CUDAError ,
15- cast_to_3_tuple ,
1612 check_or_create_options ,
1713 driver ,
1814 get_binding_version ,
@@ -37,54 +33,6 @@ def _lazy_init():
3733 _inited = True
3834
3935
@dataclass
class LaunchConfig:
    """Customizable launch options.

    Attributes
    ----------
    grid : Union[tuple, int]
        Collection of threads that will execute a kernel function.
    cluster : Union[tuple, int]
        Group of blocks (Thread Block Cluster) that will execute on the same
        GPU Processing Cluster (GPC). Blocks within a cluster have access to
        distributed shared memory and can be explicitly synchronized.
    block : Union[tuple, int]
        Group of threads (Thread Block) that will execute on the same
        streaming multiprocessor (SM). Threads within a thread blocks have
        access to shared memory and can be explicitly synchronized.
    shmem_size : int, optional
        Dynamic shared-memory size per thread block in bytes.
        (Default to size 0)

    """

    # TODO: expand LaunchConfig to include other attributes
    grid: Union[tuple, int] = None
    cluster: Union[tuple, int] = None
    block: Union[tuple, int] = None
    shmem_size: Optional[int] = None

    def __post_init__(self):
        _lazy_init()
        # Normalize grid/block to 3-tuples up front; cluster only if requested.
        self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
        self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
        if self.cluster is not None:
            self._require_cluster_support()
            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
        self.shmem_size = 0 if self.shmem_size is None else self.shmem_size

    def _require_cluster_support(self):
        # Thread block clusters need both the newer driver API (11.8+) and
        # a device of compute capability 9.0+ (supported starting H100).
        if not _use_ex:
            err, drvers = driver.cuDriverGetVersion()
            drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
            raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
        cc = Device().compute_capability
        if cc < (9, 0):
            raise CUDAError(
                f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
            )
87-
8836def launch (stream , config , kernel , * kernel_args ):
8937 """Launches a :obj:`~_module.Kernel`
9038 object with launch-time configuration.
@@ -114,6 +62,7 @@ def launch(stream, config, kernel, *kernel_args):
11462 f"stream must either be a Stream object or support __cuda_stream__ (got { type (stream )} )"
11563 ) from e
11664 assert_type (kernel , Kernel )
65+ _lazy_init ()
11766 config = check_or_create_options (LaunchConfig , config , "launch config" )
11867
11968 # TODO: can we ensure kernel_args is valid/safe to use here?
@@ -127,25 +76,13 @@ def launch(stream, config, kernel, *kernel_args):
12776 # mainly to see if the "Ex" API is available and if so we use it, as it's more feature
12877 # rich.
12978 if _use_ex :
130- drv_cfg = driver .CUlaunchConfig ()
131- drv_cfg .gridDimX , drv_cfg .gridDimY , drv_cfg .gridDimZ = config .grid
132- drv_cfg .blockDimX , drv_cfg .blockDimY , drv_cfg .blockDimZ = config .block
79+ drv_cfg = _to_native_launch_config (config )
13380 drv_cfg .hStream = stream .handle
134- drv_cfg .sharedMemBytes = config .shmem_size
135- attrs = [] # TODO: support more attributes
136- if config .cluster :
137- attr = driver .CUlaunchAttribute ()
138- attr .id = driver .CUlaunchAttributeID .CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
139- dim = attr .value .clusterDim
140- dim .x , dim .y , dim .z = config .cluster
141- attrs .append (attr )
142- drv_cfg .numAttrs = len (attrs )
143- drv_cfg .attrs = attrs
14481 handle_return (driver .cuLaunchKernelEx (drv_cfg , int (kernel ._handle ), args_ptr , 0 ))
14582 else :
14683 # TODO: check if config has any unsupported attrs
14784 handle_return (
14885 driver .cuLaunchKernel (
149- int (kernel ._handle ), * config .grid , * config .block , config .shmem_size , stream ._handle , args_ptr , 0
86+ int (kernel ._handle ), * config .grid , * config .block , config .shmem_size , stream .handle , args_ptr , 0
15087 )
15188 )
0 commit comments