Skip to content

Commit a4eeeec

Browse files
committed
fixup! feat(core.utils): narrow key inputs by backend
- name_expressions are now hashed only for NVRTC (the NVVM and PTX paths ignore them in Program.compile)
- Validate the code_type/target_type combination against the SUPPORTED_TARGETS matrix
1 parent d5f8c19 commit a4eeeec

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,14 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
173173
_VALID_CODE_TYPES = frozenset({"c++", "ptx", "nvvm"})
174174
_VALID_TARGET_TYPES = frozenset({"ptx", "cubin", "ltoir"})
175175

176+
# code_type -> allowed target_type set, mirroring Program.compile's
177+
# SUPPORTED_TARGETS matrix in _program.pyx.
178+
_SUPPORTED_TARGETS_BY_CODE_TYPE = {
179+
"c++": frozenset({"ptx", "cubin", "ltoir"}),
180+
"ptx": frozenset({"cubin", "ptx"}),
181+
"nvvm": frozenset({"ptx", "ltoir"}),
182+
}
183+
176184

177185
def _backend_for_code_type(code_type: str) -> str:
178186
if code_type == "nvvm":
@@ -341,6 +349,13 @@ def make_program_cache_key(
341349
raise ValueError(f"code_type={code_type!r} is not supported (must be one of {sorted(_VALID_CODE_TYPES)})")
342350
if target_type not in _VALID_TARGET_TYPES:
343351
raise ValueError(f"target_type={target_type!r} is not supported (must be one of {sorted(_VALID_TARGET_TYPES)})")
352+
supported_for_code = _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type]
353+
if target_type not in supported_for_code:
354+
raise ValueError(
355+
f"target_type={target_type!r} is not valid for code_type={code_type!r}"
356+
f" (supported: {sorted(supported_for_code)}). Program.compile() rejects"
357+
f" this combination, so caching a key for it is meaningless."
358+
)
344359

345360
backend = _backend_for_code_type(code_type)
346361

@@ -456,9 +471,13 @@ def _probe(label: str, fn):
456471
_update("option_count", str(len(option_bytes)).encode("ascii"))
457472
for opt in option_bytes:
458473
_update("option", bytes(opt))
459-
_update("names_count", str(len(names)).encode("ascii"))
460-
for n in names:
461-
_update("name", n)
474+
# Only NVRTC consumes ``name_expressions``; Program.compile ignores them
475+
# on the NVVM and PTX/linker paths, so folding them into the key there
476+
# would force spurious cache misses.
477+
if backend == "nvrtc":
478+
_update("names_count", str(len(names)).encode("ascii"))
479+
for n in names:
480+
_update("name", n)
462481

463482
# Hash fields that affect compilation output but are not captured by
464483
# options.as_bytes() (which only emits CLI-style flags). Without these,

cuda_core/tests/test_program_cache.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,43 @@ def test_make_program_cache_key_name_expressions_str_bytes_distinct():
247247
pytest.param({"code_type": "fortran"}, ValueError, "code_type", id="unknown_code_type"),
248248
pytest.param({"target_type": "exe"}, ValueError, "target_type", id="unknown_target_type"),
249249
pytest.param({"code": 12345}, TypeError, "code", id="non_str_bytes_code"),
250+
# Backend-specific target matrix -- Program.compile rejects these
251+
# combinations, so caching a key for them would be a lie.
252+
pytest.param(
253+
{"code_type": "ptx", "target_type": "ltoir"},
254+
ValueError,
255+
"not valid for code_type",
256+
id="ptx_cannot_ltoir",
257+
),
258+
pytest.param(
259+
{"code_type": "nvvm", "target_type": "cubin"},
260+
ValueError,
261+
"not valid for code_type",
262+
id="nvvm_cannot_cubin",
263+
),
250264
],
251265
)
252266
def test_make_program_cache_key_rejects(kwargs, exc_type, match):
253267
with pytest.raises(exc_type, match=match):
254268
_make_key(**kwargs)
255269

256270

271+
@pytest.mark.parametrize(
272+
"code_type, code, target_type",
273+
[
274+
pytest.param("nvvm", "abc", "ptx", id="nvvm"),
275+
pytest.param("ptx", ".version 7.0", "cubin", id="ptx"),
276+
],
277+
)
278+
def test_make_program_cache_key_ignores_name_expressions_for_non_nvrtc(code_type, code, target_type):
279+
"""Program.compile only forwards ``name_expressions`` on the NVRTC path
280+
(_program.pyx). Folding them into the key for NVVM/PTX compiles would
281+
cause identical compiles to miss the cache for no behavioural reason."""
282+
k_none = _make_key(code=code, code_type=code_type, target_type=target_type)
283+
k_with = _make_key(code=code, code_type=code_type, target_type=target_type, name_expressions=("foo", "bar"))
284+
assert k_none == k_with
285+
286+
257287
@pytest.mark.parametrize(
258288
"option_kw",
259289
[

0 commit comments

Comments
 (0)