Skip to content

Commit 2783ab8

Browse files
walter-erquinigo and gbonik
authored and committed
[debugging] Use different linkage names for different functions
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 91a80ad commit 2783ab8

8 files changed

Lines changed: 379 additions & 32 deletions

File tree

src/cuda/tile/_exception.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,25 @@
1212

1313
@dataclass(eq=False, frozen=True)
1414
class FunctionDesc:
15-
name: str | None
15+
name: str | None # None for lambdas
1616
filename: str
17-
line: int
17+
line: int # 1-based
18+
column: int # 1-based
19+
# If this FunctionDesc represents a concrete specialization of a source
20+
# function other than the kernel entry point, this value will hold a
21+
# unique identifier, which is used to distinguish distinct specialized
22+
# functions in debug info.
23+
specialization_id: str | None = None
24+
# True for the FunctionDesc that represents the kernel entry point.
25+
is_entry: bool = False
1826

1927
def __str__(self):
20-
return f"'{self.name}' @{self.filename}:{self.line}"
28+
return f"'{self.name}' @{self.filename}:{self.line}:{self.column}"
2129

2230
def short_str(self):
2331
if self.name is None:
2432
base_name = os.path.basename(self.filename)
25-
return f"<lambda at {base_name}:{self.line}>"
33+
return f"<lambda at {base_name}:{self.line}:{self.column}>"
2634
else:
2735
return f"<function {self.name}>"
2836

src/cuda/tile/_ir/ir.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,13 @@ def __init__(self, log_ir_on_error: bool, tileiras_version: BytecodeVersion):
4040
self.log_ir_on_error = log_ir_on_error
4141
self._aggregate_values: Dict[str, Any] = dict()
4242
self.tileiras_version: BytecodeVersion = tileiras_version
43+
self._function_specialization_id_counter = itertools.count()
44+
45+
def next_function_specialization_id(self) -> str:
    # Hand out the next function-specialization id: the letter "s" followed
    # by the next value of this context's counter. The id is attached to
    # concrete FunctionDescs so that they can be told apart in debug info.
    # Drawing ids from a monotonic counter keeps the emitted bytecode
    # deterministic, which cache hits depend on.
    ordinal = next(self._function_specialization_id_counter)
    return "s" + str(ordinal)
4350

4451
# Make a Var with a unique name based on `name`.
4552
def make_var(self, name: str, loc: Loc) -> Var:

src/cuda/tile/_ir/ops.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@ def format_var(var):
167167

168168
@impl(hir_stubs.loop)
169169
async def loop_impl(body: hir.Block, iterable: Var):
170-
from .._passes.hir2ir import dispatch_hir_block
170+
from .._passes.hir2ir import dispatch_hir_block, retarget_loc
171171

172172
scope = Scope.get_current()
173173
range_ty = require_optional_range_type(iterable)
@@ -202,7 +202,7 @@ async def loop_impl(body: hir.Block, iterable: Var):
202202

203203
# Process the loop body
204204
loop_info = ControlFlowInfo(stored_locals)
205-
body_loc = body.loc.with_call_site(scope.call_site)
205+
body_loc = retarget_loc(body.loc, scope)
206206
with enter_nested_block(body_loc) as new_body, scope.change_loop_info(loop_info), \
207207
scope.local.enter_branch():
208208
# Define body variables. Not all of them will eventually be kept,
@@ -396,7 +396,7 @@ async def _flatten_branch(branch: hir.Block) -> Var | None:
396396

397397
@impl(hir_stubs.if_else)
398398
async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block) -> Var | None:
399-
from .._passes.hir2ir import dispatch_hir_block
399+
from .._passes.hir2ir import dispatch_hir_block, retarget_loc
400400

401401
require_bool(cond)
402402
if cond.is_constant():
@@ -415,7 +415,7 @@ async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block)
415415

416416
# Convert the "then" branch from HIR to IR
417417
info = ControlFlowInfo(stored_locals)
418-
then_loc = then_block.loc.with_call_site(scope.call_site)
418+
then_loc = retarget_loc(then_block.loc, scope)
419419
with enter_nested_block(then_loc) as new_then_block, scope.change_if_else_info(info), \
420420
scope.local.enter_branch():
421421
await dispatch_hir_block(then_block)
@@ -427,7 +427,7 @@ async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block)
427427
# EndBranch
428428
# <else_block>
429429
# This is to avoid the situation where none of the branches yield.
430-
else_loc = else_block.loc.with_call_site(scope.call_site)
430+
else_loc = retarget_loc(else_block.loc, scope)
431431
if len(info.jumps) == 0:
432432
info = ControlFlowInfo(())
433433
with enter_nested_block(else_loc) as new_else_block, scope.change_if_else_info(info), \

