44
55from __future__ import annotations
66
7+ import dataclasses
78import enum
89import itertools
910import threading
10- from collections import OrderedDict
1111from collections import defaultdict
1212from collections .abc import Mapping
1313from contextlib import contextmanager
14- from copy import copy
1514from dataclasses import dataclass
1615from types import MappingProxyType
1716from typing import (
@@ -316,11 +315,6 @@ def as_tuple(self) -> tuple["Var", ...]:
316315 )
317316
318317
319- def terminator (cls ):
320- cls ._is_terminator = True
321- return cls
322-
323-
324318class MemoryEffect (enum .IntEnum ):
325319 # Int value assigned here is meaningful.
326320 # It implies the relative strength of memory effects.
@@ -330,18 +324,6 @@ class MemoryEffect(enum.IntEnum):
330324 STORE = 2
331325
332326
333- def memory_effect (eff : MemoryEffect ):
334- def decorate (cls ):
335- cls .memory_effect = eff
336- return cls
337- return decorate
338-
339-
340- def has_multiple_results (cls ):
341- cls ._multiple_results = True
342- return cls
343-
344-
345327class Mapper :
346328 def __init__ (self , ctx : IRContext , preserve_vars : bool = False ):
347329 self ._ctx = ctx
@@ -450,8 +432,8 @@ def add_operation(self, op_class,
450432 for var , ty in zip (result , result_ty , strict = True ):
451433 if ty is not None :
452434 var .set_type (ty , force = force_type )
453- if len ( result ) > 0 or op_class . _multiple_results :
454- attrs_and_operands [ " result_vars" ] = result
435+
436+ result_vars = result
455437 else :
456438 if result is None :
457439 result = self .ir_ctx .make_temp (self ._loc )
@@ -460,9 +442,10 @@ def add_operation(self, op_class,
460442 force_type = True
461443 if result_ty is not None :
462444 result .set_type (result_ty , force = force_type )
463- attrs_and_operands ["result_var" ] = result
464445
465- new_op = op_class (** attrs_and_operands , loc = self ._loc )
446+ result_vars = (result ,)
447+
448+ new_op = op_class (** attrs_and_operands , loc = self ._loc , result_vars = result_vars )
466449 self ._ops .append (new_op )
467450 if new_op .is_terminator :
468451 self .is_terminated = True
@@ -529,7 +512,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
529512
530513
531514@contextmanager
532- def nested_block (loc : Loc , reduction_body : bool = False , scan_body : bool = False ):
515+ def enter_nested_block (loc : Loc , reduction_body : bool = False , scan_body : bool = False ):
533516 prev_builder = Builder .get_current ()
534517 block = Block (prev_builder .ir_ctx , loc = loc )
535518 with Builder (prev_builder .ir_ctx , loc ,
@@ -546,66 +529,132 @@ class _CurrentBuilder(threading.local):
546529_current_builder = _CurrentBuilder ()
547530
548531
532+ class _FieldKind (enum .IntEnum ):
533+ OPERAND = 0
534+ ATTRIBUTE = 1
535+ NESTED_BLOCK = 2
536+
537+
538+ _FIELD_METADATA_KEY = "operation_field_kind"
539+
540+
541+ def attribute (* , default = dataclasses .MISSING ) -> dataclasses .Field :
542+ return dataclasses .field (default = default , metadata = {_FIELD_METADATA_KEY : _FieldKind .ATTRIBUTE },
543+ kw_only = True )
544+
545+
546+ def operand (* , default = dataclasses .MISSING ) -> dataclasses .Field :
547+ return dataclasses .field (default = default , metadata = {_FIELD_METADATA_KEY : _FieldKind .OPERAND },
548+ kw_only = True )
549+
550+
551+ def nested_block () -> dataclasses .Field :
552+ return dataclasses .field (metadata = {_FIELD_METADATA_KEY : _FieldKind .NESTED_BLOCK },
553+ kw_only = True )
554+
555+
556+ def _get_result_vars_tuple_for_single_result_op (self ):
557+ return (self .result_var ,)
558+
559+
560+ @dataclass (eq = False )
549561class Operation :
550- memory_effect = MemoryEffect .NONE
551- _multiple_results = False
552-
553- def __init__ (
554- self ,
555- op : str ,
556- operands : dict [str , Optional [Var | Tuple [Var , ...]]],
557- result_vars : Sequence [Var ],
558- attributes : Optional [Dict [str , Any ]] = None ,
559- nested_blocks : Optional [Sequence [Block ]] = None ,
560- loc : Loc = Loc .unknown (),
561- ):
562- self .op = op
563- self .result_vars = result_vars or []
564- self .attributes = attributes or {}
565- self .nested_blocks = nested_blocks or []
566- self .loc = loc
562+ result_vars : tuple [Var , ...]
563+ loc : Loc
564+
565+ def __init_subclass__ (cls ,
566+ opcode : str ,
567+ terminator : bool = False ,
568+ memory_effect : MemoryEffect = MemoryEffect .NONE ):
569+ cls ._opcode = opcode
570+ cls ._is_terminator = terminator
571+ cls .memory_effect = memory_effect
572+
573+ operand_names = []
574+ attribute_names = []
575+ nested_block_names = []
576+ for field_name in cls .__annotations__ .keys ():
577+ f = getattr (cls , field_name , None )
578+ kind = f .metadata .get (_FIELD_METADATA_KEY ) if isinstance (f , dataclasses .Field ) else None
579+ if kind == _FieldKind .OPERAND :
580+ operand_names .append (field_name )
581+ elif kind == _FieldKind .ATTRIBUTE :
582+ attribute_names .append (field_name )
583+ elif kind == _FieldKind .NESTED_BLOCK :
584+ nested_block_names .append (field_name )
585+ else :
586+ raise TypeError (f"Field { field_name } of { cls } must be annotated with either"
587+ f" operand(), attribute() or nested_block()" )
567588
568- self ._operands = OrderedDict ()
569- for k , v in operands .items ():
570- self ._add_operand (k , v )
571- self ._is_terminator = getattr (self .__class__ , "_is_terminator" , False )
572- self ._parent_block = None
589+ cls ._operand_names = tuple (operand_names )
590+ cls ._attribute_names = tuple (attribute_names )
591+ cls ._nested_block_names = tuple (nested_block_names )
592+
593+ def __post_init__ (self ):
594+ for var in self .all_inputs ():
595+ assert isinstance (var , Var | tuple ) or var is None
596+ if isinstance (var , tuple ):
597+ assert all (isinstance (x , Var ) for x in var )
598+
599+ if isinstance (var , Var ) and var .is_aggregate () and self .op != "assign" :
600+ # Don't allow aggregate values as operands, except for arrays and lists.
601+ # All other aggregates should only exist in the HIR level.
602+ # Also make an exception for the Assign op, until we find a better way to handle it.
603+ agg_val = var .get_aggregate ()
604+ assert isinstance (agg_val , ArrayValue | ListValue )
605+
606+ for nb in self .nested_blocks :
607+ assert isinstance (nb , Block )
573608
574609 def clone (self , mapper : Mapper ) -> Operation :
575610 result_vars = mapper .clone_vars (self .result_vars )
576611 return self ._clone_impl (mapper , result_vars )
577612
578613 def _clone_impl (self , mapper : Mapper , result_vars : Sequence [Var ]) -> Operation :
579- new_nested_blocks = []
580- for old_block in self .nested_blocks :
614+ new_fields = {}
615+
616+ for name in self ._attribute_names :
617+ new_fields [name ] = getattr (self , name )
618+
619+ for name in self ._operand_names :
620+ var = getattr (self , name )
621+ if isinstance (var , Var ):
622+ new_var = mapper .get_var (var )
623+ elif var is None :
624+ new_var = None
625+ else :
626+ new_var = tuple (mapper .get_var (v ) for v in var )
627+ new_fields [name ] = new_var
628+
629+ for name in self ._nested_block_names :
630+ old_block = getattr (self , name )
581631 new_block = Block (old_block .ctx , old_block .loc )
582632 new_block .params = mapper .clone_vars (old_block .params )
583633 for old_op in old_block :
584634 new_block .append (old_op .clone (mapper ))
585- new_nested_blocks . append ( new_block )
635+ new_fields [ name ] = new_block
586636
587- ret = copy (self )
588- ret ._operands = OrderedDict ()
589- for name , var in self ._operands .items ():
590- if isinstance (var , Var ):
591- ret ._operands [name ] = mapper .get_var (var )
592- elif var is None :
593- ret ._operands [name ] = None
594- else :
595- ret ._operands [name ] = tuple (mapper .get_var (v ) for v in var )
637+ return type (self )(result_vars = tuple (result_vars ), loc = self .loc , ** new_fields )
596638
597- ret .attributes = dict (ret .attributes )
598- ret .result_vars = result_vars
599- ret .parent_block = None
600- ret .nested_blocks = new_nested_blocks
601- return ret
639+ @property
640+ def op (self ) -> str :
641+ return self ._opcode
602642
603643 @property
604644 def operands (self ) -> Mapping [str , Var | Tuple [Var , ...]]:
605- return MappingProxyType (self ._operands )
645+ return MappingProxyType ({name : getattr (self , name ) for name in self ._operand_names })
646+
647+ @property
648+ def attributes (self ):
649+ return MappingProxyType ({name : getattr (self , name ) for name in self ._attribute_names })
650+
651+ @property
652+ def nested_blocks (self ):
653+ return tuple (getattr (self , name ) for name in self ._nested_block_names )
606654
607655 def all_inputs (self ) -> Iterator [Var ]:
608- for x in self ._operands .values ():
656+ for name in self ._operand_names :
657+ x = getattr (self , name )
609658 if isinstance (x , tuple ):
610659 yield from iter (x )
611660 elif x is not None :
@@ -615,28 +664,6 @@ def all_inputs(self) -> Iterator[Var]:
615664 def is_terminator (self ) -> bool :
616665 return self ._is_terminator
617666
618- def _add_operand (self , name : str , var : Var | Tuple [Var , ...]):
619- if isinstance (var , Var ) and var .is_aggregate () and self .op != "assign" :
620- # Don't allow aggregate values as operands, except for arrays and lists.
621- # All other aggregates should only exist in the HIR level.
622- # Also make an exception for the Assign op, until we find a better way to handle it.
623- agg_val = var .get_aggregate ()
624- assert isinstance (agg_val , ArrayValue | ListValue )
625- self ._operands [name ] = var
626-
627- def update_operand (self , name : str , var : Var | Tuple [Var , ...]):
628- self ._add_operand (name , var )
629-
630- def __getattr__ (self , name : str ) -> Any :
631- if name == "__setstate__" :
632- raise AttributeError (name )
633-
634- if name in self .operands :
635- return self .operands [name ]
636- if name in self .attributes :
637- return self .attributes [name ]
638- raise AttributeError (f"{ self .__class__ .__name__ } has no operand or attribute { name } " )
639-
640667 @property
641668 def result_var (self ) -> Var :
642669 if len (self .result_vars ) != 1 :
@@ -740,11 +767,6 @@ def format_var(var: Var) -> str:
740767 return ret
741768
742769
743- # TODO: no longer needed, remove by inheriting from Operation instead
744- class TypedOperation (Operation ):
745- pass
746-
747-
748770class Block :
749771 def __init__ (self , ctx : IRContext , loc : Loc ):
750772 self .ctx = ctx
0 commit comments