@@ -403,6 +403,31 @@ class ProgramOptions:
403403 # Set arch to default if not provided
404404 if self .arch is None :
405405 self .arch = f" sm_{Device().arch}"
406+ if self .extra_sources is not None :
407+ if not is_sequence(self .extra_sources):
408+ raise TypeError (
409+ " extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)"
410+ )
411+ for i, module in enumerate (self .extra_sources):
412+ if not isinstance (module, tuple ) or len (module) != 2 :
413+ raise TypeError (
414+ f" Each extra module must be a 2-tuple (name, source)"
415+ f" , got {type(module).__name__} at index {i}"
416+ )
417+
418+ module_name, module_source = module
419+
420+ if not isinstance (module_name, str ):
421+ raise TypeError (f" Module name at index {i} must be a string, got {type(module_name).__name__}" )
422+
423+ if not isinstance (module_source, (str , bytes, bytearray)):
424+ raise TypeError (
425+ f" Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
426+ f" or bytearray, got {type(module_source).__name__}"
427+ )
428+
429+ if len (module_source) == 0 :
430+ raise ValueError (f" Module source for '{module_name}' (index {i}) cannot be empty" )
406431
407432 def _prepare_nvrtc_options (self ) -> list[bytes]:
408433 return _prepare_nvrtc_options_impl(self )
@@ -456,6 +481,23 @@ class ProgramOptions:
456481 def __repr__(self ):
457482 return f" ProgramOptions(name={self.name!r}, arch={self.arch!r})"
458483
484+ def _prepare_extra_sources_bytes (self ) -> list[tuple[bytes , bytes]] | None:
485+ """Convert extra_sources to bytes format for NVVM."""
486+ if self.extra_sources is None:
487+ return None
488+
489+ result = []
490+ for module_name , module_source in self.extra_sources:
491+ name_bytes = module_name.encode(" utf-8" )
492+ if isinstance(module_source , str ):
493+ source_bytes = module_source.encode(" utf-8" )
494+ elif isinstance (module_source, bytearray):
495+ source_bytes = bytes(module_source)
496+ else :
497+ source_bytes = module_source
498+ result.append((name_bytes, source_bytes))
499+ return result
500+
459501
460502# =============================================================================
461503# Private Classes and Helper Functions
@@ -628,41 +670,11 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
628670
629671 # Add extra modules if provided
630672 if options.extra_sources is not None :
631- if not is_sequence(options.extra_sources):
632- raise TypeError (
633- " extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)"
634- )
635- for i, module in enumerate (options.extra_sources):
636- if not isinstance (module, tuple ) or len (module) != 2 :
637- raise TypeError (
638- f" Each extra module must be a 2-tuple (name, source)"
639- f" , got {type(module).__name__} at index {i}"
640- )
641-
642- module_name, module_source = module
643-
644- if not isinstance (module_name, str ):
645- raise TypeError (f" Module name at index {i} must be a string, got {type(module_name).__name__}" )
646-
647- if isinstance (module_source, str ):
648- # Textual LLVM IR - encode to UTF-8 bytes
649- module_source = module_source.encode(" utf-8" )
650- elif not isinstance (module_source, (bytes, bytearray)):
651- raise TypeError (
652- f" Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
653- f" or bytearray, got {type(module_source).__name__}"
654- )
655-
656- if len (module_source) == 0 :
657- raise ValueError (f" Module source for '{module_name}' (index {i}) cannot be empty" )
658-
659- # Add the module using NVVM API
660- module_bytes = module_source if isinstance (module_source, bytes) else bytes(module_source)
673+ extra_sources_bytes = options._prepare_extra_sources_bytes()
674+ for module_name_bytes, module_bytes in extra_sources_bytes:
661675 module_ptr = < const char * > module_bytes
662676 module_len = len (module_bytes)
663- module_name_bytes = module_name.encode()
664677 module_name_ptr = < const char * > module_name_bytes
665-
666678 with nogil:
667679 HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram(
668680 nvvm_prog, module_ptr, module_len, module_name_ptr))
0 commit comments