From 3a1e94df0c2002b324533fb7674ce2f524e0ae1a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 28 Apr 2025 20:16:23 -0700 Subject: [PATCH 1/6] Add Or pattern --- onnxscript/rewriter/pattern.py | 68 ++++++++++++++++++++++++++--- onnxscript/rewriter/pattern_test.py | 33 ++++++++++++++ 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index cfca31125f..d70146aac8 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -466,6 +466,7 @@ def __pow__(self, other): def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) +OpIdentifier = Tuple[str, str, str] class NodePattern: """Represents a pattern that matches against a Node. @@ -511,7 +512,7 @@ def __init__( if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" - self._op_identifier: tuple[str, str, str] | None = ( + self._op_identifier: OpIdentifier | None = ( domain.value(), op, overload, @@ -535,7 +536,7 @@ def __str__(self) -> str: inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs return f"{outputs} = {qualified_op} ({inputs_and_attributes})" - def op_identifier(self) -> Tuple[str, str, str] | None: + def op_identifier(self) -> OpIdentifier | None: return self._op_identifier @property @@ -629,11 +630,6 @@ def producer(self) -> NodePattern: Var = ValuePattern -def _is_pattern_variable(x: Any) -> bool: - # The derived classes of ValuePattern represent constant patterns and node-output patterns. - return type(x) is ValuePattern - - class AnyValue(ValuePattern): """Represents a pattern that matches against any value.""" @@ -717,6 +713,59 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: def __str__(self) -> str: return str(self._value) +class OrValue(ValuePattern): + """Represents a (restricted) form of value pattern disjunction.""" + + def __init__(self, values: Sequence[ValuePattern], name: str | None = None) -> None: + """ + Initialize an OrValue pattern. + + Args: + values (Sequence[ValuePattern]): A sequence of value patterns to match against. + Must contain at least two alternatives. All value patterns except the last one + must have a unique producer id. This allows the pattern-matching to be deterministic, + without the need for backtracking. + name (str | None, optional): An optional variable name for the pattern. Defaults to None. + """ + super().__init__(name) + if len(values) < 2: + raise ValueError("OrValue must have at least two alternatives.") + + mapping: dict[OpIdentifier, NodeOutputPattern] = {} + for alternative in values[:-1]: + if not isinstance(alternative, NodeOutputPattern): + raise TypeError( + f"Invalid type {type(alternative)} for OrValue. Expected NodeOutputPattern." + ) + producer = alternative.producer() + id = producer.op_identifier() + if id is None: + raise ValueError( + f"Invalid producer {producer} for OrValue. Expected a NodePattern with op identifier." + ) + if id in mapping: + raise ValueError( + f"Invalid producer {producer} for OrValue. Expected a unique producer id for each alternative." + ) + mapping[id] = alternative + self._op_to_pattern = mapping + self._default_pattern = values[-1] + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue: + return OrValue([v.clone(node_map) for v in self._values], self.name) + + def get_pattern(self, value: ir.Value) -> ValuePattern: + """Returns the pattern that should be tried for the given value.""" + producer = value.producer() + if producer is not None: + id = producer.op_identifier() + if id is not None and id in self._op_to_pattern: + return self._op_to_pattern[id] + return self._default_pattern + + def __str__(self) -> str: + return f"OrValue({self._values})" + def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]: """Returns all nodes used in a pattern, given the outputs of the pattern.""" @@ -1136,6 +1185,11 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b if value is None: return self.fail("Mismatch: Constant pattern does not match None.") return self._match_constant(pattern_value, value) + if isinstance(pattern_value, OrValue): + if value is None: + return self.fail("Mismatch: OrValue pattern does not match None.") + pattern_choice = pattern_value.get_pattern(value) + return self._match_value(pattern_choice, value) return True def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index ce11e23c19..715c93a655 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -688,6 +688,39 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(len(model.graph), 2) self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"]) + def test_or_pattern(self): + def source_pattern(op, x, y, bias): + t1 = op.MatMul(x, y) + t2 = op.Add(t1, bias) + t1_or_t2 = pattern.OrValue([t1, t2]) + return op.Relu(t1_or_t2) + + def replacement(op, x, y, bias): + if bias is None: + return op.WithoutBias(x, y) + else: + return op.WithBias(x, y, bias) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model1(x: FLOAT[16,32], y: FLOAT[32, 16]) -> FLOAT[16, 16]: + return op.Relu(op.MatMul(x, y)) + + model_proto = test_model1.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["WithoutBias"]) + + @script() + def test_model2(x: FLOAT[16,32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: + return op.Relu(op.Add(op.MatMul(x, y), bias)) + + model_proto = test_model2.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["WithBias"]) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 158ab4ccdff5f56cd308db2b14872ea7906afade Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 28 Apr 2025 21:08:00 -0700 Subject: [PATCH 2/6] Add tag var and values --- onnxscript/rewriter/pattern.py | 53 +++++++++++++++++++++++------ onnxscript/rewriter/pattern_test.py | 16 ++++----- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d70146aac8..f774d983a2 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -466,8 +466,10 @@ def __pow__(self, other): def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) + OpIdentifier = Tuple[str, str, str] + class NodePattern: """Represents a pattern that matches against a Node. @@ -713,26 +715,51 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: def __str__(self) -> str: return str(self._value) + class OrValue(ValuePattern): """Represents a (restricted) form of value pattern disjunction.""" - def __init__(self, values: Sequence[ValuePattern], name: str | None = None) -> None: + def __init__( + self, + values: Sequence[ValuePattern], + name: str | None = None, + tag_var: str | None = None, + tag_values: Sequence[Any] | None = None, + ) -> None: """ Initialize an OrValue pattern. Args: - values (Sequence[ValuePattern]): A sequence of value patterns to match against. + values: A sequence of value patterns to match against. Must contain at least two alternatives. All value patterns except the last one must have a unique producer id. This allows the pattern-matching to be deterministic, without the need for backtracking. - name (str | None, optional): An optional variable name for the pattern. Defaults to None. + name: An optional variable name for the pattern. Defaults to None. If present, + this name will be bound to the value matched by the pattern. + tag_var: An optional variable name for the tag. Defaults to None. If present, + it will be bound to a value (from tag_values) indicating which alternative was matched. + tag_values: An optional sequence of values to bind to the tag_var. Defaults to None. + If present, the length of tag_values must match the number of alternatives in values. + In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th + alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used. """ super().__init__(name) if len(values) < 2: raise ValueError("OrValue must have at least two alternatives.") + if tag_values is not None: + if tag_var is None: + raise ValueError("tag_var must be specified if tag_values is provided.") + if len(tag_values) != len(values): + raise ValueError( + "tag_values must have the same length as the number of alternatives." + ) + elif tag_var is not None: + tag_values = tuple(range(len(values))) + self._tag_var = tag_var + self._tag_values = tag_values - mapping: dict[OpIdentifier, NodeOutputPattern] = {} - for alternative in values[:-1]: + mapping: dict[OpIdentifier, tuple[Any, NodeOutputPattern]] = {} + for i, alternative in enumerate(values[:-1]): if not isinstance(alternative, NodeOutputPattern): raise TypeError( f"Invalid type {type(alternative)} for OrValue. Expected NodeOutputPattern." @@ -747,14 +774,14 @@ def __init__(self, values: Sequence[ValuePattern], name: str | None = None) -> N raise ValueError( f"Invalid producer {producer} for OrValue. Expected a unique producer id for each alternative." ) - mapping[id] = alternative + mapping[id] = (tag_values[i], alternative) self._op_to_pattern = mapping - self._default_pattern = values[-1] + self._default_pattern = (tag_values[-1], values[-1]) def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue: return OrValue([v.clone(node_map) for v in self._values], self.name) - def get_pattern(self, value: ir.Value) -> ValuePattern: + def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]: """Returns the pattern that should be tried for the given value.""" producer = value.producer() if producer is not None: @@ -762,7 +789,7 @@ def get_pattern(self, value: ir.Value) -> ValuePattern: if id is not None and id in self._op_to_pattern: return self._op_to_pattern[id] return self._default_pattern - + def __str__(self) -> str: return f"OrValue({self._values})" @@ -1188,8 +1215,12 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b if isinstance(pattern_value, OrValue): if value is None: return self.fail("Mismatch: OrValue pattern does not match None.") - pattern_choice = pattern_value.get_pattern(value) - return self._match_value(pattern_choice, value) + i, pattern_choice = pattern_value.get_pattern(value) + result = self._match_value(pattern_choice, value) + if result: + if pattern_value._tag_var is not None: + self._match.bind(pattern_value._tag_var, i) + return result return True def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 715c93a655..ca39d6c9ab 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -692,19 +692,19 @@ def test_or_pattern(self): def source_pattern(op, x, y, bias): t1 = op.MatMul(x, y) t2 = op.Add(t1, bias) - t1_or_t2 = pattern.OrValue([t1, t2]) + t1_or_t2 = pattern.OrValue([t1, t2], tag_var="has_bias", tag_values=[False, True]) return op.Relu(t1_or_t2) - def replacement(op, x, y, bias): - if bias is None: - return op.WithoutBias(x, y) - else: + def replacement(op, x, y, bias, has_bias): + if has_bias: return op.WithBias(x, y, bias) + else: + return op.WithoutBias(x, y) rule = pattern.RewriteRule(source_pattern, replacement) @script() - def test_model1(x: FLOAT[16,32], y: FLOAT[32, 16]) -> FLOAT[16, 16]: + def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16]) -> FLOAT[16, 16]: return op.Relu(op.MatMul(x, y)) model_proto = test_model1.to_model_proto() @@ -713,13 +713,13 @@ def test_model1(x: FLOAT[16,32], y: FLOAT[32, 16]) -> FLOAT[16, 16]: self.assertEqual([x.op_type for x in model.graph], ["WithoutBias"]) @script() - def test_model2(x: FLOAT[16,32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: + def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: return op.Relu(op.Add(op.MatMul(x, y), bias)) model_proto = test_model2.to_model_proto() model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) - self.assertEqual([x.op_type for x in model.graph], ["WithBias"]) + self.assertEqual([x.op_type for x in model.graph], ["WithBias"]) class PatternBuilderTest(unittest.TestCase): From d26c750f371e0d03c18d250b4e7cc8763f881b73 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 28 Apr 2025 21:13:27 -0700 Subject: [PATCH 3/6] Cleanup property reference --- onnxscript/rewriter/pattern.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index f774d983a2..5edd18471b 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -778,6 +778,11 @@ def __init__( self._op_to_pattern = mapping self._default_pattern = (tag_values[-1], values[-1]) + @property + def tag_var(self) -> str | None: + """Returns the tag variable associated with the OrValue pattern.""" + return self._tag_var + def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue: return OrValue([v.clone(node_map) for v in self._values], self.name) @@ -1218,8 +1223,8 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b i, pattern_choice = pattern_value.get_pattern(value) result = self._match_value(pattern_choice, value) if result: - if pattern_value._tag_var is not None: - self._match.bind(pattern_value._tag_var, i) + if pattern_value.tag_var is not None: + self._match.bind(pattern_value.tag_var, i) return result return True From 7926bf0f883f542cde06f15a9ffb1c78c3cbca5a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 30 Apr 2025 11:00:35 -0700 Subject: [PATCH 4/6] Address lint issues --- onnxscript/rewriter/pattern.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 5edd18471b..2e94d938d1 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -753,10 +753,11 @@ def __init__( raise ValueError( "tag_values must have the same length as the number of alternatives." ) - elif tag_var is not None: + else: tag_values = tuple(range(len(values))) self._tag_var = tag_var self._tag_values = tag_values + self._values = values mapping: dict[OpIdentifier, tuple[Any, NodeOutputPattern]] = {} for i, alternative in enumerate(values[:-1]): @@ -784,7 +785,12 @@ def tag_var(self) -> str | None: return self._tag_var def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue: - return OrValue([v.clone(node_map) for v in self._values], self.name) + return OrValue( + [v.clone(node_map) for v in self._values], + self.name, + self._tag_var, + self._tag_values, + ) def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]: """Returns the pattern that should be tried for the given value.""" @@ -795,9 +801,6 @@ def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]: return self._op_to_pattern[id] return self._default_pattern - def __str__(self) -> str: - return f"OrValue({self._values})" - def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]: """Returns all nodes used in a pattern, given the outputs of the pattern.""" From fd44167e689a1a32294a69b3d4400851ce6ddad1 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 30 Apr 2025 11:06:14 -0700 Subject: [PATCH 5/6] Use ir.OperatorIdentifier --- onnxscript/rewriter/pattern.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 2e94d938d1..115593fff0 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -18,7 +18,6 @@ MutableSequence, Protocol, Sequence, - Tuple, TypeVar, Union, ) @@ -467,9 +466,6 @@ def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) -OpIdentifier = Tuple[str, str, str] - - class NodePattern: """Represents a pattern that matches against a Node. @@ -514,7 +510,7 @@ def __init__( if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" - self._op_identifier: OpIdentifier | None = ( + self._op_identifier: ir.OperatorIdentifier | None = ( domain.value(), op, overload, @@ -538,7 +534,7 @@ def __str__(self) -> str: inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs return f"{outputs} = {qualified_op} ({inputs_and_attributes})" - def op_identifier(self) -> OpIdentifier | None: + def op_identifier(self) -> ir.OperatorIdentifier | None: return self._op_identifier @property @@ -759,7 +755,7 @@ def __init__( self._tag_values = tag_values self._values = values - mapping: dict[OpIdentifier, tuple[Any, NodeOutputPattern]] = {} + mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {} for i, alternative in enumerate(values[:-1]): if not isinstance(alternative, NodeOutputPattern): raise TypeError( From cbb75d727997c88ef7b215ad97972af771e547ea Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 30 Apr 2025 17:14:47 -0700 Subject: [PATCH 6/6] Update docs --- docs/api/rewriter_pattern.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/rewriter_pattern.md b/docs/api/rewriter_pattern.md index a3f1dcbe4b..033f65bb5c 100644 --- a/docs/api/rewriter_pattern.md +++ b/docs/api/rewriter_pattern.md @@ -25,6 +25,7 @@ rewriter.pattern.NodeOutputPattern rewriter.pattern.AnyValue rewriter.pattern.Constant + rewriter.pattern.OrValue rewriter.pattern.GraphPattern rewriter.pattern.ReplacementSubgraph rewriter.pattern.ReplacementPatternFunction