@@ -16,32 +16,23 @@ def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg
1616
1717
1818def _compute_normalized_gramian (matrix : Tensor , eps : float ) -> Tensor :
19- r"""
20- Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
21- :math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
22- .. hint::
23- :math:`J J^T` is the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of
24- :math:`J`
25- For a given matrix :math:`J` with SVD: :math:`J = U S V^T`, we can see that:
26- .. math::
27- \frac{1}{\sigma_\max^2} J J^T = \frac{1}{\sigma_\max^2} U S V^T V S^T U^T = U
28- \left( \frac{S}{\sigma_\max} \right)^2 U^T
29- This is the quantity we compute.
30- .. note::
31- If the provided matrix has dimension :math:`m \times n`, the computation only depends on
32- :math:`n` through the SVD algorithm which is efficient, therefore this is rather fast.
19+ gramian = _compute_gramian (matrix )
20+ return _normalize (gramian , eps )
21+
22+
23+ def _normalize (gramian : Tensor , eps : float ) -> Tensor :
3324 """
25+ Normalizes the gramian with respect to the Frobenius norm.
3426
35- left_unitary_matrix , singular_values , _ = torch .linalg .svd (matrix , full_matrices = False )
36- max_singular_value = torch .max (singular_values )
37- if max_singular_value < eps :
38- scaled_singular_values = torch .zeros_like (singular_values )
27+ If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the
28+ sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
29+ therefore `G` divided by the sum of its diagonal elements.
30+ """
31+ squared_frobenius_norm = gramian .diagonal ().sum ()
32+ if squared_frobenius_norm < eps :
33+ return torch .zeros_like (gramian )
3934 else :
40- scaled_singular_values = singular_values / max_singular_value
41- normalized_gramian = (
42- left_unitary_matrix @ torch .diag (scaled_singular_values ** 2 ) @ left_unitary_matrix .T
43- )
44- return normalized_gramian
35+ return gramian / squared_frobenius_norm
4536
4637
4738def _regularize (gramian : Tensor , eps : float ) -> Tensor :
0 commit comments