Skip to content

Commit 330b30e

Browse files
rluo8 and rwgk authored
tests: add coverage tests for cuda core (#1923)
* tests: add coverage tests for cuda core
* tests: add launch-level coverage for ctypes/numpy subclass fallback

---------

Co-authored-by: Ralf W. Grosse-Kunstleve <rwgkio@gmail.com>
1 parent 190df10 commit 330b30e

File tree

5 files changed

+501
-0
lines changed

5 files changed

+501
-0
lines changed

cuda_core/tests/test_event.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,54 @@ def test_event_type_safety(init_cuda):
195195
assert (event is None) is False
196196

197197

198+
def test_event_isub_not_implemented(init_cuda):
    """Check that in-place subtraction rejects a non-Event operand.

    The dunder is called directly so the ``NotImplemented`` sentinel itself
    can be inspected rather than whatever the ``-=`` operator would do
    with it.
    """
    dev = Device()
    recording_stream = dev.create_stream()
    recorded = recording_stream.record()
    non_event_operand = 42
    assert recorded.__isub__(non_event_operand) is NotImplemented
205+
206+
207+
def test_event_rsub_not_implemented(init_cuda):
    """Check that reflected subtraction rejects a non-Event operand.

    ``__rsub__`` is invoked explicitly so that the returned
    ``NotImplemented`` sentinel can be compared by identity.
    """
    dev = Device()
    recording_stream = dev.create_stream()
    recorded = recording_stream.record()
    non_event_operand = 42
    assert recorded.__rsub__(non_event_operand) is NotImplemented
214+
215+
216+
def test_event_get_ipc_descriptor_non_ipc(init_cuda):
    """An event obtained from a plain ``stream.record()`` has no IPC descriptor.

    Asking for one must raise ``RuntimeError`` mentioning that the event is
    not IPC-enabled.
    """
    dev = Device()
    recording_stream = dev.create_stream()
    plain_event = recording_stream.record()
    expect_failure = pytest.raises(RuntimeError, match="not IPC-enabled")
    with expect_failure:
        plain_event.get_ipc_descriptor()
223+
224+
225+
def test_event_is_done_false(init_cuda):
    """Event.is_done returns False while earlier work on the stream is pending.

    A latch kernel is launched on the stream before the event is recorded,
    so the event cannot complete until the latch is released.
    """
    device = Device()
    latch = LatchKernel(device)
    stream = device.create_stream()
    latch.launch(stream)
    try:
        event = stream.record()
        # The latch holds the kernel; the event cannot be done yet.
        assert event.is_done is False
    finally:
        # Release the latch even when the assertion above fails; otherwise
        # the kernel held by the latch would be left pending on the stream.
        latch.release()
    # With the latch released the recorded event can now complete.
    event.sync()
236+
237+
238+
def test_ipc_event_descriptor_direct_init():
    """Calling the IPCEventDescriptor constructor directly must raise RuntimeError."""
    from cuda.core import _event as event_impl

    expect_failure = pytest.raises(RuntimeError, match="cannot be instantiated directly")
    with expect_failure:
        event_impl.IPCEventDescriptor()
244+
245+
198246
# ============================================================================
199247
# Event Hash Tests
200248
# ============================================================================

cuda_core/tests/test_launcher.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,132 @@ def test_kernel_arg_unsupported_type():
387387

388388
with pytest.raises(TypeError, match="unsupported type"):
389389
ParamHolder(["not_a_valid_kernel_arg"])
390+
391+
392+
def test_kernel_arg_ctypes_subclass_isinstance_fallback():
    """ParamHolder accepts subclasses of ctypes scalar types.

    Subclasses are not the exact ctypes types, so they are expected to be
    picked up by the isinstance fallback in prepare_ctypes_arg.
    """
    from cuda.core._kernel_arg_handler import ParamHolder

    class Int32Subclass(ctypes.c_int32):
        """Subclass of ctypes.c_int32 with no extra behavior."""

    class FloatSubclass(ctypes.c_float):
        """Subclass of ctypes.c_float with no extra behavior."""

    class BoolSubclass(ctypes.c_bool):
        """Subclass of ctypes.c_bool with no extra behavior."""

    # Constructing the holder must NOT raise; a non-zero pointer shows the
    # arguments were packed.
    kernel_args = [Int32Subclass(42), FloatSubclass(3.14), BoolSubclass(True)]
    packed = ParamHolder(kernel_args)
    assert packed.ptr != 0
408+
409+
410+
@requires_module(np, "2.1")
def test_launch_scalar_argument_ctypes_subclass_fallback():
    """Launch a kernel with a subclassed ctypes scalar and verify the value it receives.

    The kernel writes its scalar argument into a one-element buffer allocated
    from LegacyPinnedMemoryResource, which the host reads back through a
    NumPy view.
    """

    class Int32Subclass(ctypes.c_int32):
        """Subclass of ctypes.c_int32 used as the scalar kernel argument."""

    device = Device()
    device.set_current()

    # One int32 slot, exposed to NumPy via DLPack and zero-initialized.
    pinned_resource = LegacyPinnedMemoryResource()
    buffer = pinned_resource.allocate(np.dtype(np.int32).itemsize)
    out = np.from_dlpack(buffer).view(np.int32)
    out[:] = 0

    value = Int32Subclass(-123456)

    source = r"""
    template <typename T>
    __global__ void write_scalar(T* arr, T val) {
        arr[0] = val;
    }
    """

    # Build the "sm_XY" architecture string from the compute capability.
    sm_digits = "".join(map(str, device.compute_capability))
    options = ProgramOptions(std="c++17", arch=f"sm_{sm_digits}")
    program = Program(source, code_type="c++", options=options)
    kernel_name = "write_scalar<signed int>"
    module = program.compile("cubin", name_expressions=(kernel_name,))
    kernel = module.get_kernel(kernel_name)

    # This exercises the prepare_ctypes_arg isinstance fallback through a real launch.
    default_stream = device.default_stream
    single_thread = LaunchConfig(grid=1, block=1)
    launch(default_stream, single_thread, kernel, out.ctypes.data, value)
    default_stream.sync()

    assert out[0] == value.value
448+
449+
450+
def test_kernel_arg_numpy_subclass_isinstance_fallback():
    """ParamHolder accepts subclasses of NumPy scalar types.

    Subclasses are not the exact NumPy scalar types, so they are expected to
    be picked up by the isinstance fallback in prepare_numpy_arg.
    """
    from cuda.core._kernel_arg_handler import ParamHolder

    class Int32Subclass(np.int32):
        """Subclass of np.int32 with no extra behavior."""

    class Float32Subclass(np.float32):
        """Subclass of np.float32 with no extra behavior."""

    kernel_args = [Int32Subclass(7), Float32Subclass(2.5)]
    packed = ParamHolder(kernel_args)
    assert packed.ptr != 0
462+
463+
464+
@requires_module(np, "2.1")
def test_launch_scalar_argument_numpy_subclass_fallback():
    """Launch a kernel with a subclassed NumPy scalar and verify the value it receives.

    The kernel writes its scalar argument into a one-element buffer allocated
    from LegacyPinnedMemoryResource, which the host reads back through a
    NumPy view.
    """

    class Float32Subclass(np.float32):
        """Subclass of np.float32 used as the scalar kernel argument."""

    device = Device()
    device.set_current()

    # One float32 slot, exposed to NumPy via DLPack and zero-initialized.
    pinned_resource = LegacyPinnedMemoryResource()
    buffer = pinned_resource.allocate(np.dtype(np.float32).itemsize)
    out = np.from_dlpack(buffer).view(np.float32)
    out[:] = 0.0

    value = Float32Subclass(3.14)

    source = r"""
    template <typename T>
    __global__ void write_scalar(T* arr, T val) {
        arr[0] = val;
    }
    """

    # Build the "sm_XY" architecture string from the compute capability.
    sm_digits = "".join(map(str, device.compute_capability))
    options = ProgramOptions(std="c++17", arch=f"sm_{sm_digits}")
    program = Program(source, code_type="c++", options=options)
    kernel_name = "write_scalar<float>"
    module = program.compile("cubin", name_expressions=(kernel_name,))
    kernel = module.get_kernel(kernel_name)

    # This exercises the prepare_numpy_arg isinstance fallback through a real launch.
    default_stream = device.default_stream
    single_thread = LaunchConfig(grid=1, block=1)
    launch(default_stream, single_thread, kernel, out.ctypes.data, value)
    default_stream.sync()

    assert out[0] == value
502+
503+
504+
def test_kernel_arg_python_isinstance_fallbacks():
    """ParamHolder accepts subclasses of the Python builtins int, float and complex.

    These are expected to be handled by the isinstance fallback in
    ParamHolder.
    """
    from cuda.core._kernel_arg_handler import ParamHolder

    class IntSubclass(int):
        """type(x) is not int, so fast path skips; isinstance(x, int) catches it."""

    class FloatSubclass(float):
        """Subclass of float with no extra behavior."""

    class ComplexSubclass(complex):
        """Subclass of complex with no extra behavior."""

    kernel_args = [IntSubclass(1), FloatSubclass(1.5), ComplexSubclass(complex(1, 2))]
    packed = ParamHolder(kernel_args)
    assert packed.ptr != 0

cuda_core/tests/test_linker.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,24 @@ def test_linker_logs_cached_after_link(compile_ptx_functions):
221221
# Calling again should return the same observable values.
222222
assert linker.get_error_log() == err_log
223223
assert linker.get_info_log() == info_log
224+
225+
226+
def test_linker_handle(compile_ptx_functions):
    """A freshly constructed Linker exposes a handle that is not None and not zero."""
    linker = Linker(*compile_ptx_functions, options=LinkerOptions(arch=ARCH))
    raw_handle = linker.handle
    assert raw_handle is not None
    # The handle must convert to a non-zero integer.
    assert int(raw_handle) != 0
233+
234+
235+
@pytest.mark.skipif(is_culink_backend, reason="nvjitlink options only tested with nvjitlink backend")
def test_linker_options_nvjitlink_options_as_str():
    """With as_bytes=False, _prepare_nvjitlink_options yields a list of str flags.

    The list must contain the arch, debug and lineinfo flags that were
    requested on the LinkerOptions instance.
    """
    linker_opts = LinkerOptions(arch=ARCH, debug=True, lineinfo=True)
    flags = linker_opts._prepare_nvjitlink_options(as_bytes=False)
    assert isinstance(flags, list)
    non_str_flags = [flag for flag in flags if not isinstance(flag, str)]
    assert not non_str_flags
    for expected_flag in (f"-arch={ARCH}", "-g", "-lineinfo"):
        assert expected_flag in flags

cuda_core/tests/test_program.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,107 @@ def test_program_options_as_bytes_nvvm_unsupported_option():
773773
options = ProgramOptions(arch="sm_80", lineinfo=True)
774774
with pytest.raises(CUDAError, match="not supported by NVVM backend"):
775775
options.as_bytes("nvvm")
776+
777+
778+
def test_program_options_repr():
    """repr() of ProgramOptions mentions the class name and the configured values."""
    text = repr(ProgramOptions(name="mykernel", arch="sm_80"))
    for fragment in ("ProgramOptions", "mykernel", "sm_80"):
        assert fragment in text
785+
786+
787+
def test_program_options_bad_define_macro_short_tuple():
    """A define_macro tuple holding only a name fails when NVRTC options are built."""
    name_only_macro = ("ONLY_NAME",)
    program_opts = ProgramOptions(name="test", arch="sm_80", define_macro=name_only_macro)
    expect_failure = pytest.raises(RuntimeError, match="Expected define_macro tuple")
    with expect_failure:
        program_opts.as_bytes("nvrtc")
792+
793+
794+
def test_program_options_bad_define_macro_non_str_value():
    """A define_macro tuple whose value is not a string fails when NVRTC options are built."""
    int_valued_macro = ("MY_MACRO", 99)
    program_opts = ProgramOptions(name="test", arch="sm_80", define_macro=int_valued_macro)
    expect_failure = pytest.raises(RuntimeError, match="Expected define_macro tuple")
    with expect_failure:
        program_opts.as_bytes("nvrtc")
799+
800+
801+
def test_program_options_bad_define_macro_list_non_str():
    """A define_macro list with an item that is neither str nor tuple fails for NVRTC."""
    macros_with_int = [42]
    program_opts = ProgramOptions(name="test", arch="sm_80", define_macro=macros_with_int)
    expect_failure = pytest.raises(RuntimeError, match="Expected define_macro")
    with expect_failure:
        program_opts.as_bytes("nvrtc")
806+
807+
808+
def test_program_options_bad_define_macro_list_bad_tuple():
    """A define_macro list containing a one-element tuple fails for NVRTC."""
    macros_with_short_tuple = [("ONLY_NAME",)]
    program_opts = ProgramOptions(name="test", arch="sm_80", define_macro=macros_with_short_tuple)
    expect_failure = pytest.raises(RuntimeError, match="Expected define_macro")
    with expect_failure:
        program_opts.as_bytes("nvrtc")
813+
814+
815+
def test_ptx_program_extra_sources_unsupported(ptx_code_object):
    """Creating a PTX Program with extra_sources set must raise ValueError."""
    opts_with_extras = ProgramOptions(extra_sources=[("module1", b"data")])
    expected_message = "extra_sources is not supported by the PTX backend"
    with pytest.raises(ValueError, match=expected_message):
        # The PTX text is decoded inside the context, as part of the call under test.
        Program(ptx_code_object.code.decode(), "ptx", opts_with_extras)
820+
821+
822+
def test_ptx_program_handle_is_linker_handle(init_cuda, ptx_code_object):
    """Program.handle for the PTX backend delegates to the linker handle.

    The handle must be a non-None object that converts to a non-zero integer.
    """
    program = Program(ptx_code_object.code.decode(), "ptx")
    try:
        handle = program.handle
        assert handle is not None
        assert int(handle) != 0
    finally:
        # Close the program even if an assertion fails so it is not leaked.
        program.close()
829+
830+
831+
@nvvm_available
def test_nvvm_program_wrong_code_type():
    """Passing an int as NVVM code must raise TypeError naming the accepted types."""
    expected_message = "NVVM IR code must be provided as str, bytes, or bytearray"
    not_source_code = 42
    with pytest.raises(TypeError, match=expected_message):
        Program(not_source_code, "nvvm")
836+
837+
838+
def test_extra_sources_not_sequence():
    """Passing a bare int as extra_sources is rejected with TypeError at construction."""
    expected_message = "extra_sources must be a sequence of 2-tuples"
    with pytest.raises(TypeError, match=expected_message):
        ProgramOptions(name="test", arch="sm_80", extra_sources=42)
842+
843+
844+
def test_extra_sources_bad_module_not_tuple():
    """An extra_sources entry that is a plain string is rejected with TypeError."""
    expected_message = "Each extra module must be a 2-tuple"
    entries = ["not_a_tuple"]
    with pytest.raises(TypeError, match=expected_message):
        ProgramOptions(name="test", arch="sm_80", extra_sources=entries)
848+
849+
850+
def test_extra_sources_bad_module_name_not_str():
    """An extra_sources entry whose module name is an int is rejected with TypeError."""
    expected_message = "Module name at index 0 must be a string"
    entries = [(42, b"source")]
    with pytest.raises(TypeError, match=expected_message):
        ProgramOptions(name="test", arch="sm_80", extra_sources=entries)
854+
855+
856+
def test_extra_sources_bad_module_source_wrong_type():
    """An extra_sources entry whose source is an int is rejected with TypeError."""
    expected_message = "Module source at index 0 must be str"
    entries = [("mod", 42)]
    with pytest.raises(TypeError, match=expected_message):
        ProgramOptions(name="test", arch="sm_80", extra_sources=entries)
860+
861+
862+
def test_extra_sources_empty_source():
    """An extra_sources entry with empty bytes as its source is rejected with ValueError."""
    expected_message = "Module source for 'mod'.*cannot be empty"
    entries = [("mod", b"")]
    with pytest.raises(ValueError, match=expected_message):
        ProgramOptions(name="test", arch="sm_80", extra_sources=entries)
866+
867+
868+
def test_nvrtc_compile_with_logs_capture(init_cuda):
    """Program.compile with logs= exercises the NVRTC program-log reading path.

    The source contains a ``#warning`` directive so that compilation
    succeeds while still writing text to the stream passed as ``logs``.
    """
    import io

    # #warning generates a non-empty NVRTC program log, ensuring logsize > 1.
    code = '#warning "test log capture"\nextern "C" __global__ void my_kernel() {}'
    program = Program(code, "c++")
    try:
        logs = io.StringIO()
        result = program.compile("ptx", logs=logs)
        assert isinstance(result, ObjectCode)
        assert logs.getvalue(), "Expected non-empty compilation log from #warning directive"
    finally:
        # Close the program even if compilation or an assertion fails.
        program.close()

0 commit comments

Comments
 (0)