Skip to content

Commit 02cf905

Browse files
justinchubyCopilot
andauthored
[IR] Support specifying output value in Tape (#2225)
When a user wants to specify names for output values, they can initialize the value first then supply them to tape op() call. Also renamed symbolic_multi_output to symblic_multi_out to match https://pytorch.org/docs/main/onnx_ops.html#torch.onnx.ops.symbolic_multi_out --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent c6f535f commit 02cf905

2 files changed

Lines changed: 22 additions & 6 deletions

File tree

onnxscript/ir/_tape.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,23 @@ def op(
8686
name: str | None = None,
8787
doc_string: str | None = None,
8888
metadata_props: dict[str, str] | None = None,
89+
output: ir.Value | None = None,
8990
) -> ir.Value:
9091
if attributes is None:
9192
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
9293
else:
9394
attrs = _convenience.convert_attributes(attributes)
95+
output_kwargs: dict[str, Any]
96+
if output is None:
97+
output_kwargs = dict(num_outputs=1)
98+
else:
99+
output_kwargs = dict(outputs=[output])
94100
node = ir.Node(
95101
domain,
96102
op_type,
97103
inputs,
98104
attributes=attrs,
99-
num_outputs=1,
105+
**output_kwargs,
100106
overload=overload,
101107
version=version,
102108
graph=graph or self.graph_like,
@@ -109,13 +115,14 @@ def op(
109115

110116
return node.outputs[0]
111117

112-
def op_multi_output(
118+
def op_multi_out(
113119
self,
114120
op_type: str,
115121
inputs: Sequence[ir.Value | None],
116122
attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
117123
*,
118-
num_outputs: int,
124+
num_outputs: int | None = None,
125+
outputs: Sequence[ir.Value] | None = None,
119126
domain: str = "",
120127
overload: str = "",
121128
version: int | None = None,
@@ -124,6 +131,15 @@ def op_multi_output(
124131
doc_string: str | None = None,
125132
metadata_props: dict[str, str] | None = None,
126133
) -> Sequence[ir.Value]:
134+
if num_outputs is None and outputs is None:
135+
raise ValueError("Either num_outputs or outputs must be provided.")
136+
if num_outputs is not None and outputs is not None:
137+
raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.")
138+
output_kwargs: dict[str, Any]
139+
if outputs is None:
140+
output_kwargs = dict(num_outputs=num_outputs)
141+
else:
142+
output_kwargs = dict(outputs=outputs)
127143
if attributes is None:
128144
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
129145
else:
@@ -133,7 +149,7 @@ def op_multi_output(
133149
op_type,
134150
inputs,
135151
attributes=attrs,
136-
num_outputs=num_outputs,
152+
**output_kwargs,
137153
overload=overload,
138154
version=version,
139155
graph=graph or self.graph_like,
@@ -183,7 +199,7 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
183199
if isinstance(outputs, Sequence):
184200
value.name = outputs[0]
185201
return value
186-
values = super().op_multi_output(
202+
values = super().op_multi_out(
187203
op_type,
188204
inputs=inputs,
189205
attributes=kwargs,

onnxscript/ir/_tape_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_op_multi_out(self):
6666

6767
tape = ir.tape.Tape()
6868

69-
out1, out2, out3 = tape.op_multi_output("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking
69+
out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking
7070
_ = tape.op("SomeOtherOp", inputs=[out1, out2, out3])
7171

7272
self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"])

0 commit comments

Comments
 (0)