Skip to content

Commit b2a912a

Browse files
bharatr21mdboomleofang
authored
chore: Replace isinstance(obj, T) with type(obj) is T comparisons (NVIDIA#1292)
* chore: Replace isinstance(obj, T) with type(obj) is T comparisons * Add isinstance fallback to maintain backward compat for subclasses Signed-off-by: Bharat Raghunathan <bharatrgatech@gmail.com> * Fix test failures (attempt 1/n) Signed-off-by: Bharat Raghunathan <bharatrgatech@gmail.com> * Fix test failures (attempt 2/n) Signed-off-by: Bharat Raghunathan <bharatrgatech@gmail.com> * Apply suggestions from code review by @mdboom and @leofang Explicit cast not needed since `prepare_arg` does it automatically Co-authored-by: Leo Fang <leo80042@gmail.com> * Fix test_graph.py to use bool correctly in cudaGraphSetConditional call. * More explicit typing of CUgraphConditionalHandle * Improve conditional kernel tests by testing both bools and ints * fix * Move breaking change to bugfix --------- Signed-off-by: Bharat Raghunathan <bharatrgatech@gmail.com> Co-authored-by: Michael Droettboom <mdboom@gmail.com> Co-authored-by: Leo Fang <leof@nvidia.com> Co-authored-by: Leo Fang <leo80042@gmail.com>
1 parent df394be commit b2a912a

3 files changed

Lines changed: 184 additions & 49 deletions

File tree

cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 119 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import numpy
1717

1818
from cuda.core.experimental._memory import Buffer
1919
from cuda.core.experimental._utils.cuda_utils import driver
20+
from cuda.bindings cimport cydriver
2021

2122

2223
ctypedef cpp_complex.complex[float] cpp_single_complex
@@ -128,67 +129,123 @@ cdef inline int prepare_ctypes_arg(
128129
vector.vector[void*]& data_addresses,
129130
arg,
130131
const size_t idx) except -1:
131-
if isinstance(arg, ctypes_bool):
132+
cdef object arg_type = type(arg)
133+
if arg_type is ctypes_bool:
132134
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
133-
elif isinstance(arg, ctypes_int8):
135+
elif arg_type is ctypes_int8:
134136
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
135-
elif isinstance(arg, ctypes_int16):
137+
elif arg_type is ctypes_int16:
136138
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
137-
elif isinstance(arg, ctypes_int32):
139+
elif arg_type is ctypes_int32:
138140
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
139-
elif isinstance(arg, ctypes_int64):
141+
elif arg_type is ctypes_int64:
140142
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
141-
elif isinstance(arg, ctypes_uint8):
143+
elif arg_type is ctypes_uint8:
142144
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
143-
elif isinstance(arg, ctypes_uint16):
145+
elif arg_type is ctypes_uint16:
144146
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
145-
elif isinstance(arg, ctypes_uint32):
147+
elif arg_type is ctypes_uint32:
146148
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
147-
elif isinstance(arg, ctypes_uint64):
149+
elif arg_type is ctypes_uint64:
148150
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
149-
elif isinstance(arg, ctypes_float):
151+
elif arg_type is ctypes_float:
150152
return prepare_arg[float](data, data_addresses, arg.value, idx)
151-
elif isinstance(arg, ctypes_double):
153+
elif arg_type is ctypes_double:
152154
return prepare_arg[double](data, data_addresses, arg.value, idx)
153155
else:
154-
return 1
156+
# If no exact type matches, fall back to the slower `isinstance` check
157+
if isinstance(arg, ctypes_bool):
158+
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
159+
elif isinstance(arg, ctypes_int8):
160+
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
161+
elif isinstance(arg, ctypes_int16):
162+
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
163+
elif isinstance(arg, ctypes_int32):
164+
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
165+
elif isinstance(arg, ctypes_int64):
166+
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
167+
elif isinstance(arg, ctypes_uint8):
168+
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
169+
elif isinstance(arg, ctypes_uint16):
170+
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
171+
elif isinstance(arg, ctypes_uint32):
172+
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
173+
elif isinstance(arg, ctypes_uint64):
174+
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
175+
elif isinstance(arg, ctypes_float):
176+
return prepare_arg[float](data, data_addresses, arg.value, idx)
177+
elif isinstance(arg, ctypes_double):
178+
return prepare_arg[double](data, data_addresses, arg.value, idx)
179+
else:
180+
return 1
155181

156182

157183
cdef inline int prepare_numpy_arg(
158184
vector.vector[void*]& data,
159185
vector.vector[void*]& data_addresses,
160186
arg,
161187
const size_t idx) except -1:
162-
if isinstance(arg, numpy_bool):
188+
cdef object arg_type = type(arg)
189+
if arg_type is numpy_bool:
163190
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
164-
elif isinstance(arg, numpy_int8):
191+
elif arg_type is numpy_int8:
165192
return prepare_arg[int8_t](data, data_addresses, arg, idx)
166-
elif isinstance(arg, numpy_int16):
193+
elif arg_type is numpy_int16:
167194
return prepare_arg[int16_t](data, data_addresses, arg, idx)
168-
elif isinstance(arg, numpy_int32):
195+
elif arg_type is numpy_int32:
169196
return prepare_arg[int32_t](data, data_addresses, arg, idx)
170-
elif isinstance(arg, numpy_int64):
197+
elif arg_type is numpy_int64:
171198
return prepare_arg[int64_t](data, data_addresses, arg, idx)
172-
elif isinstance(arg, numpy_uint8):
199+
elif arg_type is numpy_uint8:
173200
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
174-
elif isinstance(arg, numpy_uint16):
201+
elif arg_type is numpy_uint16:
175202
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
176-
elif isinstance(arg, numpy_uint32):
203+
elif arg_type is numpy_uint32:
177204
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
178-
elif isinstance(arg, numpy_uint64):
205+
elif arg_type is numpy_uint64:
179206
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
180-
elif isinstance(arg, numpy_float16):
207+
elif arg_type is numpy_float16:
181208
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
182-
elif isinstance(arg, numpy_float32):
209+
elif arg_type is numpy_float32:
183210
return prepare_arg[float](data, data_addresses, arg, idx)
184-
elif isinstance(arg, numpy_float64):
211+
elif arg_type is numpy_float64:
185212
return prepare_arg[double](data, data_addresses, arg, idx)
186-
elif isinstance(arg, numpy_complex64):
213+
elif arg_type is numpy_complex64:
187214
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
188-
elif isinstance(arg, numpy_complex128):
215+
elif arg_type is numpy_complex128:
189216
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
190217
else:
191-
return 1
218+
# If no exact type matches, fall back to the slower `isinstance` check
219+
if isinstance(arg, numpy_bool):
220+
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
221+
elif isinstance(arg, numpy_int8):
222+
return prepare_arg[int8_t](data, data_addresses, arg, idx)
223+
elif isinstance(arg, numpy_int16):
224+
return prepare_arg[int16_t](data, data_addresses, arg, idx)
225+
elif isinstance(arg, numpy_int32):
226+
return prepare_arg[int32_t](data, data_addresses, arg, idx)
227+
elif isinstance(arg, numpy_int64):
228+
return prepare_arg[int64_t](data, data_addresses, arg, idx)
229+
elif isinstance(arg, numpy_uint8):
230+
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
231+
elif isinstance(arg, numpy_uint16):
232+
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
233+
elif isinstance(arg, numpy_uint32):
234+
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
235+
elif isinstance(arg, numpy_uint64):
236+
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
237+
elif isinstance(arg, numpy_float16):
238+
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
239+
elif isinstance(arg, numpy_float32):
240+
return prepare_arg[float](data, data_addresses, arg, idx)
241+
elif isinstance(arg, numpy_float64):
242+
return prepare_arg[double](data, data_addresses, arg, idx)
243+
elif isinstance(arg, numpy_complex64):
244+
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
245+
elif isinstance(arg, numpy_complex128):
246+
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
247+
else:
248+
return 1
192249

193250

194251
cdef class ParamHolder:
@@ -207,44 +264,69 @@ cdef class ParamHolder:
207264
cdef size_t n_args = len(kernel_args)
208265
cdef size_t i
209266
cdef int not_prepared
267+
cdef object arg_type
210268
self.data = vector.vector[voidptr](n_args, nullptr)
211269
self.data_addresses = vector.vector[voidptr](n_args)
212270
for i, arg in enumerate(kernel_args):
213-
if isinstance(arg, Buffer):
271+
arg_type = type(arg)
272+
if arg_type is Buffer:
214273
# we need the address of where the actual buffer address is stored
215-
if isinstance(arg.handle, int):
274+
if type(arg.handle) is int:
216275
# see note below on handling int arguments
217276
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
218277
continue
219278
else:
220279
# it's a CUdeviceptr:
221280
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
222281
continue
223-
elif isinstance(arg, int):
282+
elif arg_type is bool:
283+
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
284+
continue
285+
elif arg_type is int:
224286
# Here's the dilemma: We want to have a fast path to pass in Python
225287
# integers as pointer addresses, but one could also (mistakenly) pass
226288
# it with the intention of passing a scalar integer. It's a mistake
227289
# because a Python int is ambiguous (arbitrary width). Our judgement
228290
# call here is to treat it as a pointer address, without any warning!
229291
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
230292
continue
231-
elif isinstance(arg, float):
293+
elif arg_type is float:
232294
prepare_arg[double](self.data, self.data_addresses, arg, i)
233295
continue
234-
elif isinstance(arg, complex):
296+
elif arg_type is complex:
235297
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
236298
continue
237-
elif isinstance(arg, bool):
238-
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
239-
continue
240299

241300
not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
242301
if not_prepared:
243302
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
244303
if not_prepared:
245304
# TODO: revisit this treatment if we decide to cythonize cuda.core
246-
if isinstance(arg, driver.CUgraphConditionalHandle):
247-
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
305+
if arg_type is driver.CUgraphConditionalHandle:
306+
prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, <intptr_t>int(arg), i)
307+
continue
308+
# If no exact type matches, fall back to the slower `isinstance` check
309+
elif isinstance(arg, Buffer):
310+
if isinstance(arg.handle, int):
311+
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
312+
continue
313+
else:
314+
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
315+
continue
316+
elif isinstance(arg, bool):
317+
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
318+
continue
319+
elif isinstance(arg, int):
320+
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
321+
continue
322+
elif isinstance(arg, float):
323+
prepare_arg[double](self.data, self.data_addresses, arg, i)
324+
continue
325+
elif isinstance(arg, complex):
326+
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
327+
continue
328+
elif isinstance(arg, driver.CUgraphConditionalHandle):
329+
prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg, i)
248330
continue
249331
# TODO: support ctypes/numpy struct
250332
raise TypeError("the argument is of unsupported type: " + str(type(arg)))
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
.. SPDX-License-Identifier: Apache-2.0
3+
4+
.. currentmodule:: cuda.core.experimental
5+
6+
``cuda.core`` 0.5.x Release Notes
7+
=================================
8+
9+
10+
Highlights
11+
----------
12+
13+
None.
14+
15+
16+
Breaking Changes
17+
----------------
18+
19+
None.
20+
21+
New features
22+
------------
23+
24+
None.
25+
26+
27+
New examples
28+
------------
29+
30+
None.
31+
32+
33+
Fixes and enhancements
34+
----------------------
35+
36+
- Python ``bool`` objects are now converted to C++ ``bool`` type when passed as kernel
37+
arguments. Previously, they were converted to ``int``. This brings them in line
38+
with ``ctypes.c_bool`` and ``numpy.bool_``.

