Skip to content

Commit 9bb7531

Browse files
Copilotleofangpre-commit-ci[bot]
authored
Fix linker file path support, add ObjectCode constructor tests for ltoir inputs, and expose code_type (#890)
* Initial plan * Add comprehensive tests for ObjectCode from_ltoir, from_fatbin, from_object, and from_library constructors Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Enhance ObjectCode constructor tests with error handling and default value validation Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Fix linker to handle file paths for ObjectCode constructors Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Flatten nested if statements in linker _add_code_object method Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Address review comments: reorganize tests, add NVCC-based testing, fix fixture naming Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Remove self-explanatory comments and clarify if conditions in linker Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * [pre-commit.ci] auto code formatting * nit: reorder test to make it easier to follow * rename fixture for clarity + purge object/library/fatbin tests for now * minor fixes * make code_type public --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> Co-authored-by: Leo Fang <leof@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5669118 commit 9bb7531

File tree

4 files changed

+111
-44
lines changed

4 files changed

+111
-44
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,18 +395,23 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
395395

396396
def _add_code_object(self, object_code: ObjectCode):
397397
data = object_code._module
398-
assert_type(data, bytes)
399398
with _exception_manager(self):
400399
name_str = f"{object_code.name}"
401-
if _nvjitlink:
400+
if _nvjitlink and isinstance(data, bytes):
402401
_nvjitlink.add_data(
403402
self._mnff.handle,
404403
self._input_type_from_code_type(object_code._code_type),
405404
data,
406405
len(data),
407406
name_str,
408407
)
409-
else:
408+
elif _nvjitlink and isinstance(data, str):
409+
_nvjitlink.add_file(
410+
self._mnff.handle,
411+
self._input_type_from_code_type(object_code._code_type),
412+
data,
413+
)
414+
elif (not _nvjitlink) and isinstance(data, bytes):
410415
name_bytes = name_str.encode()
411416
handle_return(
412417
_driver.cuLinkAddData(
@@ -421,6 +426,21 @@ def _add_code_object(self, object_code: ObjectCode):
421426
)
422427
)
423428
self._mnff.const_char_keep_alive.append(name_bytes)
429+
elif (not _nvjitlink) and isinstance(data, str):
430+
name_bytes = name_str.encode()
431+
handle_return(
432+
_driver.cuLinkAddFile(
433+
self._mnff.handle,
434+
self._input_type_from_code_type(object_code._code_type),
435+
data.encode(),
436+
0,
437+
None,
438+
None,
439+
)
440+
)
441+
self._mnff.const_char_keep_alive.append(name_bytes)
442+
else:
443+
raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
424444

425445
def link(self, target_type) -> ObjectCode:
426446
"""

cuda_core/cuda/core/experimental/_module.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,11 @@ def name(self) -> str:
666666
"""Return a human-readable name of this code object."""
667667
return self._name
668668

669+
@property
670+
def code_type(self) -> str:
671+
"""Return the type of the underlying code object."""
672+
return self._code_type
673+
669674
@property
670675
@precondition(_lazy_load_module)
671676
def handle(self):

cuda_core/docs/source/release/0.X.Y-notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ New features
3131
- CUDA 13.x testing support through new ``test-cu13`` dependency group.
3232
- Stream-ordered memory allocation can now be shared on Linux via :class:`DeviceMemoryResource`.
3333
- Added NVVM IR support to :class:`Program`. NVVM IR is now understood with ``code_type="nvvm"``.
34+
- Added an :attr:`ObjectCode.code_type` attribute for querying the code type.
3435

3536

3637
New examples

cuda_core/tests/test_module.py

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,20 @@ def test_object_code_init_disabled():
6060

6161

6262
@pytest.fixture(scope="function")
63-
def get_saxpy_kernel(init_cuda):
63+
def get_saxpy_kernel_cubin(init_cuda):
6464
# prepare program
6565
prog = Program(SAXPY_KERNEL, code_type="c++")
6666
mod = prog.compile(
6767
"cubin",
6868
name_expressions=("saxpy<float>", "saxpy<double>"),
6969
)
70-
7170
# run in single precision
7271
return mod.get_kernel("saxpy<float>"), mod
7372

7473

7574
@pytest.fixture(scope="function")
7675
def get_saxpy_kernel_ptx(init_cuda):
76+
# prepare program
7777
prog = Program(SAXPY_KERNEL, code_type="c++")
7878
mod = prog.compile(
7979
"ptx",
@@ -84,12 +84,10 @@ def get_saxpy_kernel_ptx(init_cuda):
8484

8585

8686
@pytest.fixture(scope="function")
87-
def get_saxpy_object_code(init_cuda):
88-
prog = Program(SAXPY_KERNEL, code_type="c++")
89-
mod = prog.compile(
90-
"cubin",
91-
name_expressions=("saxpy<float>", "saxpy<double>"),
92-
)
87+
def get_saxpy_kernel_ltoir(init_cuda):
88+
# Create LTOIR code using link-time optimization
89+
prog = Program(SAXPY_KERNEL, code_type="c++", options=ProgramOptions(link_time_optimization=True))
90+
mod = prog.compile("ltoir", name_expressions=("saxpy<float>", "saxpy<double>"))
9391
return mod
9492

9593

@@ -129,8 +127,8 @@ def test_get_kernel(init_cuda):
129127
("cluster_scheduling_policy_preference", int),
130128
],
131129
)
132-
def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
133-
kernel, _ = get_saxpy_kernel
130+
def test_read_only_kernel_attributes(get_saxpy_kernel_cubin, attr, expected_type):
131+
kernel, _ = get_saxpy_kernel_cubin
134132
method = getattr(kernel.attributes, attr)
135133
# get the value without providing a device ordinal
136134
value = method()
@@ -142,16 +140,6 @@ def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
142140
assert isinstance(value, expected_type), f"Expected {attr} to be of type {expected_type}, but got {type(value)}"
143141

144142

145-
def test_object_code_load_cubin(get_saxpy_kernel):
146-
_, mod = get_saxpy_kernel
147-
cubin = mod._module
148-
sym_map = mod._sym_map
149-
assert isinstance(cubin, bytes)
150-
mod = ObjectCode.from_cubin(cubin, symbol_mapping=sym_map)
151-
assert mod.code == cubin
152-
mod.get_kernel("saxpy<double>") # force loading
153-
154-
155143
def test_object_code_load_ptx(get_saxpy_kernel_ptx):
156144
ptx, mod = get_saxpy_kernel_ptx
157145
sym_map = mod._sym_map
@@ -162,8 +150,32 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx):
162150
mod_obj.get_kernel("saxpy<double>") # force loading
163151

164152

165-
def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
166-
_, mod = get_saxpy_kernel
153+
def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path):
154+
ptx, mod = get_saxpy_kernel_ptx
155+
sym_map = mod._sym_map
156+
assert isinstance(ptx, bytes)
157+
ptx_file = tmp_path / "test.ptx"
158+
ptx_file.write_bytes(ptx)
159+
mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map)
160+
assert mod_obj.code == str(ptx_file)
161+
assert mod_obj.code_type == "ptx"
162+
if not Program._can_load_generated_ptx():
163+
pytest.skip("PTX version too new for current driver")
164+
mod_obj.get_kernel("saxpy<double>") # force loading
165+
166+
167+
def test_object_code_load_cubin(get_saxpy_kernel_cubin):
168+
_, mod = get_saxpy_kernel_cubin
169+
cubin = mod._module
170+
sym_map = mod._sym_map
171+
assert isinstance(cubin, bytes)
172+
mod = ObjectCode.from_cubin(cubin, symbol_mapping=sym_map)
173+
assert mod.code == cubin
174+
mod.get_kernel("saxpy<double>") # force loading
175+
176+
177+
def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path):
178+
_, mod = get_saxpy_kernel_cubin
167179
cubin = mod._module
168180
sym_map = mod._sym_map
169181
assert isinstance(cubin, bytes)
@@ -174,13 +186,42 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
174186
mod.get_kernel("saxpy<double>") # force loading
175187

176188

177-
def test_object_code_handle(get_saxpy_object_code):
178-
mod = get_saxpy_object_code
189+
def test_object_code_handle(get_saxpy_kernel_cubin):
190+
_, mod = get_saxpy_kernel_cubin
179191
assert mod.handle is not None
180192

181193

182-
def test_saxpy_arguments(get_saxpy_kernel, cuda12_4_prerequisite_check):
183-
krn, _ = get_saxpy_kernel
194+
def test_object_code_load_ltoir(get_saxpy_kernel_ltoir):
195+
mod = get_saxpy_kernel_ltoir
196+
ltoir = mod._module
197+
sym_map = mod._sym_map
198+
assert isinstance(ltoir, bytes)
199+
mod_obj = ObjectCode.from_ltoir(ltoir, symbol_mapping=sym_map)
200+
assert mod_obj.code == ltoir
201+
assert mod_obj.code_type == "ltoir"
202+
# ltoir doesn't support kernel retrieval directly as it's used for linking
203+
assert mod_obj._handle is None
204+
# Test that get_kernel fails for unsupported code type
205+
with pytest.raises(RuntimeError, match=r'Unsupported code type "ltoir"'):
206+
mod_obj.get_kernel("saxpy<float>")
207+
208+
209+
def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path):
210+
mod = get_saxpy_kernel_ltoir
211+
ltoir = mod._module
212+
sym_map = mod._sym_map
213+
assert isinstance(ltoir, bytes)
214+
ltoir_file = tmp_path / "test.ltoir"
215+
ltoir_file.write_bytes(ltoir)
216+
mod_obj = ObjectCode.from_ltoir(str(ltoir_file), symbol_mapping=sym_map)
217+
assert mod_obj.code == str(ltoir_file)
218+
assert mod_obj.code_type == "ltoir"
219+
# ltoir doesn't support kernel retrieval directly as it's used for linking
220+
assert mod_obj._handle is None
221+
222+
223+
def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check):
224+
krn, _ = get_saxpy_kernel_cubin
184225

185226
if cuda12_4_prerequisite_check:
186227
assert krn.num_arguments == 5
@@ -258,8 +299,8 @@ def test_num_args_error_handling(deinit_all_contexts_function, cuda12_4_prerequi
258299

259300
@pytest.mark.parametrize("block_size", [32, 64, 96, 120, 128, 256])
260301
@pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096])
261-
def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_size, smem_size_per_block):
262-
kernel, _ = get_saxpy_kernel
302+
def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel_cubin, block_size, smem_size_per_block):
303+
kernel, _ = get_saxpy_kernel_cubin
263304
dev_props = Device().properties
264305
assert block_size <= dev_props.max_threads_per_block
265306
assert smem_size_per_block <= dev_props.max_shared_memory_per_block
@@ -275,9 +316,9 @@ def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_s
275316

276317
@pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 256, 0])
277318
@pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096])
278-
def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel, block_size_limit, smem_size_per_block):
319+
def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel_cubin, block_size_limit, smem_size_per_block):
279320
"""Tests use case when shared memory needed is independent on the block size"""
280-
kernel, _ = get_saxpy_kernel
321+
kernel, _ = get_saxpy_kernel_cubin
281322
dev_props = Device().properties
282323
assert block_size_limit <= dev_props.max_threads_per_block
283324
assert smem_size_per_block <= dev_props.max_shared_memory_per_block
@@ -302,9 +343,9 @@ def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel, block_siz
302343

303344
@pytest.mark.skipif(numba is None, reason="Test requires numba to be installed")
304345
@pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 277, 0])
305-
def test_occupancy_max_potential_block_size_b2dsize(get_saxpy_kernel, block_size_limit):
346+
def test_occupancy_max_potential_block_size_b2dsize(get_saxpy_kernel_cubin, block_size_limit):
306347
"""Tests use case when shared memory needed depends on the block size"""
307-
kernel, _ = get_saxpy_kernel
348+
kernel, _ = get_saxpy_kernel_cubin
308349

309350
def shared_memory_needed(block_size: numba.intc) -> numba.size_t:
310351
"Size of dynamic shared memory needed by kernel of this block size"
@@ -329,8 +370,8 @@ def shared_memory_needed(block_size: numba.intc) -> numba.size_t:
329370

330371

331372
@pytest.mark.parametrize("num_blocks_per_sm, block_size", [(4, 32), (2, 64), (2, 96), (3, 120), (2, 128), (1, 256)])
332-
def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, num_blocks_per_sm, block_size):
333-
kernel, _ = get_saxpy_kernel
373+
def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel_cubin, num_blocks_per_sm, block_size):
374+
kernel, _ = get_saxpy_kernel_cubin
334375
dev_props = Device().properties
335376
assert block_size <= dev_props.max_threads_per_block
336377
assert num_blocks_per_sm * block_size <= dev_props.max_threads_per_multiprocessor
@@ -340,8 +381,8 @@ def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, n
340381

341382

342383
@pytest.mark.parametrize("cluster", [None, 2])
343-
def test_occupancy_max_active_clusters(get_saxpy_kernel, cluster):
344-
kernel, _ = get_saxpy_kernel
384+
def test_occupancy_max_active_clusters(get_saxpy_kernel_cubin, cluster):
385+
kernel, _ = get_saxpy_kernel_cubin
345386
dev = Device()
346387
if dev.compute_capability < (9, 0):
347388
pytest.skip("Device with compute capability 90 or higher is required for cluster support")
@@ -355,8 +396,8 @@ def test_occupancy_max_active_clusters(get_saxpy_kernel, cluster):
355396
assert max_active_clusters >= 0
356397

357398

358-
def test_occupancy_max_potential_cluster_size(get_saxpy_kernel):
359-
kernel, _ = get_saxpy_kernel
399+
def test_occupancy_max_potential_cluster_size(get_saxpy_kernel_cubin):
400+
kernel, _ = get_saxpy_kernel_cubin
360401
dev = Device()
361402
if dev.compute_capability < (9, 0):
362403
pytest.skip("Device with compute capability 90 or higher is required for cluster support")
@@ -370,11 +411,11 @@ def test_occupancy_max_potential_cluster_size(get_saxpy_kernel):
370411
assert max_potential_cluster_size >= 0
371412

372413

373-
def test_module_serialization_roundtrip(get_saxpy_kernel):
374-
_, objcode = get_saxpy_kernel
414+
def test_module_serialization_roundtrip(get_saxpy_kernel_cubin):
415+
_, objcode = get_saxpy_kernel_cubin
375416
result = pickle.loads(pickle.dumps(objcode)) # noqa: S403, S301
376417

377418
assert isinstance(result, ObjectCode)
378419
assert objcode.code == result.code
379420
assert objcode._sym_map == result._sym_map
380-
assert objcode._code_type == result._code_type
421+
assert objcode.code_type == result.code_type

0 commit comments

Comments
 (0)