|
14 | 14 | import cuda.core.experimental |
15 | 15 | from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system |
16 | 16 |
|
17 | | -SAXPY_KERNEL = """ |
| 17 | +SAXPY_KERNEL = r""" |
18 | 18 | template<typename T> |
19 | 19 | __global__ void saxpy(const T a, |
20 | 20 | const T* x, |
@@ -162,3 +162,50 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path): |
162 | 162 | def test_object_code_handle(get_saxpy_object_code): |
163 | 163 | mod = get_saxpy_object_code |
164 | 164 | assert mod.handle is not None |
| 165 | + |
| 166 | + |
def test_saxpy_arguments(get_saxpy_kernel):
    """Check the argument metadata reported for the saxpy kernel.

    The kernel supplied by the ``get_saxpy_kernel`` fixture must report five
    arguments, and the ``(offset, size)`` entries of ``arguments_info`` must
    match, member by member, a ctypes structure declared with the same
    sequence of field types.
    """
    import ctypes

    krn, _ = get_saxpy_kernel

    assert krn.num_arguments == 5

    arg_info = krn.arguments_info
    assert len(arg_info) == krn.num_arguments

    # NOTE(review): the float field types assume the fixture provides the
    # float instantiation of the saxpy template -- confirm against the fixture.
    class ExpectedStruct(ctypes.Structure):
        _fields_ = [
            ("a", ctypes.c_float),
            ("x", ctypes.POINTER(ctypes.c_float)),
            ("y", ctypes.POINTER(ctypes.c_float)),
            ("out", ctypes.POINTER(ctypes.c_float)),
            ("N", ctypes.c_size_t),
        ]

    members = [getattr(ExpectedStruct, name) for name, _ in ExpectedStruct._fields_]
    expected_offsets = tuple(m.offset for m in members)
    expected_sizes = tuple(m.size for m in members)

    offsets, sizes = zip(*arg_info)
    # Compare whole tuples rather than all(zip(...)): a length mismatch cannot
    # be silently truncated, and pytest shows which element differs on failure.
    assert offsets == expected_offsets
    assert sizes == expected_sizes
| 190 | + |
| 191 | + |
@pytest.mark.parametrize("nargs", [0, 1, 2, 3, 16])
def test_num_arguments(init_cuda, nargs):
    """Compile an empty kernel taking ``nargs`` ints and check its argument metadata.

    ``num_arguments`` must equal ``nargs``, and ``arguments_info`` must equal
    the ``(offset, size)`` pairs of a ctypes structure holding ``nargs``
    ``c_int`` fields.
    """
    import ctypes

    kernel_name = f"foo{nargs}"
    param_list = ", ".join(f"int p_{i}" for i in range(nargs))
    source = f"__global__ void {kernel_name}({param_list}) {{ }}"

    program = Program(source, code_type="c++")
    module = program.compile("cubin", name_expressions=(kernel_name,))
    kernel = module.get_kernel(kernel_name)
    assert kernel.num_arguments == nargs

    # Build the reference layout: one c_int member per kernel parameter.
    field_names = [f"arg_{i}" for i in range(nargs)]
    ExpectedStruct = type(
        "ExpectedStruct",
        (ctypes.Structure,),
        {"_fields_": [(name, ctypes.c_int) for name in field_names]},
    )

    expected = []
    for name in field_names:
        member = getattr(ExpectedStruct, name)
        expected.append((member.offset, member.size))
    assert kernel.arguments_info == expected
0 commit comments