Skip to content

Commit 548f99c

Browse files
committed
Add support for closures (nested functions and lambdas)
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 8b6d006 commit 548f99c

13 files changed

Lines changed: 845 additions & 152 deletions

File tree

changelog.d/closures.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added support for nested functions and lambdas.

src/cuda/tile/_compile.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def _get_final_ir(pyfunc, args, tile_context) -> ir.Function:
7777
ir_ctx = ir.IRContext(tile_context)
7878
func_hir: hir.Function = get_function_hir(pyfunc, entry_point=True)
7979

80-
ir_args = _bind_kernel_arguments(func_hir.param_names, args, get_constant_annotations(pyfunc))
80+
ir_args = _bind_kernel_arguments(tuple(func_hir.signature.parameters),
81+
args, get_constant_annotations(pyfunc))
8182
func_body = hir2ir(func_hir, ir_args, ir_ctx)
8283
eliminate_assign_ops(func_body)
8384
dead_code_elimination_pass(func_body)
@@ -162,7 +163,7 @@ def _compiler_crash_dump(func_ir: ir.Function,
162163

163164
artifacts = {
164165
f"{func_ir.name}.bytecode": bytes(bytecode_buf),
165-
f"{func_ir.name}.cutileir": f"{func_ir.to_string(include_loc=False)}\n",
166+
f"{func_ir.name}.cutileir": f"{func_ir.body.to_string(include_loc=False)}\n",
166167
"debug_info.txt": debug_info,
167168
}
168169

@@ -184,7 +185,7 @@ def compile_tile(pyfunc,
184185

185186
if 'CUTILEIR' in context.config.log_keys:
186187
code = (f"==== CuTile IR for {func_ir.name}==== \n\n"
187-
f"{func_ir.to_string(include_loc=False)}\n\n")
188+
f"{func_ir.body.to_string(include_loc=False)}\n\n")
188189
print(f'\n{code}', file=sys.stderr)
189190

190191
sm_arch = get_sm_arch()

src/cuda/tile/_exception.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44
import dataclasses
55
import linecache
6+
import os.path
67
import re
78
from dataclasses import dataclass
89
from typing import Optional
@@ -11,10 +12,20 @@
1112

1213
@dataclass(eq=False, frozen=True)
1314
class FunctionDesc:
14-
name: str
15+
name: str | None
1516
filename: str
1617
line: int
1718

19+
def __str__(self):
20+
return f"'{self.name}' @{self.filename}:{self.line}"
21+
22+
def short_str(self):
23+
if self.name is None:
24+
base_name = os.path.basename(self.filename)
25+
return f"<lambda at {base_name}:{self.line}>"
26+
else:
27+
return f"<function {self.name}>"
28+
1829

1930
@dataclass(slots=True, frozen=True)
2031
class Loc:

src/cuda/tile/_ir/hir.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616

1717
import enum
18+
import inspect
1819
import threading
1920
from dataclasses import dataclass
2021
from textwrap import indent
21-
from typing import Any, Set, Mapping
22+
from typing import Any, Set, OrderedDict
2223

2324
from cuda.tile._exception import Loc, FunctionDesc
2425

@@ -119,15 +120,81 @@ def jump_str(self):
119120
return f"{self.jump._value_}{results_str} # Line {self.jump_loc.line}"
120121

121122

122-
@dataclass
123+
# Resolved position of a named variable.
124+
#
125+
# If `depth` is -1, then this is a global variable, and the `index`
126+
# points into `hir.Function.frozen_global_names` / `hir.Function.frozen_global_values`.
127+
# If both `depth` and `index` are -1, then this is an indication that no variable
128+
# with the given name has been found.
129+
#
130+
# If `depth` is non-negative but is less than the depth of the current function,
131+
# then this is a captured local variable of an enclosing function.
132+
# In this case, the local variable name can be found in
133+
# `hir.Function.enclosing_funcs[depth].local_names[index].
134+
#
135+
# If `depth` is equal to the depth of the current function, then this is a local variable.
136+
# In this case, the local variable name can be found in `hir.Function.local_names[index]`.
137+
@dataclass(frozen=True)
138+
class ResolvedName:
139+
depth: int
140+
index: int
141+
142+
143+
UNKNOWN_NAME = ResolvedName(-1, -1)
144+
145+
146+
@dataclass(eq=False, repr=False)
123147
class Function:
124148
desc: FunctionDesc
125149
body: Block
126-
param_names: tuple[str, ...]
150+
151+
# For nested functions, this signature is synthesized from AST. In this case,
152+
# default values of parameters are represented with ClosureDefaultPlaceholder objects.
153+
signature: inspect.Signature
154+
155+
# Names of all local variables defined in this function.
156+
local_names: tuple[str, ...]
157+
158+
# For each function parameter, an index into `local_names`
159+
param_local_indices: tuple[int, ...]
160+
161+
# For each function parameter, its location (same length as `param_local_indices`)
127162
param_locs: tuple[Loc, ...]
128-
frozen_globals: Mapping[str, Any]
163+
164+
# Names of all global-like variables accessible by this function.
165+
frozen_global_names: tuple[str, ...]
166+
167+
# Values of all global-like variables access by this function.
168+
# We assume that they don't change after compilation, hence the name "frozen".
169+
# Same length as `frozen_global_names`.
170+
frozen_global_values: tuple[Any, ...]
171+
172+
# Strict upper bound on hir.Value.id inside this function definition,
173+
# excluding any nested functions. Can be used to pre-allocate lists etc.
129174
value_id_upper_bound: int
130175

176+
# List of all function definitions nested directly inside this function.
177+
nested_functions: "tuple[Function, ...]"
178+
179+
# Names of all variables that are ever loaded by this function
180+
loaded_names: tuple[str, ...]
181+
182+
# For each variable name used in this function (i.e. loaded from or stored to),
183+
# its resolved position in the scope.
184+
used_names: OrderedDict[str, ResolvedName]
185+
186+
# Pre-computed view of `used_names`: for each non-negative `depth` strictly less than this
187+
# function's depth, the list of all local variable indices captured by this function.
188+
# Empty when this is a top-level function.
189+
captures_by_depth: tuple[tuple[int, ...], ...]
190+
191+
# Sequence of enclosing function definitions, outermost first.
192+
# Empty when this is a top-level function.
193+
enclosing_funcs: "tuple[Function, ...]"
194+
195+
def __repr__(self):
196+
return f"<HIR for function {self.desc}>"
197+
131198

132199
@dataclass
133200
class _OperandFormatter:
@@ -157,3 +224,4 @@ def build_tuple(*items): ... # Makes a tuple (i.e. returns `items`)
157224
def identity(x): ... # Identity function (i.e. returns `x`)
158225
def store_var(name, value, /): ... # Store into a named variable
159226
def load_var(name, /): ... # Load from a named variable
227+
def make_closure(func_hir: Function, /, *default_values): ...

src/cuda/tile/_ir/ir.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,31 @@ def as_tuple(self) -> tuple[Var, ...]:
286286
return self.base_ptr, self.length
287287

288288

289+
@dataclass
290+
class ClosureValue(AggregateValue):
291+
# Default values of parameters. These need to be carried by the closure's value
292+
# because default expressions are evaluated at definition time, not when the closure is called.
293+
# Should have the same length as the corresponding `ClosureTy.default_value_types`.
294+
default_values: tuple[Var, ...]
295+
296+
# Tuple of the same length as `ty.func_hir.enclosing_functions`
297+
# and `ty.frozen_capture_types_by_depth`, where `ty` is the `ClosureTy` of this closure.
298+
#
299+
# For each depth `i`, `frozen_captures_by_depth[i]` is either:
300+
# - None: means the enclosing function's LocalScope is still live;
301+
# - tuple[Var, ...]: means the enclosing function's LocalScope is no longer live.
302+
# The tuple contains the final values of the variables captured from the enclosing
303+
# function's scope. Its length should be the same as `ty.func_hir.captures_by_depth`.
304+
frozen_captures_by_depth: tuple[tuple[Var, ...] | None, ...]
305+
306+
def as_tuple(self) -> tuple["Var", ...]:
307+
return (
308+
*self.default_values,
309+
*(v for values in self.frozen_captures_by_depth
310+
if values is not None for v in values)
311+
)
312+
313+
289314
def terminator(cls):
290315
cls._is_terminator = True
291316
return cls

0 commit comments

Comments
 (0)