Skip to content

Commit 37d0813

Browse files
committed
Change the normalization of Gramian to use the Frobenius norm instead of the spectral norm
1 parent 0017a56 commit 37d0813

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

src/torchjd/aggregation/_gramian_utils.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,23 @@ def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg
1616

1717

1818
def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
19-
r"""
20-
Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
21-
:math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
22-
.. hint::
23-
:math:`J J^T` is the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of
24-
:math:`J`
25-
For a given matrix :math:`J` with SVD: :math:`J = U S V^T`, we can see that:
26-
.. math::
27-
\frac{1}{\sigma_\max^2} J J^T = \frac{1}{\sigma_\max^2} U S V^T V S^T U^T = U
28-
\left( \frac{S}{\sigma_\max} \right)^2 U^T
29-
This is the quantity we compute.
30-
.. note::
31-
If the provided matrix has dimension :math:`m \times n`, the computation only depends on
32-
:math:`n` through the SVD algorithm which is efficient, therefore this is rather fast.
19+
gramian = _compute_gramian(matrix)
20+
return _normalize(gramian, eps)
21+
22+
23+
def _normalize(gramian: Tensor, eps: float) -> Tensor:
3324
"""
25+
Normalizes the gramian with respect to the Frobenius norm.
3426
35-
left_unitary_matrix, singular_values, _ = torch.linalg.svd(matrix, full_matrices=False)
36-
max_singular_value = torch.max(singular_values)
37-
if max_singular_value < eps:
38-
scaled_singular_values = torch.zeros_like(singular_values)
27+
If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the
28+
sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
29+
therefore `G` divided by the sum of its diagonal elements.
30+
"""
31+
squared_frobenius_norm = gramian.diagonal().sum()
32+
if squared_frobenius_norm < eps:
33+
return torch.zeros_like(gramian)
3934
else:
40-
scaled_singular_values = singular_values / max_singular_value
41-
normalized_gramian = (
42-
left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T
43-
)
44-
return normalized_gramian
35+
return gramian / squared_frobenius_norm
4536

4637

4738
def _regularize(gramian: Tensor, eps: float) -> Tensor:

0 commit comments

Comments
 (0)