diff --git a/onnxscript/rewriter/rules/common/_basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py index f4b519bf03..d6846ec25b 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules.py +++ b/onnxscript/rewriter/rules/common/_basic_rules.py @@ -125,7 +125,7 @@ def pattern(self, op, x, shape_ignored, shape): return op.Reshape(op.Reshape(x, shape_ignored), shape) def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value): - new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name)) + new_shape = op.initializer(ir.Tensor(self._new_shape, name=self._new_shape_name)) return op.Reshape(x, new_shape, allowzero=self._allowzero) def check(self, context, x, shape_ignored, shape) -> MatchResult: @@ -145,6 +145,7 @@ def check(self, context, x, shape_ignored, shape) -> MatchResult: # Constraints for shape. self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0) + self._new_shape_name = f"{context.nodes[1].name or shape.name}/shape" if self._allowzero == 1 and any(self._new_shape == 0): return check_result if any(self._new_shape == 0) and any(self._new_shape < 0): diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py index 67ebdaa495..3eabf1f9b9 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -12,8 +12,9 @@ import onnxscript import onnxscript.onnx_types as ot -from onnxscript import ir +from onnxscript import ir, rewriter from onnxscript.onnx_opset import opset18 +from onnxscript.optimizer import _constant_folding, common_passes from onnxscript.rewriter import MatchingTracer, testing from onnxscript.rewriter import pattern as orp from onnxscript.rewriter.rules.common import _basic_rules @@ -506,6 +507,57 @@ def test_reshape_dynamic_reshape_rule(self, shape1, allowzero1=0): } testing.assert_numerically_equal(model, updated_model, feeds, atol=0, rtol=0) + def test_reshape_reshape_rule_with_shared_negative_one_shape(self): + input1 = ir.val("input1", ir.DataType.FLOAT, ir.Shape((2, 3))) + input2 = ir.val("input2", ir.DataType.FLOAT, ir.Shape((2, 6))) + output1 = ir.val("out1", ir.DataType.FLOAT, ir.Shape((2, 3))) + output2 = ir.val("out2", ir.DataType.FLOAT, ir.Shape((2, 6))) + tape = ir.tape.Tape( + ir.Graph( + [input1, input2], + [output1, output2], + nodes=[], + opset_imports={"": 21}, + name="test_reshape_reshape_rule_with_shared_negative_one_shape", + ) + ) + + shape_mid_a = tape.initializer( + ir.Tensor(np.array([6], dtype=np.int64), name="shape_mid_a") + ) + shape_mid_b = tape.initializer( + ir.Tensor(np.array([12], dtype=np.int64), name="shape_mid_b") + ) + shared_shape = tape.initializer( + ir.Tensor(np.array([2, -1], dtype=np.int64), name="shared_shape") + ) + + mid1 = tape.op("Reshape", inputs=[input1, shape_mid_a]) + mid2 = tape.op("Reshape", inputs=[input2, shape_mid_b]) + tape.op("Reshape", inputs=[mid1, shared_shape], output=output1) + tape.op("Reshape", inputs=[mid2, shared_shape], output=output2) + model = ir.Model(tape.graph_like, ir_version=10) + + _constant_folding.FoldConstantsPass( + shape_inference=True, input_size_limit=1024, output_size_limit=1024 + )(model) + rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES)(model) + common_passes.RemoveUnusedNodesPass()(model) + common_passes.LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=0)( + model + ) + common_passes.DeduplicateInitializersPass()(model) + + reshape_shape_inputs = [ + node.inputs[1] for node in model.graph if node.op_type == "Reshape" + ] + self.assertEqual(len(reshape_shape_inputs), 2) + self.assertEqual(len({shape.name for shape in reshape_shape_inputs}), 2) + for shape in reshape_shape_inputs: + self.assertIn(shape.name, model.graph.initializers) + + onnx.checker.check_model(ir.to_proto(model), full_check=True) + @parameterized.parameterized.expand( [((3, 6, 9), [0, 3, 2, -1]), ((0, 6, 2), [0, 0, 3], 1)] )