Skip to content

Commit 08ca973

Browse files
committed
fixup! feat(core.utils): backend-aware linker fingerprint collapses driver-ignored fields
1 parent 5004f7a commit 08ca973

File tree

2 files changed

+61
-12
lines changed

2 files changed

+61
-12
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,18 +259,34 @@ def _gate_identity(v):
259259
}
260260

261261

262-
def _linker_option_fingerprint(options: ProgramOptions) -> list[bytes]:
263-
"""Stable byte fingerprint of ProgramOptions fields consumed by the Linker.
264-
265-
Each field is first passed through the gate the Linker itself uses so
266-
that equivalent inputs (e.g. ``debug=False`` / ``None``) produce the
267-
same bytes. Without this, callers hitting the same linker behavior
268-
would get different cache keys and needlessly re-compile.
262+
# LinkerOptions fields the ``cuLink`` driver backend silently ignores
263+
# (emits only a DeprecationWarning; no actual flag reaches the compiler).
264+
# When the driver backend is active, collapse them to a single sentinel in
265+
# the fingerprint so nvJitLink<->driver parity of ``ObjectCode`` doesn't
266+
# cause cache misses from otherwise-equivalent configurations.
267+
_DRIVER_IGNORED_LINKER_FIELDS = frozenset({"ftz", "prec_div", "prec_sqrt", "fma"})
268+
269+
270+
def _linker_option_fingerprint(options: ProgramOptions, *, use_driver_linker: bool | None) -> list[bytes]:
271+
"""Backend-aware fingerprint of ProgramOptions fields consumed by the Linker.
272+
273+
Each field passes through the gate the Linker itself uses so equivalent
274+
inputs (e.g. ``debug=False`` / ``None``) hash to the same bytes. When
275+
the driver (cuLink) linker backend is in use, fields it silently
276+
ignores collapse to one sentinel so those options don't perturb the
277+
key on driver-backed hosts either. ``use_driver_linker=None`` means we
278+
couldn't probe the backend; we don't collapse driver-ignored fields in
279+
that case, to stay conservative.
269280
"""
270-
return [
271-
f"{name}={_LINKER_FIELD_GATES[name](getattr(options, name, None))!r}".encode()
272-
for name in _LINKER_RELEVANT_FIELDS
273-
]
281+
parts = []
282+
driver_ignored = use_driver_linker is True
283+
for name in _LINKER_RELEVANT_FIELDS:
284+
if driver_ignored and name in _DRIVER_IGNORED_LINKER_FIELDS:
285+
parts.append(f"{name}=<driver-ignored>".encode())
286+
continue
287+
gated = _LINKER_FIELD_GATES[name](getattr(options, name, None))
288+
parts.append(f"{name}={gated!r}".encode())
289+
return parts
274290

275291

276292
# ProgramOptions fields that map to LinkerOptions fields the cuLink (driver)
@@ -499,6 +515,7 @@ def make_program_cache_key(
499515
# store a key for a compilation that can't succeed in this environment.
500516
# If the probe fails we can't tell which backend will run, so skip -- the
501517
# failed-probe branch below already taints the key.
518+
use_driver_linker: bool | None = None
502519
if backend == "linker":
503520
try:
504521
from cuda.core._linker import _decide_nvjitlink_or_driver
@@ -535,7 +552,7 @@ def make_program_cache_key(
535552
# misses on PTX. For nvrtc/nvvm backends, ProgramOptions.as_bytes gives
536553
# the real compile-time flag surface.
537554
if backend == "linker":
538-
option_bytes = _linker_option_fingerprint(options)
555+
option_bytes = _linker_option_fingerprint(options, use_driver_linker=use_driver_linker)
539556
else:
540557
try:
541558
option_bytes = options.as_bytes(backend, target_type)

cuda_core/tests/test_program_cache.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,38 @@ def test_make_program_cache_key_ptx_linker_equivalent_options_hash_same(a, b, mo
373373
assert k_a == k_b
374374

375375

376+
@pytest.mark.parametrize(
377+
"field, a, b",
378+
[
379+
pytest.param("ftz", True, False, id="ftz"),
380+
pytest.param("prec_div", True, False, id="prec_div"),
381+
pytest.param("prec_sqrt", True, False, id="prec_sqrt"),
382+
pytest.param("fma", True, False, id="fma"),
383+
],
384+
)
385+
def test_make_program_cache_key_ptx_driver_ignored_fields_collapse(field, a, b, monkeypatch):
386+
"""The driver (cuLink) linker silently ignores ftz/prec_div/prec_sqrt/fma
387+
(only emits a DeprecationWarning). Under the driver backend, those
388+
fields must not perturb the PTX cache key -- two otherwise-equivalent
389+
compiles differing only in these flags produce identical ObjectCode."""
390+
from cuda.core import _linker
391+
392+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True) # driver
393+
k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**{field: a}))
394+
k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**{field: b}))
395+
assert k_a == k_b
396+
397+
398+
def test_make_program_cache_key_ptx_driver_ignored_fields_still_matter_under_nvjitlink(monkeypatch):
399+
"""nvJitLink does honour those fields; they must still differentiate keys there."""
400+
from cuda.core import _linker
401+
402+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False) # nvJitLink
403+
k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ftz=True))
404+
k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ftz=False))
405+
assert k_a != k_b
406+
407+
376408
def test_make_program_cache_key_nvvm_use_libdevice_false_equals_none():
377409
"""Program_init gates ``use_libdevice`` on truthiness, so False and None
378410
compile identically and must hash the same way."""

0 commit comments

Comments
 (0)