diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 340142df3d..fbcfcb428a 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -86,17 +86,23 @@ def op( name: str | None = None, doc_string: str | None = None, metadata_props: dict[str, str] | None = None, + output: ir.Value | None = None, ) -> ir.Value: if attributes is None: attrs: Sequence[ir.Attr | ir.RefAttr] = () else: attrs = _convenience.convert_attributes(attributes) + output_kwargs: dict[str, Any] + if output is None: + output_kwargs = dict(num_outputs=1) + else: + output_kwargs = dict(outputs=[output]) node = ir.Node( domain, op_type, inputs, attributes=attrs, - num_outputs=1, + **output_kwargs, overload=overload, version=version, graph=graph or self.graph_like, @@ -109,13 +115,14 @@ def op( return node.outputs[0] - def op_multi_output( + def op_multi_out( self, op_type: str, inputs: Sequence[ir.Value | None], attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, *, - num_outputs: int, + num_outputs: int | None = None, + outputs: Sequence[ir.Value] | None = None, domain: str = "", overload: str = "", version: int | None = None, @@ -124,6 +131,15 @@ def op_multi_output( doc_string: str | None = None, metadata_props: dict[str, str] | None = None, ) -> Sequence[ir.Value]: + if num_outputs is None and outputs is None: + raise ValueError("Either num_outputs or outputs must be provided.") + if num_outputs is not None and outputs is not None: + raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.") + output_kwargs: dict[str, Any] + if outputs is None: + output_kwargs = dict(num_outputs=num_outputs) + else: + output_kwargs = dict(outputs=outputs) if attributes is None: attrs: Sequence[ir.Attr | ir.RefAttr] = () else: @@ -133,7 +149,7 @@ def op_multi_output( op_type, inputs, attributes=attrs, - num_outputs=num_outputs, + **output_kwargs, overload=overload, version=version, graph=graph or self.graph_like, @@ -183,7 +199,7 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, if isinstance(outputs, Sequence): value.name = outputs[0] return value - values = super().op_multi_output( + values = super().op_multi_out( op_type, inputs=inputs, attributes=kwargs, diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py index 922c6d7eaa..46cbcc23fe 100644 --- a/onnxscript/ir/_tape_test.py +++ b/onnxscript/ir/_tape_test.py @@ -66,7 +66,7 @@ def test_op_multi_out(self): tape = ir.tape.Tape() - out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking + out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"])