Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c451d4d
Fix include_self for scatter_reduce
xadupre Mar 7, 2025
87c4085
fix values
xadupre Mar 7, 2025
0a6c181
lint
xadupre Mar 7, 2025
a11b84a
Merge branch 'main' of https://github.com/microsoft/onnxscript into s…
xadupre Mar 9, 2025
9aaf142
Update onnxscript/function_libs/torch_lib/ops/core.py
xadupre Mar 10, 2025
b89b653
Merge branch 'scatter' of https://github.com/xadupre/onnxscript into …
xadupre Mar 10, 2025
27d0ff7
more comments
xadupre Mar 10, 2025
2ed4f9b
dtype
xadupre Mar 10, 2025
8813d44
Merge branch 'main' into scatter
xadupre Mar 14, 2025
f00fe37
use inf
xadupre Mar 14, 2025
e662f81
fix one bug
xadupre Mar 14, 2025
914283f
Update tests/function_libs/torch_lib/ops_test_data.py
justinchuby Mar 14, 2025
ba2a563
Add float16 dtype to xfail tests
justinchuby Mar 14, 2025
6e9e756
Fix tuple syntax for torch.float16 dtypes
justinchuby Mar 14, 2025
d7ab3b8
fix dtype
xadupre Mar 14, 2025
dd5e5d3
Merge branch 'main' of https://github.com/microsoft/onnxscript into s…
xadupre Mar 17, 2025
8af5353
simple try
xadupre Mar 17, 2025
559888d
variant
xadupre Mar 17, 2025
3976cd6
lint
xadupre Mar 17, 2025
020ec1c
Update onnxscript/function_libs/torch_lib/ops/core.py
xadupre Mar 17, 2025
02fdc55
fix missing type
xadupre Mar 17, 2025
f8071df
merge
xadupre Mar 17, 2025
b6b57f7
disable two tests
xadupre Mar 17, 2025
8ce51a5
Merge branch 'main' of https://github.com/microsoft/onnxscript into s…
xadupre Mar 18, 2025
5c063f9
fix remaining test
xadupre Mar 18, 2025
d2a40e5
Update tests/function_libs/torch_lib/e2e_ops_tests.py
xadupre Mar 18, 2025
f86e2f6
lint
xadupre Mar 18, 2025
673a32d
fix merhe
xadupre Mar 18, 2025
a4d17ff
comment
xadupre Mar 18, 2025
7fa5127
Merge branch 'main' into scatter
justinchuby Mar 25, 2025
33eac3e
Refactor get_constant_value function
justinchuby Mar 26, 2025
fe0c0ed
Apply suggestions from code review
justinchuby Mar 26, 2025
2743b8e
Merge branch 'main' into scatter
titaiwangms Mar 26, 2025
7fc0c27
Merge branch 'main' into scatter
justinchuby Mar 31, 2025
0f1e996
Add dtype assertion and fix dtype checks
justinchuby Mar 31, 2025
dfd613b
Replace ml_dtypes with torch for BFLOAT16
justinchuby Mar 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import math
from typing import Any, Optional, Sequence, Tuple, Union

import numpy as np
import ml_dtypes
Comment thread Fixed
Comment thread Fixed

from onnxscript import (
BFLOAT16,
BOOL,
Expand Down Expand Up @@ -7589,7 +7592,67 @@ def aten_scatter_reduce(
self = op.Reshape(self, neg_1)
index = op.Reshape(index, neg_1)
src = op.Reshape(src, neg_1)

if not include_self:
Comment thread
xadupre marked this conversation as resolved.
# onnx standard always assume the value from self is part of the reduction.
# A first step is added to replace the impacted value by another one
# chosen in a way that the results of the reduction is not changed
# whether or not it takes part in it.
# It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul.
# mean is not supported.
dtype = src.dtype or self.dtype
post_process_after_cast_like = False
if dtype is None:
dtype = ir.DataType.FLOAT
cast_like = True
else:
cast_like = False
Comment thread
justinchuby marked this conversation as resolved.
Outdated
# dtype should be not None.
if onnx_reduce == "max":
if dtype in {
ir.DataType.FLOAT16,
ir.DataType.FLOAT,
ir.DataType.DOUBLE,
}:
value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype)
elif dtype in {ir.DataType.BFLOAT16}:
value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).min], dtype=dtype)
Comment thread Fixed
Comment thread Fixed
else:
value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype)
reduction_init = "min"
elif onnx_reduce == "min":
post_process_after_cast_like = cast_like
Comment thread
justinchuby marked this conversation as resolved.
Outdated
if dtype in {
ir.DataType.FLOAT16,
ir.DataType.FLOAT,
ir.DataType.DOUBLE,
}:
value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype)
elif dtype == ir.DataType.BFLOAT16:
value = ir.tensor([ml_dtypes.finfo(dtype.numpy()).max], dtype=dtype)
else:
# Cast 1e20 into int32 returns the min value -2147483648
value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype)
reduction_init = "max"
elif onnx_reduce == "add":
value = ir.tensor([0], dtype=dtype)
reduction_init = "none"
elif onnx_reduce == "mul":
value = ir.tensor([1], dtype=dtype)
reduction_init = "none"
else:
value = 0
Comment thread
xadupre marked this conversation as resolved.
reduction_init = "none"

