Skip to content

Commit 9edceaf

Browse files
committed
Remove aggregate types (tuples etc.) from the IR
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 9af1b63 commit 9edceaf

14 files changed

Lines changed: 850 additions & 807 deletions

src/cuda/tile/_exception.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class FunctionDesc:
1616
line: int
1717

1818

19-
@dataclass(slots=True)
19+
@dataclass(slots=True, frozen=True)
2020
class Loc:
2121
line: int
2222
col: int
@@ -138,10 +138,6 @@ class TileInternalError(TileError):
138138
pass
139139

140140

141-
class ConstFoldNotImplementedError(Exception):
142-
pass
143-
144-
145141
class ConstantNotFoundError(Exception):
146142
pass
147143

src/cuda/tile/_ir/ir.py

Lines changed: 180 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@
2828
from cuda.tile._ir2bytecode import BytecodeContext
2929

3030

31-
@dataclass
32-
class RangeInfo:
33-
known_step: int
34-
35-
3631
class IRContext:
3732
def __init__(self, tile_ctx: TileContext):
3833
self._all_vars: Dict[str, str] = {}
@@ -41,8 +36,8 @@ def __init__(self, tile_ctx: TileContext):
4136
self.typemap: Dict[str, Type] = dict()
4237
self.constants: Dict[str, Any] = dict()
4338
self._loose_typemap: Dict[str, Type] = dict()
44-
self.range_infos: Dict[str, RangeInfo] = dict()
4539
self.tile_ctx: TileContext = tile_ctx
40+
self._aggregate_values: Dict[str, Any] = dict()
4641

4742
# Make a Var with a unique name based on `name`.
4843
def make_var(self, name: str, loc: Loc, undefined: bool = False) -> Var:
@@ -68,8 +63,8 @@ def copy_type_information(self, src: Var, dst: Var):
6863
self._loose_typemap[dst.name] = self._loose_typemap[src.name]
6964
if src.name in self.constants:
7065
self.constants[dst.name] = self.constants[src.name]
71-
if src.name in self.range_infos:
72-
self.range_infos[dst.name] = self.range_infos[src.name]
66+
if src.name in self._aggregate_values:
67+
self._aggregate_values[dst.name] = self._aggregate_values[src.name]
7368

7469

7570
class ConstantState(enum.Enum):
@@ -82,26 +77,15 @@ class ConstantState(enum.Enum):
8277
class PhiState:
8378
ty: Type | None = None
8479
loose_ty: Type | None = None
85-
last_loc: Loc | None = None
86-
constant_state: ConstantState = ConstantState.UNSET
87-
constant_value: Any = None
80+
last_loc: Loc = Loc.unknown()
81+
initial_constant_state: ConstantState = ConstantState.UNSET
8882

89-
def set_nonconstant(self):
90-
self.constant_state = ConstantState.NONCONSTANT
83+
# Constant propagation state, per aggregate item.
84+
# We initialize it to None because we don't know yet how many items we have.
85+
constant_state: list[ConstantState] | None = None
86+
constant_value: list[Any] | None = None
9187

