Skip to content

Commit 5004f7a

Browse files
committed
fixup! feat(core.utils): per-field linker gates + use_libdevice truthy gate
1 parent f28e7fa commit 5004f7a

File tree

2 files changed

+108
-5
lines changed

2 files changed

+108
-5
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,62 @@ def _backend_for_code_type(code_type: str) -> str:
215215
)
216216

217217

218+
# Map each linker-relevant ProgramOptions field to the gate the Linker uses
# to turn it into a flag (see ``_prepare_nvjitlink_options`` and
# ``_prepare_driver_options`` in _linker.pyx). Collapsing inputs through
# these gates means semantically equivalent configurations
# (``debug=False`` vs ``None``, ``time=True`` vs ``time="path"``) hash to
# the same cache key instead of forcing spurious misses.
224+
def _gate_presence(v):
225+
return v is not None
226+
227+
228+
def _gate_truthy(v):
229+
return bool(v)
230+
231+
232+
def _gate_is_true(v):
233+
return v is True
234+
235+
236+
def _gate_tristate_bool(v):
237+
return None if v is None else bool(v)
238+
239+
240+
def _gate_identity(v):
241+
return v
242+
243+
244+
# Field name -> gate callable applied to that field's value before it is
# folded into the linker fingerprint.
_LINKER_FIELD_GATES = {
    "name": _gate_identity,
    "arch": _gate_identity,
    "max_register_count": _gate_identity,
    "time": _gate_presence,  # linker emits ``-time`` iff value is not None
    "link_time_optimization": _gate_truthy,
    "debug": _gate_truthy,
    "lineinfo": _gate_truthy,
    "ftz": _gate_tristate_bool,
    "prec_div": _gate_tristate_bool,
    "prec_sqrt": _gate_tristate_bool,
    "fma": _gate_tristate_bool,
    "split_compile": _gate_identity,
    "ptxas_options": _gate_identity,
    "no_cache": _gate_is_true,
}
260+
261+
218262
def _linker_option_fingerprint(options: ProgramOptions) -> list[bytes]:
    """Stable byte fingerprint of ProgramOptions fields consumed by the Linker.

    Each field is first passed through the gate the Linker itself uses so
    that equivalent inputs (e.g. ``debug=False`` / ``None``) produce the
    same bytes. Without this, callers hitting the same linker behavior
    would get different cache keys and needlessly re-compile.

    Returns one ``b"<name>=<repr of gated value>"`` entry per name in
    ``_LINKER_RELEVANT_FIELDS``, in that sequence's order. An attribute
    missing from ``options`` is treated as ``None``.

    Raises ``KeyError`` if a name in ``_LINKER_RELEVANT_FIELDS`` has no entry
    in ``_LINKER_FIELD_GATES``.
    """
    return [
        f"{name}={_LINKER_FIELD_GATES[name](getattr(options, name, None))!r}".encode()
        for name in _LINKER_RELEVANT_FIELDS
    ]
221274

222275

223276
# ProgramOptions fields that map to LinkerOptions fields the cuLink (driver)
@@ -585,9 +638,11 @@ def _probe(label: str, fn):
585638
else:
586639
# Fallback for unexpected format.
587640
_update("extra_source", str(item).encode("utf-8"))
588-
use_libdevice = getattr(options, "use_libdevice", None)
589-
if use_libdevice is not None:
590-
_update("use_libdevice", str(use_libdevice).encode("ascii"))
641+
# Program_init gates ``use_libdevice`` on truthiness, not ``is not None``
642+
# (see _program.pyx), so False and None produce identical ObjectCode and
643+
# must hash the same way here.
644+
if getattr(options, "use_libdevice", None):
645+
_update("use_libdevice", b"1")
591646

592647
# Program.compile() propagates options.name onto the returned ObjectCode,
593648
# so two compiles identical in everything but name produce ObjectCodes

cuda_core/tests/test_program_cache.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,54 @@ def test_make_program_cache_key_ignores_name_expressions_for_non_nvrtc(code_type
340340
assert k_none == k_with
341341

342342

343+
@pytest.mark.parametrize(
    "a, b",
    [
        # ``debug`` / ``lineinfo`` / ``link_time_optimization`` are truthy-only
        # gates in the linker; False and None produce identical output.
        pytest.param({"debug": False}, {"debug": None}, id="debug_false_eq_none"),
        pytest.param({"lineinfo": False}, {"lineinfo": None}, id="lineinfo_false_eq_none"),
        pytest.param(
            {"link_time_optimization": False},
            {"link_time_optimization": None},
            id="lto_false_eq_none",
        ),
        # ``time`` is a presence gate: the linker emits ``-time`` for any
        # non-None value, so True / "path" produce the same flag.
        pytest.param({"time": True}, {"time": "timing.csv"}, id="time_true_eq_path"),
        # ``no_cache`` has an ``is True`` gate; False and None are equivalent.
        pytest.param({"no_cache": False}, {"no_cache": None}, id="no_cache_false_eq_none"),
    ],
)
def test_make_program_cache_key_ptx_linker_equivalent_options_hash_same(a, b, monkeypatch):
    """The linker folds several PTX-relevant fields through simple gates:
    truthy-only (``debug``, ``lineinfo``, ``link_time_optimization``),
    presence-only (``time``), ``is True`` (``no_cache``). Semantically
    equivalent inputs under those gates must hash to the same key."""
    # Pin the linker probe so the only variable is the options gate.
    from cuda.core.utils import _program_cache

    monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: ("nvJitLink", "12030"))
    k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**a))
    k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**b))
    assert k_a == k_b
374+
375+
376+
def test_make_program_cache_key_nvvm_use_libdevice_false_equals_none():
    """Program_init gates ``use_libdevice`` on truthiness, so False and None
    compile identically and must hash the same way."""
    # The NVVM probe is deliberately left unpatched; the three keys differ
    # only in the ``use_libdevice`` option passed in.
    k_none = _make_key(code="abc", code_type="nvvm", target_type="ptx", options=_opts(use_libdevice=None))
    k_false = _make_key(code="abc", code_type="nvvm", target_type="ptx", options=_opts(use_libdevice=False))
    k_true = _make_key(code="abc", code_type="nvvm", target_type="ptx", options=_opts(use_libdevice=True))
    assert k_none == k_false
    assert k_true != k_none
389+
390+
343391
def test_make_program_cache_key_nvvm_probe_changes_key(monkeypatch):
344392
"""NVVM keys must reflect the NVVM toolchain identity (IR version)
345393
so an upgraded libNVVM does not silently reuse pre-upgrade entries."""

0 commit comments

Comments
 (0)