diff --git a/docs/api/rewriter_pattern.md b/docs/api/rewriter_pattern.md index a3f1dcbe4b..033f65bb5c 100644 --- a/docs/api/rewriter_pattern.md +++ b/docs/api/rewriter_pattern.md @@ -25,6 +25,7 @@ rewriter.pattern.NodeOutputPattern rewriter.pattern.AnyValue rewriter.pattern.Constant + rewriter.pattern.OrValue rewriter.pattern.GraphPattern rewriter.pattern.ReplacementSubgraph rewriter.pattern.ReplacementPatternFunction diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index cfca31125f..115593fff0 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -18,7 +18,6 @@ MutableSequence, Protocol, Sequence, - Tuple, TypeVar, Union, ) @@ -511,7 +510,7 @@ def __init__( if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" - self._op_identifier: tuple[str, str, str] | None = ( + self._op_identifier: ir.OperatorIdentifier | None = ( domain.value(), op, overload, @@ -535,7 +534,7 @@ def __str__(self) -> str: inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs return f"{outputs} = {qualified_op} ({inputs_and_attributes})" - def op_identifier(self) -> Tuple[str, str, str] | None: + def op_identifier(self) -> ir.OperatorIdentifier | None: return self._op_identifier @property @@ -629,11 +628,6 @@ def producer(self) -> NodePattern: Var = ValuePattern -def _is_pattern_variable(x: Any) -> bool: - # The derived classes of ValuePattern represent constant patterns and node-output patterns. 
- return type(x) is ValuePattern - - class AnyValue(ValuePattern): """Represents a pattern that matches against any value.""" @@ -718,6 +712,92 @@ def __str__(self) -> str: return str(self._value) +class OrValue(ValuePattern): + """Represents a (restricted) form of value pattern disjunction.""" + + def __init__( + self, + values: Sequence[ValuePattern], + name: str | None = None, + tag_var: str | None = None, + tag_values: Sequence[Any] | None = None, + ) -> None: + """ + Initialize an OrValue pattern. + + Args: + values: A sequence of value patterns to match against. + Must contain at least two alternatives. All value patterns except the last one + must have a unique producer id. This allows the pattern-matching to be deterministic, + without the need for backtracking. + name: An optional variable name for the pattern. Defaults to None. If present, + this name will be bound to the value matched by the pattern. + tag_var: An optional variable name for the tag. Defaults to None. If present, + it will be bound to a value (from tag_values) indicating which alternative was matched. + tag_values: An optional sequence of values to bind to the tag_var. Defaults to None. + If present, the length of tag_values must match the number of alternatives in values. + In a successful match, tag_var will be bound to the i-th value in tag_values if the i-th + alternative pattern matched. If omitted, the default values (0, 1, 2, ...) will be used. + """ + super().__init__(name) + if len(values) < 2: + raise ValueError("OrValue must have at least two alternatives.") + if tag_values is not None: + if tag_var is None: + raise ValueError("tag_var must be specified if tag_values is provided.") + if len(tag_values) != len(values): + raise ValueError( + "tag_values must have the same length as the number of alternatives." 
+ ) + else: + tag_values = tuple(range(len(values))) + self._tag_var = tag_var + self._tag_values = tag_values + self._values = values + + mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {} + for i, alternative in enumerate(values[:-1]): + if not isinstance(alternative, NodeOutputPattern): + raise TypeError( + f"Invalid type {type(alternative)} for OrValue. Expected NodeOutputPattern." + ) + producer = alternative.producer() + id = producer.op_identifier() + if id is None: + raise ValueError( + f"Invalid producer {producer} for OrValue. Expected a NodePattern with op identifier." + ) + if id in mapping: + raise ValueError( + f"Invalid producer {producer} for OrValue. Expected a unique producer id for each alternative." + ) + mapping[id] = (tag_values[i], alternative) + self._op_to_pattern = mapping + self._default_pattern = (tag_values[-1], values[-1]) + + @property + def tag_var(self) -> str | None: + """Returns the tag variable associated with the OrValue pattern.""" + return self._tag_var + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue: + return OrValue( + [v.clone(node_map) for v in self._values], + self.name, + self._tag_var, + self._tag_values, + ) + + def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]: + """Returns the pattern that should be tried for the given value.""" + producer = value.producer() + if producer is not None: + id = producer.op_identifier() + if id is not None and id in self._op_to_pattern: + return self._op_to_pattern[id] + return self._default_pattern + + def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]: """Returns all nodes used in a pattern, given the outputs of the pattern.""" node_patterns: list[NodePattern] = [] @@ -1136,6 +1216,15 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b if value is None: return self.fail("Mismatch: Constant pattern does not match None.") return self._match_constant(pattern_value, 
value) + if isinstance(pattern_value, OrValue): + if value is None: + return self.fail("Mismatch: OrValue pattern does not match None.") + i, pattern_choice = pattern_value.get_pattern(value) + result = self._match_value(pattern_choice, value) + if result: + if pattern_value.tag_var is not None: + self._match.bind(pattern_value.tag_var, i) + return result return True def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index ce11e23c19..ca39d6c9ab 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -688,6 +688,39 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(len(model.graph), 2) self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"]) + def test_or_pattern(self): + def source_pattern(op, x, y, bias): + t1 = op.MatMul(x, y) + t2 = op.Add(t1, bias) + t1_or_t2 = pattern.OrValue([t1, t2], tag_var="has_bias", tag_values=[False, True]) + return op.Relu(t1_or_t2) + + def replacement(op, x, y, bias, has_bias): + if has_bias: + return op.WithBias(x, y, bias) + else: + return op.WithoutBias(x, y) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16]) -> FLOAT[16, 16]: + return op.Relu(op.MatMul(x, y)) + + model_proto = test_model1.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["WithoutBias"]) + + @script() + def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: + return op.Relu(op.Add(op.MatMul(x, y), bias)) + + model_proto = test_model2.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["WithBias"]) + class PatternBuilderTest(unittest.TestCase): def 
test_pattern_builder_context(self):