Skip to content

Commit f0d3eac

Browse files
refactor(aggregation): Simplify Gramian normalization (#302)
* Use the Frobenius norm instead of the spectral norm * Replace compute_normalized_gramian by normalize since the normalization is now composable with the gramian computation * Remove compute_regularized_normalized_gramian (in favor of composition of simpler functions) * Add test_scale_invariant * Make gramian_utils functions public to their package * Update the changelog entry relative to the projection changes of UPGrad and DualProj * Add changelog entry relative to the normalization changes of UPGrad, DualProj and CAGrad
1 parent 9c29977 commit f0d3eac

File tree

7 files changed

+41
-40
lines changed

7 files changed

+41
-40
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@ changes that do not affect the user.
1515
### Changed
1616

1717
- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
18-
project onto the dual cone. This may minimally affect the output of these aggregators.
18+
project onto the dual cone. This should slightly improve the performance and precision of these
19+
aggregators.
1920
- Refactored internal verifications in the autojac engine so that they do not run at runtime
2021
anymore. This should minimally improve the performance and reduce the memory usage of `backward`
2122
and `mtl_backward`.
2223
- Refactored internal typing in the autojac engine so that fewer casts are made and so that code is
2324
simplified. This should slightly improve the performance of `backward` and `mtl_backward`.
2425
- Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It
2526
should slightly improve the performance of `ConFIG` and minimally affect its behavior.
27+
- Simplified the normalization of the Gramian in `UPGrad`, `DualProj` and `CAGrad`. This should
28+
slightly improve their performance and precision.
2629

2730
### Fixed
2831

src/torchjd/aggregation/_gramian_utils.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,30 @@
22
from torch import Tensor
33

44

5-
def _compute_gramian(matrix: Tensor) -> Tensor:
5+
def compute_gramian(matrix: Tensor) -> Tensor:
66
"""
77
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
88
"""
99

1010
return matrix @ matrix.T
1111

1212

13-
def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float):
14-
normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
15-
return _regularize(normalized_gramian, reg_eps)
16-
17-
18-
def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
19-
r"""
20-
Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
21-
:math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
22-
.. hint::
23-
:math:`J J^T` is the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of
24-
:math:`J`
25-
For a given matrix :math:`J` with SVD: :math:`J = U S V^T`, we can see that:
26-
.. math::
27-
\frac{1}{\sigma_\max^2} J J^T = \frac{1}{\sigma_\max^2} U S V^T V S^T U^T = U
28-
\left( \frac{S}{\sigma_\max} \right)^2 U^T
29-
This is the quantity we compute.
30-
.. note::
31-
If the provided matrix has dimension :math:`m \times n`, the computation only depends on
32-
:math:`n` through the SVD algorithm which is efficient, therefore this is rather fast.
13+
def normalize(gramian: Tensor, eps: float) -> Tensor:
3314
"""
15+
Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.
3416
35-
left_unitary_matrix, singular_values, _ = torch.linalg.svd(matrix, full_matrices=False)
36-
max_singular_value = torch.max(singular_values)
37-
if max_singular_value < eps:
38-
scaled_singular_values = torch.zeros_like(singular_values)
17+
If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the
18+
sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
19+
therefore `G` divided by the sum of its diagonal elements.
20+
"""
21+
squared_frobenius_norm = gramian.diagonal().sum()
22+
if squared_frobenius_norm < eps:
23+
return torch.zeros_like(gramian)
3924
else:
40-
scaled_singular_values = singular_values / max_singular_value
41-
normalized_gramian = (
42-
left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T
43-
)
44-
return normalized_gramian
25+
return gramian / squared_frobenius_norm
4526

4627

47-
def _regularize(gramian: Tensor, eps: float) -> Tensor:
28+
def regularize(gramian: Tensor, eps: float) -> Tensor:
4829
"""
4930
Adds a regularization term to the gramian to enforce positive definiteness.
5031

src/torchjd/aggregation/cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._gramian_utils import _compute_normalized_gramian
6+
from ._gramian_utils import compute_gramian, normalize
77
from .bases import _WeightedAggregator, _Weighting
88

99

@@ -72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float):
7272
self.norm_eps = norm_eps
7373

7474
def forward(self, matrix: Tensor) -> Tensor:
75-
gramian = _compute_normalized_gramian(matrix, self.norm_eps)
75+
gramian = normalize(compute_gramian(matrix), self.norm_eps)
7676
U, S, _ = torch.svd(gramian)
7777

7878
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44

55
from ._dual_cone_utils import project_weights
6-
from ._gramian_utils import _compute_regularized_normalized_gramian
6+
from ._gramian_utils import compute_gramian, normalize, regularize
77
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
99
from .mean import _MeanWeighting
@@ -100,6 +100,6 @@ def __init__(
100100

101101
def forward(self, matrix: Tensor) -> Tensor:
102102
u = self.weighting(matrix)
103-
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
103+
G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps)
104104
w = project_weights(u, G, self.solver)
105105
return w

src/torchjd/aggregation/mgda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch import Tensor
33

4-
from ._gramian_utils import _compute_gramian
4+
from ._gramian_utils import compute_gramian
55
from .bases import _WeightedAggregator, _Weighting
66

77

@@ -57,7 +57,7 @@ def __init__(self, epsilon: float, max_iters: int):
5757
self.max_iters = max_iters
5858

5959
def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
60-
gramian = _compute_gramian(matrix)
60+
gramian = compute_gramian(matrix)
6161
device = matrix.device
6262
dtype = matrix.dtype
6363

src/torchjd/aggregation/upgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._dual_cone_utils import project_weights
7-
from ._gramian_utils import _compute_regularized_normalized_gramian
7+
from ._gramian_utils import compute_gramian, normalize, regularize
88
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
1010
from .mean import _MeanWeighting
@@ -96,6 +96,6 @@ def __init__(
9696

9797
def forward(self, matrix: Tensor) -> Tensor:
9898
U = torch.diag(self.weighting(matrix))
99-
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
99+
G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps)
100100
W = project_weights(U, G, self.solver)
101101
return torch.sum(W, dim=0)

tests/unit/aggregation/test_dual_cone_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ def test_solution_weights(shape: tuple[int, int]):
5151
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
5252

5353

54+
@mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)])
55+
@mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4])
56+
def test_scale_invariant(shape: tuple[int, int], scaling: float):
57+
"""
58+
Tests that `project_weights` is invariant under scaling.
59+
"""
60+
61+
J = torch.randn(shape)
62+
G = J @ J.T
63+
u = torch.rand(shape[0])
64+
65+
w = project_weights(u, G, "quadprog")
66+
w_scaled = project_weights(u, scaling * G, "quadprog")
67+
68+
assert_close(w_scaled, w)
69+
70+
5471
@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
5572
def test_tensorization_shape(shape: tuple[int, ...]):
5673
"""

0 commit comments

Comments
 (0)