diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d2648d94a4..ba3b9bfb3f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -14,6 +14,9 @@ import math from typing import Any, Optional, Sequence, Tuple, Union +import numpy as np +import torch + from onnxscript import ( BFLOAT16, BOOL, @@ -7599,13 +7602,62 @@ def aten_scatter_reduce( "amax": "max", } onnx_reduce = reduce_mode[reduce] + dtype = src.dtype or self.dtype + assert dtype is not None, "dtype should be not None" + self_is_scalar = len(self.shape) == 0 if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) neg_1 = op.Constant(value_ints=[-1]) self = op.Reshape(self, neg_1) index = op.Reshape(index, neg_1) src = op.Reshape(src, neg_1) + + if not include_self: + # The ONNX standard always assumes the value from self is part of the reduction. + # A first step is added to replace the impacted values by a placeholder + # chosen in a way that the result of the reduction is not changed + # whether or not it takes part in it. + # It is the dtype's lowest value for max, its highest for min, 0 for add, 1 for mul. + # mean is not supported. 
+ if onnx_reduce == "max": + if dtype in { + ir.DataType.FLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + }: + value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) + elif dtype == ir.DataType.BFLOAT16: + value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) + else: + value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) + reduction_init = "min" + elif onnx_reduce == "min": + if dtype in { + ir.DataType.FLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + }: + value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) + elif dtype == ir.DataType.BFLOAT16: + value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype) + else: + value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) + reduction_init = "max" + elif onnx_reduce == "add": + value = ir.tensor([0], dtype=dtype) + reduction_init = "none" + elif onnx_reduce == "mul": + value = ir.tensor([1], dtype=dtype) + reduction_init = "none" + else: + value = ir.tensor([0], dtype=dtype) + reduction_init = "none" + + cst = op.ConstantOfShape(op.Shape(src), value=value) + self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) + result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) + if self_is_scalar: result = op.Squeeze(result) return result diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 034724a3a8..5fa7848626 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -867,9 +867,10 @@ def _do_inference(self, node: ir.Node) -> None: # TODO: handle optional inputs def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: - value = _get_numpy_value(x) - if isinstance(value, np.ndarray) and value.size < 20: - return onnx.numpy_helper.from_array(value, x.name) + value = _get_numpy_value(x, size_limit=20) + if value is not None: + assert x.const_value is not None + return ir.serde.serialize_tensor(x.const_value) return None def get_type(value: 
ir.Value) -> onnx.TypeProto | None: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py new file mode 100644 index 0000000000..e933ab8d8b --- /dev/null +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo + +import unittest + +import onnxruntime +import torch + +from tests.common import testutils + + +class TorchLibe2eTest(testutils.TestBase): + def test_investigate_one_particular_model(self): + """Investigate one case: scatter_reduce amin, include_self=False, int32.""" + red, include, stype = "amin", False, "int32" + dtype = getattr(torch, stype) + + class Model(torch.nn.Module): + def __init__(self, include, red): + super().__init__() + self.include = include + self.red = red + + def forward(self, x, indices, updates): + x = x.clone() + return x.scatter_reduce( + 0, indices, updates, self.red, include_self=self.include + ) + + model = Model(include, red) + xs = ( + torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype), + torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64), + torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype), + ) + expected = model(*xs) + model_path = ( + f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx" + ) + torch.onnx.export(model, xs, model_path, dynamo=True) + feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs])) + + sess_options = onnxruntime.SessionOptions() + sess = onnxruntime.InferenceSession( + model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] + ) + got = sess.run(None, feeds)[0] + torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index e8ccc87aea..f1c0918cda 100644 
--- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2026,26 +2026,30 @@ def _where_input_wrangler( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", ) - .skip( - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", + .xfail( + variant_name="prod", + dtypes=(torch.float16, torch.float64), + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) .xfail( - variant_name="amax", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", + variant_name="sum", + dtypes=(torch.float16, torch.float64), + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ) .xfail( - variant_name="amin", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", + variant_name="mean", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", ) .xfail( variant_name="prod", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'prod'", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", ) .xfail( variant_name="sum", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), TorchLibOpInfo("slice", core_ops.aten_slice),