2828 from cuda .tile ._ir2bytecode import BytecodeContext
2929
3030
31- @dataclass
32- class RangeInfo :
33- known_step : int
34-
35-
3631class IRContext :
3732 def __init__ (self , tile_ctx : TileContext ):
3833 self ._all_vars : Dict [str , str ] = {}
@@ -41,8 +36,8 @@ def __init__(self, tile_ctx: TileContext):
4136 self .typemap : Dict [str , Type ] = dict ()
4237 self .constants : Dict [str , Any ] = dict ()
4338 self ._loose_typemap : Dict [str , Type ] = dict ()
44- self .range_infos : Dict [str , RangeInfo ] = dict ()
4539 self .tile_ctx : TileContext = tile_ctx
40+ self ._aggregate_values : Dict [str , Any ] = dict ()
4641
4742 # Make a Var with a unique name based on `name`.
4843 def make_var (self , name : str , loc : Loc , undefined : bool = False ) -> Var :
@@ -68,8 +63,8 @@ def copy_type_information(self, src: Var, dst: Var):
6863 self ._loose_typemap [dst .name ] = self ._loose_typemap [src .name ]
6964 if src .name in self .constants :
7065 self .constants [dst .name ] = self .constants [src .name ]
71- if src .name in self .range_infos :
72- self .range_infos [dst .name ] = self .range_infos [src .name ]
66+ if src .name in self ._aggregate_values :
67+ self ._aggregate_values [dst .name ] = self ._aggregate_values [src .name ]
7368
7469
7570class ConstantState (enum .Enum ):
@@ -82,26 +77,15 @@ class ConstantState(enum.Enum):
8277class PhiState :
8378 ty : Type | None = None
8479 loose_ty : Type | None = None
85- last_loc : Loc | None = None
86- constant_state : ConstantState = ConstantState .UNSET
87- constant_value : Any = None
80+ last_loc : Loc = Loc .unknown ()
81+ initial_constant_state : ConstantState = ConstantState .UNSET
8882
89- def set_nonconstant (self ):
90- self .constant_state = ConstantState .NONCONSTANT
83+ # Constant propagation state, per aggregate item.
84+ # We initialize it to None because we don't know yet how many items we have.
85+ constant_state : list [ConstantState ] | None = None
86+ constant_value : list [Any ] | None = None
9187
9288 def propagate (self , src : Var , fail_eagerly : bool = False , allow_loose_typing : bool = True ):
93- # Constant propagation
94- if src .is_constant ():
95- new_const = src .get_constant ()
96- if self .constant_state == ConstantState .UNSET :
97- self .constant_state = ConstantState .MAY_BE_CONSTANT
98- self .constant_value = new_const
99- elif (self .constant_state == ConstantState .MAY_BE_CONSTANT
100- and new_const != self .constant_value ):
101- self .constant_state = ConstantState .NONCONSTANT
102- else :
103- self .set_nonconstant ()
104-
10589 # Type & loose type propagation
10690 src_ty = src .get_type_allow_invalid ()
10791 src_loose_ty = src .get_loose_type_allow_invalid () if allow_loose_typing else src_ty
@@ -128,12 +112,46 @@ def propagate(self, src: Var, fail_eagerly: bool = False, allow_loose_typing: bo
128112 if self .loose_ty != src_loose_ty :
129113 self .loose_ty = self .ty
130114
115+ # Constant propagation
116+ if isinstance (src_ty , InvalidType ):
117+ self .constant_state = None
118+ self .initial_constant_state = ConstantState .NONCONSTANT
119+ else :
120+ agg_items = tuple (src .flatten_aggregate ())
121+ if self .constant_state is None :
122+ self .constant_state = [self .initial_constant_state for _ in range (len (agg_items ))]
123+ self .constant_value = [None for _ in range (len (agg_items ))]
124+ else :
125+ # This should be true because we already checked the type.
126+ # If the type matches, it must have the same aggregate length.
127+ assert len (self .constant_state ) == len (agg_items )
128+
129+ for i , item in enumerate (agg_items ):
130+ if item .is_constant ():
131+ new_const = item .get_constant ()
132+ if self .constant_state [i ] == ConstantState .UNSET :
133+ self .constant_state [i ] = ConstantState .MAY_BE_CONSTANT
134+ self .constant_value [i ] = new_const
135+ elif (self .constant_state [i ] == ConstantState .MAY_BE_CONSTANT
136+ and new_const != self .constant_value [i ]):
137+ self .constant_state [i ] = ConstantState .NONCONSTANT
138+ else :
139+ self .constant_state [i ] = ConstantState .NONCONSTANT
140+
131141 def finalize_constant_and_loose_type (self , dst : Var ):
132- if self .constant_state == ConstantState .MAY_BE_CONSTANT :
133- dst .set_constant (self .constant_value )
142+ assert self .constant_state is not None
143+ for item , state , val in zip (dst .flatten_aggregate (),
144+ self .constant_state , self .constant_value , strict = True ):
145+ if state == ConstantState .MAY_BE_CONSTANT :
146+ item .set_constant (val )
134147 dst .set_loose_type (self .loose_ty )
135148
136149
150+ class AggregateValue :
151+ def as_tuple (self ) -> tuple ["Var" , ...]:
152+ raise NotImplementedError ()
153+
154+
137155class Var :
138156 def __init__ (self , name : str , loc : Loc , ctx : IRContext , undefined : bool = False ):
139157 self .name = name
@@ -192,16 +210,6 @@ def set_loose_type(self, ty: Type, force: bool = False):
192210 assert self .name not in self .ctx ._loose_typemap
193211 self .ctx ._loose_typemap [self .name ] = ty
194212
195- def has_range_info (self ) -> bool :
196- return self .name in self .ctx .range_infos
197-
198- def get_range_info (self ) -> RangeInfo :
199- return self .ctx .range_infos [self .name ]
200-
201- def set_range_info (self , range_info : RangeInfo ):
202- assert self .name not in self .ctx .range_infos
203- self .ctx .range_infos [self .name ] = range_info
204-
205213 def is_undefined (self ) -> bool :
206214 return self ._undefined
207215
@@ -211,14 +219,72 @@ def set_undefined(self):
211219 def get_original_name (self ) -> str :
212220 return self .ctx .get_original_name (self .name )
213221
222+ def is_aggregate (self ) -> bool :
223+ return self .name in self .ctx ._aggregate_values
224+
225+ def get_aggregate (self ) -> AggregateValue :
226+ return self .ctx ._aggregate_values [self .name ]
227+
228+ def set_aggregate (self , agg_value : AggregateValue ):
229+ self .ctx ._aggregate_values [self .name ] = agg_value
230+
231+ def flatten_aggregate (self ) -> Iterator ["Var" ]:
232+ if self .is_aggregate ():
233+ for x in self .get_aggregate ().as_tuple ():
234+ yield from x .flatten_aggregate ()
235+ else :
236+ yield self
237+
214238 def __repr__ (self ):
215239 return f"Var<{ self .name } @{ self .loc } >"
216240
217241 def __str__ (self ) -> str :
218242 return self .name
219243
220244
221- TypeResult = list [Type ] | Type
245+ @dataclass
246+ class TupleValue (AggregateValue ):
247+ items : tuple [Var , ...]
248+
249+ def as_tuple (self ) -> tuple ["Var" , ...]:
250+ return self .items
251+
252+
253+ @dataclass
254+ class RangeValue (AggregateValue ):
255+ start : Var
256+ stop : Var
257+ step : Var
258+
259+ def as_tuple (self ) -> tuple [Var , ...]:
260+ return self .start , self .stop , self .step
261+
262+
263+ @dataclass
264+ class BoundMethodValue (AggregateValue ):
265+ bound_self : Var
266+
267+ def as_tuple (self ) -> tuple [Var , ...]:
268+ return (self .bound_self ,)
269+
270+
271+ @dataclass
272+ class ArrayValue (AggregateValue ):
273+ base_ptr : Var
274+ shape : tuple [Var , ...]
275+ strides : tuple [Var , ...]
276+
277+ def as_tuple (self ) -> tuple [Var , ...]:
278+ return self .base_ptr , * self .shape , * self .strides
279+
280+
281+ @dataclass
282+ class ListValue (AggregateValue ):
283+ base_ptr : Var
284+ length : Var
285+
286+ def as_tuple (self ) -> tuple [Var , ...]:
287+ return self .base_ptr , self .length
222288
223289
224290def terminator (cls ):
@@ -268,7 +334,13 @@ def set_var(self, old_var: Var, new_var: Var):
268334def add_operation (op_class ,
269335 result_ty : Type | None | Tuple [Type | None , ...],
270336 ** attrs_and_operands ) -> Var | Tuple [Var , ...]:
271- return Builder .get_current ().add_operation (op_class , result_ty , ** attrs_and_operands )
337+ return Builder .get_current ().add_operation (op_class , result_ty , attrs_and_operands )
338+
339+
340+ def make_aggregate (value : AggregateValue ,
341+ ty : Type | None ,
342+ loose_ty : Type | None = None ):
343+ return Builder .get_current ().make_aggregate (value , ty , loose_ty )
272344
273345
274346@dataclass
@@ -427,26 +499,55 @@ def __init__(self, ctx: IRContext, loc: Loc, scope: Scope,
427499
428500 def add_operation (self , op_class ,
429501 result_ty : Type | None | Tuple [Type | None , ...],
430- ** attrs_and_operands ) -> Var | Tuple [Var , ...]:
502+ attrs_and_operands ,
503+ result : Var | Sequence [Var ] | None = None ) -> Var | Tuple [Var , ...]:
431504 assert not self .is_terminated
505+ force_type = False
432506 if isinstance (result_ty , tuple ):
433- ret = tuple (self .ir_ctx .make_temp (self ._loc ) for _ in result_ty )
434- for var , ty in zip (ret , result_ty , strict = True ):
507+ if result is None :
508+ result = tuple (self .ir_ctx .make_temp (self ._loc ) for _ in result_ty )
509+ else :
510+ result = tuple (result )
511+ assert all (isinstance (v , Var ) for v in result )
512+ force_type = True
513+
514+ for var , ty in zip (result , result_ty , strict = True ):
435515 if ty is not None :
436- var .set_type (ty )
437- if len (ret ) > 0 or op_class ._multiple_results :
438- attrs_and_operands ["result_vars" ] = ret
516+ var .set_type (ty , force = force_type )
517+ if len (result ) > 0 or op_class ._multiple_results :
518+ attrs_and_operands ["result_vars" ] = result
439519 else :
440- ret = self .ir_ctx .make_temp (self ._loc )
520+ if result is None :
521+ result = self .ir_ctx .make_temp (self ._loc )
522+ else :
523+ assert isinstance (result , Var )
524+ force_type = True
441525 if result_ty is not None :
442- ret .set_type (result_ty )
443- attrs_and_operands ["result_var" ] = ret
526+ result .set_type (result_ty , force = force_type )
527+ attrs_and_operands ["result_var" ] = result
444528
445529 new_op = op_class (** attrs_and_operands , loc = self ._loc )
446530 self ._ops .append (new_op )
447531 if new_op .is_terminator :
448532 self .is_terminated = True
449- return ret
533+ return result
534+
535+ def make_aggregate (self ,
536+ value : AggregateValue ,
537+ ty : Type | None ,
538+ loose_ty : Type | None = None ,
539+ result_var : Var | None = None ) -> Var :
540+ force_type = True
541+ if result_var is None :
542+ result_var = self .ir_ctx .make_temp (self ._loc )
543+ force_type = False
544+
545+ if ty is not None :
546+ result_var .set_type (ty , force = force_type )
547+ if loose_ty is not None :
548+ result_var .set_loose_type (ty , force = force_type )
549+ result_var .set_aggregate (value )
550+ return result_var
450551
451552 @property
452553 def ops (self ) -> list [Operation ]:
@@ -591,6 +692,12 @@ def has_side_effects(self) -> bool:
591692 return self ._has_side_effects
592693
593694 def _add_operand (self , name : str , var : Var | Tuple [Var , ...]):
695+ if isinstance (var , Var ) and var .is_aggregate () and self .op != "assign" :
696+ # Don't allow aggregate values as operands, except for arrays and lists.
697+ # All other aggregates should only exist in the HIR level.
698+ # Also make an exception for the Assign op, until we find a better way to handle it.
699+ agg_val = var .get_aggregate ()
700+ assert isinstance (agg_val , ArrayValue | ListValue )
594701 self ._operands [name ] = var
595702
596703 def update_operand (self , name : str , var : Var | Tuple [Var , ...]):
@@ -630,9 +737,10 @@ def _to_string_rhs(self) -> str:
630737 operands_str_list = []
631738 for name , val in self .operands .items ():
632739 if isinstance (val , Var ):
633- operands_str_list .append (f"{ name } ={ str (val )} " )
740+ operands_str_list .append (f"{ name } ={ var_aggregate_name (val )} " )
634741 elif isinstance (val , tuple ) and all (isinstance (v , Var ) for v in val ):
635- operands_str_list .append (f"{ name } =({ ', ' .join (str (v ) for v in val )} )" )
742+ tup_str = ', ' .join (var_aggregate_name (v ) for v in val )
743+ operands_str_list .append (f"{ name } =({ tup_str } )" )
636744 elif val is None :
637745 operands_str_list .append (f"{ name } =None" )
638746 else :
@@ -655,14 +763,6 @@ def to_string(self,
655763 indent : int = 0 ,
656764 highlight_loc : Optional [Loc ] = None ,
657765 include_loc : bool = False ) -> str :
658- def format_var (var : Var ):
659- ty = var .try_get_type ()
660- if ty is None :
661- return var .name
662- else :
663- const_prefix = "const " if var .is_constant () else ""
664- return f"{ var .name } : { const_prefix } { ty } "
665-
666766 indent_str = " " * indent
667767 lhs = (
668768 ", " .join (format_var (var ) for var in self .result_vars )
@@ -698,6 +798,24 @@ def __str__(self) -> str:
698798 return self .to_string ()
699799
700800
801+ def var_aggregate_name (var : Var ) -> str :
802+ ret = var .name
803+ if var .is_aggregate ():
804+ ret += "{" + ", " .join (x .name for x in var .flatten_aggregate ()) + "}"
805+ return ret
806+
807+
808+ def format_var (var : Var ) -> str :
809+ ret = var_aggregate_name (var )
810+
811+ ty = var .try_get_type ()
812+ if ty is not None :
813+ const_prefix = "const " if var .is_constant () else ""
814+ ret += f": { const_prefix } { ty } "
815+
816+ return ret
817+
818+
701819# TODO: no longer needed, remove by inheriting from Operation instead
702820class TypedOperation (Operation ):
703821 pass
@@ -777,14 +895,15 @@ def to_string(self,
777895 indent : int = 0 ,
778896 highlight_loc : Optional [Loc ] = None ,
779897 include_loc : bool = False ) -> str :
780- op_strings = (
898+ params = ", " .join (format_var (p ) for p in self .params )
899+ ops = "\n " .join (
781900 op .to_string (
782901 indent ,
783902 highlight_loc ,
784903 include_loc
785904 ) for op in self .operations
786905 )
787- return " \n " . join ( op_strings )
906+ return f" { ' ' * indent } ( { params } ): \n { ops } "
788907
789908 def traverse (self ) -> Iterator [Operation ]:
790909 for op in self .operations :
0 commit comments