-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_flattening.py
More file actions
37 lines (29 loc) · 1.5 KB
/
_flattening.py
File metadata and controls
37 lines (29 loc) · 1.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from math import prod
from torch import Tensor
from torchjd._linalg import PSDMatrix, PSDQuadraticForm, is_psd_matrix
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
from torchjd.autogram._gramian_utils import reshape_gramian
class Flattening(GeneralizedWeighting):
    """
    :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` that reduces a generalized
    Gramian to an ordinary square Gramian matrix. It delegates the computation of weights to a
    :class:`~torchjd.aggregation._weighting_bases.Weighting`, then reshapes the resulting weight
    vector back to the leading dimensions of the generalized Gramian.

    Example: given a generalized Gramian of shape ``[2, 3, 3, 2]``, the leading half of its
    dimensions is ``[2, 3]``. The Gramian is therefore viewed as a ``[6, 6]`` matrix. The wrapped
    weighting produces a ``[6]`` vector, which is returned as a ``[2, 3]`` tensor of weights.

    :param weighting: The weighting applied to the flattened (square) Gramian matrix.
    """

    def __init__(self, weighting: Weighting[PSDMatrix]):
        super().__init__()
        self.weighting = weighting

    def forward(self, generalized_gramian: PSDQuadraticForm) -> Tensor:
        # The leading half of the dimensions indexes the weights; its total size is the side
        # length of the flattened square matrix.
        half_ndim = generalized_gramian.ndim // 2
        weight_shape = generalized_gramian.shape[:half_ndim]
        n_weights = prod(weight_shape)

        matrix = reshape_gramian(generalized_gramian, [n_weights])
        # Narrows the flattened tensor to a PSDMatrix before handing it to the weighting.
        assert is_psd_matrix(matrix)

        return self.weighting(matrix).reshape(weight_shape)