Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,18 +395,26 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):

def _add_code_object(self, object_code: ObjectCode):
data = object_code._module
assert_type(data, bytes)
with _exception_manager(self):
name_str = f"{object_code.name}"
if _nvjitlink:
if _nvjitlink and isinstance(data, bytes):
# Handle bytes input with nvjitlink
Comment thread
leofang marked this conversation as resolved.
Outdated
_nvjitlink.add_data(
self._mnff.handle,
self._input_type_from_code_type(object_code._code_type),
data,
len(data),
name_str,
)
else:
elif _nvjitlink and isinstance(data, str):
Comment thread
leofang marked this conversation as resolved.
# Handle file path input with nvjitlink
Comment thread
leofang marked this conversation as resolved.
Outdated
_nvjitlink.add_file(
self._mnff.handle,
self._input_type_from_code_type(object_code._code_type),
data,
)
elif isinstance(data, bytes):
# Handle bytes input with driver API
Comment thread
leofang marked this conversation as resolved.
Outdated
name_bytes = name_str.encode()
handle_return(
_driver.cuLinkAddData(
Expand All @@ -421,6 +429,22 @@ def _add_code_object(self, object_code: ObjectCode):
)
)
self._mnff.const_char_keep_alive.append(name_bytes)
elif isinstance(data, str):
# Handle file path input with driver API
Comment thread
leofang marked this conversation as resolved.
Outdated
name_bytes = name_str.encode()
handle_return(
_driver.cuLinkAddFile(
self._mnff.handle,
self._input_type_from_code_type(object_code._code_type),
data.encode(),
0,
None,
None,
)
)
self._mnff.const_char_keep_alive.append(name_bytes)
else:
raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")

