1818 MutableSequence ,
1919 Protocol ,
2020 Sequence ,
21- Tuple ,
2221 TypeVar ,
2322 Union ,
2423)
@@ -511,7 +510,7 @@ def __init__(
511510 if isinstance (op , str ) and isinstance (domain , StringConstantPattern ):
512511 # TODO(rama): support overloaded operators.
513512 overload = ""
514- self ._op_identifier : tuple [ str , str , str ] | None = (
513+ self ._op_identifier : ir . OperatorIdentifier | None = (
515514 domain .value (),
516515 op ,
517516 overload ,
@@ -535,7 +534,7 @@ def __str__(self) -> str:
535534 inputs_and_attributes = f"{ inputs } , { attributes } " if attributes else inputs
536535 return f"{ outputs } = { qualified_op } ({ inputs_and_attributes } )"
537536
538- def op_identifier (self ) -> Tuple [ str , str , str ] | None :
537+ def op_identifier (self ) -> ir . OperatorIdentifier | None :
539538 return self ._op_identifier
540539
541540 @property
@@ -629,11 +628,6 @@ def producer(self) -> NodePattern:
629628Var = ValuePattern
630629
631630
632- def _is_pattern_variable (x : Any ) -> bool :
633- # The derived classes of ValuePattern represent constant patterns and node-output patterns.
634- return type (x ) is ValuePattern
635-
636-
637631class AnyValue (ValuePattern ):
638632 """Represents a pattern that matches against any value."""
639633
@@ -718,6 +712,92 @@ def __str__(self) -> str:
718712 return str (self ._value )
719713
720714
715+ class OrValue (ValuePattern ):
716+ """Represents a (restricted) form of value pattern disjunction."""
717+
718+ def __init__ (
719+ self ,
720+ values : Sequence [ValuePattern ],
721+ name : str | None = None ,
722+ tag_var : str | None = None ,
723+ tag_values : Sequence [Any ] | None = None ,
724+ ) -> None :
725+ """
726+ Initialize an OrValue pattern.
727+
728+ Args:
729+ values: A sequence of value patterns to match against.
730+ Must contain at least two alternatives. All value patterns except the last one
731+ must have a unique producer id. This allows the pattern-matching to be deterministic,
732+ without the need for backtracking.
733+ name: An optional variable name for the pattern. Defaults to None. If present,
734+ this name will be bound to the value matched by the pattern.
735+ tag_var: An optional variable name for the tag. Defaults to None. If present,
736+ it will be bound to a value (from tag_values) indicating which alternative was matched.
737+ tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
738+ If present, the length of tag_values must match the number of alternatives in values.
739+ In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th
740+ alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used.
741+ """
742+ super ().__init__ (name )
743+ if len (values ) < 2 :
744+ raise ValueError ("OrValue must have at least two alternatives." )
745+ if tag_values is not None :
746+ if tag_var is None :
747+ raise ValueError ("tag_var must be specified if tag_values is provided." )
748+ if len (tag_values ) != len (values ):
749+ raise ValueError (
750+ "tag_values must have the same length as the number of alternatives."
751+ )
752+ else :
753+ tag_values = tuple (range (len (values )))
754+ self ._tag_var = tag_var
755+ self ._tag_values = tag_values
756+ self ._values = values
757+
758+ mapping : dict [ir .OperatorIdentifier , tuple [Any , NodeOutputPattern ]] = {}
759+ for i , alternative in enumerate (values [:- 1 ]):
760+ if not isinstance (alternative , NodeOutputPattern ):
761+ raise TypeError (
762+ f"Invalid type { type (alternative )} for OrValue. Expected NodeOutputPattern."
763+ )
764+ producer = alternative .producer ()
765+ id = producer .op_identifier ()
766+ if id is None :
767+ raise ValueError (
768+ f"Invalid producer { producer } for OrValue. Expected a NodePattern with op identifier."
769+ )
770+ if id in mapping :
771+ raise ValueError (
772+ f"Invalid producer { producer } for OrValue. Expected a unique producer id for each alternative."
773+ )
774+ mapping [id ] = (tag_values [i ], alternative )
775+ self ._op_to_pattern = mapping
776+ self ._default_pattern = (tag_values [- 1 ], values [- 1 ])
777+
778+ @property
779+ def tag_var (self ) -> str | None :
780+ """Returns the tag variable associated with the OrValue pattern."""
781+ return self ._tag_var
782+
783+ def clone (self , node_map : dict [NodePattern , NodePattern ]) -> OrValue :
784+ return OrValue (
785+ [v .clone (node_map ) for v in self ._values ],
786+ self .name ,
787+ self ._tag_var ,
788+ self ._tag_values ,
789+ )
790+
791+ def get_pattern (self , value : ir .Value ) -> tuple [Any , ValuePattern ]:
792+ """Returns the pattern that should be tried for the given value."""
793+ producer = value .producer ()
794+ if producer is not None :
795+ id = producer .op_identifier ()
796+ if id is not None and id in self ._op_to_pattern :
797+ return self ._op_to_pattern [id ]
798+ return self ._default_pattern
799+
800+
721801def _nodes_in_pattern (outputs : Sequence [ValuePattern ]) -> list [NodePattern ]:
722802 """Returns all nodes used in a pattern, given the outputs of the pattern."""
723803 node_patterns : list [NodePattern ] = []
@@ -1136,6 +1216,15 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b
11361216 if value is None :
11371217 return self .fail ("Mismatch: Constant pattern does not match None." )
11381218 return self ._match_constant (pattern_value , value )
1219+ if isinstance (pattern_value , OrValue ):
1220+ if value is None :
1221+ return self .fail ("Mismatch: OrValue pattern does not match None." )
1222+ i , pattern_choice = pattern_value .get_pattern (value )
1223+ result = self ._match_value (pattern_choice , value )
1224+ if result :
1225+ if pattern_value .tag_var is not None :
1226+ self ._match .bind (pattern_value .tag_var , i )
1227+ return result
11391228 return True
11401229
11411230 def _match_node_output (self , pattern_value : NodeOutputPattern , value : ir .Value ) -> bool :
0 commit comments