@@ -128,7 +128,61 @@ def forward(self, x):
128128
129129 print(pretty_onnx(onx))
130130
131- # And with :func:`torch.onnx.export`:
131+ We do the same with :func:`torch.onnx.export`:
132+
133+ .. runpython::
134+ :showcode:
135+
136+ import onnx.helper as oh
137+ import torch
138+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
139+ from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
140+ from onnx_diagnostic.export.api import to_onnx
141+ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
142+
143+
144+ def demo_customsub(x, y):
145+ return x - y
146+
147+
148+ def demo_customsub_shape(x, y):
149+ return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
150+
151+
152+ def make_function_proto():
153+ return oh.make_function(
154+ "onnx_plug",
155+ "demo_customsub",
156+ ["x", "y"],
157+ ["z"],
158+ [oh.make_node("Sub", ["x", "y"], ["z"])],
159+ opset_imports=[oh.make_opsetid("", 22)],
160+ )
161+
162+
163+ class Model(torch.nn.Module):
164+ def forward(self, x):
165+ y = x.sum(axis=1, keepdim=True)
166+ d = torch.ops.onnx_plug.demo_customsub(x, y)
167+ return torch.abs(d)
168+
169+
170+ replacements = [
171+ EagerDirectReplacementWithOnnx(
172+ demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
173+ )
174+ ]
175+
176+ x = torch.randn((3, 4), dtype=torch.float32)
177+ model = Model()
178+ ds = ({0: "d1", 1: "d2"},)
179+
180+ # The exported program shows a custom op.
181+ ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
182+ print("ep")
183+
184+ # As the exporter knows how the replace this custom op.
185+ # Let's export.
132186
133187 onx = to_onnx(
134188 model,
0 commit comments