|
2 | 2 | from torch import Tensor |
3 | 3 |
|
4 | 4 |
|
def compute_gramian(matrix: Tensor) -> Tensor:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.

    For an input ``J``, the result is ``J @ J.T``, which is square, symmetric and
    positive semi-definite.
    """
    gramian = matrix.matmul(matrix.T)
    return gramian
|
12 | 12 |
|
def normalize(gramian: Tensor, eps: float) -> Tensor:
    """
    Normalizes the gramian ``G = A A^T`` with respect to the Frobenius norm of ``A``.

    Since the squared Frobenius norm of ``A`` equals the trace of ``G`` (the sum of its
    diagonal elements), the gramian of the Frobenius-normalized ``A`` is simply ``G``
    divided by that trace.

    :param gramian: The gramian matrix ``G = A A^T`` to normalize.
    :param eps: Threshold below which the trace is considered zero; in that case a
        zero matrix of the same shape is returned instead of dividing by (almost) zero.
    """
    trace = gramian.diagonal().sum()
    # Guard against division by a vanishing norm: a (near-)zero gramian maps to zero.
    if trace < eps:
        return torch.zeros_like(gramian)
    return gramian / trace
|
46 | 27 |
|
47 | | -def _regularize(gramian: Tensor, eps: float) -> Tensor: |
| 28 | +def regularize(gramian: Tensor, eps: float) -> Tensor: |
48 | 29 | """ |
49 | 30 | Adds a regularization term to the gramian to enforce positive definiteness. |
50 | 31 |
|
|
0 commit comments