From 5feef8144550ab5ed12407e20de194902f9da4a7 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sun, 6 Apr 2025 10:59:09 +0200 Subject: [PATCH 1/9] Change the normalization of Gramian to use the Frobenius norm instead of the spectral norm --- src/torchjd/aggregation/_gramian_utils.py | 37 +++++++++-------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py index 11c31fd8..a0ad9c14 100644 --- a/src/torchjd/aggregation/_gramian_utils.py +++ b/src/torchjd/aggregation/_gramian_utils.py @@ -16,32 +16,23 @@ def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor: - r""" - Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where - :math:`{\sigma_\max^2}` is :math:`J`'s largest singular value. - .. hint:: - :math:`J J^T` is the `Gramian matrix `_ of - :math:`J` - For a given matrix :math:`J` with SVD: :math:`J = U S V^T`, we can see that: - .. math:: - \frac{1}{\sigma_\max^2} J J^T = \frac{1}{\sigma_\max^2} U S V^T V S^T U^T = U - \left( \frac{S}{\sigma_\max} \right)^2 U^T - This is the quantity we compute. - .. note:: - If the provided matrix has dimension :math:`m \times n`, the computation only depends on - :math:`n` through the SVD algorithm which is efficient, therefore this is rather fast. + gramian = _compute_gramian(matrix) + return _normalize(gramian, eps) + + +def _normalize(gramian: Tensor, eps: float) -> Tensor: """ + Normalizes the gramian with respect to the Frobenius norm. 
- left_unitary_matrix, singular_values, _ = torch.linalg.svd(matrix, full_matrices=False) - max_singular_value = torch.max(singular_values) - if max_singular_value < eps: - scaled_singular_values = torch.zeros_like(singular_values) + If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the + sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is + therefore `G` divided by the sum of its diagonal elements. + """ + squared_frobenius_norm = gramian.diagonal().sum() + if squared_frobenius_norm < eps: + return torch.zeros_like(gramian) else: - scaled_singular_values = singular_values / max_singular_value - normalized_gramian = ( - left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T - ) - return normalized_gramian + return gramian / squared_frobenius_norm def _regularize(gramian: Tensor, eps: float) -> Tensor: From 921659788459450ca4e91e91b9b0fc23a64be3fa Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 7 Apr 2025 15:25:15 +0200 Subject: [PATCH 2/9] Add tests for scaling invariance of `_project_weights` --- tests/unit/aggregation/test_dual_cone_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/unit/aggregation/test_dual_cone_utils.py b/tests/unit/aggregation/test_dual_cone_utils.py index bf192323..126bd2ea 100644 --- a/tests/unit/aggregation/test_dual_cone_utils.py +++ b/tests/unit/aggregation/test_dual_cone_utils.py @@ -51,6 +51,23 @@ def test_solution_weights(shape: tuple[int, int]): assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) +@mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) +@mark.parametrize("scaling", [0.25, 0.5, 4.0, 16.0]) +def test_scale_invariant(shape: tuple[int, int], scaling: float): + """ + Tests that `_project_weights` is invariant under scaling. 
+ """ + + J = torch.randn(shape) + G = J @ J.T + u = torch.rand(shape[0]) + + w = _project_weights(u, G, "quadprog") + w_scaled = _project_weights(u, scaling * G, "quadprog") + + assert_close(w_scaled, w) + + @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) def test_tensorization_shape(shape: tuple[int, ...]): """ From 5fe88e2209825b725f8a8f0a6359a137a3594df6 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 8 Apr 2025 08:37:50 +0200 Subject: [PATCH 3/9] Remove `_compute_regularized_normalized_gramian` and `_compute_normalized_gramian` as they are now just combinations of `_compute_gramian`, `_regularize` and `_normalize` --- src/torchjd/aggregation/_gramian_utils.py | 5 ----- src/torchjd/aggregation/cagrad.py | 4 ++-- src/torchjd/aggregation/dualproj.py | 4 ++-- src/torchjd/aggregation/upgrad.py | 4 ++-- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py index a0ad9c14..dd1de36d 100644 --- a/src/torchjd/aggregation/_gramian_utils.py +++ b/src/torchjd/aggregation/_gramian_utils.py @@ -10,11 +10,6 @@ def _compute_gramian(matrix: Tensor) -> Tensor: return matrix @ matrix.T -def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float): - normalized_gramian = _compute_normalized_gramian(matrix, norm_eps) - return _regularize(normalized_gramian, reg_eps) - - def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor: gramian = _compute_gramian(matrix) return _normalize(gramian, eps) diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index 20fef34b..da784694 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from ._gramian_utils import _compute_normalized_gramian +from ._gramian_utils import _compute_gramian, _normalize from .bases import _WeightedAggregator, _Weighting @@
-72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float): self.norm_eps = norm_eps def forward(self, matrix: Tensor) -> Tensor: - gramian = _compute_normalized_gramian(matrix, self.norm_eps) + gramian = _normalize(_compute_gramian(matrix), self.norm_eps) U, S, _ = torch.svd(gramian) reduced_matrix = U @ S.sqrt().diag() diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py index d0e20d75..970d4d21 100644 --- a/src/torchjd/aggregation/dualproj.py +++ b/src/torchjd/aggregation/dualproj.py @@ -3,7 +3,7 @@ from torch import Tensor from ._dual_cone_utils import project_weights -from ._gramian_utils import _compute_regularized_normalized_gramian +from ._gramian_utils import compute_gramian, normalize, regularize from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -100,6 +100,6 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: u = self.weighting(matrix) - G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) + G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps) w = project_weights(u, G, self.solver) return w diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index fe258402..a53e8903 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -4,7 +4,7 @@ from torch import Tensor from ._dual_cone_utils import project_weights -from ._gramian_utils import _compute_regularized_normalized_gramian +from ._gramian_utils import compute_gramian, normalize, regularize from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -96,6 +96,6 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: U = torch.diag(self.weighting(matrix)) - G = _compute_regularized_normalized_gramian(matrix, 
self.norm_eps, self.reg_eps) + G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps) W = project_weights(U, G, self.solver) return torch.sum(W, dim=0) From 50bd22021d814b582092fc22ae42829545f9c32e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 8 Apr 2025 20:36:20 +0200 Subject: [PATCH 4/9] Remove `_compute_normalized_gramian` --- src/torchjd/aggregation/_gramian_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py index dd1de36d..2535fc3e 100644 --- a/src/torchjd/aggregation/_gramian_utils.py +++ b/src/torchjd/aggregation/_gramian_utils.py @@ -10,11 +10,6 @@ def _compute_gramian(matrix: Tensor) -> Tensor: return matrix @ matrix.T -def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor: - gramian = _compute_gramian(matrix) - return _normalize(gramian, eps) - - def _normalize(gramian: Tensor, eps: float) -> Tensor: """ Normalizes the gramian with respect to the Frobenius norm. From 3211be8e8991cf6bc8cd75b6af052432ca09926d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 8 Apr 2025 20:37:13 +0200 Subject: [PATCH 5/9] Make functions in `_gramian_utils` public (for the same package) --- src/torchjd/aggregation/_gramian_utils.py | 6 +++--- src/torchjd/aggregation/cagrad.py | 4 ++-- src/torchjd/aggregation/mgda.py | 4 ++-- tests/unit/aggregation/test_dual_cone_utils.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py index 2535fc3e..272e034a 100644 --- a/src/torchjd/aggregation/_gramian_utils.py +++ b/src/torchjd/aggregation/_gramian_utils.py @@ -2,7 +2,7 @@ from torch import Tensor -def _compute_gramian(matrix: Tensor) -> Tensor: +def compute_gramian(matrix: Tensor) -> Tensor: """ Computes the `Gramian matrix `_ of a given matrix. 
""" @@ -10,7 +10,7 @@ def _compute_gramian(matrix: Tensor) -> Tensor: return matrix @ matrix.T -def _normalize(gramian: Tensor, eps: float) -> Tensor: +def normalize(gramian: Tensor, eps: float) -> Tensor: """ Normalizes the gramian with respect to the Frobenius norm. @@ -25,7 +25,7 @@ def _normalize(gramian: Tensor, eps: float) -> Tensor: return gramian / squared_frobenius_norm -def _regularize(gramian: Tensor, eps: float) -> Tensor: +def regularize(gramian: Tensor, eps: float) -> Tensor: """ Adds a regularization term to the gramian to enforce positive definiteness. diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index da784694..c6c66788 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from ._gramian_utils import _compute_gramian, _normalize +from ._gramian_utils import compute_gramian, normalize from .bases import _WeightedAggregator, _Weighting @@ -72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float): self.norm_eps = norm_eps def forward(self, matrix: Tensor) -> Tensor: - gramian = _normalize(_compute_gramian(matrix), self.norm_eps) + gramian = normalize(compute_gramian(matrix), self.norm_eps) U, S, _ = torch.svd(gramian) reduced_matrix = U @ S.sqrt().diag() diff --git a/src/torchjd/aggregation/mgda.py b/src/torchjd/aggregation/mgda.py index 99e7a111..f2548ce2 100644 --- a/src/torchjd/aggregation/mgda.py +++ b/src/torchjd/aggregation/mgda.py @@ -1,7 +1,7 @@ import torch from torch import Tensor -from ._gramian_utils import _compute_gramian +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -57,7 +57,7 @@ def __init__(self, epsilon: float, max_iters: int): self.max_iters = max_iters def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor: - gramian = _compute_gramian(matrix) + gramian = compute_gramian(matrix) device = matrix.device dtype = matrix.dtype diff --git 
a/tests/unit/aggregation/test_dual_cone_utils.py b/tests/unit/aggregation/test_dual_cone_utils.py index 126bd2ea..7f824c72 100644 --- a/tests/unit/aggregation/test_dual_cone_utils.py +++ b/tests/unit/aggregation/test_dual_cone_utils.py @@ -62,8 +62,8 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float): G = J @ J.T u = torch.rand(shape[0]) - w = _project_weights(u, G, "quadprog") - w_scaled = _project_weights(u, scaling * G, "quadprog") + w = project_weights(u, G, "quadprog") + w_scaled = project_weights(u, scaling * G, "quadprog") assert_close(w_scaled, w) From 53e73ab1495e5499bc38310c74fc78bb45e2aa83 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 10 Apr 2025 08:35:18 +0200 Subject: [PATCH 6/9] Change the scaling to be between 1/16 to 16 --- tests/unit/aggregation/test_dual_cone_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_dual_cone_utils.py b/tests/unit/aggregation/test_dual_cone_utils.py index 7f824c72..d7976d1c 100644 --- a/tests/unit/aggregation/test_dual_cone_utils.py +++ b/tests/unit/aggregation/test_dual_cone_utils.py @@ -52,7 +52,7 @@ def test_solution_weights(shape: tuple[int, int]): @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) -@mark.parametrize("scaling", [0.25, 0.5, 4.0, 16.0]) +@mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) def test_scale_invariant(shape: tuple[int, int], scaling: float): """ Tests that `_project_weights` is invariant under scaling. 
From ee066595fb6ba0a999ad957baadc055643c53edf Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 10 Apr 2025 08:36:10 +0200 Subject: [PATCH 7/9] Update the docstring of `normalize` --- src/torchjd/aggregation/_gramian_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py index 272e034a..dfe16987 100644 --- a/src/torchjd/aggregation/_gramian_utils.py +++ b/src/torchjd/aggregation/_gramian_utils.py @@ -12,7 +12,7 @@ def compute_gramian(matrix: Tensor) -> Tensor: def normalize(gramian: Tensor, eps: float) -> Tensor: """ - Normalizes the gramian with respect to the Frobenius norm. + Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`. If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is From 40ae25094ee7630c8714114b12dbdd8abce41928 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 10 Apr 2025 08:52:48 +0200 Subject: [PATCH 8/9] Update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec07f4c6..c0c87dfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ changes that do not affect the user. ### Changed - Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to - project onto the dual cone. This may minimally affect the output of these aggregators. + project onto the dual cone. This should slightly improve the performance and precision of these + aggregators. - Refactored internal verifications in the autojac engine so that they do not run at runtime anymore. This should minimally improve the performance and reduce the memory usage of `backward` and `mtl_backward`. 
From 4aea90e19475e5a9e4c374ddf9763df6f88aa957 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 10 Apr 2025 13:51:25 +0200 Subject: [PATCH 9/9] Add changelog entry --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0c87dfd..89c7d7f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ changes that do not affect the user. simplified. This should slightly improve the performance of `backward` and `mtl_backward`. - Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It should slightly improve the performance of `ConFIG` and minimally affect its behavior. +- Simplified the normalization of the Gramian in `UPGrad`, `DualProj` and `CAGrad`. This should + slightly improve their performance and precision. ### Fixed