-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy path_sum.py
More file actions
30 lines (22 loc) · 842 Bytes
/
_sum.py
File metadata and controls
30 lines (22 loc) · 842 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch import Tensor
from torchjd._utils.compute_gramian import Matrix
from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting
class Sum(WeightedAggregator):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums the rows of the input
    matrices.
    """

    def __init__(self):
        # All behavior comes from the weighting: SumWeighting assigns weight 1 to every row.
        # Presumably WeightedAggregator combines the rows using these weights, which makes the
        # result the plain row sum. TODO confirm against WeightedAggregator.
        super().__init__(weighting=SumWeighting())
class SumWeighting(Weighting[Matrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that assigns the same weight of one
    to every row, i.e. it gives the weights
    :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
    """

    def forward(self, matrix: Tensor) -> Tensor:
        """
        Build the all-ones weight vector for ``matrix``.

        :param matrix: Input tensor. Only its first dimension, its dtype and its device are read.
        :returns: A 1-D tensor of ones of length ``matrix.shape[0]``, with the same dtype and on
            the same device as ``matrix``.
        """
        n_rows = matrix.shape[0]
        # Allocate the weights with the input's dtype and device so they stay compatible with it.
        return torch.ones(n_rows, dtype=matrix.dtype, device=matrix.device)