From c451d4dd2c5046d693146aa7a8e6da356ba2bf5a Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 7 Mar 2025 15:13:32 +0100 Subject: [PATCH 01/26] Fix include_self for scatter_reduce --- .../function_libs/torch_lib/ops/core.py | 24 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 5 ---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 249569fbca..ea2d1ef330 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -14,6 +14,8 @@ import math from typing import Any, Optional, Sequence, Tuple, Union +import numpy as np + from onnxscript import ( BFLOAT16, BOOL, @@ -7584,7 +7586,29 @@ def aten_scatter_reduce( self = op.Reshape(self, neg_1) index = op.Reshape(index, neg_1) src = op.Reshape(src, neg_1) + + if not include_self: + if onnx_reduce == "max": + value = np.finfo(src.dtype.numpy()).min + reduction_init = "min" + elif onnx_reduce == "min": + value = np.finfo(src.dtype.numpy()).max + reduction_init = "max" + elif onnx_reduce == "add": + value = 0 + reduction_init = "none" + elif onnx_reduce == "mul": + value = 1 + reduction_init = "none" + else: + value = 0 + reduction_init = "none" + + cst = op.ConstantOfShape(op.Shape(src), value=value) + self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) + result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) + if self_is_scalar: result = op.Squeeze(result) return result diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index c1d380f9f5..5958f43c2a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2025,11 +2025,6 @@ def _where_input_wrangler( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", ) - .skip( - # ONNX has not include_self parameter 
and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ) .xfail( variant_name="amax", reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", From 87c4085f08a9d6e8b8bafb756db29578408c8bbe Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 7 Mar 2025 15:28:09 +0100 Subject: [PATCH 02/26] fix values --- onnxscript/function_libs/torch_lib/ops/core.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ea2d1ef330..1fcc2d5dbf 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -15,6 +15,7 @@ from typing import Any, Optional, Sequence, Tuple, Union import numpy as np +import onnx.numpy_helper as onh from onnxscript import ( BFLOAT16, @@ -7589,16 +7590,16 @@ def aten_scatter_reduce( if not include_self: if onnx_reduce == "max": - value = np.finfo(src.dtype.numpy()).min + value = onh.from_array(np.array([np.finfo(src.dtype.numpy()).min], dtype=src.dtype.numpy())) reduction_init = "min" elif onnx_reduce == "min": - value = np.finfo(src.dtype.numpy()).max + value = onh.from_array(np.array([np.finfo(src.dtype.numpy()).max], dtype=src.dtype.numpy())) reduction_init = "max" elif onnx_reduce == "add": - value = 0 + value = onh.from_array(np.array([0], dtype=src.dtype.numpy())) reduction_init = "none" elif onnx_reduce == "mul": - value = 1 + value = onh.from_array(np.array([1], dtype=src.dtype.numpy())) reduction_init = "none" else: value = 0 From 0a6c1819aaf25751188000c2a7dc81e116a0f8b8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 7 Mar 2025 15:51:54 +0100 Subject: [PATCH 03/26] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git 
a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1fcc2d5dbf..97242ed36e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7590,10 +7590,14 @@ def aten_scatter_reduce( if not include_self: if onnx_reduce == "max": - value = onh.from_array(np.array([np.finfo(src.dtype.numpy()).min], dtype=src.dtype.numpy())) + value = onh.from_array( + np.array([np.finfo(src.dtype.numpy()).min], dtype=src.dtype.numpy()) + ) reduction_init = "min" elif onnx_reduce == "min": - value = onh.from_array(np.array([np.finfo(src.dtype.numpy()).max], dtype=src.dtype.numpy())) + value = onh.from_array( + np.array([np.finfo(src.dtype.numpy()).max], dtype=src.dtype.numpy()) + ) reduction_init = "max" elif onnx_reduce == "add": value = onh.from_array(np.array([0], dtype=src.dtype.numpy())) From 9aaf14221c6f8b100236f4093a8b6bd14d302bcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 10 Mar 2025 16:07:25 +0100 Subject: [PATCH 04/26] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 97242ed36e..02d8a28717 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7590,9 +7590,7 @@ def aten_scatter_reduce( if not include_self: if onnx_reduce == "max": - value = onh.from_array( - np.array([np.finfo(src.dtype.numpy()).min], dtype=src.dtype.numpy()) - ) + value = ir.tensor([np.finfo(src.dtype.numpy()).min], dtype=src.dtype) reduction_init = "min" elif onnx_reduce == "min": value = onh.from_array( From 27d0ff7fc3b79e4fde08361add37328168f11b75 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 10 Mar 2025 16:16:26 +0100 Subject: [PATCH 05/26] more comments 
--- onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++++++------ tests/function_libs/torch_lib/ops_test_data.py | 16 ---------------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a809594cdc..537c7c138e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -15,7 +15,6 @@ from typing import Any, Optional, Sequence, Tuple, Union import numpy as np -import onnx.numpy_helper as onh from onnxscript import ( BFLOAT16, @@ -34,6 +33,7 @@ UINT32, UINT64, graph, + ir, ) from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op @@ -7589,19 +7589,23 @@ def aten_scatter_reduce( src = op.Reshape(src, neg_1) if not include_self: + # onnx standard always assume the value from self is part of the reduction. + # A first step is added to replace the impacted value by another one + # chosen in a way that the results of the reduction is not changed + # whether or not it takes part in it. + # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. + # mean is not supported. 
if onnx_reduce == "max": value = ir.tensor([np.finfo(src.dtype.numpy()).min], dtype=src.dtype) reduction_init = "min" elif onnx_reduce == "min": - value = onh.from_array( - np.array([np.finfo(src.dtype.numpy()).max], dtype=src.dtype.numpy()) - ) + value = ir.tensor([np.finfo(src.dtype.numpy()).max], dtype=src.dtype) reduction_init = "max" elif onnx_reduce == "add": - value = onh.from_array(np.array([0], dtype=src.dtype.numpy())) + value = ir.tensor([0], dtype=src.dtype) reduction_init = "none" elif onnx_reduce == "mul": - value = onh.from_array(np.array([1], dtype=src.dtype.numpy())) + value = ir.tensor([1], dtype=src.dtype) reduction_init = "none" else: value = 0 diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8bd9ad5b9e..5e7ccc97e6 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2024,22 +2024,6 @@ def _where_input_wrangler( .xfail( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", - ) - .xfail( - variant_name="amax", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", - ) - .xfail( - variant_name="amin", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", - ) - .xfail( - variant_name="prod", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'prod'", - ) - .xfail( - variant_name="sum", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), TorchLibOpInfo("slice", core_ops.aten_slice), From 2ed4f9bf11be630543fd138c33f32fb4f1c819c3 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 10 Mar 2025 18:39:02 +0100 Subject: [PATCH 06/26] dtype --- onnxscript/function_libs/torch_lib/ops/core.py | 10 ++++++---- 1 file changed, 6 
insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 537c7c138e..4efc5bc30c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7595,17 +7595,19 @@ def aten_scatter_reduce( # whether or not it takes part in it. # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. + dtype = src.dtype or cst.dtype + # dtype should be not None. if onnx_reduce == "max": - value = ir.tensor([np.finfo(src.dtype.numpy()).min], dtype=src.dtype) + value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" elif onnx_reduce == "min": - value = ir.tensor([np.finfo(src.dtype.numpy()).max], dtype=src.dtype) + value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max" elif onnx_reduce == "add": - value = ir.tensor([0], dtype=src.dtype) + value = ir.tensor([0], dtype=dtype) reduction_init = "none" elif onnx_reduce == "mul": - value = ir.tensor([1], dtype=src.dtype) + value = ir.tensor([1], dtype=dtype) reduction_init = "none" else: value = 0 From f00fe37fcdd6a09cb67f8cd51cd9d8c91a151d87 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 14 Mar 2025 16:24:57 +0100 Subject: [PATCH 07/26] use inf --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5f0fedff0e..9175c98be5 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7585,13 +7585,13 @@ def aten_scatter_reduce( # whether or not it takes part in it. # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. - dtype = src.dtype or cst.dtype + dtype = src.dtype or self.dtype # dtype should be not None. 
if onnx_reduce == "max": - value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) + value = ir.tensor([-np.inf], dtype=dtype) reduction_init = "min" elif onnx_reduce == "min": - value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) + value = ir.tensor([np.inf], dtype=dtype) reduction_init = "max" elif onnx_reduce == "add": value = ir.tensor([0], dtype=dtype) From e662f81252dbc43c3020b6c69bfec3a0cbafa7df Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 14 Mar 2025 17:15:39 +0100 Subject: [PATCH 08/26] fix one bug --- tests/function_libs/torch_lib/ops_test_data.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 5e7ccc97e6..d793a6f086 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2024,6 +2024,22 @@ def _where_input_wrangler( .xfail( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", + ) + .xfail( + variant_name="amax", + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", + ) + .xfail( + variant_name="amin", + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", + ) + .xfail( + variant_name="prod", + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", + ) + .xfail( + variant_name="sum", + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), TorchLibOpInfo("slice", core_ops.aten_slice), From 914283fd0fece39f013e7f0d8c22f2f612523d8b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Mar 2025 09:26:03 -0700 Subject: [PATCH 09/26] Update tests/function_libs/torch_lib/ops_test_data.py --- tests/function_libs/torch_lib/ops_test_data.py | 1 + 1 file 
changed, 1 insertion(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index d793a6f086..4cca2a58f8 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2027,6 +2027,7 @@ def _where_input_wrangler( ) .xfail( variant_name="amax", + dtypes=(torch.float16), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", ) .xfail( From ba2a563bf701179582d22dff57f7e9c21e520031 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Mar 2025 09:26:39 -0700 Subject: [PATCH 10/26] Add float16 dtype to xfail tests --- tests/function_libs/torch_lib/ops_test_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4cca2a58f8..26ca174609 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2032,14 +2032,17 @@ def _where_input_wrangler( ) .xfail( variant_name="amin", + dtypes=(torch.float16), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", ) .xfail( variant_name="prod", + dtypes=(torch.float16), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) .xfail( variant_name="sum", + dtypes=(torch.float16), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), From 6e9e756d61df2ad2dfab86ad01fbe87653accbb7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Mar 2025 09:33:12 -0700 Subject: [PATCH 11/26] Fix tuple syntax for torch.float16 dtypes --- tests/function_libs/torch_lib/ops_test_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 26ca174609..b7170e76ed 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2027,22 +2027,22 @@ def _where_input_wrangler( ) .xfail( variant_name="amax", - dtypes=(torch.float16), + dtypes=(torch.float16,), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", ) .xfail( variant_name="amin", - dtypes=(torch.float16), + dtypes=(torch.float16,), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", ) .xfail( variant_name="prod", - dtypes=(torch.float16), + dtypes=(torch.float16,), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) .xfail( variant_name="sum", - dtypes=(torch.float16), + dtypes=(torch.float16,), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), From d7ab3b83ec0c43145972490fa2bfde0ec0b2aea8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 14 Mar 2025 19:47:37 +0100 Subject: [PATCH 12/26] fix dtype --- .../function_libs/torch_lib/ops/core.py | 21 +++++++++++++++++-- .../function_libs/torch_lib/ops_test_data.py | 8 +++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9175c98be5..ab0ba1f51f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7586,12 +7586,29 @@ def aten_scatter_reduce( # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. dtype = src.dtype or self.dtype + assert dtype is not None, f"dtype is None, src={src}, self={self}" # dtype should be not None. 
if onnx_reduce == "max": - value = ir.tensor([-np.inf], dtype=dtype) + if dtype in { + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + }: + value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) + else: + value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" elif onnx_reduce == "min": - value = ir.tensor([np.inf], dtype=dtype) + if dtype in { + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + }: + value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) + else: + value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max" elif onnx_reduce == "add": value = ir.tensor([0], dtype=dtype) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b7170e76ed..ae03ba4f0c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2027,22 +2027,22 @@ def _where_input_wrangler( ) .xfail( variant_name="amax", - dtypes=(torch.float16,), + dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", ) .xfail( variant_name="amin", - dtypes=(torch.float16,), + dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", ) .xfail( variant_name="prod", - dtypes=(torch.float16,), + dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) .xfail( variant_name="sum", - dtypes=(torch.float16,), + dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), From 8af535392a13891cd6680b7ff351f61382df3df6 Mon Sep 
17 00:00:00 2001 From: xadupre Date: Mon, 17 Mar 2025 12:43:15 +0100 Subject: [PATCH 13/26] simple try --- .../function_libs/torch_lib/ops/core.py | 8 +- .../function_libs/torch_lib/e2e_ops_tests.py | 73 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 14 +--- 3 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 tests/function_libs/torch_lib/e2e_ops_tests.py diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ab0ba1f51f..2d209bc86d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7591,22 +7591,26 @@ def aten_scatter_reduce( if onnx_reduce == "max": if dtype in { ir.DataType.FLOAT16, - ir.DataType.BFLOAT16, ir.DataType.FLOAT, ir.DataType.DOUBLE, }: value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) + elif dtype in {ir.DataType.BFLOAT16}: + import ml_dtypes + value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).min], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" elif onnx_reduce == "min": if dtype in { ir.DataType.FLOAT16, - ir.DataType.BFLOAT16, ir.DataType.FLOAT, ir.DataType.DOUBLE, }: value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) + elif dtype in {ir.DataType.BFLOAT16}: + import ml_dtypes + value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max" diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py new file mode 100644 index 0000000000..e3bebd1f02 --- /dev/null +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import itertools +import unittest + +import onnxruntime +import torch + +from tests.common import testutils + + +class TorchLibe2eTest(testutils.TestBase): + + def test_aten_scatter_reduce_include_self(self): + # known failing configurations because of onnxruntime + skip_ort = { + ("sum", False, "float16"), + ("sum", True, "float16"), + ("prod", False, "float16"), + ("prod", True, "float16"), + } + + for red, include, stype in itertools.product( + ["amin", "amax", "sum", "prod"], + [False, True], + ["bfloat16", "float32", "float16", "int32", "int64", "float64"], + ): + with self.subTest(reduce=red, include=include, type=stype): + key = red, include, stype + if key in skip_ort: + continue + dtype = getattr(torch, stype) + + class Model(torch.nn.Module): + def __init__(self, include, red): + super().__init__() + self.include = include + self.red = red + + def forward(self, x, indices, updates): + x = x.clone() + return x.scatter_reduce( + 0, indices, updates, self.red, include_self=self.include + ) + + model = Model(include, red) + xs = ( + torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype), + torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64), + torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype), + ) + expected = model(*xs) + model_path = ( + f"test_aten_scatter_{red}_" + f"{'include' if include else 'exclude'}_{stype}.onnx" + ) + torch.onnx.export(model, xs, model_path, dynamo=True) + if stype == "bfloat16": + # not supported yet by onnxruntime + continue + feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs])) + + sess_options = onnxruntime.SessionOptions() + sess = onnxruntime.InferenceSession( + model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] + ) + got = sess.run(None, feeds)[0] + torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py 
b/tests/function_libs/torch_lib/ops_test_data.py index ae03ba4f0c..4d3a42d840 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2025,16 +2025,6 @@ def _where_input_wrangler( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", ) - .xfail( - variant_name="amax", - dtypes=(torch.float16, torch.float64), - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", - ) - .xfail( - variant_name="amin", - dtypes=(torch.float16, torch.float64), - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", - ) .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64), @@ -2044,6 +2034,10 @@ def _where_input_wrangler( variant_name="sum", dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", + ) + .xfail( + dtypes=(torch.bfloat16), + reason="onnxruntime does not support ml_dtypes.bfloat16", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), TorchLibOpInfo("slice", core_ops.aten_slice), From 559888d6a157677646fcf10cdd2aad650034ebe5 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 17 Mar 2025 13:08:58 +0100 Subject: [PATCH 14/26] variant --- tests/function_libs/torch_lib/ops_test_data.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4d3a42d840..d8f2d642d7 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2036,7 +2036,18 @@ def _where_input_wrangler( reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ) .xfail( - dtypes=(torch.bfloat16), + variant_name="mean", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not 
support ml_dtypes.bfloat16", + ) + .xfail( + variant_name="prod", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", + ) + .xfail( + variant_name="sum", + dtypes=(torch.bfloat16,), reason="onnxruntime does not support ml_dtypes.bfloat16", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), From 3976cd6964fa8a2290cbdcdf08a2166ef6868006 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 17 Mar 2025 14:42:12 +0100 Subject: [PATCH 15/26] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- tests/function_libs/torch_lib/e2e_ops_tests.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2d209bc86d..ae14890e39 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -15,6 +15,7 @@ from typing import Any, Optional, Sequence, Tuple, Union import numpy as np +import ml_dtypes from onnxscript import ( BFLOAT16, @@ -7596,7 +7597,6 @@ def aten_scatter_reduce( }: value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) elif dtype in {ir.DataType.BFLOAT16}: - import ml_dtypes value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).min], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index e3bebd1f02..4b594a8cc0 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -66,7 +66,9 @@ def forward(self, x, indices, updates): model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] ) got = sess.run(None, feeds)[0] - torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5 + ) if __name__ == "__main__": From 
020ec1c61d165550f4cd215cd447d17d0f14e0d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 17 Mar 2025 19:04:39 +0100 Subject: [PATCH 16/26] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ae14890e39..06b6fa0b94 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7608,7 +7608,7 @@ def aten_scatter_reduce( ir.DataType.DOUBLE, }: value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) - elif dtype in {ir.DataType.BFLOAT16}: + elif dtype == ir.DataType.BFLOAT16: import ml_dtypes value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype) else: From 02fdc558a3d38dad52c5813e0ff4143bad1f3af4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 17 Mar 2025 19:05:15 +0100 Subject: [PATCH 17/26] fix missing type --- onnxscript/function_libs/torch_lib/ops/core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ae14890e39..91f88dbc87 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7587,7 +7587,11 @@ def aten_scatter_reduce( # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. dtype = src.dtype or self.dtype - assert dtype is not None, f"dtype is None, src={src}, self={self}" + if dtype is None: + dtype = ir.DataType.FLOAT + cast_like = True + else: + cast_like = False # dtype should be not None. 
if onnx_reduce == "max": if dtype in { @@ -7609,7 +7613,6 @@ def aten_scatter_reduce( }: value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) elif dtype in {ir.DataType.BFLOAT16}: - import ml_dtypes value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) @@ -7625,6 +7628,8 @@ def aten_scatter_reduce( reduction_init = "none" cst = op.ConstantOfShape(op.Shape(src), value=value) + if cast_like: + cst = op.CastLike(cst, self) self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) From b6b57f7da1c2ee721720269e82b1f14034cd42ae Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 17 Mar 2025 19:50:32 +0100 Subject: [PATCH 18/26] disable two tests --- tests/function_libs/torch_lib/ops_test_data.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index d8f2d642d7..9b6d98ab5d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2030,6 +2030,11 @@ def _where_input_wrangler( dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) + .xfail( + variant_name="amin", + dtypes=(torch.int32, torch.int64), + reason="fixme: discrepancies, that should be investigated", + ) .xfail( variant_name="sum", dtypes=(torch.float16, torch.float64), From 5c063f9ec3fd2585015193dc56953b5587c0b9fc Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Mar 2025 11:33:29 +0100 Subject: [PATCH 19/26] fix remaining test --- .../function_libs/torch_lib/ops/core.py | 5 + onnxscript/ir/_enums.py | 2 + onnxscript/optimizer/_constant_folding.py | 6 +- .../function_libs/torch_lib/e2e_ops_tests.py | 93 ++++++++----------- .../function_libs/torch_lib/ops_test_data.py | 5 - 
5 files changed, 49 insertions(+), 62 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 996c65fe33..2580ffb97e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7601,6 +7601,7 @@ def aten_scatter_reduce( # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. dtype = src.dtype or self.dtype + post_process_after_cast_like = False if dtype is None: dtype = ir.DataType.FLOAT cast_like = True @@ -7620,6 +7621,7 @@ def aten_scatter_reduce( value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" elif onnx_reduce == "min": + post_process_after_cast_like = cast_like if dtype in { ir.DataType.FLOAT16, ir.DataType.FLOAT, @@ -7629,6 +7631,7 @@ def aten_scatter_reduce( elif dtype == ir.DataType.BFLOAT16: value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype) else: + # Cast 1e20 into int32 returns the min value -2147483648 value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max" elif onnx_reduce == "add": @@ -7644,6 +7647,8 @@ def aten_scatter_reduce( cst = op.ConstantOfShape(op.Shape(src), value=value) if cast_like: cst = op.CastLike(cst, self) + if post_process_after_cast_like: + cst = op.Max(cst, op.Neg(op.Add(cst, op.CastLike(1, cst)))) self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index d0d8c19270..31a119eef3 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -74,6 +74,8 @@ def from_numpy(cls, dtype: np.dtype) -> DataType: TypeError: If the data type is not supported by ONNX. 
""" if dtype not in _NP_TYPE_TO_DATA_TYPE: + if dtype == ml_dtypes.bfloat16 or (hasattr(dtype, "names") and dtype.names == ("bfloat16",)): + return DataType.BFLOAT16 raise TypeError(f"Unsupported numpy data type: {dtype}") return cls(_NP_TYPE_TO_DATA_TYPE[dtype]) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 3b91e378d2..b30b3ec3c4 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -827,7 +827,11 @@ def _do_inference(self, node: ir.Node) -> None: def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: value = _get_numpy_value(x) if isinstance(value, np.ndarray) and value.size < 20: - return onnx.numpy_helper.from_array(value, x.name) + try: + return onnx.numpy_helper.from_array(value, x.name) + except ValueError: + # This happens for bfloat16 and old versions of ONNX. + return None return None def get_type(value: ir.Value) -> onnx.TypeProto | None: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 4b594a8cc0..8327b9a135 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -11,64 +11,45 @@ class TorchLibe2eTest(testutils.TestBase): - - def test_aten_scatter_reduce_include_self(self): - # known failing configurations because of onnxruntime - skip_ort = { - ("sum", False, "float16"), - ("sum", True, "float16"), - ("prod", False, "float16"), - ("prod", True, "float16"), - } - - for red, include, stype in itertools.product( - ["amin", "amax", "sum", "prod"], - [False, True], - ["bfloat16", "float32", "float16", "int32", "int64", "float64"], - ): - with self.subTest(reduce=red, include=include, type=stype): - key = red, include, stype - if key in skip_ort: - continue - dtype = getattr(torch, stype) - - class Model(torch.nn.Module): - def __init__(self, include, red): - super().__init__() - self.include = include - self.red = 
red - - def forward(self, x, indices, updates): - x = x.clone() - return x.scatter_reduce( - 0, indices, updates, self.red, include_self=self.include - ) - - model = Model(include, red) - xs = ( - torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype), - torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64), - torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype), - ) - expected = model(*xs) - model_path = ( - f"test_aten_scatter_{red}_" - f"{'include' if include else 'exclude'}_{stype}.onnx" + def test_investigate_one_particular_model(self): + """This test can be used to investigate a particular issue.""" + red, include, stype = "amin", False, "int32" + dtype = getattr(torch, stype) + + class Model(torch.nn.Module): + def __init__(self, include, red): + super().__init__() + self.include = include + self.red = red + + def forward(self, x, indices, updates): + x = x.clone() + return x.scatter_reduce( + 0, indices, updates, self.red, include_self=self.include ) - torch.onnx.export(model, xs, model_path, dynamo=True) - if stype == "bfloat16": - # not supported yet by onnxruntime - continue - feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs])) - sess_options = onnxruntime.SessionOptions() - sess = onnxruntime.InferenceSession( - model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] - ) - got = sess.run(None, feeds)[0] - torch.testing.assert_close( - expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5 - ) + model = Model(include, red) + xs = ( + torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype), + torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64), + torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype), + ) + expected = model(*xs) + model_path = ( + f"test_aten_scatter_{red}_" + f"{'include' if include else 'exclude'}_{stype}.onnx" + ) + torch.onnx.export(model, xs, model_path, dynamo=True) + feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs])) + + sess_options = onnxruntime.SessionOptions() + sess 
= onnxruntime.InferenceSession( + model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] + ) + got = sess.run(None, feeds)[0] + torch.testing.assert_close( + expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5 + ) if __name__ == "__main__": diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a2fa5ed2cf..c759bb3da2 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2030,11 +2030,6 @@ def _where_input_wrangler( dtypes=(torch.float16, torch.float64), reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) - .xfail( - variant_name="amin", - dtypes=(torch.int32, torch.int64), - reason="fixme: discrepancies, that should be investigated", - ) .xfail( variant_name="sum", dtypes=(torch.float16, torch.float64), From d2a40e5df7f562d0a0a2fd8f300d92f00bc6b140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Mar 2025 12:38:37 +0100 Subject: [PATCH 20/26] Update tests/function_libs/torch_lib/e2e_ops_tests.py Co-authored-by: Justin Chu --- tests/function_libs/torch_lib/e2e_ops_tests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 8327b9a135..29615d12cd 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo + import itertools import unittest From f86e2f6f73ae1779cb428def5022e454ea535ce0 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Mar 2025 13:19:30 +0100 Subject: [PATCH 21/26] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- onnxscript/ir/_enums.py | 4 +++- tests/function_libs/torch_lib/e2e_ops_tests.py | 8 ++------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2580ffb97e..acc530a1d2 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -14,8 +14,8 @@ import math from typing import Any, Optional, Sequence, Tuple, Union -import numpy as np import ml_dtypes +import numpy as np from onnxscript import ( BFLOAT16, diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index 31a119eef3..5c9d3e32b4 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -74,7 +74,9 @@ def from_numpy(cls, dtype: np.dtype) -> DataType: TypeError: If the data type is not supported by ONNX. """ if dtype not in _NP_TYPE_TO_DATA_TYPE: - if dtype == ml_dtypes.bfloat16 or (hasattr(dtype, "names") and dtype.names == ("bfloat16",)): + if dtype == ml_dtypes.bfloat16 or ( + hasattr(dtype, "names") and dtype.names == ("bfloat16",) + ): return DataType.BFLOAT16 raise TypeError(f"Unsupported numpy data type: {dtype}") return cls(_NP_TYPE_TO_DATA_TYPE[dtype]) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 8327b9a135..c58c162991 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import itertools import unittest import onnxruntime @@ -36,8 +35,7 @@ def forward(self, x, indices, updates): ) expected = model(*xs) model_path = ( - f"test_aten_scatter_{red}_" - f"{'include' if include else 'exclude'}_{stype}.onnx" + f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx" ) torch.onnx.export(model, xs, model_path, dynamo=True) feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs])) @@ -47,9 +45,7 @@ def forward(self, x, indices, updates): model_path, sess_options=sess_options, providers=["CPUExecutionProvider"] ) got = sess.run(None, feeds)[0] - torch.testing.assert_close( - expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5 - ) + torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5) if __name__ == "__main__": From a4d17ff212fb0f5d5a209d544e5a7ad4c00214d0 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Mar 2025 17:17:30 +0100 Subject: [PATCH 22/26] comment --- onnxscript/function_libs/torch_lib/ops/core.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index acc530a1d2..5dc2bb1f15 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7648,6 +7648,12 @@ def aten_scatter_reduce( if cast_like: cst = op.CastLike(cst, self) if post_process_after_cast_like: + # torch.tensor(1e20, dtype=torch.float32).to(torch.int32) -> + # -2147483648 and we need 2147483647. These extra operators + # compute that value. It could be a constant but this + # works for int16, int32, int64. + # This is not added where one of the input type (src or self) + # is known. Constant folding should fold them anyway. 
cst = op.Max(cst, op.Neg(op.Add(cst, op.CastLike(1, cst)))) self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) From 33eac3ea8002e38e2525685d01e9210ce7285cc0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 26 Mar 2025 09:05:22 -0700 Subject: [PATCH 23/26] Refactor get_constant_value function --- onnxscript/optimizer/_constant_folding.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 54c15bdded..04ae756bb0 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -829,13 +829,10 @@ def _do_inference(self, node: ir.Node) -> None: # TODO: handle optional inputs def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: - value = _get_numpy_value(x) - if isinstance(value, np.ndarray) and value.size < 20: - try: - return onnx.numpy_helper.from_array(value, x.name) - except ValueError: - # This happens for bfloat16 and old versions of ONNX. - return None + value = _get_numpy_value(x, size_limit=20) + if value is not None: + assert x.const_value is not None + return ir.serde.serialize_tensor(x.const_value) return None def get_type(value: ir.Value) -> onnx.TypeProto | None: From fe0c0edc2cc6d960185ec81a734e0944b8a7ac10 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 26 Mar 2025 11:38:42 -0700 Subject: [PATCH 24/26] Apply suggestions from code review --- onnxscript/function_libs/torch_lib/ops/core.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d7afd73595..1a94b73801 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7617,12 +7617,6 @@ def aten_scatter_reduce( # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. 
dtype = src.dtype or self.dtype - post_process_after_cast_like = False - if dtype is None: - dtype = ir.DataType.FLOAT - cast_like = True - else: - cast_like = False # dtype should be not None. if onnx_reduce == "max": if dtype in { @@ -7637,7 +7631,6 @@ def aten_scatter_reduce( value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" elif onnx_reduce == "min": - post_process_after_cast_like = cast_like if dtype in { ir.DataType.FLOAT16, ir.DataType.FLOAT, @@ -7661,16 +7654,6 @@ def aten_scatter_reduce( reduction_init = "none" cst = op.ConstantOfShape(op.Shape(src), value=value) - if cast_like: - cst = op.CastLike(cst, self) - if post_process_after_cast_like: - # torch.tensor(1e20, dtype=torch.float32).to(torch.int32) -> - # -2147483648 and we need 2147483647. These extra operators - # compute that value. It could be a constant but this - # works for int16, int32, int64. - # This is not added where one of the input type (src or self) - # is known. Constant folding should fold them anyway. 
- cst = op.Max(cst, op.Neg(op.Add(cst, op.CastLike(1, cst)))) self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) From 0f1e996c179150b04836d672fafdb27352cb71a6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 31 Mar 2025 12:42:53 -0700 Subject: [PATCH 25/26] Add dtype assertion and fix dtype checks --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a94b73801..97bb984bae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7602,6 +7602,9 @@ def aten_scatter_reduce( "amax": "max", } onnx_reduce = reduce_mode[reduce] + dtype = src.dtype or self.dtype + assert dtype is not None, "dtype should be not None" + self_is_scalar = len(self.shape) == 0 if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) neg_1 = op.Constant(value_ints=[-1]) @@ -7616,8 +7619,6 @@ def aten_scatter_reduce( # whether or not it takes part in it. # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. # mean is not supported. - dtype = src.dtype or self.dtype - # dtype should be not None. 
if onnx_reduce == "max": if dtype in { ir.DataType.FLOAT16, @@ -7625,7 +7626,7 @@ def aten_scatter_reduce( ir.DataType.DOUBLE, }: value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) - elif dtype in {ir.DataType.BFLOAT16}: + elif dtype == ir.DataType.BFLOAT16: value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).min], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) @@ -7640,7 +7641,6 @@ def aten_scatter_reduce( elif dtype == ir.DataType.BFLOAT16: value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype) else: - # Cast 1e20 into int32 returns the min value -2147483648 value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max" elif onnx_reduce == "add": From dfd613b85364862ef7cff111c41fa819e7ad0620 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 31 Mar 2025 12:59:04 -0700 Subject: [PATCH 26/26] Replace ml_dtypes with torch for BFLOAT16 --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 97bb984bae..ba3b9bfb3f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -14,8 +14,8 @@ import math from typing import Any, Optional, Sequence, Tuple, Union -import ml_dtypes import numpy as np +import torch from onnxscript import ( BFLOAT16, @@ -7627,7 +7627,7 @@ def aten_scatter_reduce( }: value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) elif dtype == ir.DataType.BFLOAT16: - value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).min], dtype=dtype) + value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" @@ -7639,7 +7639,7 @@ def aten_scatter_reduce( }: value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) elif dtype == ir.DataType.BFLOAT16: - value = 
ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype) + value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max"