From b35d3de3f5d2d315003b5f94fda6be82733c46b1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 24 Apr 2025 09:33:27 -0700 Subject: [PATCH 1/4] [IR] Support specifying output value in Tape When a user wants to specify names for output values, they can initialize the value first then supply them to tape op() call. --- onnxscript/ir/_tape.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 0a63118d4f..29abe83681 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -85,17 +85,23 @@ def op( name: str | None = None, doc_string: str | None = None, metadata_props: dict[str, str] | None = None, + output: ir.Value | None = None, ) -> ir.Value: if attributes is None: attrs: Sequence[ir.Attr | ir.RefAttr] = () else: attrs = _convenience.convert_attributes(attributes) + output_kwargs: dict[str, Any] + if output is None: + output_kwargs = dict(num_outputs=1) + else: + output_kwargs = dict(outputs=[output]) node = ir.Node( domain, op_type, inputs, attributes=attrs, - num_outputs=1, + **output_kwargs, overload=overload, version=version, graph=graph or self.graph_like, @@ -114,7 +120,8 @@ def op_multi_output( inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, *, - num_outputs: int, + num_outputs: int | None = None, + outputs: Sequence[ir.Value] | None = None, domain: str = "", overload: str = "", version: int | None = None, @@ -123,6 +130,13 @@ def op_multi_output( doc_string: str | None = None, metadata_props: dict[str, str] | None = None, ) -> Sequence[ir.Value]: + if num_outputs is None and outputs is None: + raise ValueError("Either num_outputs or outputs must be provided.") + output_kwargs: dict[str, Any] + if outputs is None: + output_kwargs = dict(num_outputs=num_outputs) + else: + output_kwargs = dict(outputs=outputs) if attributes is None: attrs: Sequence[ir.Attr | ir.RefAttr] = () else: @@ -132,7 +146,7 @@ def op_multi_output( op_type, inputs, attributes=attrs, - num_outputs=num_outputs, + **output_kwargs, overload=overload, version=version, graph=graph or self.graph_like, From c1e64e118e07dcf4ce6b591c1086e59b4731560f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 25 Apr 2025 07:35:19 -0700 Subject: [PATCH 2/4] Update onnxscript/ir/_tape.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/ir/_tape.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 29abe83681..d19e13d214 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -132,6 +132,8 @@ def op_multi_output( ) -> Sequence[ir.Value]: if num_outputs is None and outputs is None: raise ValueError("Either num_outputs or outputs must be provided.") + if num_outputs is not None and outputs is not None: + raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.") output_kwargs: dict[str, Any] if outputs is None: output_kwargs = dict(num_outputs=num_outputs) From 27ba3cf6a905e51a7286d83ef2a3116b0ff1d6fb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 25 Apr 2025 09:46:10 -0700 Subject: [PATCH 3/4] Rename op_multi_out --- onnxscript/ir/_tape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index d19e13d214..74fd0186a6 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -114,7 +114,7 @@ def op( return node.outputs[0] - def op_multi_output( + def op_multi_out( self, op_type: str, inputs: Sequence[ir.Value | None], @@ -198,7 +198,7 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, if isinstance(outputs, Sequence): value.name = outputs[0] return value - values = super().op_multi_output( + values = super().op_multi_out( op_type, inputs=inputs, attributes=kwargs, From ae44748dd81a887cffc544e489b8742dd5567fbe Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 25 Apr 2025 16:17:35 -0700 Subject: [PATCH 4/4] op_multi_out --- onnxscript/ir/_tape_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py index 922c6d7eaa..46cbcc23fe 100644 --- a/onnxscript/ir/_tape_test.py +++ b/onnxscript/ir/_tape_test.py @@ -66,7 +66,7 @@ def test_op_multi_out(self): tape = ir.tape.Tape() - out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking + out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"])