-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_upgrad.py
More file actions
95 lines (79 loc) · 3.99 KB
/
_upgrad.py
File metadata and controls
95 lines (79 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import Literal
import torch
from torch import Tensor
from torchjd._linalg import PSDMatrix, normalize, regularize
from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
class UPGrad(GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the method proposed in
    `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_: each
    row of the input matrix is projected onto the dual cone of all rows of that matrix, and the
    projected rows are then combined.

    :param pref_vector: Preference vector used when combining the projected rows. When it is not
        provided, :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
        \mathbb{R}^m` is used instead.
    :param norm_eps: Small value guarding against division by zero during normalization.
    :param reg_eps: Small value added to the diagonal of the gramian of the matrix. Numerical
        errors in the computation of the gramian may leave it not exactly positive definite, which
        can make the optimization fail; adding ``reg_eps`` to the diagonal ensures positive
        definiteness.
    :param solver: Solver used for the underlying optimization problem.
    """

    def __init__(
        self,
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: Literal["quadprog"] = "quadprog",
    ):
        # Keep the raw constructor arguments on the instance: ``__repr__`` and ``__str__`` read
        # them back. They are assigned before the parent initializer runs, as in the original.
        self._pref_vector, self._norm_eps = pref_vector, norm_eps
        self._reg_eps, self._solver = reg_eps, solver

        # The weighting that actually computes the weights receives the same hyper-parameters.
        upgrad_weighting = UPGradWeighting(
            pref_vector,
            norm_eps=norm_eps,
            reg_eps=reg_eps,
            solver=solver,
        )
        super().__init__(upgrad_weighting)

        # Without this hook, the computed weights would be considered constant w.r.t. the matrix
        # when differentiating through the aggregator.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    def __repr__(self) -> str:
        """Return a string showing the class name and every constructor argument."""
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}(pref_vector={self._pref_vector!r}, norm_eps={self._norm_eps}, "
            f"reg_eps={self._reg_eps}, solver={self._solver!r})"
        )

    def __str__(self) -> str:
        """Return ``"UPGrad"`` followed by the suffix derived from the preference vector."""
        suffix = pref_vector_to_str_suffix(self._pref_vector)
        return f"UPGrad{suffix}"
class UPGradWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights of
    :class:`~torchjd.aggregation.UPGrad`.

    :param pref_vector: Preference vector to use. When it is not provided,
        :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`
        is used instead.
    :param norm_eps: Small value guarding against division by zero during normalization.
    :param reg_eps: Small value added to the diagonal of the gramian of the matrix. Numerical
        errors in the computation of the gramian may leave it not exactly positive definite, which
        can make the optimization fail; adding ``reg_eps`` to the diagonal ensures positive
        definiteness.
    :param solver: Solver used for the underlying optimization problem.
    """

    def __init__(
        self,
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: Literal["quadprog"] = "quadprog",
    ):
        super().__init__()
        self._pref_vector = pref_vector
        # Weighting derived from the preference vector; MeanWeighting is the fallback passed as
        # ``default`` to the helper.
        fallback = MeanWeighting()
        self.weighting = pref_vector_to_weighting(pref_vector, default=fallback)
        self.norm_eps, self.reg_eps, self.solver = norm_eps, reg_eps, solver

    def forward(self, gramian: PSDMatrix) -> Tensor:
        """
        Compute the weights from the given gramian.

        :param gramian: The gramian from which the weights are computed.
        """
        # Output of the preference weighting, passed through ``torch.diag``
        # (presumably a 1-D weight vector turned into a diagonal matrix - TODO confirm).
        preference_weights = self.weighting(gramian)
        diag_preferences = torch.diag(preference_weights)

        # Normalize the gramian first, then regularize the normalized result.
        normalized_gramian = normalize(gramian, self.norm_eps)
        regularized_gramian = regularize(normalized_gramian, self.reg_eps)

        projected = project_weights(diag_preferences, regularized_gramian, self.solver)

        # Collapse the projected weights along their first dimension.
        return torch.sum(projected, dim=0)