|
15 | 15 |
|
16 | 16 |
|
17 | 17 | import enum |
| 18 | +import inspect |
18 | 19 | import threading |
19 | 20 | from dataclasses import dataclass |
20 | 21 | from textwrap import indent |
21 | | -from typing import Any, Set, Mapping |
| 22 | +from typing import Any, Set, OrderedDict |
22 | 23 |
|
23 | 24 | from cuda.tile._exception import Loc, FunctionDesc |
24 | 25 |
|
@@ -119,15 +120,81 @@ def jump_str(self): |
119 | 120 | return f"{self.jump._value_}{results_str} # Line {self.jump_loc.line}" |
120 | 121 |
|
121 | 122 |
|
122 | | -@dataclass |
| 123 | +# Resolved position of a named variable. |
| 124 | +# |
| 125 | +# If `depth` is -1, then this is a global variable, and the `index` |
| 126 | +# points into `hir.Function.frozen_global_names` / `hir.Function.frozen_global_values`. |
| 127 | +# If both `depth` and `index` are -1, then this is an indication that no variable |
| 128 | +# with the given name has been found. |
| 129 | +# |
| 130 | +# If `depth` is non-negative but is less than the depth of the current function, |
| 131 | +# then this is a captured local variable of an enclosing function. |
| 132 | +# In this case, the local variable name can be found in |
| 133 | +# `hir.Function.enclosing_funcs[depth].local_names[index]. |
| 134 | +# |
| 135 | +# If `depth` is equal to the depth of the current function, then this is a local variable. |
| 136 | +# In this case, the local variable name can be found in `hir.Function.local_names[index]`. |
| 137 | +@dataclass(frozen=True) |
| 138 | +class ResolvedName: |
| 139 | + depth: int |
| 140 | + index: int |
| 141 | + |
| 142 | + |
| 143 | +UNKNOWN_NAME = ResolvedName(-1, -1) |
| 144 | + |
| 145 | + |
| 146 | +@dataclass(eq=False, repr=False) |
123 | 147 | class Function: |
124 | 148 | desc: FunctionDesc |
125 | 149 | body: Block |
126 | | - param_names: tuple[str, ...] |
| 150 | + |
| 151 | + # For nested functions, this signature is synthesized from AST. In this case, |
| 152 | + # default values of parameters are represented with ClosureDefaultPlaceholder objects. |
| 153 | + signature: inspect.Signature |
| 154 | + |
| 155 | + # Names of all local variables defined in this function. |
| 156 | + local_names: tuple[str, ...] |
| 157 | + |
| 158 | + # For each function parameter, an index into `local_names` |
| 159 | + param_local_indices: tuple[int, ...] |
| 160 | + |
| 161 | + # For each function parameter, its location (same length as `param_local_indices`) |
127 | 162 | param_locs: tuple[Loc, ...] |
128 | | - frozen_globals: Mapping[str, Any] |
| 163 | + |
| 164 | + # Names of all global-like variables accessible by this function. |
| 165 | + frozen_global_names: tuple[str, ...] |
| 166 | + |
| 167 | + # Values of all global-like variables access by this function. |
| 168 | + # We assume that they don't change after compilation, hence the name "frozen". |
| 169 | + # Same length as `frozen_global_names`. |
| 170 | + frozen_global_values: tuple[Any, ...] |
| 171 | + |
| 172 | + # Strict upper bound on hir.Value.id inside this function definition, |
| 173 | + # excluding any nested functions. Can be used to pre-allocate lists etc. |
129 | 174 | value_id_upper_bound: int |
130 | 175 |
|
| 176 | + # List of all function definitions nested directly inside this function. |
| 177 | + nested_functions: "tuple[Function, ...]" |
| 178 | + |
| 179 | + # Names of all variables that are ever loaded by this function |
| 180 | + loaded_names: tuple[str, ...] |
| 181 | + |
| 182 | + # For each variable name used in this function (i.e. loaded from or stored to), |
| 183 | + # its resolved position in the scope. |
| 184 | + used_names: OrderedDict[str, ResolvedName] |
| 185 | + |
| 186 | + # Pre-computed view of `used_names`: for each non-negative `depth` strictly less than this |
| 187 | + # function's depth, the list of all local variable indices captured by this function. |
| 188 | + # Empty when this is a top-level function. |
| 189 | + captures_by_depth: tuple[tuple[int, ...], ...] |
| 190 | + |
| 191 | + # Sequence of enclosing function definitions, outermost first. |
| 192 | + # Empty when this is a top-level function. |
| 193 | + enclosing_funcs: "tuple[Function, ...]" |
| 194 | + |
| 195 | + def __repr__(self): |
| 196 | + return f"<HIR for function {self.desc}>" |
| 197 | + |
131 | 198 |
|
132 | 199 | @dataclass |
133 | 200 | class _OperandFormatter: |
@@ -157,3 +224,4 @@ def build_tuple(*items): ... # Makes a tuple (i.e. returns `items`) |
157 | 224 | def identity(x): ... # Identity function (i.e. returns `x`) |
158 | 225 | def store_var(name, value, /): ... # Store into a named variable |
159 | 226 | def load_var(name, /): ... # Load from a named variable |
| 227 | +def make_closure(func_hir: Function, /, *default_values): ... |
0 commit comments