src/cuda/tile/_ir/scope.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from dataclasses import dataclass
1010
from typing import TypeVar, Generic
1111

12-
from cuda.tile._exception import Loc, TileSyntaxError
12+
from cuda.tile._exception import Loc, FunctionDesc, TileSyntaxError
1313
from cuda.tile._ir import hir
1414
from cuda.tile._ir.hir import ResolvedName
1515
from cuda.tile._ir.ir import Operation, Var, IRContext
@@ -142,6 +142,9 @@ class Scope:
142142
call_site: Loc | None
143143
hir2ir_varmap: IntMap[Var]
144144
func_hir: hir.Function
145+
# FunctionDesc for the concrete specialization of `func_hir` that this
146+
# scope corresponds to.
147+
concrete_func_desc: FunctionDesc
145148
context_stack: list[ContextManagerState] = dataclasses.field(default_factory=list)
146149
loop_context_stack_depth: int | None = None
147150

src/cuda/tile/_ir2bytecode.py

Lines changed: 48 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import functools
66
import os
7+
import re
78
from contextlib import contextmanager
89
from typing import Dict, Tuple, Any, Optional
910
import warnings
@@ -232,14 +233,56 @@ def encode_comparison(builder: bc.CodeBuilder, fn: str, lhs: bc.Value, rhs: bc.V
232233
raise TileInternalError(f'Unexpected dtype: {dtype}')
233234

234235

236+
def create_synthetic_linkage_name(func_desc: FunctionDesc) -> str:
    # Build a synthetic linkage name for a helper function or lambda. Format:
    #
    #     <name>@<stem>:<line>:<column>_<specialization_id>
    #
    # <name> is the function's name, or "lambda" when the desc has no name.
    # <stem> is the basename of the source file without its extension, with
    # every character other than ASCII letters, digits and "_" replaced by "_".
    #
    # By construction every FunctionDesc reaching this point has been
    # concretized by hir2ir. Only the kernel entry skips this path (it uses
    # the externally-visible symbol instead), so a specialization_id is
    # required here and its absence trips the assertion below.
    spec_id = func_desc.specialization_id
    assert spec_id is not None, (
        f"create_synthetic_linkage_name called on a FunctionDesc without a "
        f"specialization_id: {func_desc}. hir2ir must concretize every "
        f"non-entry function before bytecode generation."
    )

    file_name = os.path.basename(func_desc.filename)
    if not file_name:
        # The filename has no final path component to take a stem from.
        file_name = "unknown"

    raw_stem, _extension = os.path.splitext(file_name)
    safe_stem = re.sub(r"[^A-Za-z0-9_]", "_", raw_stem)
    if not safe_stem:
        # Defensive fallback in case the stem ends up empty.
        safe_stem = "anonymous"

    label = "lambda" if func_desc.name is None else func_desc.name
    position = f"{func_desc.line}:{func_desc.column}"
    return f"{label}@{safe_stem}:{position}_{spec_id}"
257+
258+
235259
class DebugAttrMap:
236-
def __init__(self, debug_attr_table: bc.DebugAttrTable, linkage_name: str, anonymize: bool):
260+
def __init__(self,
261+
debug_attr_table: bc.DebugAttrTable,
262+
entry_symbol: str,
263+
anonymize: bool):
237264
self._subprogram_cache = {}
238265
self._debug_attr_table = debug_attr_table
239-
self._linkage_name = linkage_name
266+
self._entry_symbol = entry_symbol
240267
self._anonymize = anonymize
241268

269+
def _linkage_for(self, func_desc: FunctionDesc) -> str:
    # Choose the linkage name to record for `func_desc`.
    #
    # Any function other than the kernel entry point receives its own
    # artificial, per-function linkage name. The entry point instead keeps
    # the externally-visible symbol, so that the loader can look it up.
    if not func_desc.is_entry:
        return create_synthetic_linkage_name(func_desc)
    return self._entry_symbol
276+
242277
def get_subprogram(self, func_desc: FunctionDesc) -> bc.DebugAttrId:
278+
# Every FunctionDesc reaching DI emission must satisfy: a function has
279+
# no specialization_id iff it is the kernel entry. hir2ir leaves the
280+
# entry's abstract desc as-is and concretizes everyone else; anything
281+
# else is a bug.
282+
assert func_desc.is_entry == (func_desc.specialization_id is None), (
283+
f"FunctionDesc invariant violated: is_entry={func_desc.is_entry} "
284+
f"but specialization_id={func_desc.specialization_id!r}: {func_desc}"
285+
)
243286
try:
244287
return self._subprogram_cache[func_desc]
245288
except KeyError:
@@ -252,7 +295,7 @@ def get_subprogram(self, func_desc: FunctionDesc) -> bc.DebugAttrId:
252295
file=file_attr,
253296
line=func_desc.line,
254297
name="<lambda>" if func_desc.name is None else func_desc.name,
255-
linkage_name=self._linkage_name,
298+
linkage_name=self._linkage_for(func_desc),
256299
compile_unit=compile_unit_attr,
257300
scope_line=func_desc.line,
258301
)
@@ -435,7 +478,8 @@ def generate_bytecode_for_kernel(func_body: Block,
435478
num_worker_warps_per_cta=num_worker_warps)
436479