9288
def propagate(self, src: Var, fail_eagerly: bool = False, allow_loose_typing: bool = True):
93-
# Constant propagation
94-
if src.is_constant():
95-
new_const = src.get_constant()
96-
if self.constant_state == ConstantState.UNSET:
97-
self.constant_state = ConstantState.MAY_BE_CONSTANT
98-
self.constant_value = new_const
99-
elif (self.constant_state == ConstantState.MAY_BE_CONSTANT
100-
and new_const != self.constant_value):
101-
self.constant_state = ConstantState.NONCONSTANT
102-
else:
103-
self.set_nonconstant()
104-
10589
# Type & loose type propagation
10690
src_ty = src.get_type_allow_invalid()
10791
src_loose_ty = src.get_loose_type_allow_invalid() if allow_loose_typing else src_ty
@@ -128,12 +112,46 @@ def propagate(self, src: Var, fail_eagerly: bool = False, allow_loose_typing: bo
128112
if self.loose_ty != src_loose_ty:
129113
self.loose_ty = self.ty
130114

115+
# Constant propagation
116+
if isinstance(src_ty, InvalidType):
117+
self.constant_state = None
118+
self.initial_constant_state = ConstantState.NONCONSTANT
119+
else:
120+
agg_items = tuple(src.flatten_aggregate())
121+
if self.constant_state is None:
122+
self.constant_state = [self.initial_constant_state for _ in range(len(agg_items))]
123+
self.constant_value = [None for _ in range(len(agg_items))]
124+
else:
125+
# This should be true because we already checked the type.
126+
# If the type matches, it must have the same aggregate length.
127+
assert len(self.constant_state) == len(agg_items)
128+
129+
for i, item in enumerate(agg_items):
130+
if item.is_constant():
131+
new_const = item.get_constant()
132+
if self.constant_state[i] == ConstantState.UNSET:
133+
self.constant_state[i] = ConstantState.MAY_BE_CONSTANT
134+
self.constant_value[i] = new_const
135+
elif (self.constant_state[i] == ConstantState.MAY_BE_CONSTANT
136+
and new_const != self.constant_value[i]):
137+
self.constant_state[i] = ConstantState.NONCONSTANT
138+
else:
139+
self.constant_state[i] = ConstantState.NONCONSTANT
140+
131141
def finalize_constant_and_loose_type(self, dst: Var):
132-
if self.constant_state == ConstantState.MAY_BE_CONSTANT:
133-
dst.set_constant(self.constant_value)
142+
assert self.constant_state is not None
143+
for item, state, val in zip(dst.flatten_aggregate(),
144+
self.constant_state, self.constant_value, strict=True):
145+
if state == ConstantState.MAY_BE_CONSTANT:
146+
item.set_constant(val)
134147
dst.set_loose_type(self.loose_ty)
135148

136149

150+
class AggregateValue:
151+
def as_tuple(self) -> tuple["Var", ...]:
152+
raise NotImplementedError()
153+
154+
137155
class Var:
138156
def __init__(self, name: str, loc: Loc, ctx: IRContext, undefined: bool = False):
139157
self.name = name
@@ -192,16 +210,6 @@ def set_loose_type(self, ty: Type, force: bool = False):
192210
assert self.name not in self.ctx._loose_typemap
193211
self.ctx._loose_typemap[self.name] = ty
194212

195-
def has_range_info(self) -> bool:
196-
return self.name in self.ctx.range_infos
197-
198-
def get_range_info(self) -> RangeInfo:
199-
return self.ctx.range_infos[self.name]
200-
201-
def set_range_info(self, range_info: RangeInfo):
202-
assert self.name not in self.ctx.range_infos
203-
self.ctx.range_infos[self.name] = range_info
204-
205213
def is_undefined(self) -> bool:
206214
return self._undefined
207215

@@ -211,14 +219,72 @@ def set_undefined(self):
211219
def get_original_name(self) -> str:
212220
return self.ctx.get_original_name(self.name)
213221

222+
def is_aggregate(self) -> bool:
223+
return self.name in self.ctx._aggregate_values
224+
225+
def get_aggregate(self) -> AggregateValue:
226+
return self.ctx._aggregate_values[self.name]
227+
228+
def set_aggregate(self, agg_value: AggregateValue):
229+
self.ctx._aggregate_values[self.name] = agg_value
230+
231+
def flatten_aggregate(self) -> Iterator["Var"]:
232+
if self.is_aggregate():
233+
for x in self.get_aggregate().as_tuple():
234+
yield from x.flatten_aggregate()
235+
else:
236+
yield self
237+
214238
def __repr__(self):
215239
return f"Var<{self.name} @{self.loc}>"
216240

217241
def __str__(self) -> str:
218242
return self.name
219243

220244

221-
TypeResult = list[Type] | Type
245+
@dataclass
246+
class TupleValue(AggregateValue):
247+
items: tuple[Var, ...]
248+
249+
def as_tuple(self) -> tuple["Var", ...]:
250+
return self.items
251+
252+
253+
@dataclass
254+
class RangeValue(AggregateValue):
255+
start: Var
256+
stop: Var
257+
step: Var
258+
259+
def as_tuple(self) -> tuple[Var, ...]:
260+
return self.start, self.stop, self.step
261+
262+
263+
@dataclass
264+
class BoundMethodValue(AggregateValue):
265+
bound_self: Var
266+
267+
def as_tuple(self) -> tuple[Var, ...]:
268+
return (self.bound_self,)
269+
270+
271+
@dataclass
272+
class ArrayValue(AggregateValue):
273+
base_ptr: Var
274+
shape: tuple[Var, ...]
275+
strides: tuple[Var, ...]
276+
277+
def as_tuple(self) -> tuple[Var, ...]:
278+
return self.base_ptr, *self.shape, *self.strides
279+
280+
281+
@dataclass
282+
class ListValue(AggregateValue):
283+
base_ptr: Var
284+
length: Var
285+
286+
def as_tuple(self) -> tuple[Var, ...]:
287+
return self.base_ptr, self.length
222288

223289

224290
def terminator(cls):
@@ -268,7 +334,13 @@ def set_var(self, old_var: Var, new_var: Var):
268334
def add_operation(op_class,
269335
result_ty: Type | None | Tuple[Type | None, ...],
270336
**attrs_and_operands) -> Var | Tuple[Var, ...]:
271-
return Builder.get_current().add_operation(op_class, result_ty, **attrs_and_operands)
337+
return Builder.get_current().add_operation(op_class, result_ty, attrs_and_operands)
338+
339+
340+
def make_aggregate(value: AggregateValue,
341+
ty: Type | None,
342+
loose_ty: Type | None = None):
343+
return Builder.get_current().make_aggregate(value, ty, loose_ty)
272344

273345

274346
@dataclass
@@ -427,26 +499,55 @@ def __init__(self, ctx: IRContext, loc: Loc, scope: Scope,
427499

428500
def add_operation(self, op_class,
429501
result_ty: Type | None | Tuple[Type | None, ...],
430-
**attrs_and_operands) -> Var | Tuple[Var, ...]:
502+
attrs_and_operands,
503+
result: Var | Sequence[Var] | None = None) -> Var | Tuple[Var, ...]:
431504
assert not self.is_terminated
505+
force_type = False
432506
if isinstance(result_ty, tuple):
433-
ret = tuple(self.ir_ctx.make_temp(self._loc) for _ in result_ty)
434-
for var, ty in zip(ret, result_ty, strict=True):
507+
if result is None:
508+
result = tuple(self.ir_ctx.make_temp(self._loc) for _ in result_ty)
509+
else:
510+
result = tuple(result)
511+
assert all(isinstance(v, Var) for v in result)
512+
force_type = True
513+
514+
for var, ty in zip(result, result_ty, strict=True):
435515
if ty is not None:
436-
var.set_type(ty)
437-
if len(ret) > 0 or op_class._multiple_results:
438-
attrs_and_operands["result_vars"] = ret
516+
var.set_type(ty, force=force_type)
517+
if len(result) > 0 or op_class._multiple_results:
518+
attrs_and_operands["result_vars"] = result
439519
else:
440-
ret = self.ir_ctx.make_temp(self._loc)
520+
if result is None:
521+
result = self.ir_ctx.make_temp(self._loc)
522+
else:
523+
assert isinstance(result, Var)
524+
force_type = True
441525
if result_ty is not None:
442-
ret.set_type(result_ty)
443-
attrs_and_operands["result_var"] = ret
526+
result.set_type(result_ty, force=force_type)
527+
attrs_and_operands["result_var"] = result
444528

445529
new_op = op_class(**attrs_and_operands, loc=self._loc)
446530
self._ops.append(new_op)
447531
if new_op.is_terminator:
448532
self.is_terminated = True
449-
return ret
533+
return result
534+
535+
def make_aggregate(self,
536+
value: AggregateValue,
537+
ty: Type | None,
538+
loose_ty: Type | None = None,
539+
result_var: Var | None = None) -> Var:
540+
force_type = True
541+
if result_var is None:
542+
result_var = self.ir_ctx.make_temp(self._loc)
543+
force_type = False
544+
545+
if ty is not None:
546+
result_var.set_type(ty, force=force_type)
547+
if loose_ty is not None:
548+
result_var.set_loose_type(ty, force=force_type)
549+
result_var.set_aggregate(value)
550+
return result_var
450551

451552
@property
452553
def ops(self) -> list[Operation]:
@@ -591,6 +692,12 @@ def has_side_effects(self) -> bool:
591692
return self._has_side_effects
592693

593694
def _add_operand(self, name: str, var: Var | Tuple[Var, ...]):
695+
if isinstance(var, Var) and var.is_aggregate() and self.op != "assign":
696+
# Don't allow aggregate values as operands, except for arrays and lists.
697+
# All other aggregates should only exist in the HIR level.
698+
# Also make an exception for the Assign op, until we find a better way to handle it.
699+
agg_val = var.get_aggregate()
700+
assert isinstance(agg_val, ArrayValue | ListValue)
594701
self._operands[name] = var
595702

596703
def update_operand(self, name: str, var: Var | Tuple[Var, ...]):
@@ -630,9 +737,10 @@ def _to_string_rhs(self) -> str:
630737
operands_str_list = []
631738
for name, val in self.operands.items():
632739
if isinstance(val, Var):
633-
operands_str_list.append(f"{name}={str(val)}")
740+
operands_str_list.append(f"{name}={var_aggregate_name(val)}")
634741
elif isinstance(val, tuple) and all(isinstance(v, Var) for v in val):
635-
operands_str_list.append(f"{name}=({', '.join(str(v) for v in val)})")
742+
tup_str = ', '.join(var_aggregate_name(v) for v in val)
743+
operands_str_list.append(f"{name}=({tup_str})")
636744
elif val is None:
637745
operands_str_list.append(f"{name}=None")
638746
else:
@@ -655,14 +763,6 @@ def to_string(self,
655763
indent: int = 0,
656764
highlight_loc: Optional[Loc] = None,
657765
include_loc: bool = False) -> str:
658-
def format_var(var: Var):
659-
ty = var.try_get_type()
660-
if ty is None:
661-
return var.name
662-
else:
663-
const_prefix = "const " if var.is_constant() else ""
664-
return f"{var.name}: {const_prefix}{ty}"
665-
666766
indent_str = " " * indent
667767
lhs = (
668768
", ".join(format_var(var) for var in self.result_vars)
@@ -698,6 +798,24 @@ def __str__(self) -> str:
698798
return self.to_string()
699799

700800

801+
def var_aggregate_name(var: Var) -> str:
802+
ret = var.name
803+
if var.is_aggregate():
804+
ret += "{" + ", ".join(x.name for x in var.flatten_aggregate()) + "}"
805+
return ret
806+
807+
808+
def format_var(var: Var) -> str:
809+
ret = var_aggregate_name(var)
810+
811+
ty = var.try_get_type()
812+
if ty is not None:
813+
const_prefix = "const " if var.is_constant() else ""
814+
ret += f": {const_prefix}{ty}"
815+
816+
return ret
817+
818+
701819
# TODO: no longer needed, remove by inheriting from Operation instead
702820
class TypedOperation(Operation):
703821
pass
@@ -777,14 +895,15 @@ def to_string(self,
777895
indent: int = 0,
778896
highlight_loc: Optional[Loc] = None,
779897
include_loc: bool = False) -> str:
780-
op_strings = (
898+
params = ", ".join(format_var(p) for p in self.params)
899+
ops = "\n".join(
781900
op.to_string(
782901
indent,
783902
highlight_loc,
784903
include_loc
785904
) for op in self.operations
786905
)
787-
return "\n".join(op_strings)
906+
return f"{' ' * indent}({params}):\n{ops}"
788907

789908
def traverse(self) -> Iterator[Operation]:
790909
for op in self.operations:

0 commit comments

Comments
 (0)