Skip to content

Commit e22b891

Browse files
Implement Kernel.num_arguments, and Kernel.arguments_info
1 parent bd770e1 commit e22b891

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def _lazy_init():
4343
"data": driver.cuLibraryLoadData,
4444
"kernel": driver.cuLibraryGetKernel,
4545
"attribute": driver.cuKernelGetAttribute,
46+
"paraminfo": driver.cuKernelGetParamInfo,
4647
}
4748
_kernel_ctypes = (driver.CUfunction, driver.CUkernel)
4849
else:
@@ -215,6 +216,36 @@ def attributes(self) -> KernelAttributes:
215216
self._attributes = KernelAttributes._init(self._handle)
216217
return self._attributes
217218

219+
@property
220+
def num_arguments(self) -> int:
221+
"""int : The number of arguments of this function"""
222+
attr_impl = self.attributes
223+
if attr_impl._backend_version != "new":
224+
raise NotImplementedError("New backend is required")
225+
arg_pos = 0
226+
while True:
227+
result = attr_impl._loader["paraminfo"](self._handle, arg_pos)
228+
if result[0] != driver.CUresult.CUDA_SUCCESS:
229+
break
230+
arg_pos = arg_pos + 1
231+
return arg_pos
232+
233+
@property
234+
def arguments_info(self) -> list[tuple[int, int]]:
235+
"""tuple[tuple[int, int]]: (offset, size) for each argument of this function"""
236+
attr_impl = self.attributes
237+
if attr_impl._backend_version != "new":
238+
raise NotImplementedError("New backend is required")
239+
arg_pos = 0
240+
param_info = []
241+
while True:
242+
result = attr_impl._loader["paraminfo"](self._handle, arg_pos)
243+
if result[0] != driver.CUresult.CUDA_SUCCESS:
244+
break
245+
param_info.append((result[1], result[2]))
246+
arg_pos = arg_pos + 1
247+
return param_info
248+
218249
# TODO: implement from_handle()
219250

220251

cuda_core/tests/test_module.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import cuda.core.experimental
1515
from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system
1616

17-
SAXPY_KERNEL = """
17+
SAXPY_KERNEL = r"""
1818
template<typename T>
1919
__global__ void saxpy(const T a,
2020
const T* x,
@@ -162,3 +162,50 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
162162
def test_object_code_handle(get_saxpy_object_code):
163163
mod = get_saxpy_object_code
164164
assert mod.handle is not None
165+
166+
167+
def test_saxpy_arguments(get_saxpy_kernel):
168+
import ctypes
169+
krn, _ = get_saxpy_kernel
170+
171+
assert krn.num_arguments == 5
172+
173+
arg_info = krn.arguments_info
174+
n_args = len(arg_info)
175+
assert n_args == krn.num_arguments
176+
class ExpectedStruct(ctypes.Structure):
177+
_fields_ = [
178+
('a', ctypes.c_float),
179+
('x', ctypes.POINTER(ctypes.c_float)),
180+
('y', ctypes.POINTER(ctypes.c_float)),
181+
('out', ctypes.POINTER(ctypes.c_float)),
182+
('N', ctypes.c_size_t)
183+
]
184+
offsets, sizes = zip(*arg_info)
185+
members = [getattr(ExpectedStruct, name) for name, _ in ExpectedStruct._fields_]
186+
expected_offsets = tuple(m.offset for m in members)
187+
assert all(actual == expected for actual, expected in zip(offsets, expected_offsets))
188+
expected_sizes = tuple(m.size for m in members)
189+
assert all(actual == expected for actual, expected in zip(sizes, expected_sizes))
190+
191+
192+
@pytest.mark.parametrize("nargs", [0, 1, 2, 3, 16])
193+
def test_num_arguments(init_cuda, nargs):
194+
args_str = ", ".join([f"int p_{i}" for i in range(nargs)])
195+
src = f"__global__ void foo{nargs}({args_str}) {{ }}"
196+
prog = Program(src, code_type="c++")
197+
mod = prog.compile(
198+
"cubin",
199+
name_expressions=(f"foo{nargs}", ),
200+
)
201+
krn = mod.get_kernel(f"foo{nargs}")
202+
assert krn.num_arguments == nargs
203+
204+
import ctypes
205+
class ExpectedStruct(ctypes.Structure):
206+
_fields_ = [
207+
(f'arg_{i}', ctypes.c_int) for i in range(nargs)
208+
]
209+
members = tuple(getattr(ExpectedStruct, f"arg_{i}") for i in range(nargs))
210+
expected = [(m.offset, m.size) for m in members]
211+
assert krn.arguments_info == expected

0 commit comments

Comments
 (0)