437480
param_type_ids = [typeid(writer.type_table, p.get_type()) for p in func_body.params]
438-
debug_attr_map = DebugAttrMap(writer.debug_attr_table, symbol, anonymize_debug_attr)
481+
debug_attr_map = DebugAttrMap(writer.debug_attr_table, symbol,
482+
anonymize=anonymize_debug_attr)
439483
func_debug_attr = debug_attr_map.get_debugattr(func_body.loc)
440484

441485
with writer.function(name=symbol,

src/cuda/tile/_passes/ast2hir.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def get_function_hir(pyfunc: Callable, entry_point: bool) -> hir.Function:
5252
assert len(mod.body[0].body) == 1
5353
func_def = mod.body[0].body[0]
5454
assert isinstance(func_def, ast.FunctionDef)
55-
_fix_line_numbers(func_def, first_line)
55+
_fix_line_and_column_numbers(func_def, first_line)
5656

5757
func_globals = dict(pyfunc.__builtins__)
5858
func_globals.update(pyfunc.__globals__)
@@ -62,7 +62,8 @@ def get_function_hir(pyfunc: Callable, entry_point: bool) -> hir.Function:
6262
func_globals[name] = cell.cell_contents
6363

6464
filename = inspect.getfile(pyfunc)
65-
desc = FunctionDesc(func_def.name, filename, first_line)
65+
desc = FunctionDesc(func_def.name, filename, first_line, func_def.col_offset + 1,
66+
is_entry=entry_point)
6667
local_names, _, _ = ast_get_all_local_names(func_def)
6768
ctx = _Context(filename, first_line, desc, func_globals, local_names, entry_point)
6869
signature = inspect.signature(pyfunc)
@@ -73,9 +74,9 @@ def get_function_hir(pyfunc: Callable, entry_point: bool) -> hir.Function:
7374
return ret
7475

7576

76-
# Translate the 1-based line numbers of the chunk we passed to the AST parser
77-
# to the original 1-based line numbers in the file.
78-
def _fix_line_numbers(tree: ast.AST, first_line: int):
77+
# Translate the 1-based line and 0-based column numbers of the chunk we passed to the
78+
# AST parser to the original line and column numbers in the file.
79+
def _fix_line_and_column_numbers(tree: ast.AST, first_line: int):
7980
for node in ast.walk(tree):
8081
if hasattr(node, "lineno"):
8182
# Why -2?
@@ -846,9 +847,8 @@ def _pass_stmt(stmt: ast.Pass, ctx: _Context) -> None:
846847
def _make_closure(node: ast.FunctionDef | ast.Lambda, ctx: _Context) -> hir.Value:
847848
signature, default_exprs = _signature_from_ast_arguments(node.args)
848849
default_values = tuple(_expr(x, ctx) for x in default_exprs)
849-
line_no = node.lineno
850850
name = None if isinstance(node, ast.Lambda) else node.name
851-
desc = FunctionDesc(name, ctx.filename, line_no)
851+
desc = FunctionDesc(name, ctx.filename, node.lineno, node.col_offset + 1)
852852
new_locals, new_globals, _ = ast_get_all_local_names(node)
853853
local_names = (ctx.local_names - new_globals) | new_locals
854854
new_ctx = _Context(ctx.filename, ctx.first_line, desc, ctx.frozen_globals,

src/cuda/tile/_passes/hir2ir.py

Lines changed: 36 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@
1111
from .ast2hir import get_function_hir
1212
from .. import TileTypeError
1313
from .._coroutine_util import resume_after, run_coroutine
14-
from .._exception import Loc, TileSyntaxError, TileInternalError, TileError, TileRecursionError
14+
from .._exception import Loc, FunctionDesc, TileSyntaxError, TileInternalError, TileError, \
15+
TileRecursionError
1516
from .._execution import is_stub
1617
from .._ir import hir, ir
1718
from .._ir.ir import Var, IRContext, BoundMethodValue, ClosureValue, TupleValue
@@ -35,10 +36,11 @@ def hir2ir(func_hir: hir.Function,
3536
run_coroutine(_hir2ir_coroutine(func_hir, param_aggregate_vars, ir_ctx))
3637

3738

38-
async def _hir2ir_coroutine(func_hir: hir.Function,
39-
param_aggregate_vars: Sequence[ir.Var],
40-
ir_ctx: IRContext):
41-
scope = _create_scope(func_hir, ir_ctx, call_site=None, parent_scopes=())
39+
async def _hir2ir_coroutine(
40+
func_hir: hir.Function, param_aggregate_vars: Sequence[ir.Var], ir_ctx: IRContext
41+
):
42+
scope = _create_scope(func_hir, ir_ctx, call_site=None, parent_scopes=(),
43+
concrete_func_desc=func_hir.desc)
4244
for local_idx, var in zip(func_hir.param_local_indices, param_aggregate_vars, strict=True):
4345
scope.local[local_idx] = var
4446

@@ -56,9 +58,28 @@ async def _hir2ir_coroutine(func_hir: hir.Function,
5658

5759

5860
def _create_scope(func_hir: hir.Function, ir_ctx: IRContext, call_site: Loc | None,
59-
parent_scopes: tuple[LocalScope, ...]) -> Scope:
61+
parent_scopes: tuple[LocalScope, ...],
62+
concrete_func_desc: FunctionDesc) -> Scope:
6063
local_scope = LocalScope(func_hir.local_names, ir_ctx)
61-
return Scope(parent_scopes + (local_scope,), None, None, call_site, IntMap(), func_hir)
64+
return Scope(parent_scopes + (local_scope,), None, None, call_site, IntMap(), func_hir,
65+
concrete_func_desc=concrete_func_desc)
66+
67+
68+
def _concretize_func_desc(func_hir: hir.Function, ir_ctx: IRContext) -> FunctionDesc:
    # Produce a new FunctionDesc: a copy of `func_hir.desc` whose
    # `specialization_id` is freshly drawn from `ir_ctx`. That id is what
    # makes the synthesized linkage name unique across every inlining.
    abstract_desc = func_hir.desc
    fresh_id = ir_ctx.next_function_specialization_id()
    return dataclasses.replace(abstract_desc, specialization_id=fresh_id)
73+
74+
75+
def retarget_loc(loc: Loc, scope: Scope) -> Loc:
    # Return a copy of `loc` with the scope's call site spliced in.
    #
    # When `loc` belongs to the function currently being inlined (its
    # function is the very same object as `scope.func_hir.desc`), the
    # per-specialization FunctionDesc is swapped in as well, so that emitted
    # ops carry the right debug info. Otherwise the function is kept as-is.
    belongs_to_scope_func = loc.function is scope.func_hir.desc
    target_function = scope.concrete_func_desc if belongs_to_scope_func else loc.function
    return dataclasses.replace(loc, function=target_function, call_site=scope.call_site)
6283

6384

6485
async def dispatch_hir_block(block: hir.Block, cur_builder: ir.Builder | None = None):
@@ -72,7 +93,7 @@ async def _dispatch_hir_block_inner(block: hir.Block, builder: ir.Builder):
7293
try:
7394
scope = Scope.get_current()
7495
for cursor, call in enumerate(block.calls):
75-
loc = call.loc.with_call_site(scope.call_site)
96+
loc = retarget_loc(call.loc, scope)
7697
with _wrap_exceptions(loc), builder.change_loc(loc):
7798
await _dispatch_call(call, scope)
7899
if builder.is_terminated:
@@ -81,7 +102,7 @@ async def _dispatch_hir_block_inner(block: hir.Block, builder: ir.Builder):
81102
return
82103
cursor = len(block.calls)
83104

84-
loc = block.jump_loc.with_call_site(scope.call_site)
105+
loc = retarget_loc(block.jump_loc, scope)
85106
with _wrap_exceptions(loc), builder.change_loc(loc):
86107
_dispatch_hir_jump(block, scope)
87108
except Exception:
@@ -155,9 +176,13 @@ async def _call_user_defined(callee_hir: hir.Function,
155176
raise TileSyntaxError("Variadic keyword parameters in user-defined"
156177
" functions are not supported")
157178

158-
# Activate a fresh Scope.
179+
# Activate a fresh Scope. Each inlining gets its own concretized
180+
# FunctionDesc so that debug info (DI) never merges two specializations
181+
# whose generated IR might differ.
159182
new_scope = _create_scope(callee_hir, builder.ir_ctx, call_site=builder.loc,
160-
parent_scopes=parent_scopes)
183+
parent_scopes=parent_scopes,
184+
concrete_func_desc=_concretize_func_desc(callee_hir,
185+
builder.ir_ctx))
161186
with new_scope.make_current():
162187
# Call store_var() to bind arguments to parameters.
163188
for arg, local_idx, param_loc in zip(arg_list, callee_hir.param_local_indices,

0 commit comments

Comments
 (0)