1+ # the code is torch_scatter: https://github.com/rusty1s/pytorch_scatter/blob/1.3.0/torch_scatter/add.py
2+ from itertools import repeat
3+
4+ def maybe_dim_size (index , dim_size = None ):
5+ if dim_size is not None :
6+ return dim_size
7+ return index .max ().item () + 1 if index .numel () > 0 else 0
8+
9+
10+ def gen (src , index , dim = - 1 , out = None , dim_size = None , fill_value = 0 ):
11+ dim = range (src .dim ())[dim ] # Get real dim value.
12+
13+ # Automatically expand index tensor to the right dimensions.
14+ if index .dim () == 1 :
15+ index_size = list (repeat (1 , src .dim ()))
16+ index_size [dim ] = src .size (dim )
17+ index = index .view (index_size ).expand_as (src )
18+
19+ # Generate output tensor if not given.
20+ if out is None :
21+ out_size = list (src .size ())
22+ dim_size = maybe_dim_size (index , dim_size )
23+ out_size [dim ] = dim_size
24+ out = src .new_full (out_size , fill_value )
25+
26+ return src , out , index , dim
27+
28+ def scatter_add (src , index , dim = - 1 , out = None , dim_size = None , fill_value = 0 ):
29+ r"""
30+ |
31+
32+ .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
33+ master/docs/source/_figures/add.svg?sanitize=true
34+ :align: center
35+ :width: 400px
36+
37+ |
38+
39+ Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
40+ specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
41+ each value in :attr:`src`, its output index is specified by its index in
42+ :attr:`input` for dimensions outside of :attr:`dim` and by the
43+ corresponding value in :attr:`index` for dimension :attr:`dim`. If
44+ multiple indices reference the same location, their **contributions add**.
45+
46+ Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
47+ size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
48+ :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
49+ size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
50+ values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
51+
52+ For one-dimensional tensors, the operation computes
53+
54+ .. math::
55+ \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
56+
57+ where :math:`\sum_j` is over :math:`j` such that
58+ :math:`\mathrm{index}_j = i`.
59+
60+ Args:
61+ src (Tensor): The source tensor.
62+ index (LongTensor): The indices of elements to scatter.
63+ dim (int, optional): The axis along which to index.
64+ (default: :obj:`-1`)
65+ out (Tensor, optional): The destination tensor. (default: :obj:`None`)
66+ dim_size (int, optional): If :attr:`out` is not given, automatically
67+ create output with size :attr:`dim_size` at dimension :attr:`dim`.
68+ If :attr:`dim_size` is not given, a minimal sized output tensor is
69+ returned. (default: :obj:`None`)
70+ fill_value (int, optional): If :attr:`out` is not given, automatically
71+ fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
72+
73+ :rtype: :class:`Tensor`
74+
75+ .. testsetup::
76+
77+ import torch
78+
79+ .. testcode::
80+
81+ from torch_scatter import scatter_add
82+
83+ src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
84+ index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
85+ out = src.new_zeros((2, 6))
86+
87+ out = scatter_add(src, index, out=out)
88+
89+ print(out)
90+
91+ .. testoutput::
92+
93+ tensor([[0., 0., 4., 3., 3., 0.],
94+ [2., 4., 4., 0., 0., 0.]])
95+ """
96+ src , out , index , dim = gen (src , index , dim , out , dim_size , fill_value )
97+ return out .scatter_add_ (dim , index , src )
0 commit comments