cst = op.ConstantOfShape(op.Shape(src), value=value)
if cast_like:
cst = op.CastLike(cst, self)
if post_process_after_cast_like:
cst = op.Max(cst, op.Neg(op.Add(cst, op.CastLike(1, cst))))
self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init)

result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce)

if self_is_scalar:
result = op.Squeeze(result)
return result
Expand Down
2 changes: 2 additions & 0 deletions onnxscript/ir/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
TypeError: If the data type is not supported by ONNX.
"""
if dtype not in _NP_TYPE_TO_DATA_TYPE:
if dtype == ml_dtypes.bfloat16 or (hasattr(dtype, "names") and dtype.names == ("bfloat16",)):
return DataType.BFLOAT16
raise TypeError(f"Unsupported numpy data type: {dtype}")
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])

Expand Down
6 changes: 5 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,11 @@ def _do_inference(self, node: ir.Node) -> None:
def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
value = _get_numpy_value(x)
if isinstance(value, np.ndarray) and value.size < 20:
return onnx.numpy_helper.from_array(value, x.name)
try:
return onnx.numpy_helper.from_array(value, x.name)
except ValueError:
# This happens for bfloat16 and old versions of ONNX.
Comment thread
justinchuby marked this conversation as resolved.
Outdated
return None
return None

def get_type(value: ir.Value) -> onnx.TypeProto | None:
Expand Down
58 changes: 58 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.
Comment thread
xadupre marked this conversation as resolved.

# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo

import itertools
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
import unittest

import onnxruntime
import torch

from tests.common import testutils


class TorchLibe2eTest(testutils.TestBase):
Comment thread
xadupre marked this conversation as resolved.
def test_investigate_one_particular_model(self):
"""This test can be used to investigate a particular issue."""
red, include, stype = "amin", False, "int32"
dtype = getattr(torch, stype)

class Model(torch.nn.Module):
def __init__(self, include, red):
super().__init__()
self.include = include
self.red = red

def forward(self, x, indices, updates):
x = x.clone()
return x.scatter_reduce(
0, indices, updates, self.red, include_self=self.include
)

model = Model(include, red)
xs = (
torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype),
torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
)
expected = model(*xs)
model_path = (
f"test_aten_scatter_{red}_"
f"{'include' if include else 'exclude'}_{stype}.onnx"
)
torch.onnx.export(model, xs, model_path, dynamo=True)
feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs]))

sess_options = onnxruntime.SessionOptions()
sess = onnxruntime.InferenceSession(
model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
)
got = sess.run(None, feeds)[0]
torch.testing.assert_close(
expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5
)


if __name__ == "__main__":
unittest.main()
24 changes: 14 additions & 10 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,26 +2025,30 @@ def _where_input_wrangler(
variant_name="mean",
reason="ONNX doesn't support reduce='mean' option",
)
.skip(
# ONNX has not include_self parameter and default is include_self=True mode
matcher=lambda sample: sample.kwargs.get("include_self") is False,
reason="ONNX does't support include_self=False option",
.xfail(
variant_name="prod",
dtypes=(torch.float16, torch.float64),
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'",
)
.xfail(
variant_name="amax",
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'",
variant_name="sum",
dtypes=(torch.float16, torch.float64),
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'",
)
.xfail(
variant_name="amin",
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'",
variant_name="mean",
dtypes=(torch.bfloat16,),
reason="onnxruntime does not support ml_dtypes.bfloat16",
)
.xfail(
variant_name="prod",
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'prod'",
dtypes=(torch.bfloat16,),
reason="onnxruntime does not support ml_dtypes.bfloat16",
)
.xfail(
variant_name="sum",
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'",
dtypes=(torch.bfloat16,),
Comment thread
xadupre marked this conversation as resolved.
reason="onnxruntime does not support ml_dtypes.bfloat16",
),
TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter),
TorchLibOpInfo("slice", core_ops.aten_slice),
Expand Down