Skip to content

Commit 91de1f3

Browse files
committed
documentation
1 parent b2b284d commit 91de1f3

1 file changed

Lines changed: 55 additions & 1 deletion

File tree

onnx_diagnostic/export/onnx_plug.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,61 @@ def forward(self, x):
128128
129129
print(pretty_onnx(onx))
130130
131-
# And with :func:`torch.onnx.export`:
131+
We do the same with :func:`torch.onnx.export`:
132+
133+
.. runpython::
134+
:showcode:
135+
136+
import onnx.helper as oh
137+
import torch
138+
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
139+
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
140+
from onnx_diagnostic.export.api import to_onnx
141+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
142+
143+
144+
def demo_customsub(x, y):
145+
return x - y
146+
147+
148+
def demo_customsub_shape(x, y):
149+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
150+
151+
152+
def make_function_proto():
153+
return oh.make_function(
154+
"onnx_plug",
155+
"demo_customsub",
156+
["x", "y"],
157+
["z"],
158+
[oh.make_node("Sub", ["x", "y"], ["z"])],
159+
opset_imports=[oh.make_opsetid("", 22)],
160+
)
161+
162+
163+
class Model(torch.nn.Module):
164+
def forward(self, x):
165+
y = x.sum(axis=1, keepdim=True)
166+
d = torch.ops.onnx_plug.demo_customsub(x, y)
167+
return torch.abs(d)
168+
169+
170+
replacements = [
171+
EagerDirectReplacementWithOnnx(
172+
demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
173+
)
174+
]
175+
176+
x = torch.randn((3, 4), dtype=torch.float32)
177+
model = Model()
178+
ds = ({0: "d1", 1: "d2"},)
179+
180+
# The exported program shows a custom op.
181+
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
182+
print("ep")
183+
184+
# As the exporter knows how the replace this custom op.
185+
# Let's export.
132186
133187
onx = to_onnx(
134188
model,

0 commit comments

Comments
 (0)