def link(self, target_type) -> ObjectCode:
"""
Expand Down
236 changes: 236 additions & 0 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,242 @@ def test_object_code_handle(get_saxpy_object_code):
assert mod.handle is not None


@pytest.fixture(scope="function")
def get_ltoir_object_code(init_cuda):
Comment thread
leofang marked this conversation as resolved.
Outdated
# Create LTOIR code using link-time optimization
prog = Program(SAXPY_KERNEL, code_type="c++", options=ProgramOptions(link_time_optimization=True))
mod = prog.compile("ltoir", name_expressions=("saxpy<float>", "saxpy<double>"))
return mod


def test_object_code_load_ltoir(get_ltoir_object_code):
mod = get_ltoir_object_code
ltoir = mod._module
sym_map = mod._sym_map
assert isinstance(ltoir, bytes)
mod_obj = ObjectCode.from_ltoir(ltoir, symbol_mapping=sym_map)
assert mod_obj.code == ltoir
assert mod_obj._code_type == "ltoir"
# ltoir doesn't support kernel retrieval directly as it's used for linking
assert mod_obj._handle is None # Should only be loaded when needed
# Test that get_kernel fails for unsupported code type
with pytest.raises(RuntimeError, match=r'Unsupported code type "ltoir"'):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_ltoir_from_file(get_ltoir_object_code, tmp_path):
mod = get_ltoir_object_code
ltoir = mod._module
sym_map = mod._sym_map
assert isinstance(ltoir, bytes)
ltoir_file = tmp_path / "test.ltoir"
ltoir_file.write_bytes(ltoir)
mod_obj = ObjectCode.from_ltoir(str(ltoir_file), symbol_mapping=sym_map)
assert mod_obj.code == str(ltoir_file)
assert mod_obj._code_type == "ltoir"
assert mod_obj._handle is None # Should only be loaded when needed


def test_object_code_load_fatbin(get_saxpy_kernel):
Comment thread
leofang marked this conversation as resolved.
Outdated
# Use cubin as a substitute for fatbin since they have similar structure
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
mod_obj = ObjectCode.from_fatbin(cubin, symbol_mapping=sym_map)
assert mod_obj.code == cubin
assert mod_obj._code_type == "fatbin"
# fatbin supports kernel retrieval
mod_obj.get_kernel("saxpy<double>") # force loading


def test_object_code_load_fatbin_from_file(get_saxpy_kernel, tmp_path):
# Use cubin as a substitute for fatbin since they have similar structure
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
fatbin_file = tmp_path / "test.fatbin"
fatbin_file.write_bytes(cubin)
mod_obj = ObjectCode.from_fatbin(str(fatbin_file), symbol_mapping=sym_map)
assert mod_obj.code == str(fatbin_file)
assert mod_obj._code_type == "fatbin"
mod_obj.get_kernel("saxpy<double>") # force loading


def test_object_code_load_object(get_saxpy_kernel):
Comment thread
leofang marked this conversation as resolved.
Outdated
# Use cubin as a substitute for object code since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
mod_obj = ObjectCode.from_object(cubin, symbol_mapping=sym_map)
assert mod_obj.code == cubin
assert mod_obj._code_type == "object"
# object code doesn't support direct kernel retrieval
assert mod_obj._handle is None # Should only be loaded when needed
# Test that get_kernel fails for unsupported code type
with pytest.raises(RuntimeError, match=r'Unsupported code type "object"'):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_object_from_file(get_saxpy_kernel, tmp_path):
# Use cubin as a substitute for object code since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
object_file = tmp_path / "test.o"
object_file.write_bytes(cubin)
mod_obj = ObjectCode.from_object(str(object_file), symbol_mapping=sym_map)
assert mod_obj.code == str(object_file)
assert mod_obj._code_type == "object"
assert mod_obj._handle is None # Should only be loaded when needed


def test_object_code_load_library(get_saxpy_kernel):
# Use cubin as a substitute for library since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
mod_obj = ObjectCode.from_library(cubin, symbol_mapping=sym_map)
assert mod_obj.code == cubin
assert mod_obj._code_type == "library"
# library code doesn't support direct kernel retrieval
assert mod_obj._handle is None # Should only be loaded when needed
# Test that get_kernel fails for unsupported code type
with pytest.raises(RuntimeError, match=r'Unsupported code type "library"'):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_library_from_file(get_saxpy_kernel, tmp_path):
# Use cubin as a substitute for library since they're binary formats
_, mod = get_saxpy_kernel
cubin = mod._module
sym_map = mod._sym_map
assert isinstance(cubin, bytes)
library_file = tmp_path / "test.a"
library_file.write_bytes(cubin)
mod_obj = ObjectCode.from_library(str(library_file), symbol_mapping=sym_map)
assert mod_obj.code == str(library_file)
assert mod_obj._code_type == "library"
assert mod_obj._handle is None # Should only be loaded when needed


def test_object_code_constructors_with_name_and_symbol_mapping():
Comment thread
leofang marked this conversation as resolved.
Outdated
"""Test that all from_* constructors properly set name and symbol_mapping"""
# Dummy data for testing
dummy_bytes = b"dummy_code_data"
test_name = "test_object"
test_sym_map = {"kernel1": "mangled_kernel1", "kernel2": "mangled_kernel2"}

# Test all constructors
constructors = [
(ObjectCode.from_cubin, "cubin"),
(ObjectCode.from_ptx, "ptx"),
(ObjectCode.from_ltoir, "ltoir"),
(ObjectCode.from_fatbin, "fatbin"),
(ObjectCode.from_object, "object"),
(ObjectCode.from_library, "library"),
]

for constructor, code_type in constructors:
obj = constructor(dummy_bytes, name=test_name, symbol_mapping=test_sym_map)
assert obj.name == test_name
assert obj._sym_map == test_sym_map
assert obj._code_type == code_type
assert obj.code == dummy_bytes


def test_object_code_constructors_default_values():
Comment thread
leofang marked this conversation as resolved.
Outdated
"""Test that all from_* constructors handle default values correctly"""
# Dummy data for testing
dummy_bytes = b"dummy_code_data"

# Test all constructors with defaults
constructors = [
(ObjectCode.from_cubin, "cubin"),
(ObjectCode.from_ptx, "ptx"),
(ObjectCode.from_ltoir, "ltoir"),
(ObjectCode.from_fatbin, "fatbin"),
(ObjectCode.from_object, "object"),
(ObjectCode.from_library, "library"),
]

for constructor, code_type in constructors:
obj = constructor(dummy_bytes) # Use defaults
assert obj.name == "" # Default name should be empty string
assert obj._sym_map == {} # Default symbol mapping should be empty dict
assert obj._code_type == code_type
assert obj.code == dummy_bytes


def test_object_code_file_path_linker_integration(get_saxpy_kernel, tmp_path):
"""Test that ObjectCode created from file paths works with the Linker"""
_, mod = get_saxpy_kernel
cubin = mod._module
assert isinstance(cubin, bytes)

# Create temporary files for different code types
test_files = {}
for code_type in ["cubin", "ptx", "ltoir", "fatbin", "object", "library"]:
file_path = tmp_path / f"test.{code_type}"
file_path.write_bytes(cubin) # Use cubin bytes as proxy for all types
test_files[code_type] = str(file_path)

# Create ObjectCode instances from file paths
file_based_objects = []
for code_type, file_path in test_files.items():
if code_type == "cubin":
obj = ObjectCode.from_cubin(file_path, name=f"file_{code_type}")
elif code_type == "ptx":
obj = ObjectCode.from_ptx(file_path, name=f"file_{code_type}")
elif code_type == "ltoir":
obj = ObjectCode.from_ltoir(file_path, name=f"file_{code_type}")
elif code_type == "fatbin":
obj = ObjectCode.from_fatbin(file_path, name=f"file_{code_type}")
elif code_type == "object":
obj = ObjectCode.from_object(file_path, name=f"file_{code_type}")
elif code_type == "library":
obj = ObjectCode.from_library(file_path, name=f"file_{code_type}")

# Verify the ObjectCode was created correctly
assert obj.code == file_path
assert obj._code_type == code_type
assert obj.name == f"file_{code_type}"
assert isinstance(obj._module, str) # Should store the file path
file_based_objects.append(obj)

# Test that these ObjectCode instances can be used with Linker
# Note: We can't actually link most of these types together in practice,
# but we can verify the linker accepts them and handles the file path correctly
from cuda.core.experimental import Linker, LinkerOptions

# Test with ptx which should be linkable (use only PTX for actual linking)
ptx_obj = None
for obj in file_based_objects:
if obj._code_type == "ptx":
ptx_obj = obj
break

if ptx_obj is not None:
# Create a simple linker test - this will test that _add_code_object
# handles file paths correctly by not crashing on the file path
try:
arch = "sm_" + "".join(f"{i}" for i in Device().compute_capability)
options = LinkerOptions(arch=arch)
# This should not crash - it should handle the file path in _add_code_object
linker = Linker(ptx_obj, options=options)
# We don't need to actually link since that might fail due to content,
# but creating the linker tests our file path handling
assert linker is not None
except Exception as e:
# If it fails, it should be due to content issues, not file path handling
# The key is that it should not fail with "Expected type bytes, but got str"
assert "Expected type bytes, but got str" not in str(e), f"File path handling failed: {e}"


def test_saxpy_arguments(get_saxpy_kernel, cuda12_4_prerequisite_check):
krn, _ = get_saxpy_kernel

Expand Down