Skip to content

Commit acaccca

Browse files
committed
fixup! feat(core.utils): canonicalize ptxas_options; scope use_libdevice to NVVM
1 parent 08ca973 commit acaccca

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,19 @@ def _gate_identity(v):
241241
return v
242242

243243

244+
def _gate_ptxas_options(v):
245+
# ``_prepare_nvjitlink_options`` emits one ``-Xptxas=<s>`` per element, and
246+
# treats ``str`` as a single-element sequence. Canonicalize to a tuple so
247+
# ``"-v"`` / ``["-v"]`` / ``("-v",)`` all hash the same.
248+
if v is None:
249+
return None
250+
if isinstance(v, str):
251+
return ("-Xptxas=" + v,)
252+
if isinstance(v, collections.abc.Sequence):
253+
return tuple(f"-Xptxas={s}" for s in v)
254+
return v
255+
256+
244257
_LINKER_FIELD_GATES = {
245258
"name": _gate_identity,
246259
"arch": _gate_identity,
@@ -254,7 +267,7 @@ def _gate_identity(v):
254267
"prec_sqrt": _gate_tristate_bool,
255268
"fma": _gate_tristate_bool,
256269
"split_compile": _gate_identity,
257-
"ptxas_options": _gate_identity,
270+
"ptxas_options": _gate_ptxas_options,
258271
"no_cache": _gate_is_true,
259272
}
260273

@@ -655,10 +668,11 @@ def _probe(label: str, fn):
655668
else:
656669
# Fallback for unexpected format.
657670
_update("extra_source", str(item).encode("utf-8"))
658-
# Program_init gates ``use_libdevice`` on truthiness, not ``is not None``
659-
# (see _program.pyx), so False and None produce identical ObjectCode and
660-
# must hash the same way here.
661-
if getattr(options, "use_libdevice", None):
671+
# ``use_libdevice`` is only consumed on the NVVM compile path
672+
# (_program.pyx Program_init); NVRTC and PTX/linker ignore it, so
673+
# folding it into the key there would force spurious misses. On NVVM,
674+
# Program_init gates it on truthiness -- False and None match.
675+
if backend == "nvvm" and getattr(options, "use_libdevice", None):
662676
_update("use_libdevice", b"1")
663677

664678
# Program.compile() propagates options.name onto the returned ObjectCode,

cuda_core/tests/test_program_cache.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,26 @@ def test_make_program_cache_key_ptx_driver_ignored_fields_collapse(field, a, b,
395395
assert k_a == k_b
396396

397397

398+
@pytest.mark.parametrize(
399+
"a, b",
400+
[
401+
pytest.param("-v", ["-v"], id="str_vs_list"),
402+
pytest.param("-v", ("-v",), id="str_vs_tuple"),
403+
pytest.param(["-v"], ("-v",), id="list_vs_tuple"),
404+
],
405+
)
406+
def test_make_program_cache_key_ptx_ptxas_options_canonicalized(a, b, monkeypatch):
407+
"""_prepare_nvjitlink_options emits the same -Xptxas= flags for str,
408+
list, and tuple shapes of ptxas_options. The cache key must treat them
409+
as equivalent so equivalent compiles don't miss the cache."""
410+
from cuda.core import _linker
411+
412+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False) # nvJitLink
413+
k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ptxas_options=a))
414+
k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ptxas_options=b))
415+
assert k_a == k_b
416+
417+
398418
def test_make_program_cache_key_ptx_driver_ignored_fields_still_matter_under_nvjitlink(monkeypatch):
399419
"""nvJitLink does honour those fields; they must still differentiate keys there."""
400420
from cuda.core import _linker
@@ -405,6 +425,22 @@ def test_make_program_cache_key_ptx_driver_ignored_fields_still_matter_under_nvj
405425
assert k_a != k_b
406426

407427

428+
@pytest.mark.parametrize(
429+
"code_type, code, target_type",
430+
[
431+
pytest.param("c++", "void k(){}", "cubin", id="nvrtc"),
432+
pytest.param("ptx", ".version 7.0", "cubin", id="ptx"),
433+
],
434+
)
435+
def test_make_program_cache_key_use_libdevice_ignored_for_non_nvvm(code_type, code, target_type):
436+
"""``use_libdevice`` is only consumed on the NVVM path; NVRTC and PTX
437+
ignore it, so toggling it must not perturb the cache key elsewhere."""
438+
k_off = _make_key(code=code, code_type=code_type, target_type=target_type, options=_opts(use_libdevice=False))
439+
k_on = _make_key(code=code, code_type=code_type, target_type=target_type, options=_opts(use_libdevice=True))
440+
k_none = _make_key(code=code, code_type=code_type, target_type=target_type, options=_opts(use_libdevice=None))
441+
assert k_off == k_on == k_none
442+
443+
408444
def test_make_program_cache_key_nvvm_use_libdevice_false_equals_none():
409445
"""Program_init gates ``use_libdevice`` on truthiness, so False and None
410446
compile identically and must hash the same way."""

0 commit comments

Comments
 (0)