Skip to content

Commit 09eadde

Browse files
committed
Set unknown loc for dummy return in main function
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent b9d077f commit 09eadde

4 files changed

Lines changed: 42 additions & 3 deletions

File tree

src/cuda/tile/_ir2bytecode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def get_subprogram(self, func_desc: FunctionDesc) -> bc.DebugAttrId:
330330
return ret
331331

332332
def get_debugattr(self, loc: Loc) -> bc.DebugAttrId:
333-
if self._anonymize:
333+
if self._anonymize or loc.is_unknown():
334334
return bc.MISSING_DEBUG_ATTR_ID
335335

336336
subprogram = self.get_subprogram(loc.function)

src/cuda/tile/_passes/ast2hir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,8 @@ def _ast2hir(func_def: ast.FunctionDef | ast.Lambda, ctx: _Context) -> hir.Block
10001000
_stmt_list(func_def.body, ctx)
10011001
# Add a Return jump to the root block if it doesn't have one
10021002
if root_block.jump is None:
1003-
ctx.set_block_jump(hir.Jump.RETURN)
1003+
with ctx.change_loc(Loc.unknown()):
1004+
ctx.set_block_jump(hir.Jump.RETURN)
10041005
elif isinstance(func_def, ast.FunctionDef):
10051006
# To enable early returns in a helper function, wrap the body in a loop.
10061007
# Thus, we can use "break" to implement the return statement.

src/cuda/tile/_passes/hir2ir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def _concretize_func_desc(func_hir: hir.Function, ir_ctx: IRContext) -> Function
7575

7676

7777
def retarget_loc(loc: Loc, scope: Scope) -> Loc:
78+
if loc.is_unknown():
79+
return loc
80+
7881
# Splice in the scope's call site and, if this loc belongs to the function
7982
# currently being inlined, swap in the per-specialization FunctionDesc so
8083
# emitted ops carry the right debug info.

test/test_function_debuginfo.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
import pytest
88

99
import cuda.tile as ct
10+
import cuda.tile._bytecode as bc
1011
from cuda.tile._bytecode.basic import StringTable
1112
from cuda.tile._bytecode.debug_info import DebugAttrTable
1213
from cuda.tile._cext import CallingConvention
1314
from cuda.tile._compile import compile_tile
14-
from cuda.tile._exception import FunctionDesc
15+
from cuda.tile._exception import FunctionDesc, Loc
16+
from cuda.tile._ir.ops import Return
1517
from cuda.tile._ir2bytecode import DebugAttrMap, create_synthetic_linkage_name
1618
from cuda.tile.compilation import ArrayConstraint, KernelSignature
1719

@@ -202,6 +204,34 @@ def kernel(x):
202204
assert entry.line == kernel_line and entry.line != body_line
203205

204206

207+
def test_dummy_return_sets_unknown_loc():
208+
def kernel(x):
209+
t = ct.load(x, (0,), (1,))
210+
ct.store(x, (1,), t)
211+
212+
body = _compile(kernel)
213+
return_ops = [op for op in body.traverse() if isinstance(op, Return)]
214+
assert len(return_ops) == 1, "expected one Return op"
215+
assert return_ops[0].loc.is_unknown()
216+
217+
218+
def test_explicit_return_keeps_loc():
219+
def kernel(x):
220+
t = ct.load(x, (0,), (1,))
221+
ct.store(x, (1,), t)
222+
return None
223+
224+
body = _compile(kernel)
225+
return_line, _ = _line_col(kernel, "return None")
226+
return_ops = [op for op in body.traverse() if isinstance(op, Return)]
227+
assert len(return_ops) == 1, "expected one Return op"
228+
return_op = return_ops[0]
229+
assert return_op, "expected a Return op"
230+
assert return_op.loc.line == return_line, (return_op.loc.line, return_line)
231+
assert return_op.loc.function is not None
232+
assert return_op.loc.function.name == "kernel"
233+
234+
205235
def test_linkage_names_appear_in_emitted_bytecode():
206236
# End-to-end smoke check: every linkage name we compute should also be
207237
# serialized into the bytecode's string table. This catches regressions
@@ -258,3 +288,8 @@ def compile_once():
258288
)
259289

260290
assert compile_once() == compile_once()
291+
292+
293+
def test_unknown_loc_maps_to_missing_debug_attr():
294+
m = DebugAttrMap(DebugAttrTable(StringTable()), entry_symbol="kern", anonymize=False)
295+
assert m.get_debugattr(Loc.unknown()) == bc.MISSING_DEBUG_ATTR_ID

0 commit comments

Comments
 (0)