|
7 | 7 | import pytest |
8 | 8 |
|
9 | 9 | import cuda.tile as ct |
| 10 | +import cuda.tile._bytecode as bc |
10 | 11 | from cuda.tile._bytecode.basic import StringTable |
11 | 12 | from cuda.tile._bytecode.debug_info import DebugAttrTable |
12 | 13 | from cuda.tile._cext import CallingConvention |
13 | 14 | from cuda.tile._compile import compile_tile |
14 | | -from cuda.tile._exception import FunctionDesc |
| 15 | +from cuda.tile._exception import FunctionDesc, Loc |
| 16 | +from cuda.tile._ir.ops import Return |
15 | 17 | from cuda.tile._ir2bytecode import DebugAttrMap, create_synthetic_linkage_name |
16 | 18 | from cuda.tile.compilation import ArrayConstraint, KernelSignature |
17 | 19 |
|
@@ -202,6 +204,34 @@ def kernel(x): |
202 | 204 | assert entry.line == kernel_line and entry.line != body_line |
203 | 205 |
|
204 | 206 |
|
| 207 | +def test_dummy_return_sets_unknown_loc(): |
| 208 | + def kernel(x): |
| 209 | + t = ct.load(x, (0,), (1,)) |
| 210 | + ct.store(x, (1,), t) |
| 211 | + |
| 212 | + body = _compile(kernel) |
| 213 | + return_ops = [op for op in body.traverse() if isinstance(op, Return)] |
| 214 | + assert len(return_ops) == 1, "expected one Return op" |
| 215 | + assert return_ops[0].loc.is_unknown() |
| 216 | + |
| 217 | + |
| 218 | +def test_explicit_return_keeps_loc(): |
| 219 | + def kernel(x): |
| 220 | + t = ct.load(x, (0,), (1,)) |
| 221 | + ct.store(x, (1,), t) |
| 222 | + return None |
| 223 | + |
| 224 | + body = _compile(kernel) |
| 225 | + return_line, _ = _line_col(kernel, "return None") |
| 226 | + return_ops = [op for op in body.traverse() if isinstance(op, Return)] |
| 227 | + assert len(return_ops) == 1, "expected one Return op" |
| 228 | + return_op = return_ops[0] |
| 229 | + assert return_op, "expected a Return op" |
| 230 | + assert return_op.loc.line == return_line, (return_op.loc.line, return_line) |
| 231 | + assert return_op.loc.function is not None |
| 232 | + assert return_op.loc.function.name == "kernel" |
| 233 | + |
| 234 | + |
205 | 235 | def test_linkage_names_appear_in_emitted_bytecode(): |
206 | 236 | # End-to-end smoke check: every linkage name we compute should also be |
207 | 237 | # serialized into the bytecode's string table. This catches regressions |
@@ -258,3 +288,8 @@ def compile_once(): |
258 | 288 | ) |
259 | 289 |
|
260 | 290 | assert compile_once() == compile_once() |
| 291 | + |
| 292 | + |
| 293 | +def test_unknown_loc_maps_to_missing_debug_attr(): |
| 294 | + m = DebugAttrMap(DebugAttrTable(StringTable()), entry_symbol="kern", anonymize=False) |
| 295 | + assert m.get_debugattr(Loc.unknown()) == bc.MISSING_DEBUG_ATTR_ID |
0 commit comments