Skip to content

Commit 47b47da

Browse files
committed
fixup! test(core.utils): use tokenize to parse SUPPORTED_TARGETS (ignore string/comment braces)
1 parent 531444a commit 47b47da

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

cuda_core/tests/test_program_cache.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -270,44 +270,55 @@ def test_make_program_cache_key_rejects(kwargs, exc_type, match):
270270

271271
def test_make_program_cache_key_supported_targets_matches_program_compile():
272272
"""``_SUPPORTED_TARGETS_BY_CODE_TYPE`` duplicates the backend target
273-
matrix in ``_program.pyx``. Guard against drift: parse the pyx source,
274-
extract its SUPPORTED_TARGETS dict, and assert the two views agree for
275-
every code_type."""
273+
matrix in ``_program.pyx``. Guard against drift: parse the pyx source
274+
with :mod:`tokenize` (which skips string literals and comments) to
275+
extract ``SUPPORTED_TARGETS`` and assert the two views agree."""
276276
import ast
277+
import io
278+
import tokenize
277279
from pathlib import Path
278280

279281
from cuda.core.utils._program_cache import _SUPPORTED_TARGETS_BY_CODE_TYPE
280282

281-
# Map Program._backend strings to the public code_type values.
282283
backend_to_code_type = {"NVRTC": "c++", "NVVM": "nvvm"}
283284
linker_backends = ("nvJitLink", "driver")
284285

285286
pyx = Path(__file__).parent.parent / "cuda" / "core" / "_program.pyx"
286287
text = pyx.read_text()
287-
# Find the assignment: ``cdef dict SUPPORTED_TARGETS = { ... }``.
288-
marker = "cdef dict SUPPORTED_TARGETS"
289-
start = text.index(marker)
290-
brace = text.index("{", start)
291-
# Find the matching closing brace.
288+
marker_idx = text.index("cdef dict SUPPORTED_TARGETS")
289+
tokens = tokenize.generate_tokens(io.StringIO(text[marker_idx:]).readline)
290+
292291
depth = 0
293-
for idx in range(brace, len(text)):
294-
if text[idx] == "{":
292+
start_offset = None
293+
end_offset = None
294+
lines = text[marker_idx:].splitlines(keepends=True)
295+
line_starts = [0]
296+
for line in lines[:-1]:
297+
line_starts.append(line_starts[-1] + len(line))
298+
299+
def _offset(row, col):
300+
return line_starts[row - 1] + col
301+
302+
for tok in tokens:
303+
if tok.type != tokenize.OP:
304+
continue
305+
if tok.string == "{":
306+
if depth == 0:
307+
start_offset = _offset(tok.start[0], tok.start[1])
295308
depth += 1
296-
elif text[idx] == "}":
309+
elif tok.string == "}":
297310
depth -= 1
298311
if depth == 0:
299-
end = idx + 1
312+
end_offset = _offset(tok.end[0], tok.end[1])
300313
break
301-
pyx_targets = ast.literal_eval(text[brace:end])
314+
assert start_offset is not None and end_offset is not None, "could not locate SUPPORTED_TARGETS literal"
315+
pyx_targets = ast.literal_eval(text[marker_idx + start_offset : marker_idx + end_offset])
302316

303-
# NVRTC and NVVM map 1:1 to their code_types.
304317
for backend, code_type in backend_to_code_type.items():
305318
assert frozenset(pyx_targets[backend]) == _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type], (
306319
backend,
307320
code_type,
308321
)
309-
# Both linker backends ultimately serve code_type="ptx"; they should
310-
# agree with each other and with the cache table.
311322
linker_sets = [frozenset(pyx_targets[b]) for b in linker_backends]
312323
assert all(s == linker_sets[0] for s in linker_sets)
313324
assert linker_sets[0] == _SUPPORTED_TARGETS_BY_CODE_TYPE["ptx"]

0 commit comments

Comments (0)