|
14 | 14 | import math |
15 | 15 | from typing import Any, Optional, Sequence, Tuple, Union |
16 | 16 |
|
| 17 | +import numpy as np |
| 18 | +import torch |
| 19 | + |
17 | 20 | from onnxscript import ( |
18 | 21 | BFLOAT16, |
19 | 22 | BOOL, |
@@ -7512,13 +7515,62 @@ def aten_scatter_reduce( |
7512 | 7515 | "amax": "max", |
7513 | 7516 | } |
7514 | 7517 | onnx_reduce = reduce_mode[reduce] |
| 7518 | + dtype = src.dtype or self.dtype |
| 7519 | + assert dtype is not None, "dtype should be not None" |
| 7520 | + |
7515 | 7521 | self_is_scalar = len(self.shape) == 0 |
7516 | 7522 | if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) |
7517 | 7523 | neg_1 = op.Constant(value_ints=[-1]) |
7518 | 7524 | self = op.Reshape(self, neg_1) |
7519 | 7525 | index = op.Reshape(index, neg_1) |
7520 | 7526 | src = op.Reshape(src, neg_1) |
| 7527 | + |
| 7528 | + if not include_self: |
| 7529 | + # The ONNX standard always assumes the value from self is part of the reduction. |
| 7530 | + # A first step is added to replace each impacted value with a neutral one, |
| 7531 | + # chosen so that the result of the reduction is unchanged |
| 7532 | + # whether or not that value takes part in it. |
| 7533 | + # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. |
| 7534 | + # mean is not supported. |
| 7535 | + if onnx_reduce == "max": |
| 7536 | + if dtype in { |
| 7537 | + ir.DataType.FLOAT16, |
| 7538 | + ir.DataType.FLOAT, |
| 7539 | + ir.DataType.DOUBLE, |
| 7540 | + }: |
| 7541 | + value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) |
| 7542 | + elif dtype == ir.DataType.BFLOAT16: |
| 7543 | + value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) |
| 7544 | + else: |
| 7545 | + value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) |
| 7546 | + reduction_init = "min" |
| 7547 | + elif onnx_reduce == "min": |
| 7548 | + if dtype in { |
| 7549 | + ir.DataType.FLOAT16, |
| 7550 | + ir.DataType.FLOAT, |
| 7551 | + ir.DataType.DOUBLE, |
| 7552 | + }: |
| 7553 | + value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) |
| 7554 | + elif dtype == ir.DataType.BFLOAT16: |
| 7555 | + value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype) |
| 7556 | + else: |
| 7557 | + value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) |
| 7558 | + reduction_init = "max" |
| 7559 | + elif onnx_reduce == "add": |
| 7560 | + value = ir.tensor([0], dtype=dtype) |
| 7561 | + reduction_init = "none" |
| 7562 | + elif onnx_reduce == "mul": |
| 7563 | + value = ir.tensor([1], dtype=dtype) |
| 7564 | + reduction_init = "none" |
| 7565 | + else: |
| 7566 | + value = 0 |
| 7567 | + reduction_init = "none" |
| 7568 | + |
| 7569 | + cst = op.ConstantOfShape(op.Shape(src), value=value) |
| 7570 | + self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) |
| 7571 | + |
7521 | 7572 | result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) |
| 7573 | + |
7522 | 7574 | if self_is_scalar: |
7523 | 7575 | result = op.Squeeze(result) |
7524 | 7576 | return result |
|
0 commit comments