66import sys
77from contextlib import contextmanager
88from dataclasses import dataclass
9- from typing import Any , Callable , Sequence
9+ from typing import Any , Sequence
1010
1111from .ast2hir import get_function_hir
1212from .. import TileTypeError
13- from .._exception import Loc , TileSyntaxError , TileInternalError , TileError
13+ from .._exception import Loc , TileSyntaxError , TileInternalError , TileError , TileRecursionError
1414from .._ir import hir , ir
1515from .._ir .ir import Var , IRContext , Argument , Scope , LocalScope
1616from .._ir .op_impl import op_implementations , impl
2020from .._ir .typing_support import get_signature
2121
2222
23+ MAX_RECURSION_DEPTH = 50
24+
25+
2326def hir2ir (func_hir : hir .Function ,
2427 args : tuple [Argument , ...],
2528 ir_ctx : IRContext ) -> ir .Block :
26- scope = _create_scope (func_hir , ir_ctx )
29+ scope = _create_scope (func_hir , ir_ctx , call_site = None )
2730 ir_params = tuple (scope .local .redefine (name , loc )
2831 for name , loc in zip (func_hir .param_names , func_hir .param_locs , strict = True ))
2932 preamble = []
@@ -50,9 +53,9 @@ def hir2ir(func_hir: hir.Function,
5053 return ret
5154
5255
53- def _create_scope (func_hir : hir .Function , ir_ctx : IRContext ) :
56+ def _create_scope (func_hir : hir .Function , ir_ctx : IRContext , call_site : Loc | None ) -> Scope :
5457 local_scope = LocalScope (func_hir .body .stored_names , ir_ctx )
55- return Scope (local_scope , func_hir .frozen_globals )
58+ return Scope (local_scope , func_hir .frozen_globals , call_site )
5659
5760
5861def dispatch_hir_block (block : hir .Block ):
@@ -83,7 +86,8 @@ def _dispatch_hir_block_inner(preamble: Sequence[hir.Call],
8386 if not _dispatch_hir_calls (state , builder ):
8487 return ()
8588 result_vars = tuple (_resolve_operand (x ) for x in block .results )
86- with _wrap_exceptions (block .jump_loc ), builder .change_loc (block .jump_loc ):
89+ loc = _add_call_site (block .jump_loc , builder )
90+ with _wrap_exceptions (loc ), builder .change_loc (loc ):
8791 _dispatch_hir_jump (block .jump , result_vars )
8892 except Exception :
8993 if 'CUTILEIR' in builder .ir_ctx .tile_ctx .config .log_keys :
@@ -120,7 +124,8 @@ def _dispatch_hir_jump(jump: hir.Jump,
120124def _dispatch_hir_calls (state : _State , cur_builder : ir .Builder ) -> bool :
121125 while len (state .todo_stack ) > 0 :
122126 with state .next_call () as call :
123- with _wrap_exceptions (call .loc ), cur_builder .change_loc (call .loc ):
127+ loc = _add_call_site (call .loc , cur_builder )
128+ with _wrap_exceptions (loc ), cur_builder .change_loc (loc ):
124129 _dispatch_call (call , cur_builder , state .todo_stack )
125130 if cur_builder .is_terminated :
126131 # The current block has been terminated, e.g. by flattening an if-else
@@ -130,6 +135,10 @@ def _dispatch_hir_calls(state: _State, cur_builder: ir.Builder) -> bool:
130135 return True
131136
132137
138+ def _add_call_site (loc : Loc , builder : ir .Builder ) -> Loc :
139+ return loc .with_call_site (builder .scope .call_site )
140+
141+
133142@contextmanager
134143def _wrap_exceptions (loc : Loc ):
135144 with loc :
@@ -183,14 +192,14 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
183192 builder .ops [i ] = builder .ops [i ].clone (mapper )
184193 else :
185194 # Callee is a user-defined function.
186- _check_recursive_call (call .loc , callee )
195+ _check_recursive_call (builder .loc )
187196 sig = get_signature (callee )
188197 for param_name , param in sig .parameters .items ():
189198 if param .kind in (inspect .Parameter .VAR_POSITIONAL ,
190199 inspect .Parameter .VAR_KEYWORD ):
191200 raise TileSyntaxError ("Variadic parameters in user-defined"
192201 " functions are not supported" )
193- callee_hir = get_function_hir (callee , builder .ir_ctx , call_site = call . loc )
202+ callee_hir = get_function_hir (callee , builder .ir_ctx , entry_point = False )
194203
195204 # Since `todo_stack` is a stack, we push things backwards. First, we push identity()
196205 # calls to assign the temporary return values back to the original result variables.
@@ -202,7 +211,7 @@ def _dispatch_call(call: hir.Call, builder: ir.Builder, todo_stack: list[hir.Cal
202211 # For this purpose, we push a call to the special _set_scope stub.
203212 old_scope = builder .scope
204213 todo_stack .append (hir .Call ((), _set_scope , (old_scope ,), (), call .loc ))
205- builder .scope = _create_scope (callee_hir , builder .ir_ctx )
214+ builder .scope = _create_scope (callee_hir , builder .ir_ctx , call_site = builder . loc )
206215
207216 # Now push the function body.
208217 todo_stack .extend (reversed (callee_hir .body .calls ))
@@ -219,11 +228,14 @@ def _is_freshly_defined(var: Var, builder: ir.Builder, first_idx: int):
219228 for r in builder .ops [i ].result_vars )
220229
221230
222- def _check_recursive_call (call_loc : Loc , callee : Callable ):
231+ def _check_recursive_call (call_loc : Loc ):
232+ depth = 1
223233 while call_loc is not None :
224- if call_loc .function is callee :
225- raise TileTypeError ("Recursive function call detected" )
234+ depth += 1
226235 call_loc = call_loc .call_site
236+ if depth > MAX_RECURSION_DEPTH :
237+ raise TileRecursionError (f"Maximum recursion depth ({ MAX_RECURSION_DEPTH } ) reached"
238+ f" while inlining a function call" )
227239
228240
229241def _get_callee_and_self (callee_var : Var ) -> tuple [Any , tuple [()] | tuple [Var ]]:
0 commit comments