Skip to content
Merged
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ changes that do not affect the user.
### Changed

- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
project onto the dual cone. This may minimally affect the output of these aggregators.
project onto the dual cone. This should slightly improve the performance and precision of these
aggregators.
- Refactored internal verifications in the autojac engine so that they do not run at runtime
anymore. This should minimally improve the performance and reduce the memory usage of `backward`
and `mtl_backward`.
- Refactored internal typing in the autojac engine so that fewer casts are made and so that code is
simplified. This should slightly improve the performance of `backward` and `mtl_backward`.
- Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It
should slightly improve the performance of `ConFIG` and minimally affect its behavior.
- Simplified the normalization of the Gramian in `UPGrad`, `DualProj` and `CAGrad`. This should
slightly improve their performance and precision.

### Fixed

Expand Down
43 changes: 12 additions & 31 deletions src/torchjd/aggregation/_gramian_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,30 @@
from torch import Tensor


def compute_gramian(matrix: Tensor) -> Tensor:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.

    :param matrix: A 2D tensor `A` of shape `[m, n]`.
    :return: The `[m, m]` tensor `A A^T`.
    """

    return matrix @ matrix.T


def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float):
normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
return _regularize(normalized_gramian, reg_eps)


def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
r"""
Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
:math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
.. hint::
:math:`J J^T` is the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of
:math:`J`
For a given matrix :math:`J` with SVD: :math:`J = U S V^T`, we can see that:
.. math::
\frac{1}{\sigma_\max^2} J J^T = \frac{1}{\sigma_\max^2} U S V^T V S^T U^T = U
\left( \frac{S}{\sigma_\max} \right)^2 U^T
This is the quantity we compute.
.. note::
If the provided matrix has dimension :math:`m \times n`, the computation only depends on
:math:`n` through the SVD algorithm which is efficient, therefore this is rather fast.
def normalize(gramian: Tensor, eps: float) -> Tensor:
    """
    Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.

    If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the
    sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
    therefore `G` divided by the sum of its diagonal elements.

    :param gramian: A square matrix of the form `A A^T` (symmetric positive semi-definite).
    :param eps: Threshold below which the squared norm is considered zero, to avoid dividing by a
        vanishing quantity.
    :return: The normalized gramian, or a zero matrix when the squared norm is below `eps`.
    """

    squared_frobenius_norm = gramian.diagonal().sum()
    if squared_frobenius_norm < eps:
        # A is numerically zero: the gramian of its normalization is taken to be zero as well.
        return torch.zeros_like(gramian)
    else:
        return gramian / squared_frobenius_norm


def _regularize(gramian: Tensor, eps: float) -> Tensor:
def regularize(gramian: Tensor, eps: float) -> Tensor:
"""
Adds a regularization term to the gramian to enforce positive definiteness.

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._gramian_utils import compute_gramian, normalize
from .bases import _WeightedAggregator, _Weighting


Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float):
self.norm_eps = norm_eps

def forward(self, matrix: Tensor) -> Tensor:
gramian = _compute_normalized_gramian(matrix, self.norm_eps)
gramian = normalize(compute_gramian(matrix), self.norm_eps)
U, S, _ = torch.svd(gramian)

reduced_matrix = U @ S.sqrt().diag()
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import Tensor

from ._dual_cone_utils import project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._gramian_utils import compute_gramian, normalize, regularize
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting
Expand Down Expand Up @@ -100,6 +100,6 @@ def __init__(

def forward(self, matrix: Tensor) -> Tensor:
    """
    Computes the aggregation weights for `matrix`.

    The preference weights `u` are projected (via `project_weights`) using the gramian of
    `matrix`, normalized for scale invariance and then regularized for positive definiteness.
    """
    u = self.weighting(matrix)
    G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps)
    w = project_weights(u, G, self.solver)
    return w
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/mgda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import Tensor

from ._gramian_utils import _compute_gramian
from ._gramian_utils import compute_gramian
from .bases import _WeightedAggregator, _Weighting


Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self, epsilon: float, max_iters: int):
self.max_iters = max_iters

def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
gramian = _compute_gramian(matrix)
gramian = compute_gramian(matrix)
device = matrix.device
dtype = matrix.dtype

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import Tensor

from ._dual_cone_utils import project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._gramian_utils import compute_gramian, normalize, regularize
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting
Expand Down Expand Up @@ -96,6 +96,6 @@ def __init__(

def forward(self, matrix: Tensor) -> Tensor:
    """
    Computes the aggregation weights for `matrix`.

    Each row of the diagonal matrix `U` of per-row weights is projected (via `project_weights`)
    using the gramian of `matrix` — normalized for scale invariance, then regularized for
    positive definiteness — and the projected rows are summed into a single weight vector.
    """
    U = torch.diag(self.weighting(matrix))
    G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps)
    W = project_weights(U, G, self.solver)
    return torch.sum(W, dim=0)
17 changes: 17 additions & 0 deletions tests/unit/aggregation/test_dual_cone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ def test_solution_weights(shape: tuple[int, int]):
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


@mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)])
@mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4])
def test_scale_invariant(shape: tuple[int, int], scaling: float):
    """
    Tests that `project_weights` is invariant under (positive) scaling of the gramian.
    """

    J = torch.randn(shape)
    G = J @ J.T  # Gramian: symmetric positive semi-definite by construction.
    u = torch.rand(shape[0])

    w = project_weights(u, G, "quadprog")
    w_scaled = project_weights(u, scaling * G, "quadprog")

    assert_close(w_scaled, w)


@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
def test_tensorization_shape(shape: tuple[int, ...]):
"""
Expand Down