Skip to content

Commit 6de62d0

Browse files
Factor out common logic between num_arguments and arguments_info properties
Also parametrize test to check arguments_info to check with all int, and all short arguments.
1 parent e22b891 commit 6de62d0

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -216,34 +216,31 @@ def attributes(self) -> KernelAttributes:
216216
self._attributes = KernelAttributes._init(self._handle)
217217
return self._attributes
218218

219-
@property
220-
def num_arguments(self) -> int:
221-
"""int : The number of arguments of this function"""
219+
def _get_arguments_info(self, param_info=False) -> tuple[int, list[tuple[int, int]]]:
222220
attr_impl = self.attributes
223221
if attr_impl._backend_version != "new":
224222
raise NotImplementedError("New backend is required")
225223
arg_pos = 0
224+
param_info_data = []
226225
while True:
227226
result = attr_impl._loader["paraminfo"](self._handle, arg_pos)
228227
if result[0] != driver.CUresult.CUDA_SUCCESS:
229228
break
229+
if param_info:
230+
param_info_data.append(result[1:3])
230231
arg_pos = arg_pos + 1
231-
return arg_pos
232+
return arg_pos, param_info_data
233+
234+
@property
235+
def num_arguments(self) -> int:
236+
"""int : The number of arguments of this function"""
237+
num_args, _ = self._get_arguments_info()
238+
return num_args
232239

233240
@property
234241
def arguments_info(self) -> list[tuple[int, int]]:
235-
"""tuple[tuple[int, int]]: (offset, size) for each argument of this function"""
236-
attr_impl = self.attributes
237-
if attr_impl._backend_version != "new":
238-
raise NotImplementedError("New backend is required")
239-
arg_pos = 0
240-
param_info = []
241-
while True:
242-
result = attr_impl._loader["paraminfo"](self._handle, arg_pos)
243-
if result[0] != driver.CUresult.CUDA_SUCCESS:
244-
break
245-
param_info.append((result[1], result[2]))
246-
arg_pos = arg_pos + 1
242+
"""list[tuple[int, int]]: (offset, size) for each argument of this function"""
243+
_, param_info = self._get_arguments_info(param_info=True)
247244
return param_info
248245

249246
# TODO: implement from_handle()

cuda_core/tests/test_module.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
import warnings
11+
import ctypes
1112

1213
import pytest
1314

@@ -165,7 +166,6 @@ def test_object_code_handle(get_saxpy_object_code):
165166

166167

167168
def test_saxpy_arguments(get_saxpy_kernel):
168-
import ctypes
169169
krn, _ = get_saxpy_kernel
170170

171171
assert krn.num_arguments == 5
@@ -190,8 +190,9 @@ class ExpectedStruct(ctypes.Structure):
190190

191191

192192
@pytest.mark.parametrize("nargs", [0, 1, 2, 3, 16])
193-
def test_num_arguments(init_cuda, nargs):
194-
args_str = ", ".join([f"int p_{i}" for i in range(nargs)])
193+
@pytest.mark.parametrize("c_type_name,c_type", [("int", ctypes.c_int), ("short", ctypes.c_short)], ids=["int", "short"])
194+
def test_num_arguments(init_cuda, nargs, c_type_name, c_type):
195+
args_str = ", ".join([f"{c_type_name} p_{i}" for i in range(nargs)])
195196
src = f"__global__ void foo{nargs}({args_str}) {{ }}"
196197
prog = Program(src, code_type="c++")
197198
mod = prog.compile(
@@ -201,10 +202,9 @@ def test_num_arguments(init_cuda, nargs):
201202
krn = mod.get_kernel(f"foo{nargs}")
202203
assert krn.num_arguments == nargs
203204

204-
import ctypes
205205
class ExpectedStruct(ctypes.Structure):
206206
_fields_ = [
207-
(f'arg_{i}', ctypes.c_int) for i in range(nargs)
207+
(f'arg_{i}', c_type) for i in range(nargs)
208208
]
209209
members = tuple(getattr(ExpectedStruct, f"arg_{i}") for i in range(nargs))
210210
expected = [(m.offset, m.size) for m in members]

0 commit comments

Comments
 (0)