Skip to content

Commit 2d39bec

Browse files
leofangCopilotpre-commit-ci[bot]kkraus14
authored
Make {Program,Linker}Options.as_bytes a public API with backend-specific option handling (NVIDIA#1355)
* Initial plan * Add as_bytes() public API with backend-specific option preparation Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Address PR feedback: remove docstrings from private methods, make as_bytes always return list[bytes], remove _as_bytes Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Remove wrapper methods and use prepare methods directly Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * Restore _translate_program_options method for PTX code path Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * remove redundant docstrings * Remove blank lines in private methods to make code blocks more compact Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * nits * [pre-commit.ci] auto code formatting * Refactor: move NVRTC option building to lazy evaluation in _prepare_nvrtc_options Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> * fix & expand tests * cover new NVRTC options * fix linker options handling * fix two NVRTC bugs - the program name is used for pch filename, but on Windows it is problematic - trace.json could not be properly created with NVRTC 12.9 * Apply suggestions from code review Co-authored-by: Keith Kraus <keith.j.kraus@gmail.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: leofang <5534781+leofang@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Keith Kraus <keith.j.kraus@gmail.com>
1 parent d8e9317 commit 2d39bec

File tree

4 files changed

+521
-151
lines changed

4 files changed

+521
-151
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 93 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -202,74 +202,78 @@ class LinkerOptions:
202202
def __post_init__(self):
203203
_lazy_init()
204204
self._name = self.name.encode()
205-
self.formatted_options = []
206-
if _nvjitlink:
207-
self._init_nvjitlink()
208-
else:
209-
self._init_driver()
210205

211-
def _init_nvjitlink(self):
206+
def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> Union[list[bytes], list[str]]:
207+
options = []
208+
212209
if self.arch is not None:
213-
self.formatted_options.append(f"-arch={self.arch}")
210+
options.append(f"-arch={self.arch}")
214211
else:
215-
self.formatted_options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
212+
options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
216213
if self.max_register_count is not None:
217-
self.formatted_options.append(f"-maxrregcount={self.max_register_count}")
214+
options.append(f"-maxrregcount={self.max_register_count}")
218215
if self.time is not None:
219-
self.formatted_options.append("-time")
216+
options.append("-time")
220217
if self.verbose:
221-
self.formatted_options.append("-verbose")
218+
options.append("-verbose")
222219
if self.link_time_optimization:
223-
self.formatted_options.append("-lto")
220+
options.append("-lto")
224221
if self.ptx:
225-
self.formatted_options.append("-ptx")
222+
options.append("-ptx")
226223
if self.optimization_level is not None:
227-
self.formatted_options.append(f"-O{self.optimization_level}")
224+
options.append(f"-O{self.optimization_level}")
228225
if self.debug:
229-
self.formatted_options.append("-g")
226+
options.append("-g")
230227
if self.lineinfo:
231-
self.formatted_options.append("-lineinfo")
228+
options.append("-lineinfo")
232229
if self.ftz is not None:
233-
self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}")
230+
options.append(f"-ftz={'true' if self.ftz else 'false'}")
234231
if self.prec_div is not None:
235-
self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
232+
options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
236233
if self.prec_sqrt is not None:
237-
self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
234+
options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
238235
if self.fma is not None:
239-
self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
236+
options.append(f"-fma={'true' if self.fma else 'false'}")
240237
if self.kernels_used is not None:
241238
if isinstance(self.kernels_used, str):
242-
self.formatted_options.append(f"-kernels-used={self.kernels_used}")
239+
options.append(f"-kernels-used={self.kernels_used}")
243240
elif isinstance(self.kernels_used, list):
244241
for kernel in self.kernels_used:
245-
self.formatted_options.append(f"-kernels-used={kernel}")
242+
options.append(f"-kernels-used={kernel}")
246243
if self.variables_used is not None:
247244
if isinstance(self.variables_used, str):
248-
self.formatted_options.append(f"-variables-used={self.variables_used}")
245+
options.append(f"-variables-used={self.variables_used}")
249246
elif isinstance(self.variables_used, list):
250247
for variable in self.variables_used:
251-
self.formatted_options.append(f"-variables-used={variable}")
248+
options.append(f"-variables-used={variable}")
252249
if self.optimize_unused_variables is not None:
253-
self.formatted_options.append("-optimize-unused-variables")
250+
options.append("-optimize-unused-variables")
254251
if self.ptxas_options is not None:
255252
if isinstance(self.ptxas_options, str):
256-
self.formatted_options.append(f"-Xptxas={self.ptxas_options}")
253+
options.append(f"-Xptxas={self.ptxas_options}")
257254
elif is_sequence(self.ptxas_options):
258255
for opt in self.ptxas_options:
259-
self.formatted_options.append(f"-Xptxas={opt}")
256+
options.append(f"-Xptxas={opt}")
260257
if self.split_compile is not None:
261-
self.formatted_options.append(f"-split-compile={self.split_compile}")
258+
options.append(f"-split-compile={self.split_compile}")
262259
if self.split_compile_extended is not None:
263-
self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
260+
options.append(f"-split-compile-extended={self.split_compile_extended}")
264261
if self.no_cache is True:
265-
self.formatted_options.append("-no-cache")
262+
options.append("-no-cache")
263+
264+
if as_bytes:
265+
return [o.encode() for o in options]
266+
else:
267+
return options
268+
269+
def _prepare_driver_options(self) -> tuple[list, list]:
270+
formatted_options = []
271+
option_keys = []
266272

