-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_mean.py
More file actions
31 lines (23 loc) · 871 Bytes
/
_mean.py
File metadata and controls
31 lines (23 loc) · 871 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch import Tensor
from torchjd._linalg import Matrix
from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Weighting
class Mean(WeightedAggregator):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
    matrices.

    This is a weighted aggregator whose weighting is ``MeanWeighting``, i.e. every row receives
    the same weight.
    """

    def __init__(self):
        # All of the aggregation logic lives in the weighting; this class only wires it in.
        uniform_weighting = MeanWeighting()
        super().__init__(weighting=uniform_weighting)
class MeanWeighting(Weighting[Matrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
    :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
    \mathbb{R}^m`.
    """

    def forward(self, matrix: Tensor, /) -> Tensor:
        """
        Return a vector of ``m`` equal weights ``1 / m``, where ``m = matrix.shape[0]``.

        Only the number of rows, the device and the dtype of ``matrix`` are used; its values are
        ignored.

        :param matrix: The matrix whose rows are to be weighted.
        :returns: A 1-D tensor of length ``m`` on ``matrix``'s device and with ``matrix``'s
            dtype. When ``m == 0`` the result is an empty tensor.
        """
        m = matrix.shape[0]
        # Guard the empty case: ``1 / 0`` would raise ZeroDivisionError, whereas an empty weight
        # vector is the consistent answer for a matrix with no rows (the fill value is unused).
        fill_value = 1.0 / m if m > 0 else 0.0
        return torch.full(
            size=[m],
            fill_value=fill_value,
            device=matrix.device,
            dtype=matrix.dtype,
        )