Skip to content

Commit bd01457

Browse files
abhilash1910, pre-commit-ci[bot], and rwgk
authored
[NVVM] Refactor program_init (#1653)
* refactor program init
* [pre-commit.ci] auto code formatting
* fix CI
* refresh to perform str to bytes before
* [pre-commit.ci] auto code formatting
* Update pathfinder descriptor catalogs for cusparseLt release 0.9.0

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ralf W. Grosse-Kunstleve <rwgkio@gmail.com>
Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
1 parent 346afc8 commit bd01457

File tree

4 files changed

+48
-36
lines changed

4 files changed

+48
-36
lines changed

cuda_core/cuda/core/_program.pyx

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,31 @@ class ProgramOptions:
403403
# Set arch to default if not provided
404404
if self.arch is None:
405405
self.arch = f"sm_{Device().arch}"
406+
if self.extra_sources is not None:
407+
if not is_sequence(self.extra_sources):
408+
raise TypeError(
409+
"extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)"
410+
)
411+
for i, module in enumerate(self.extra_sources):
412+
if not isinstance(module, tuple) or len(module) != 2:
413+
raise TypeError(
414+
f"Each extra module must be a 2-tuple (name, source)"
415+
f", got {type(module).__name__} at index {i}"
416+
)
417+
418+
module_name, module_source = module
419+
420+
if not isinstance(module_name, str):
421+
raise TypeError(f"Module name at index {i} must be a string, got {type(module_name).__name__}")
422+
423+
if not isinstance(module_source, (str, bytes, bytearray)):
424+
raise TypeError(
425+
f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
426+
f"or bytearray, got {type(module_source).__name__}"
427+
)
428+
429+
if len(module_source) == 0:
430+
raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty")
406431

407432
def _prepare_nvrtc_options(self) -> list[bytes]:
408433
return _prepare_nvrtc_options_impl(self)
@@ -456,6 +481,23 @@ class ProgramOptions:
456481
def __repr__(self):
457482
return f"ProgramOptions(name={self.name!r}, arch={self.arch!r})"
458483

484+
def _prepare_extra_sources_bytes(self) -> list[tuple[bytes, bytes]] | None:
485+
"""Convert extra_sources to bytes format for NVVM."""
486+
if self.extra_sources is None:
487+
return None
488+
489+
result = []
490+
for module_name, module_source in self.extra_sources:
491+
name_bytes = module_name.encode("utf-8")
492+
if isinstance(module_source, str):
493+
source_bytes = module_source.encode("utf-8")
494+
elif isinstance(module_source, bytearray):
495+
source_bytes = bytes(module_source)
496+
else:
497+
source_bytes = module_source
498+
result.append((name_bytes, source_bytes))
499+
return result
500+
459501

460502
# =============================================================================
461503
# Private Classes and Helper Functions
@@ -628,41 +670,11 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
628670

629671
# Add extra modules if provided
630672
if options.extra_sources is not None:
631-
if not is_sequence(options.extra_sources):
632-
raise TypeError(
633-
"extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)"
634-
)
635-
for i, module in enumerate(options.extra_sources):
636-
if not isinstance(module, tuple) or len(module) != 2:
637-
raise TypeError(
638-
f"Each extra module must be a 2-tuple (name, source)"
639-
f", got {type(module).__name__} at index {i}"
640-
)
641-
642-
module_name, module_source = module
643-
644-
if not isinstance(module_name, str):
645-
raise TypeError(f"Module name at index {i} must be a string, got {type(module_name).__name__}")
646-
647-
if isinstance(module_source, str):
648-
# Textual LLVM IR - encode to UTF-8 bytes
649-
module_source = module_source.encode("utf-8")
650-
elif not isinstance(module_source, (bytes, bytearray)):
651-
raise TypeError(
652-
f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
653-
f"or bytearray, got {type(module_source).__name__}"
654-
)
655-
656-
if len(module_source) == 0:
657-
raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty")
658-
659-
# Add the module using NVVM API
660-
module_bytes = module_source if isinstance(module_source, bytes) else bytes(module_source)
673+
extra_sources_bytes = options._prepare_extra_sources_bytes()
674+
for module_name_bytes, module_bytes in extra_sources_bytes:
661675
module_ptr = <const char*>module_bytes
662676
module_len = len(module_bytes)
663-
module_name_bytes = module_name.encode()
664677
module_name_ptr = <const char*>module_name_bytes
665-
666678
with nogil:
667679
HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram(
668680
nvvm_prog, module_ptr, module_len, module_name_ptr))

cuda_core/tests/test_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ def test_cpp_program_with_extra_sources():
724724
# negative test with NVRTC with multiple sources
725725
code = 'extern "C" __global__ void my_kernel(){}'
726726
helper = 'extern "C" __global__ void helper(){}'
727-
options = ProgramOptions(extra_sources=helper)
727+
options = ProgramOptions(extra_sources=[("helper", helper)])
728728
with pytest.raises(ValueError, match="extra_sources is not supported by the NVRTC backend"):
729729
Program(code, "c++", options)
730730

cuda_pathfinder/cuda/pathfinder/_dynamic_libs/descriptor_catalog.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ class DescriptorSpec:
331331
packaged_with="other",
332332
linux_sonames=("libcusparseLt.so.0",),
333333
windows_dlls=("cusparseLt.dll",),
334-
site_packages_linux=("nvidia/cusparselt/lib",),
335-
site_packages_windows=("nvidia/cusparselt/bin",),
334+
site_packages_linux=("nvidia/cu13/lib", "nvidia/cusparselt/lib"),
335+
site_packages_windows=("nvidia/cu13/bin/x64", "nvidia/cusparselt/bin"),
336336
),
337337
DescriptorSpec(
338338
name="cutensor",

cuda_pathfinder/cuda/pathfinder/_headers/header_descriptor_catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class HeaderDescriptorSpec:
141141
name="cusparseLt",
142142
packaged_with="other",
143143
header_basename="cusparseLt.h",
144-
site_packages_dirs=("nvidia/cusparselt/include",),
144+
site_packages_dirs=("nvidia/cu13/include", "nvidia/cusparselt/include"),
145145
conda_targets_layout=False,
146146
use_ctk_root_canary=False,
147147
),

0 commit comments

Comments
 (0)