Skip to content

Commit 531444a

Browse files
committed
fixup! test(core.utils): guard against drift between cache + Program.compile target matrices
1 parent a4eeeec commit 531444a

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

cuda_core/tests/test_program_cache.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,51 @@ def test_make_program_cache_key_rejects(kwargs, exc_type, match):
268268
_make_key(**kwargs)
269269

270270

271+
def test_make_program_cache_key_supported_targets_matches_program_compile():
272+
"""``_SUPPORTED_TARGETS_BY_CODE_TYPE`` duplicates the backend target
273+
matrix in ``_program.pyx``. Guard against drift: parse the pyx source,
274+
extract its SUPPORTED_TARGETS dict, and assert the two views agree for
275+
every code_type."""
276+
import ast
277+
from pathlib import Path
278+
279+
from cuda.core.utils._program_cache import _SUPPORTED_TARGETS_BY_CODE_TYPE
280+
281+
# Map Program._backend strings to the public code_type values.
282+
backend_to_code_type = {"NVRTC": "c++", "NVVM": "nvvm"}
283+
linker_backends = ("nvJitLink", "driver")
284+
285+
pyx = Path(__file__).parent.parent / "cuda" / "core" / "_program.pyx"
286+
text = pyx.read_text()
287+
# Find the assignment: ``cdef dict SUPPORTED_TARGETS = { ... }``.
288+
marker = "cdef dict SUPPORTED_TARGETS"
289+
start = text.index(marker)
290+
brace = text.index("{", start)
291+
# Find the matching closing brace.
292+
depth = 0
293+
for idx in range(brace, len(text)):
294+
if text[idx] == "{":
295+
depth += 1
296+
elif text[idx] == "}":
297+
depth -= 1
298+
if depth == 0:
299+
end = idx + 1
300+
break
301+
pyx_targets = ast.literal_eval(text[brace:end])
302+
303+
# NVRTC and NVVM map 1:1 to their code_types.
304+
for backend, code_type in backend_to_code_type.items():
305+
assert frozenset(pyx_targets[backend]) == _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type], (
306+
backend,
307+
code_type,
308+
)
309+
# Both linker backends ultimately serve code_type="ptx"; they should
310+
# agree with each other and with the cache table.
311+
linker_sets = [frozenset(pyx_targets[b]) for b in linker_backends]
312+
assert all(s == linker_sets[0] for s in linker_sets)
313+
assert linker_sets[0] == _SUPPORTED_TARGETS_BY_CODE_TYPE["ptx"]
314+
315+
271316
@pytest.mark.parametrize(
272317
"code_type, code, target_type",
273318
[

0 commit comments

Comments
 (0)