Skip to content

Commit 9db6438

Browse files
committed
Helper function call improvements
- Allow recursive function calls. Add a recursion limit instead. - Include cuTile stack trace in error messages, rather than a single immediate location. - Don't set Loc.call_site in ast2hir, delay until hir2ir. This way, a single HIR object could in theory be reused for multiple calls (handy for implementing closures/lambdas, as well as potentially caching the HIR for faster kernel specialization) Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent c1278ba commit 9db6438

10 files changed

Lines changed: 146 additions & 57 deletions

File tree

changelog.d/allow-recursion.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Lift the ban on recursive helper function calls. Instead, add a limit on recursion depth.
5+
- Add a new exception class `TileRecursionError`, thrown at compile time when the recursion limit
6+
is reached during function call inlining.
7+
- Include a full cuTile traceback in error messages. Improve formatting of code locations:
8+
include function names, remove unnecessary characters to reduce line lengths.
9+
- Expose the `TileError` base class in the public API.
10+

src/cuda/tile/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
from cuda.tile._exception import (
4343
TileCompilerExecutionError,
4444
TileCompilerTimeoutError,
45+
TileError,
4546
TileInternalError,
47+
TileRecursionError,
4648
TileSyntaxError,
4749
TileTypeError,
4850
TileValueError,
@@ -172,7 +174,9 @@
172174

173175
"TileCompilerExecutionError",
174176
"TileCompilerTimeoutError",
177+
"TileError",
175178
"TileInternalError",
179+
"TileRecursionError",
176180
"TileSyntaxError",
177181
"TileTypeError",
178182
"TileValueError",

src/cuda/tile/_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def wrapper(*args, **kwargs):
7575

7676
def _get_final_ir(pyfunc, args, tile_context) -> ir.Block:
7777
ir_ctx = ir.IRContext(tile_context)
78-
func_hir: hir.Function = get_function_hir(pyfunc, ir_ctx, call_site=None)
78+
func_hir: hir.Function = get_function_hir(pyfunc, ir_ctx, entry_point=True)
7979

8080
ir_args = _bind_kernel_arguments(func_hir.param_names, args, get_constant_annotations(pyfunc))
8181
func_body = hir2ir(func_hir, ir_args, ir_ctx)

src/cuda/tile/_exception.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
11
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
4+
import dataclasses
55
import linecache
66
import re
77
from dataclasses import dataclass
8-
from typing import Optional, Callable
8+
from typing import Optional
99
from unicodedata import east_asian_width
1010

1111

12+
@dataclass(eq=False, frozen=True)
13+
class FunctionDesc:
14+
name: str
15+
filename: str
16+
line: int
17+
18+
1219
@dataclass(slots=True)
1320
class Loc:
1421
line: int
1522
col: int
1623
filename: Optional[str] = None
1724
last_line: Optional[int] = None
1825
end_col: Optional[int] = None
19-
function: Optional[Callable] = None
26+
function: Optional[FunctionDesc] = None
2027
call_site: Optional["Loc"] = None
2128

29+
def with_call_site(self, call_site) -> "Loc":
30+
return dataclasses.replace(self, call_site=call_site)
31+
2232
def __str__(self) -> str:
2333
if self.filename:
2434
return f"{self.filename}:{self.line}:{self.col}"
@@ -49,6 +59,14 @@ def _wcwidth(s: str) -> int:
4959

5060

5161
def format_location(loc: Loc):
62+
frames = []
63+
while loc is not None:
64+
frames.append(loc)
65+
loc = loc.call_site
66+
return "".join(_format_location_frame(x) for x in reversed(frames))
67+
68+
69+
def _format_location_frame(loc: Loc) -> str:
5270
if loc.is_unknown():
5371
return "Unknown location"
5472

@@ -74,12 +92,14 @@ def format_location(loc: Loc):
7492
cols_str = f"col {visual_col + 1}"
7593
else:
7694
end_visual_col = _wcwidth(line_bytes[:end_col].decode())
77-
cols_str = f"col {visual_col + 1}--{end_visual_col}"
95+
cols_str = f"col {visual_col + 1}-{end_visual_col}"
7896

7997
spaces = " " * visual_col
8098
carets = "^" * (end_visual_col - visual_col)
8199

82-
return (f' In file "{loc.filename}", {lines_str}, {cols_str}:\n'
100+
func_str = "" if loc.function is None else f", in {loc.function.name}"
101+
102+
return (f' "{loc.filename}", {lines_str}, {cols_str}{func_str}:\n'
83103
f" {line_text}\n"
84104
f" {spaces}{carets}\n")
85105

@@ -103,6 +123,12 @@ class TileTypeError(TileError):
103123
pass
104124

105125

126+
class TileRecursionError(TileError):
127+
"""Thrown at compile time to indicate that the recursion limit has been reached
128+
when inlining a function call.
129+
"""
130+
131+
106132
class TileValueError(TileError):
107133
"""Exception when an unexpected python value is encountered."""
108134
pass

src/cuda/tile/_ir/hir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from textwrap import indent
2020
from typing import Any, Set, Mapping
2121

22-
from cuda.tile._exception import Loc
22+
from cuda.tile._exception import Loc, FunctionDesc
2323
from cuda.tile._ir.ir import Var
2424

2525

@@ -97,6 +97,7 @@ def jump_str(self):
9797

9898
@dataclass
9999
class Function:
100+
desc: FunctionDesc
100101
body: Block
101102
param_names: tuple[str, ...]
102103
param_locs: tuple[Loc, ...]

src/cuda/tile/_ir/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def nested_block(name: str, loc: Loc, params: Sequence[Var] = (),
327327
new_loop_info = loop_info or prev_builder.loop_info
328328
new_if_else_info = if_else_info or prev_builder.if_else_info
329329
scope = prev_builder.scope
330+
loc = loc.with_call_site(scope.call_site)
330331
with Builder(prev_builder.ir_ctx, loc, scope, new_loop_info, new_if_else_info) as builder, \
331332
scope.local.enter_branch():
332333
yield block
@@ -410,6 +411,7 @@ def __setitem__(self, key, value):
410411
class Scope:
411412
local: LocalScope
412413
frozen_globals: Mapping[str, Any]
414+
call_site: Loc | None
413415

414416

415417
class Builder:

src/cuda/tile/_ir2bytecode.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import functools
6-
import inspect
76
import os
87
from contextlib import contextmanager
98
from typing import Dict, Tuple, Sequence, Any, List, Optional, Iterator, Set
@@ -14,7 +13,8 @@
1413
import cuda.tile._bytecode as bc
1514
from cuda.tile._compiler_options import CompilerOptions
1615
from cuda.tile._debug import CUDA_TILE_TESTING_DISABLE_DIV
17-
from cuda.tile._exception import TileInternalError, TileError, ConstFoldNotImplementedError
16+
from cuda.tile._exception import TileInternalError, TileError, ConstFoldNotImplementedError, \
17+
FunctionDesc
1818
from cuda.tile._ir.ir import Block, Loc, Var, IRContext
1919
from cuda.tile._ir.ops_utils import (
2020
padding_mode_to_bytecode, rounding_mode_to_bytecode,
@@ -472,27 +472,24 @@ def __init__(self, debug_attr_table: bc.DebugAttrTable, linkage_name: str, anony
472472
self._linkage_name = linkage_name
473473
self._anonymize = anonymize
474474

475-
def get_subprogram(self, pyfunc) -> bc.DebugAttrId:
475+
def get_subprogram(self, func_desc: FunctionDesc) -> bc.DebugAttrId:
476476
try:
477-
return self._subprogram_cache[pyfunc]
477+
return self._subprogram_cache[func_desc]
478478
except KeyError:
479479
pass
480480

481-
func_name = pyfunc.__name__
482-
func_filename = inspect.getfile(pyfunc)
483-
_, func_line = inspect.findsource(pyfunc)
484-
func_dirname, func_basename = os.path.split(func_filename)
481+
func_dirname, func_basename = os.path.split(func_desc.filename)
485482
file_attr = self._debug_attr_table.file(func_basename, func_dirname)
486483
compile_unit_attr = self._debug_attr_table.compile_unit(file_attr)
487484
ret = self._debug_attr_table.subprogram(
488485
file=file_attr,
489-
line=func_line,
490-
name=func_name,
486+
line=func_desc.line,
487+
name=func_desc.name,
491488
linkage_name=self._linkage_name,
492489
compile_unit=compile_unit_attr,
493-
scope_line=func_line,
490+
scope_line=func_desc.line,
494491
)
495-
self._subprogram_cache[pyfunc] = ret
492+
self._subprogram_cache[func_desc] = ret
496493
return ret
497494

498495
def get_debugattr(self, loc: Loc) -> bc.DebugAttrId:

src/cuda/tile/_passes/ast2hir.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
from typing import List, Sequence, Optional, Any, Dict, Type, Callable
1111

1212
from cuda.tile import _datatype as datatype
13-
from cuda.tile._exception import TileSyntaxError, Loc
13+
from cuda.tile._exception import TileSyntaxError, Loc, FunctionDesc
1414
from cuda.tile._ir.ir import IRContext, Var
1515
from cuda.tile._ir import hir
1616

1717

1818
def get_function_hir(pyfunc: Callable,
1919
ir_ctx: IRContext,
20-
call_site: Optional[Loc]) -> hir.Function:
20+
entry_point: bool) -> hir.Function:
2121
# Get the original function from the decorated function if it exists.
2222
pyfunc = getattr(pyfunc, "__wrapped__", pyfunc)
2323

@@ -56,14 +56,17 @@ def get_function_hir(pyfunc: Callable,
5656
for name, cell in zip(pyfunc.__code__.co_freevars, pyfunc.__closure__):
5757
func_globals[name] = cell.cell_contents
5858

59-
ctx = _Context(inspect.getfile(pyfunc), first_line, call_site, pyfunc, ir_ctx)
59+
filename = inspect.getfile(pyfunc)
60+
desc = FunctionDesc(func_def.name, filename, first_line)
61+
ctx = _Context(filename, first_line, desc, entry_point, ir_ctx)
6062
assert isinstance(func_def, ast.FunctionDef)
6163
body = _ast2hir(func_def, ctx)
6264
all_ast_args = _get_all_parameters(func_def, ctx)
6365
param_names = tuple(p.arg for p in all_ast_args)
6466
param_locs = tuple(ctx.get_loc(p) for p in all_ast_args)
6567
body.stored_names.update(param_names)
66-
return hir.Function(body, param_names, param_locs, func_globals)
68+
69+
return hir.Function(desc, body, param_names, param_locs, func_globals)
6770

6871

6972
# Translate the 1-based line number of the chunk we passed to the AST parser
@@ -81,13 +84,12 @@ class LoopKind(Enum):
8184

8285

8386
class _Context:
84-
def __init__(self, filename: str, first_line: int, call_site: Optional[Loc],
85-
function: Callable, ir_ctx: IRContext):
87+
def __init__(self, filename: str, first_line: int, function_desc: FunctionDesc,
88+
entry_point: bool, ir_ctx: IRContext):
8689
self.filename = filename
8790
self.first_line = first_line
88-
self.entry_point = call_site is None
89-
self.call_site = call_site
90-
self.function = function
91+
self.function_desc = function_desc
92+
self.entry_point = entry_point
9193
self.parent_loops: List[LoopKind] = []
9294
self.current_loc = Loc.unknown()
9395
self.current_block: Optional[hir.Block] = None
@@ -145,8 +147,7 @@ def get_loc(self, node: ast.AST) -> Loc:
145147
# Subtract 1 from the column offset to correct for an extra level
146148
# of indentation we inserted for the dummy "if True" block.
147149
return Loc(line_no, node.col_offset - 1, self.filename,
148-
last_line_no, node.end_col_offset - 1, self.function,
149-
self.call_site)
150+
last_line_no, node.end_col_offset - 1, self.function_desc)
150151

151152
def syntax_error(self, message: str, loc=None) -> TileSyntaxError:
152153
if loc is None:

src/cuda/tile/_passes/hir2ir.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import sys
77
from contextlib import contextmanager
88
from dataclasses import dataclass
9-
from typing import Any, Callable, Sequence
9+
from typing import Any, Sequence
1010

1111
from .ast2hir import get_function_hir
1212
from .. import TileTypeError
13-
from .._exception import Loc, TileSyntaxError, TileInternalError, TileError
13+
from .._exception import Loc, TileSyntaxError, TileInternalError, TileError, TileRecursionError
1414
from .._ir import hir, ir
1515
from .._ir.ir import Var, IRContext, Argument, Scope, LocalScope
1616
from .._ir.op_impl import op_implementations, impl
@@ -20,10 +20,13 @@
2020
from .._ir.typing_support import get_signature
2121

2222

23+
MAX_RECURSION_DEPTH = 50
24+
25+
2326
def hir2ir(func_hir: hir.Function,
2427
args: tuple[Argument, ...],
2528
ir_ctx: IRContext) -> ir.Block:
26-
scope = _create_scope(func_hir, ir_ctx)
29+
scope = _create_scope(func_hir, ir_ctx, call_site=None)
2730
ir_params = tuple(scope.local.redefine(name, loc)
2831
for name, loc in zip(func_hir.param_names, func_hir.param_locs, strict=True))
2932
preamble = []
@@ -50,9 +53,9 @@ def hir2ir(func_hir: hir.Function,
5053
return ret
5154

5255

53-
def _create_scope(func_hir: hir.Function, ir_ctx: IRContext):
56+
def _create_scope(func_hir: hir.Function, ir_ctx: IRContext, call_site: Loc | None) -> Scope:
5457
local_scope = LocalScope(func_hir.body.stored_names, ir_ctx)
55-
return Scope(local_scope, func_hir.frozen_globals)
58+
return Scope(local_scope, func_hir.frozen_globals, call_site)
5659

5760

5861
def dispatch_hir_block(block: hir.Block):
@@ -83,7 +86,8 @@ def _dispatch_hir_block_inner(preamble: Sequence[hir.Call],
8386
if not _dispatch_hir_calls(state, builder):
8487
return ()
8588
result_vars = tuple(_resolve_operand(x) for x in block.results)
86-
with _wrap_exceptions(block.jump_loc), builder.change_loc(block.jump_loc):
89+
loc = _add_call_site(block.jump_loc, builder)
90+
with _wrap_exceptions(loc), builder.change_loc(loc):
8791
_dispatch_hir_jump(block.jump, result_vars)
8892
except Exception:
8993
if 'CUTILEIR' in builder.ir_ctx.tile_ctx.config.log_keys:
@@ -120,7 +124,8 @@ def _dispatch_hir_jump(jump: hir.Jump,
120124
def _dispatch_hir_calls(state: _State, cur_builder: ir.Builder) -> bool:
121125
while len(state.todo_stack) > 0:
122126
with state.next_call() as call:
123-
with _wrap_exceptions(call.loc), cur_builder.change_loc(call.loc):
127+
loc = _add_call_site(call.loc, cur_builder)
128+
with _wrap_exceptions(loc), cur_builder.change_loc(loc):
124129
_dispatch_call(call, cur_builder, state.todo_stack)
125130
if cur_builder.is_terminated:
126131
# The current block has been terminated, e.g. by flattening an if-else
@@ -130,6 +135,10 @@ def _dispatch_hir_calls(state: _State, cur_builder: ir.Builder) -> bool:
130135
return True
131136

132137

138+
def _add_call_site(loc: Loc, builder: ir.Builder) -> Loc:
139+
return loc.with_call_site(builder.scope.call_site)
140+
141+
133142
@contextmanager
134143
def _wrap_exceptions(loc: Loc):
135144
with loc:
@@ -183,14 +192,14 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
183192
builder.ops[i] = builder.ops[i].clone(mapper)
184193
else:
185194
# Callee is a user-defined function.
186-
_check_recursive_call(call.loc, callee)
195+
_check_recursive_call(builder.loc)
187196
sig = get_signature(callee)
188197
for param_name, param in sig.parameters.items():
189198
if param.kind in (inspect.Parameter.VAR_POSITIONAL,
190199
inspect.Parameter.VAR_KEYWORD):
191200
raise TileSyntaxError("Variadic parameters in user-defined"
192201
" functions are not supported")
193-
callee_hir = get_function_hir(callee, builder.ir_ctx, call_site=call.loc)
202+
callee_hir = get_function_hir(callee, builder.ir_ctx, entry_point=False)
194203

195204
# Since `todo_stack` is a stack, we push things backwards. First, we push identity()
196205
# calls to assign the temporary return values back to the original result variables.
@@ -202,7 +211,7 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
202211
# For this purpose, we push a call to the special _set_scope stub.
203212
old_scope = builder.scope
204213
todo_stack.append(hir.Call((), _set_scope, (old_scope,), (), call.loc))
205-
builder.scope = _create_scope(callee_hir, builder.ir_ctx)
214+
builder.scope = _create_scope(callee_hir, builder.ir_ctx, call_site=builder.loc)
206215

207216
# Now push the function body.
208217
todo_stack.extend(reversed(callee_hir.body.calls))
@@ -219,11 +228,14 @@ def _is_freshly_defined(var: Var, builder: ir.Builder, first_idx: int):
219228
for r in builder.ops[i].result_vars)
220229

221230

222-
def _check_recursive_call(call_loc: Loc, callee: Callable):
231+
def _check_recursive_call(call_loc: Loc):
232+
depth = 1
223233
while call_loc is not None:
224-
if call_loc.function is callee:
225-
raise TileTypeError("Recursive function call detected")
234+
depth += 1
226235
call_loc = call_loc.call_site
236+
if depth > MAX_RECURSION_DEPTH:
237+
raise TileRecursionError(f"Maximum recursion depth ({MAX_RECURSION_DEPTH}) reached"
238+
f" while inlining a function call")
227239

228240

229241
def _get_callee_and_self(callee_var: Var) -> tuple[Any, tuple[()] | tuple[Var]]:

0 commit comments

Comments
 (0)