@@ -17,6 +17,7 @@ import numpy
1717
1818from cuda.core.experimental._memory import Buffer
1919from cuda.core.experimental._utils.cuda_utils import driver
20+ from cuda.bindings cimport cydriver
2021
2122
2223ctypedef cpp_complex.complex[float ] cpp_single_complex
@@ -128,67 +129,123 @@ cdef inline int prepare_ctypes_arg(
128129 vector.vector[void * ]& data_addresses,
129130 arg,
130131 const size_t idx) except - 1 :
131- if isinstance (arg, ctypes_bool):
132+ cdef object arg_type = type (arg)
133+ if arg_type is ctypes_bool:
132134 return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
133- elif isinstance (arg, ctypes_int8) :
135+ elif arg_type is ctypes_int8:
134136 return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
135- elif isinstance (arg, ctypes_int16) :
137+ elif arg_type is ctypes_int16:
136138 return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
137- elif isinstance (arg, ctypes_int32) :
139+ elif arg_type is ctypes_int32:
138140 return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
139- elif isinstance (arg, ctypes_int64) :
141+ elif arg_type is ctypes_int64:
140142 return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
141- elif isinstance (arg, ctypes_uint8) :
143+ elif arg_type is ctypes_uint8:
142144 return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
143- elif isinstance (arg, ctypes_uint16) :
145+ elif arg_type is ctypes_uint16:
144146 return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
145- elif isinstance (arg, ctypes_uint32) :
147+ elif arg_type is ctypes_uint32:
146148 return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
147- elif isinstance (arg, ctypes_uint64) :
149+ elif arg_type is ctypes_uint64:
148150 return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
149- elif isinstance (arg, ctypes_float) :
151+ elif arg_type is ctypes_float:
150152 return prepare_arg[float ](data, data_addresses, arg.value, idx)
151- elif isinstance (arg, ctypes_double) :
153+ elif arg_type is ctypes_double:
152154 return prepare_arg[double ](data, data_addresses, arg.value, idx)
153155 else :
154- return 1
156+ # If no exact types are found, fallback to slower `isinstance` check
157+ if isinstance (arg, ctypes_bool):
158+ return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
159+ elif isinstance (arg, ctypes_int8):
160+ return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
161+ elif isinstance (arg, ctypes_int16):
162+ return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
163+ elif isinstance (arg, ctypes_int32):
164+ return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
165+ elif isinstance (arg, ctypes_int64):
166+ return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
167+ elif isinstance (arg, ctypes_uint8):
168+ return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
169+ elif isinstance (arg, ctypes_uint16):
170+ return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
171+ elif isinstance (arg, ctypes_uint32):
172+ return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
173+ elif isinstance (arg, ctypes_uint64):
174+ return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
175+ elif isinstance (arg, ctypes_float):
176+ return prepare_arg[float ](data, data_addresses, arg.value, idx)
177+ elif isinstance (arg, ctypes_double):
178+ return prepare_arg[double ](data, data_addresses, arg.value, idx)
179+ else :
180+ return 1
155181
156182
157183cdef inline int prepare_numpy_arg(
158184 vector.vector[void * ]& data,
159185 vector.vector[void * ]& data_addresses,
160186 arg,
161187 const size_t idx) except - 1 :
162- if isinstance (arg, numpy_bool):
188+ cdef object arg_type = type (arg)
189+ if arg_type is numpy_bool:
163190 return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
164- elif isinstance (arg, numpy_int8) :
191+ elif arg_type is numpy_int8:
165192 return prepare_arg[int8_t](data, data_addresses, arg, idx)
166- elif isinstance (arg, numpy_int16) :
193+ elif arg_type is numpy_int16:
167194 return prepare_arg[int16_t](data, data_addresses, arg, idx)
168- elif isinstance (arg, numpy_int32) :
195+ elif arg_type is numpy_int32:
169196 return prepare_arg[int32_t](data, data_addresses, arg, idx)
170- elif isinstance (arg, numpy_int64) :
197+ elif arg_type is numpy_int64:
171198 return prepare_arg[int64_t](data, data_addresses, arg, idx)
172- elif isinstance (arg, numpy_uint8) :
199+ elif arg_type is numpy_uint8:
173200 return prepare_arg[uint8_t](data, data_addresses, arg, idx)
174- elif isinstance (arg, numpy_uint16) :
201+ elif arg_type is numpy_uint16:
175202 return prepare_arg[uint16_t](data, data_addresses, arg, idx)
176- elif isinstance (arg, numpy_uint32) :
203+ elif arg_type is numpy_uint32:
177204 return prepare_arg[uint32_t](data, data_addresses, arg, idx)
178- elif isinstance (arg, numpy_uint64) :
205+ elif arg_type is numpy_uint64:
179206 return prepare_arg[uint64_t](data, data_addresses, arg, idx)
180- elif isinstance (arg, numpy_float16) :
207+ elif arg_type is numpy_float16:
181208 return prepare_arg[__half_raw](data, data_addresses, arg, idx)
182- elif isinstance (arg, numpy_float32) :
209+ elif arg_type is numpy_float32:
183210 return prepare_arg[float ](data, data_addresses, arg, idx)
184- elif isinstance (arg, numpy_float64) :
211+ elif arg_type is numpy_float64:
185212 return prepare_arg[double ](data, data_addresses, arg, idx)
186- elif isinstance (arg, numpy_complex64) :
213+ elif arg_type is numpy_complex64:
187214 return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
188- elif isinstance (arg, numpy_complex128) :
215+ elif arg_type is numpy_complex128:
189216 return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
190217 else :
191- return 1
218+ # If no exact types are found, fallback to slower `isinstance` check
219+ if isinstance (arg, numpy_bool):
220+ return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
221+ elif isinstance (arg, numpy_int8):
222+ return prepare_arg[int8_t](data, data_addresses, arg, idx)
223+ elif isinstance (arg, numpy_int16):
224+ return prepare_arg[int16_t](data, data_addresses, arg, idx)
225+ elif isinstance (arg, numpy_int32):
226+ return prepare_arg[int32_t](data, data_addresses, arg, idx)
227+ elif isinstance (arg, numpy_int64):
228+ return prepare_arg[int64_t](data, data_addresses, arg, idx)
229+ elif isinstance (arg, numpy_uint8):
230+ return prepare_arg[uint8_t](data, data_addresses, arg, idx)
231+ elif isinstance (arg, numpy_uint16):
232+ return prepare_arg[uint16_t](data, data_addresses, arg, idx)
233+ elif isinstance (arg, numpy_uint32):
234+ return prepare_arg[uint32_t](data, data_addresses, arg, idx)
235+ elif isinstance (arg, numpy_uint64):
236+ return prepare_arg[uint64_t](data, data_addresses, arg, idx)
237+ elif isinstance (arg, numpy_float16):
238+ return prepare_arg[__half_raw](data, data_addresses, arg, idx)
239+ elif isinstance (arg, numpy_float32):
240+ return prepare_arg[float ](data, data_addresses, arg, idx)
241+ elif isinstance (arg, numpy_float64):
242+ return prepare_arg[double ](data, data_addresses, arg, idx)
243+ elif isinstance (arg, numpy_complex64):
244+ return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
245+ elif isinstance (arg, numpy_complex128):
246+ return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
247+ else :
248+ return 1
192249
193250
194251cdef class ParamHolder:
@@ -207,44 +264,69 @@ cdef class ParamHolder:
207264 cdef size_t n_args = len (kernel_args)
208265 cdef size_t i
209266 cdef int not_prepared
267+ cdef object arg_type
210268 self .data = vector.vector[voidptr](n_args, nullptr)
211269 self .data_addresses = vector.vector[voidptr](n_args)
212270 for i, arg in enumerate (kernel_args):
213- if isinstance (arg, Buffer):
271+ arg_type = type (arg)
272+ if arg_type is Buffer:
214273 # we need the address of where the actual buffer address is stored
215- if isinstance (arg.handle, int ) :
274+ if type (arg.handle) is int :
216275 # see note below on handling int arguments
217276 prepare_arg[intptr_t](self .data, self .data_addresses, arg.handle, i)
218277 continue
219278 else :
220279 # it's a CUdeviceptr:
221280 self .data_addresses[i] = < void * >< intptr_t> (arg.handle.getPtr())
222281 continue
223- elif isinstance (arg, int ):
282+ elif arg_type is bool :
283+ prepare_arg[cpp_bool](self .data, self .data_addresses, arg, i)
284+ continue
285+ elif arg_type is int :
224286 # Here's the dilemma: We want to have a fast path to pass in Python
225287 # integers as pointer addresses, but one could also (mistakenly) pass
226288 # it with the intention of passing a scalar integer. It's a mistake
227289 # bacause a Python int is ambiguous (arbitrary width). Our judgement
228290 # call here is to treat it as a pointer address, without any warning!
229291 prepare_arg[intptr_t](self .data, self .data_addresses, arg, i)
230292 continue
231- elif isinstance (arg, float ) :
293+ elif arg_type is float :
232294 prepare_arg[double ](self .data, self .data_addresses, arg, i)
233295 continue
234- elif isinstance (arg, complex ) :
296+ elif arg_type is complex :
235297 prepare_arg[cpp_double_complex](self .data, self .data_addresses, arg, i)
236298 continue
237- elif isinstance (arg, bool ):
238- prepare_arg[cpp_bool](self .data, self .data_addresses, arg, i)
239- continue
240299
241300 not_prepared = prepare_numpy_arg(self .data, self .data_addresses, arg, i)
242301 if not_prepared:
243302 not_prepared = prepare_ctypes_arg(self .data, self .data_addresses, arg, i)
244303 if not_prepared:
245304 # TODO: revisit this treatment if we decide to cythonize cuda.core
246- if isinstance (arg, driver.CUgraphConditionalHandle):
247- prepare_arg[intptr_t](self .data, self .data_addresses, < intptr_t> int (arg), i)
305+ if arg_type is driver.CUgraphConditionalHandle:
306+ prepare_arg[cydriver.CUgraphConditionalHandle](self .data, self .data_addresses, < intptr_t> int (arg), i)
307+ continue
308+ # If no exact types are found, fallback to slower `isinstance` check
309+ elif isinstance (arg, Buffer):
310+ if isinstance (arg.handle, int ):
311+ prepare_arg[intptr_t](self .data, self .data_addresses, arg.handle, i)
312+ continue
313+ else :
314+ self .data_addresses[i] = < void * >< intptr_t> (arg.handle.getPtr())
315+ continue
316+ elif isinstance (arg, bool ):
317+ prepare_arg[cpp_bool](self .data, self .data_addresses, arg, i)
318+ continue
319+ elif isinstance (arg, int ):
320+ prepare_arg[intptr_t](self .data, self .data_addresses, arg, i)
321+ continue
322+ elif isinstance (arg, float ):
323+ prepare_arg[double ](self .data, self .data_addresses, arg, i)
324+ continue
325+ elif isinstance (arg, complex ):
326+ prepare_arg[cpp_double_complex](self .data, self .data_addresses, arg, i)
327+ continue
328+ elif isinstance (arg, driver.CUgraphConditionalHandle):
329+ prepare_arg[cydriver.CUgraphConditionalHandle](self .data, self .data_addresses, arg, i)
248330 continue
249331 # TODO: support ctypes/numpy struct
250332 raise TypeError (" the argument is of unsupported type: " + str (type (arg)))
0 commit comments