|
14 | 14 | import math |
15 | 15 | from typing import Any, Optional, Sequence, Tuple, Union |
16 | 16 |
|
| 17 | +import numpy as np |
| 18 | +import torch |
| 19 | + |
17 | 20 | from onnxscript import ( |
18 | 21 | BFLOAT16, |
19 | 22 | BOOL, |
@@ -7599,13 +7602,62 @@ def aten_scatter_reduce( |
7599 | 7602 | "amax": "max", |
7600 | 7603 | } |
7601 | 7604 | onnx_reduce = reduce_mode[reduce] |
| 7605 | + dtype = src.dtype or self.dtype |
| 7606 | + assert dtype is not None, "dtype should be not None" |
| 7607 | + |
7602 | 7608 | self_is_scalar = len(self.shape) == 0 |
7603 | 7609 | if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) |
7604 | 7610 | neg_1 = op.Constant(value_ints=[-1]) |
7605 | 7611 | self = op.Reshape(self, neg_1) |
7606 | 7612 | index = op.Reshape(index, neg_1) |
7607 | 7613 | src = op.Reshape(src, neg_1) |
| 7614 | + |
| 7615 | + if not include_self: |
| 7616 | +    if not include_self:
| 7617 | +        # The ONNX standard always assumes the value from self takes part in the reduction.
| 7618 | +        # A first step is added to replace the impacted values with a neutral one,
| 7619 | +        # chosen so that the result of the reduction is unchanged
| 7620 | +        # whether or not the original value takes part in it.
| 7621 | +        # It is the dtype's lowest value if the reduction is max, its highest
| 7622 | +        # for min, 0 for add, and 1 for mul. mean is not supported.
| 7622 | + if onnx_reduce == "max": |
| 7623 | + if dtype in { |
| 7624 | + ir.DataType.FLOAT16, |
| 7625 | + ir.DataType.FLOAT, |
| 7626 | + ir.DataType.DOUBLE, |
| 7627 | + }: |
| 7628 | + value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) |
| 7629 | + elif dtype == ir.DataType.BFLOAT16: |
| 7630 | + value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) |
| 7631 | + else: |
| 7632 | + value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) |
| 7633 | + reduction_init = "min" |
| 7634 | + elif onnx_reduce == "min": |
| 7635 | + if dtype in { |
| 7636 | + ir.DataType.FLOAT16, |
| 7637 | + ir.DataType.FLOAT, |
| 7638 | + ir.DataType.DOUBLE, |
| 7639 | + }: |
| 7640 | + value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) |
| 7641 | + elif dtype == ir.DataType.BFLOAT16: |
| 7642 | +                value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype)
| 7643 | + else: |
| 7644 | + value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) |
| 7645 | + reduction_init = "max" |
| 7646 | + elif onnx_reduce == "add": |
| 7647 | + value = ir.tensor([0], dtype=dtype) |
| 7648 | + reduction_init = "none" |
| 7649 | + elif onnx_reduce == "mul": |
| 7650 | + value = ir.tensor([1], dtype=dtype) |
| 7651 | + reduction_init = "none" |
| 7652 | + else: |
| 7653 | + value = 0 |
| 7654 | + reduction_init = "none" |
| 7655 | + |
| 7656 | + cst = op.ConstantOfShape(op.Shape(src), value=value) |
| 7657 | + self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) |
| 7658 | + |
7608 | 7659 | result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) |
| 7660 | + |
7609 | 7661 | if self_is_scalar: |
7610 | 7662 | result = op.Squeeze(result) |
7611 | 7663 | return result |
|
0 commit comments