267-
def _init_driver(self):
268-
self.option_keys = []
269273
# allocate 4 KiB each for info/error logs
270274
size = 4194304
271-
self.formatted_options.extend((bytearray(size), size, bytearray(size), size))
272-
self.option_keys.extend(
275+
formatted_options.extend((bytearray(size), size, bytearray(size), size))
276+
option_keys.extend(
273277
(
274278
_driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
275279
_driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
@@ -280,30 +284,30 @@ def _init_driver(self):
280284

281285
if self.arch is not None:
282286
arch = self.arch.split("_")[-1].upper()
283-
self.formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
284-
self.option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
287+
formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
288+
option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
285289
if self.max_register_count is not None:
286-
self.formatted_options.append(self.max_register_count)
287-
self.option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
290+
formatted_options.append(self.max_register_count)
291+
option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
288292
if self.time is not None:
289293
raise ValueError("time option is not supported by the driver API")
290294
if self.verbose:
291-
self.formatted_options.append(1)
292-
self.option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
295+
formatted_options.append(1)
296+
option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
293297
if self.link_time_optimization:
294-
self.formatted_options.append(1)
295-
self.option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
298+
formatted_options.append(1)
299+
option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
296300
if self.ptx:
297301
raise ValueError("ptx option is not supported by the driver API")
298302
if self.optimization_level is not None:
299-
self.formatted_options.append(self.optimization_level)
300-
self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
303+
formatted_options.append(self.optimization_level)
304+
option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
301305
if self.debug:
302-
self.formatted_options.append(1)
303-
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
306+
formatted_options.append(1)
307+
option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
304308
if self.lineinfo:
305-
self.formatted_options.append(1)
306-
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
309+
formatted_options.append(1)
310+
option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
307311
if self.ftz is not None:
308312
warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
309313
if self.prec_div is not None:
@@ -325,8 +329,37 @@ def _init_driver(self):
325329
if self.split_compile_extended is not None:
326330
raise ValueError("split_compile_extended option is not supported by the driver API")
327331
if self.no_cache is True:
328-
self.formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
329-
self.option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)
332+
formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
333+
option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)
334+
335+
return formatted_options, option_keys
336+
337+
def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
338+
"""Convert linker options to bytes format for the nvjitlink backend.
339+
340+
Parameters
341+
----------
342+
backend : str, optional
343+
The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".
344+
345+
Returns
346+
-------
347+
list[bytes]
348+
List of option strings encoded as bytes.
349+
350+
Raises
351+
------
352+
ValueError
353+
If an unsupported backend is specified.
354+
RuntimeError
355+
If nvJitLink backend is not available.
356+
"""
357+
backend = backend.lower()
358+
if backend != "nvjitlink":
359+
raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
360+
if not _nvjitlink:
361+
raise RuntimeError("nvJitLink backend is not available")
362+
return self._prepare_nvjitlink_options(as_bytes=True)
330363

331364

332365
# This needs to be a free function not a method, as it's disallowed by contextmanager.
@@ -369,7 +402,7 @@ class Linker:
369402
"""
370403

371404
class _MembersNeededForFinalize:
372-
__slots__ = ("handle", "use_nvjitlink", "const_char_keep_alive")
405+
__slots__ = ("handle", "use_nvjitlink", "const_char_keep_alive", "formatted_options", "option_keys")
373406

374407
def __init__(self, program_obj, handle, use_nvjitlink):
375408
self.handle = handle
@@ -394,14 +427,17 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
394427
self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
395428
with _exception_manager(self):
396429
if _nvjitlink:
397-
handle = _nvjitlink.create(len(options.formatted_options), options.formatted_options)
430+
formatted_options = options._prepare_nvjitlink_options(as_bytes=False)
431+
handle = _nvjitlink.create(len(formatted_options), formatted_options)
398432
use_nvjitlink = True
399433
else:
400-
handle = handle_return(
401-
_driver.cuLinkCreate(len(options.formatted_options), options.option_keys, options.formatted_options)
402-
)
434+
formatted_options, option_keys = options._prepare_driver_options()
435+
handle = handle_return(_driver.cuLinkCreate(len(formatted_options), option_keys, formatted_options))
403436
use_nvjitlink = False
404437
self._mnff = Linker._MembersNeededForFinalize(self, handle, use_nvjitlink)
438+
self._mnff.formatted_options = formatted_options # Store for log access
439+
if not _nvjitlink:
440+
self._mnff.option_keys = option_keys
405441

406442
for code in object_codes:
407443
assert_type(code, ObjectCode)
@@ -508,7 +544,7 @@ def get_error_log(self) -> str:
508544
log = bytearray(log_size)
509545
_nvjitlink.get_error_log(self._mnff.handle, log)
510546
else:
511-
log = self._options.formatted_options[2]
547+
log = self._mnff.formatted_options[2]
512548
return log.decode("utf-8", errors="backslashreplace")
513549

514550
def get_info_log(self) -> str:
@@ -524,7 +560,7 @@ def get_info_log(self) -> str:
524560
log = bytearray(log_size)
525561
_nvjitlink.get_info_log(self._mnff.handle, log)
526562
else:
527-
log = self._options.formatted_options[0]
563+
log = self._mnff.formatted_options[0]
528564
return log.decode("utf-8", errors="backslashreplace")
529565

530566
def _input_type_from_code_type(self, code_type: str):

0 commit comments

Comments
 (0)