Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,177 @@ def test_object_code_handle(get_saxpy_object_code):
assert mod.handle is not None


@pytest.fixture(scope="function")
def get_ltoir_object_code(init_cuda):
Comment thread
leofang marked this conversation as resolved.
Outdated
# Create LTOIR code using link-time optimization
prog = Program(SAXPY_KERNEL, code_type="c++", options=ProgramOptions(link_time_optimization=True))
mod = prog.compile("ltoir", name_expressions=("saxpy<float>", "saxpy<double>"))
return mod


def test_object_code_load_ltoir(get_ltoir_object_code):
mod = get_ltoir_object_code
ltoir = mod._module
sym_map = mod._sym_map
assert isinstance(ltoir, bytes)
mod_obj = ObjectCode.from_ltoir(ltoir, symbol_mapping=sym_map)
assert mod_obj.code == ltoir
assert mod_obj._code_type == "ltoir"
# ltoir doesn't support kernel retrieval directly as it's used for linking
assert mod_obj._handle is None # Should only be loaded when needed
# Test that get_kernel fails for unsupported code type
with pytest.raises(RuntimeError, match=r'Unsupported code type "ltoir"'):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_ltoir_from_file(get_ltoir_object_code, tmp_path):
mod = get_ltoir_object_code
ltoir = mod._module
sym_map = mod._sym_map
assert isinstance(ltoir, bytes)
ltoir_file = tmp_path / "test.ltoir"
ltoir_file.write_bytes(ltoir)
mod_obj = ObjectCode.from_ltoir(str(ltoir_file), symbol_mapping=sym_map)
assert mod_obj.code == str(ltoir_file)
assert mod_obj._code_type == "ltoir"
assert mod_obj._handle is None # Should only be loaded when needed


def test_object_code_load_fatbin(get_saxpy_kernel):
Comment thread
leofang marked this conversation as resolved.
Outdated
# Use cubin as a substitute for fatbin since they have similar structure
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
mod_obj = ObjectCode.from_fatbin(cubin, symbol_mapping=sym_map)
assert mod_obj.code == cubin
assert mod_obj._code_type == "fatbin"
# fatbin supports kernel retrieval
mod_obj.get_kernel("saxpy<double>") # force loading


def test_object_code_load_fatbin_from_file(get_saxpy_kernel, tmp_path):
# Use cubin as a substitute for fatbin since they have similar structure
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
fatbin_file = tmp_path / "test.fatbin"
fatbin_file.write_bytes(cubin)
mod_obj = ObjectCode.from_fatbin(str(fatbin_file), symbol_mapping=sym_map)
assert mod_obj.code == str(fatbin_file)
assert mod_obj._code_type == "fatbin"
mod_obj.get_kernel("saxpy<double>") # force loading


def test_object_code_load_object(get_saxpy_kernel):
Comment thread
leofang marked this conversation as resolved.
Outdated
# Use cubin as a substitute for object code since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
mod_obj = ObjectCode.from_object(cubin, symbol_mapping=sym_map)
assert mod_obj.code == cubin
assert mod_obj._code_type == "object"
# object code doesn't support direct kernel retrieval
assert mod_obj._handle is None # Should only be loaded when needed
# Test that get_kernel fails for unsupported code type
with pytest.raises(RuntimeError, match=r'Unsupported code type "object"'):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_object_from_file(get_saxpy_kernel, tmp_path):
# Use cubin as a substitute for object code since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
object_file = tmp_path / "test.o"
object_file.write_bytes(cubin)
mod_obj = ObjectCode.from_object(str(object_file), symbol_mapping=sym_map)
assert mod_obj.code == str(object_file)
assert mod_obj._code_type == "object"
assert mod_obj._handle is None # Should only be loaded when needed


def test_object_code_load_library(get_saxpy_kernel):
# Use cubin as a substitute for library since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
mod_obj = ObjectCode.from_library(cubin, symbol_mapping=sym_map)
assert mod_obj.code == cubin
assert mod_obj._code_type == "library"
# library code doesn't support direct kernel retrieval
assert mod_obj._handle is None # Should only be loaded when needed
# Test that get_kernel fails for unsupported code type
with pytest.raises(RuntimeError, match=r'Unsupported code type "library"'):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_library_from_file(get_saxpy_kernel, tmp_path):
# Use cubin as a substitute for library since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
library_file = tmp_path / "test.a"
library_file.write_bytes(cubin)
mod_obj = ObjectCode.from_library(str(library_file), symbol_mapping=sym_map)
assert mod_obj.code == str(library_file)
assert mod_obj._code_type == "library"
assert mod_obj._handle is None # Should only be loaded when needed


def test_object_code_constructors_with_name_and_symbol_mapping():
Comment thread
leofang marked this conversation as resolved.
Outdated
"""Test that all from_* constructors properly set name and symbol_mapping"""
# Dummy data for testing
dummy_bytes = b"dummy_code_data"
test_name = "test_object"
test_sym_map = {"kernel1": "mangled_kernel1", "kernel2": "mangled_kernel2"}

# Test all constructors
constructors = [
(ObjectCode.from_cubin, "cubin"),
(ObjectCode.from_ptx, "ptx"),
(ObjectCode.from_ltoir, "ltoir"),
(ObjectCode.from_fatbin, "fatbin"),
(ObjectCode.from_object, "object"),
(ObjectCode.from_library, "library"),
]

for constructor, code_type in constructors:
obj = constructor(dummy_bytes, name=test_name, symbol_mapping=test_sym_map)
assert obj.name == test_name
assert obj._sym_map == test_sym_map
assert obj._code_type == code_type
assert obj.code == dummy_bytes


def test_object_code_constructors_default_values():
Comment thread
leofang marked this conversation as resolved.
Outdated
"""Test that all from_* constructors handle default values correctly"""
# Dummy data for testing
dummy_bytes = b"dummy_code_data"

# Test all constructors with defaults
constructors = [
(ObjectCode.from_cubin, "cubin"),
(ObjectCode.from_ptx, "ptx"),
(ObjectCode.from_ltoir, "ltoir"),
(ObjectCode.from_fatbin, "fatbin"),
(ObjectCode.from_object, "object"),
(ObjectCode.from_library, "library"),
]

for constructor, code_type in constructors:
obj = constructor(dummy_bytes) # Use defaults
assert obj.name == "" # Default name should be empty string
assert obj._sym_map == {} # Default symbol mapping should be empty dict
assert obj._code_type == code_type
assert obj.code == dummy_bytes


def test_saxpy_arguments(get_saxpy_kernel, cuda12_4_prerequisite_check):
krn, _ = get_saxpy_kernel

Expand Down