diff --git a/CHANGELOG.md b/CHANGELOG.md index ec07f4c6..89c7d7f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ changes that do not affect the user. ### Changed - Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to - project onto the dual cone. This may minimally affect the output of these aggregators. + project onto the dual cone. This should slightly improve the performance and precision of these + aggregators. - Refactored internal verifications in the autojac engine so that they do not run at runtime anymore. This should minimally improve the performance and reduce the memory usage of `backward` and `mtl_backward`. @@ -23,6 +24,8 @@ changes that do not affect the user. simplified. This should slightly improve the performance of `backward` and `mtl_backward`. - Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It should slightly improve the performance of `ConFIG` and minimally affect its behavior. +- Simplified the normalization of the Gramian in `UPGrad`, `DualProj` and `CAGrad`. This should + slightly improve their performance and precision. ### Fixed diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py index 11c31fd8..dfe16987 100644 --- a/src/torchjd/aggregation/_gramian_utils.py +++ b/src/torchjd/aggregation/_gramian_utils.py @@ -2,7 +2,7 @@ from torch import Tensor -def _compute_gramian(matrix: Tensor) -> Tensor: +def compute_gramian(matrix: Tensor) -> Tensor: """ Computes the `Gramian matrix `_ of a given matrix. 
""" @@ -10,41 +10,22 @@ def _compute_gramian(matrix: Tensor) -> Tensor: return matrix @ matrix.T -def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float): - normalized_gramian = _compute_normalized_gramian(matrix, norm_eps) - return _regularize(normalized_gramian, reg_eps) - - -def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor: - r""" - Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where - :math:`{\sigma_\max^2}` is :math:`J`'s largest singular value. - .. hint:: - :math:`J J^T` is the `Gramian matrix `_ of - :math:`J` - For a given matrix :math:`J` with SVD: :math:`J = U S V^T`, we can see that: - .. math:: - \frac{1}{\sigma_\max^2} J J^T = \frac{1}{\sigma_\max^2} U S V^T V S^T U^T = U - \left( \frac{S}{\sigma_\max} \right)^2 U^T - This is the quantity we compute. - .. note:: - If the provided matrix has dimension :math:`m \times n`, the computation only depends on - :math:`n` through the SVD algorithm which is efficient, therefore this is rather fast. +def normalize(gramian: Tensor, eps: float) -> Tensor: """ + Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`. - left_unitary_matrix, singular_values, _ = torch.linalg.svd(matrix, full_matrices=False) - max_singular_value = torch.max(singular_values) - if max_singular_value < eps: - scaled_singular_values = torch.zeros_like(singular_values) + If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the + sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is + therefore `G` divided by the sum of its diagonal elements. 
+ """ + squared_frobenius_norm = gramian.diagonal().sum() + if squared_frobenius_norm < eps: + return torch.zeros_like(gramian) else: - scaled_singular_values = singular_values / max_singular_value - normalized_gramian = ( - left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T - ) - return normalized_gramian + return gramian / squared_frobenius_norm -def _regularize(gramian: Tensor, eps: float) -> Tensor: +def regularize(gramian: Tensor, eps: float) -> Tensor: """ Adds a regularization term to the gramian to enforce positive definiteness. diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index 20fef34b..c6c66788 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from ._gramian_utils import _compute_normalized_gramian +from ._gramian_utils import compute_gramian, normalize from .bases import _WeightedAggregator, _Weighting @@ -72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float): self.norm_eps = norm_eps def forward(self, matrix: Tensor) -> Tensor: - gramian = _compute_normalized_gramian(matrix, self.norm_eps) + gramian = normalize(compute_gramian(matrix), self.norm_eps) U, S, _ = torch.svd(gramian) reduced_matrix = U @ S.sqrt().diag() diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py index d0e20d75..970d4d21 100644 --- a/src/torchjd/aggregation/dualproj.py +++ b/src/torchjd/aggregation/dualproj.py @@ -3,7 +3,7 @@ from torch import Tensor from ._dual_cone_utils import project_weights -from ._gramian_utils import _compute_regularized_normalized_gramian +from ._gramian_utils import compute_gramian, normalize, regularize from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -100,6 +100,6 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: u = 
self.weighting(matrix) - G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) + G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps) w = project_weights(u, G, self.solver) return w diff --git a/src/torchjd/aggregation/mgda.py b/src/torchjd/aggregation/mgda.py index 99e7a111..f2548ce2 100644 --- a/src/torchjd/aggregation/mgda.py +++ b/src/torchjd/aggregation/mgda.py @@ -1,7 +1,7 @@ import torch from torch import Tensor -from ._gramian_utils import _compute_gramian +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -57,7 +57,7 @@ def __init__(self, epsilon: float, max_iters: int): self.max_iters = max_iters def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor: - gramian = _compute_gramian(matrix) + gramian = compute_gramian(matrix) device = matrix.device dtype = matrix.dtype diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index fe258402..a53e8903 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -4,7 +4,7 @@ from torch import Tensor from ._dual_cone_utils import project_weights -from ._gramian_utils import _compute_regularized_normalized_gramian +from ._gramian_utils import compute_gramian, normalize, regularize from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -96,6 +96,6 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: U = torch.diag(self.weighting(matrix)) - G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) + G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps) W = project_weights(U, G, self.solver) return torch.sum(W, dim=0) diff --git a/tests/unit/aggregation/test_dual_cone_utils.py b/tests/unit/aggregation/test_dual_cone_utils.py index bf192323..d7976d1c 100644 --- 
a/tests/unit/aggregation/test_dual_cone_utils.py +++ b/tests/unit/aggregation/test_dual_cone_utils.py @@ -51,6 +51,23 @@ def test_solution_weights(shape: tuple[int, int]): assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) +@mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) +@mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) +def test_scale_invariant(shape: tuple[int, int], scaling: float): + """ + Tests that `project_weights` is invariant under scaling. + """ + + J = torch.randn(shape) + G = J @ J.T + u = torch.rand(shape[0]) + + w = project_weights(u, G, "quadprog") + w_scaled = project_weights(u, scaling * G, "quadprog") + + assert_close(w_scaled, w) + + @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) def test_tensorization_shape(shape: tuple[int, ...]): """