cuda_core/tests/test_graph.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import ctypes
6+
57
import numpy as np
68
import pytest
79

@@ -35,19 +37,28 @@ def _common_kernels():
3537
return mod
3638

3739

38-
def _common_kernels_conditional():
40+
def _common_kernels_conditional(cond_type):
41+
if cond_type in (bool, np.bool_, ctypes.c_bool):
42+
cond_type_str = "bool"
43+
elif cond_type is int:
44+
cond_type_str = "unsigned int"
45+
else:
46+
raise ValueError("Unsupported cond_type")
47+
3948
code = """
4049
extern "C" __device__ __cudart_builtin__ void CUDARTAPI cudaGraphSetConditional(cudaGraphConditionalHandle handle,
4150
unsigned int value);
4251
__global__ void empty_kernel() {}
4352
__global__ void add_one(int *a) { *a += 1; }
44-
__global__ void set_handle(cudaGraphConditionalHandle handle, int value) { cudaGraphSetConditional(handle, value); }
53+
__global__ void set_handle(cudaGraphConditionalHandle handle, $cond_type_str value) {
54+
cudaGraphSetConditional(handle, value);
55+
}
4556
__global__ void loop_kernel(cudaGraphConditionalHandle handle)
4657
{
4758
static int count = 10;
4859
cudaGraphSetConditional(handle, --count ? 1 : 0);
4960
}
50-
"""
61+
""".replace("$cond_type_str", cond_type_str)
5162
arch = "".join(f"{i}" for i in Device().compute_capability)
5263
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
5364
prog = Program(code, code_type="c++", options=program_options)
@@ -216,10 +227,12 @@ def test_graph_capture_errors(init_cuda):
216227
gb.end_building().complete()
217228

218229

219-
@pytest.mark.parametrize("condition_value", [True, False])
230+
@pytest.mark.parametrize(
231+
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
232+
)
220233
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
221234
def test_graph_conditional_if(init_cuda, condition_value):
222-
mod = _common_kernels_conditional()
235+
mod = _common_kernels_conditional(type(condition_value))
223236
add_one = mod.get_kernel("add_one")
224237
set_handle = mod.get_kernel("set_handle")
225238

@@ -278,10 +291,12 @@ def test_graph_conditional_if(init_cuda, condition_value):
278291
b.close()
279292

280293

281-
@pytest.mark.parametrize("condition_value", [True, False])
294+
@pytest.mark.parametrize(
295+
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
296+
)
282297
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
283298
def test_graph_conditional_if_else(init_cuda, condition_value):
284-
mod = _common_kernels_conditional()
299+
mod = _common_kernels_conditional(type(condition_value))
285300
add_one = mod.get_kernel("add_one")
286301
set_handle = mod.get_kernel("set_handle")
287302

@@ -353,7 +368,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value):
353368
@pytest.mark.parametrize("condition_value", [0, 1, 2, 3])
354369
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
355370
def test_graph_conditional_switch(init_cuda, condition_value):
356-
mod = _common_kernels_conditional()
371+
mod = _common_kernels_conditional(type(condition_value))
357372
add_one = mod.get_kernel("add_one")
358373
set_handle = mod.get_kernel("set_handle")
359374

@@ -441,10 +456,10 @@ def test_graph_conditional_switch(init_cuda, condition_value):
441456
b.close()
442457

443458

444-
@pytest.mark.parametrize("condition_value", [True, False])
459+
@pytest.mark.parametrize("condition_value", [True, False, 1, 0])
445460
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
446461
def test_graph_conditional_while(init_cuda, condition_value):
447-
mod = _common_kernels_conditional()
462+
mod = _common_kernels_conditional(type(condition_value))
448463
add_one = mod.get_kernel("add_one")
449464
loop_kernel = mod.get_kernel("loop_kernel")
450465
empty_kernel = mod.get_kernel("empty_kernel")
@@ -545,7 +560,7 @@ def test_graph_child_graph(init_cuda):
545560

546561
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
547562
def test_graph_update(init_cuda):
548-
mod = _common_kernels_conditional()
563+
mod = _common_kernels_conditional(int)
549564
add_one = mod.get_kernel("add_one")
550565

551566
# Allocate memory
@@ -668,7 +683,7 @@ def test_graph_stream_lifetime(init_cuda):
668683

669684

670685
def test_graph_dot_print_options(init_cuda, tmp_path):
671-
mod = _common_kernels_conditional()
686+
mod = _common_kernels_conditional(bool)
672687
set_handle = mod.get_kernel("set_handle")
673688
empty_kernel = mod.get_kernel("empty_kernel")
674689

0 commit comments

Comments
 (0)