|
10 | 10 | import onnx_ir as ir |
11 | 11 |
|
12 | 12 | import onnxscript._internal.builder as builder |
| 13 | +import onnxscript.testing |
13 | 14 | from onnxscript import script |
14 | 15 | from onnxscript.onnx_types import DOUBLE, FLOAT |
15 | 16 |
|
@@ -713,6 +714,31 @@ def add_mul(X, Y): |
713 | 714 | self.assertEqual(nodes[0].op_type, "Add") |
714 | 715 | self.assertEqual(nodes[1].op_type, "Mul") |
715 | 716 |
|
| 717 | + def test_call_with_outer_scope_value(self): |
| 718 | + """Test that script supports references to pre-existing values.""" |
| 719 | + # Create a GraphBuilder first |
| 720 | + op, x, y = _create_builder_with_inputs() |
| 721 | + product = op.Mul(x, y) |
| 722 | + |
| 723 | + @script() |
| 724 | + def add_product(X): |
| 725 | + return op.Add(X, product) # Reference to 'product' from outer scope |
| 726 | + |
| 727 | + x_plus = op.call(add_product, x, _outputs=["x_plus"]) |
| 728 | + y_plus = op.call(add_product, y, _outputs=["y_plus"]) |
| 729 | + |
| 730 | + op.builder.graph.outputs.extend([x_plus, y_plus]) |
| 731 | + |
| 732 | + # Now, create the same graph directly: |
| 733 | + op2, x2, y2 = _create_builder_with_inputs() |
| 734 | + product2 = op2.Mul(x2, y2) |
| 735 | + x2_plus = op2.Add(x2, product2, _outputs=["x_plus"]) |
| 736 | + y2_plus = op2.Add(y2, product2, _outputs=["y_plus"]) |
| 737 | + op2.builder.graph.outputs.extend([x2_plus, y2_plus]) |
| 738 | + |
| 739 | + # Verify that the two graphs are structurally equivalent |
| 740 | + onnxscript.testing.assert_isomorphic_graph(op.builder.graph, op2.builder.graph) |
| 741 | + |
716 | 742 | def test_call_with_prefix_option(self): |
717 | 743 | """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" |
718 | 744 | # Create a GraphBuilder first |
|
0 commit comments