Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/rewriter_pattern.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
rewriter.pattern.NodeOutputPattern
rewriter.pattern.AnyValue
rewriter.pattern.Constant
rewriter.pattern.OrValue
rewriter.pattern.GraphPattern
rewriter.pattern.ReplacementSubgraph
rewriter.pattern.ReplacementPatternFunction
Expand Down
105 changes: 97 additions & 8 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
MutableSequence,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
)
Expand Down Expand Up @@ -511,7 +510,7 @@ def __init__(
if isinstance(op, str) and isinstance(domain, StringConstantPattern):
# TODO(rama): support overloaded operators.
overload = ""
self._op_identifier: tuple[str, str, str] | None = (
self._op_identifier: ir.OperatorIdentifier | None = (
domain.value(),
op,
overload,
Expand All @@ -535,7 +534,7 @@ def __str__(self) -> str:
inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs
return f"{outputs} = {qualified_op} ({inputs_and_attributes})"

def op_identifier(self) -> Tuple[str, str, str] | None:
def op_identifier(self) -> ir.OperatorIdentifier | None:
return self._op_identifier

@property
Expand Down Expand Up @@ -629,11 +628,6 @@ def producer(self) -> NodePattern:
Var = ValuePattern


def _is_pattern_variable(x: Any) -> bool:
# The derived classes of ValuePattern represent constant patterns and node-output patterns.
return type(x) is ValuePattern


class AnyValue(ValuePattern):
"""Represents a pattern that matches against any value."""

Expand Down Expand Up @@ -718,6 +712,92 @@ def __str__(self) -> str:
return str(self._value)


class OrValue(ValuePattern):
"""Represents a (restricted) form of value pattern disjunction."""

def __init__(
self,
values: Sequence[ValuePattern],
name: str | None = None,
tag_var: str | None = None,
tag_values: Sequence[Any] | None = None,
) -> None:
"""
Initialize an OrValue pattern.

Args:
values: A sequence of value patterns to match against.
Must contain at least two alternatives. All value patterns except the last one
must have a unique producer id. This allows the pattern-matching to be deterministic,
without the need for backtracking.
name: An optional variable name for the pattern. Defaults to None. If present,
this name will be bound to the value matched by the pattern.
tag_var: An optional variable name for the tag. Defaults to None. If present,
it will be bound to a value (from tag_values) indicating which alternative was matched.
tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
If present, the length of tag_values must match the number of alternatives in values.
In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th
alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used.
"""
super().__init__(name)
if len(values) < 2:
raise ValueError("OrValue must have at least two alternatives.")
if tag_values is not None:
if tag_var is None:
raise ValueError("tag_var must be specified if tag_values is provided.")
if len(tag_values) != len(values):
raise ValueError(
"tag_values must have the same length as the number of alternatives."
)
else:
tag_values = tuple(range(len(values)))
self._tag_var = tag_var
self._tag_values = tag_values
self._values = values

mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {}
for i, alternative in enumerate(values[:-1]):
if not isinstance(alternative, NodeOutputPattern):
raise TypeError(
f"Invalid type {type(alternative)} for OrValue. Expected NodeOutputPattern."
)
producer = alternative.producer()
id = producer.op_identifier()
if id is None:
raise ValueError(
f"Invalid producer {producer} for OrValue. Expected a NodePattern with op identifier."
)
if id in mapping:
raise ValueError(
f"Invalid producer {producer} for OrValue. Expected a unique producer id for each alternative."
)
mapping[id] = (tag_values[i], alternative)
Comment thread Fixed
self._op_to_pattern = mapping
self._default_pattern = (tag_values[-1], values[-1])
Comment thread Fixed

@property
def tag_var(self) -> str | None:
"""Returns the tag variable associated with the OrValue pattern."""
return self._tag_var

def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue:
return OrValue(
[v.clone(node_map) for v in self._values],
self.name,
self._tag_var,
self._tag_values,
)

def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]:
"""Returns the pattern that should be tried for the given value."""
producer = value.producer()
if producer is not None:
id = producer.op_identifier()
if id is not None and id in self._op_to_pattern:
return self._op_to_pattern[id]
return self._default_pattern


def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]:
"""Returns all nodes used in a pattern, given the outputs of the pattern."""
node_patterns: list[NodePattern] = []
Expand Down Expand Up @@ -1136,6 +1216,15 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b
if value is None:
return self.fail("Mismatch: Constant pattern does not match None.")
return self._match_constant(pattern_value, value)
if isinstance(pattern_value, OrValue):
if value is None:
return self.fail("Mismatch: OrValue pattern does not match None.")
i, pattern_choice = pattern_value.get_pattern(value)
result = self._match_value(pattern_choice, value)
if result:
if pattern_value.tag_var is not None:
self._match.bind(pattern_value.tag_var, i)
return result
return True

def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool:
Expand Down
33 changes: 33 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,39 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
self.assertEqual(len(model.graph), 2)
self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"])

def test_or_pattern(self):
def source_pattern(op, x, y, bias):
t1 = op.MatMul(x, y)
t2 = op.Add(t1, bias)
t1_or_t2 = pattern.OrValue([t1, t2], tag_var="has_bias", tag_values=[False, True])
Comment thread
justinchuby marked this conversation as resolved.
return op.Relu(t1_or_t2)

def replacement(op, x, y, bias, has_bias):
Comment thread
shubhambhokare1 marked this conversation as resolved.
if has_bias:
return op.WithBias(x, y, bias)
else:
return op.WithoutBias(x, y)

rule = pattern.RewriteRule(source_pattern, replacement)

@script()
def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16]) -> FLOAT[16, 16]:
return op.Relu(op.MatMul(x, y))

model_proto = test_model1.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
Comment thread
gramalingam marked this conversation as resolved.
rule.apply_to_model(model)
self.assertEqual([x.op_type for x in model.graph], ["WithoutBias"])

@script()
def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]:
Comment thread
gramalingam marked this conversation as resolved.
return op.Relu(op.Add(op.MatMul(x, y), bias))

model_proto = test_model2.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
rule.apply_to_model(model)
self.assertEqual([x.op_type for x in model.graph], ["WithBias"])


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down
Loading