diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index e3ed995d..c84366ba 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -128,7 +128,61 @@ def forward(self, x): print(pretty_onnx(onx)) - # And with :func:`torch.onnx.export`: + We do the same with :func:`torch.onnx.export`: + + .. runpython:: + :showcode: + + import onnx.helper as oh + import torch + from onnx_diagnostic.helpers.onnx_helper import pretty_onnx + from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx + from onnx_diagnostic.export.api import to_onnx + from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str + + + def demo_customsub(x, y): + return x - y + + + def demo_customsub_shape(x, y): + return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype) + + + def make_function_proto(): + return oh.make_function( + "onnx_plug", + "demo_customsub", + ["x", "y"], + ["z"], + [oh.make_node("Sub", ["x", "y"], ["z"])], + opset_imports=[oh.make_opsetid("", 22)], + ) + + + class Model(torch.nn.Module): + def forward(self, x): + y = x.sum(axis=1, keepdim=True) + d = torch.ops.onnx_plug.demo_customsub(x, y) + return torch.abs(d) + + + replacements = [ + EagerDirectReplacementWithOnnx( + demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1 + ) + ] + + x = torch.randn((3, 4), dtype=torch.float32) + model = Model() + ds = ({0: "d1", 1: "d2"},) + + # The exported program shows a custom op. + ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds)) + print("ep") + + # As the exporter knows how the replace this custom op. + # Let's export. onx = to_onnx( model,