Skip to content

Commit d33cdf5

Browse files
authored
typing(aggregation): Add type hint to (gramian_)weighting fields (#651)
1 parent d1d12c9 commit d33cdf5

File tree

14 files changed

+424
-396
lines changed

14 files changed

+424
-396
lines changed

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,6 @@
1717
SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]
1818

1919

20-
class AlignedMTL(GramianWeightedAggregator):
21-
r"""
22-
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
23-
`Independent Component Alignment for Multi-Task Learning
24-
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.
25-
26-
:param pref_vector: The preference vector to use. If not provided, defaults to
27-
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
28-
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
29-
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
30-
uses the mean eigenvalue (as in the original implementation).
31-
32-
.. note::
33-
This implementation was adapted from the official implementation of SamsungLabs/MTL,
34-
which is not available anymore at the time of writing.
35-
"""
36-
37-
def __init__(
38-
self,
39-
pref_vector: Tensor | None = None,
40-
scale_mode: SUPPORTED_SCALE_MODE = "min",
41-
) -> None:
42-
self._pref_vector = pref_vector
43-
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
44-
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
45-
46-
def __repr__(self) -> str:
47-
return (
48-
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
49-
f"scale_mode={repr(self._scale_mode)})"
50-
)
51-
52-
def __str__(self) -> str:
53-
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
54-
55-
5620
class AlignedMTLWeighting(Weighting[PSDMatrix]):
5721
r"""
5822
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -113,3 +77,41 @@ def _compute_balance_transformation(
11377

11478
B = scale.sqrt() * V @ sigma_inv @ V.T
11579
return B
80+
81+
82+
class AlignedMTL(GramianWeightedAggregator):
83+
r"""
84+
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
85+
`Independent Component Alignment for Multi-Task Learning
86+
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.
87+
88+
:param pref_vector: The preference vector to use. If not provided, defaults to
89+
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
90+
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
91+
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
92+
uses the mean eigenvalue (as in the original implementation).
93+
94+
.. note::
95+
This implementation was adapted from the official implementation of SamsungLabs/MTL,
96+
which is not available anymore at the time of writing.
97+
"""
98+
99+
gramian_weighting: AlignedMTLWeighting
100+
101+
def __init__(
102+
self,
103+
pref_vector: Tensor | None = None,
104+
scale_mode: SUPPORTED_SCALE_MODE = "min",
105+
) -> None:
106+
self._pref_vector = pref_vector
107+
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
108+
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
109+
110+
def __repr__(self) -> str:
111+
return (
112+
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
113+
f"scale_mode={repr(self._scale_mode)})"
114+
)
115+
116+
def __str__(self) -> str:
117+
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"

src/torchjd/aggregation/_cagrad.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,6 @@
1818
from ._utils.non_differentiable import raise_non_differentiable_error
1919

2020

21-
class CAGrad(GramianWeightedAggregator):
22-
"""
23-
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
24-
`Conflict-Averse Gradient Descent for Multi-task Learning
25-
<https://arxiv.org/pdf/2110.14048.pdf>`_.
26-
27-
:param c: The scale of the radius of the ball constraint.
28-
:param norm_eps: A small value to avoid division by zero when normalizing.
29-
30-
.. note::
31-
This aggregator is not installed by default. When not installed, trying to import it should
32-
result in the following error:
33-
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
34-
To install it, use ``pip install "torchjd[cagrad]"``.
35-
"""
36-
37-
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
38-
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
39-
self._c = c
40-
self._norm_eps = norm_eps
41-
42-
# This prevents considering the computed weights as constant w.r.t. the matrix.
43-
self.register_full_backward_pre_hook(raise_non_differentiable_error)
44-
45-
def __repr__(self) -> str:
46-
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
47-
48-
def __str__(self) -> str:
49-
c_str = str(self._c).rstrip("0")
50-
return f"CAGrad{c_str}"
51-
52-
5321
class CAGradWeighting(Weighting[PSDMatrix]):
5422
"""
5523
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -104,3 +72,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
10472
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
10573

10674
return weights
75+
76+
77+
class CAGrad(GramianWeightedAggregator):
78+
"""
79+
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
80+
`Conflict-Averse Gradient Descent for Multi-task Learning
81+
<https://arxiv.org/pdf/2110.14048.pdf>`_.
82+
83+
:param c: The scale of the radius of the ball constraint.
84+
:param norm_eps: A small value to avoid division by zero when normalizing.
85+
86+
.. note::
87+
This aggregator is not installed by default. When not installed, trying to import it should
88+
result in the following error:
89+
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
90+
To install it, use ``pip install "torchjd[cagrad]"``.
91+
"""
92+
93+
gramian_weighting: CAGradWeighting
94+
95+
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
96+
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
97+
self._c = c
98+
self._norm_eps = norm_eps
99+
100+
# This prevents considering the computed weights as constant w.r.t. the matrix.
101+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
102+
103+
def __repr__(self) -> str:
104+
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
105+
106+
def __str__(self) -> str:
107+
c_str = str(self._c).rstrip("0")
108+
return f"CAGrad{c_str}"

src/torchjd/aggregation/_constant.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,6 @@
77
from ._weighting_bases import Weighting
88

99

10-
class Constant(WeightedAggregator):
11-
"""
12-
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of
13-
the rows of the provided matrix, with constant, pre-determined weights.
14-
15-
:param weights: The weights associated to the rows of the input matrices.
16-
"""
17-
18-
def __init__(self, weights: Tensor) -> None:
19-
super().__init__(weighting=ConstantWeighting(weights=weights))
20-
self._weights = weights
21-
22-
def __repr__(self) -> str:
23-
return f"{self.__class__.__name__}(weights={repr(self._weights)})"
24-
25-
def __str__(self) -> str:
26-
weights_str = vector_to_str(self._weights)
27-
return f"{self.__class__.__name__}([{weights_str}])"
28-
29-
3010
class ConstantWeighting(Weighting[Matrix]):
3111
"""
3212
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
@@ -55,3 +35,25 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:
5535
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
5636
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
5737
)
38+
39+
40+
class Constant(WeightedAggregator):
41+
"""
42+
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of
43+
the rows of the provided matrix, with constant, pre-determined weights.
44+
45+
:param weights: The weights associated to the rows of the input matrices.
46+
"""
47+
48+
weighting: ConstantWeighting
49+
50+
def __init__(self, weights: Tensor) -> None:
51+
super().__init__(weighting=ConstantWeighting(weights=weights))
52+
self._weights = weights
53+
54+
def __repr__(self) -> str:
55+
return f"{self.__class__.__name__}(weights={repr(self._weights)})"
56+
57+
def __str__(self) -> str:
58+
weights_str = vector_to_str(self._weights)
59+
return f"{self.__class__.__name__}([{weights_str}])"

src/torchjd/aggregation/_dualproj.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,42 @@
1010
from ._weighting_bases import Weighting
1111

1212

13+
class DualProjWeighting(Weighting[PSDMatrix]):
14+
r"""
15+
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
16+
:class:`~torchjd.aggregation.DualProj`.
17+
18+
:param pref_vector: The preference vector to use. If not provided, defaults to
19+
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
20+
:param norm_eps: A small value to avoid division by zero when normalizing.
21+
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
22+
numerical errors when computing the gramian, it might not exactly be positive definite.
23+
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
24+
ensures that it is positive definite.
25+
:param solver: The solver used to optimize the underlying optimization problem.
26+
"""
27+
28+
def __init__(
29+
self,
30+
pref_vector: Tensor | None = None,
31+
norm_eps: float = 0.0001,
32+
reg_eps: float = 0.0001,
33+
solver: SUPPORTED_SOLVER = "quadprog",
34+
) -> None:
35+
super().__init__()
36+
self._pref_vector = pref_vector
37+
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
38+
self.norm_eps = norm_eps
39+
self.reg_eps = reg_eps
40+
self.solver: SUPPORTED_SOLVER = solver
41+
42+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
43+
u = self.weighting(gramian)
44+
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
45+
w = project_weights(u, G, self.solver)
46+
return w
47+
48+
1349
class DualProj(GramianWeightedAggregator):
1450
r"""
1551
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
@@ -27,6 +63,8 @@ class DualProj(GramianWeightedAggregator):
2763
:param solver: The solver used to optimize the underlying optimization problem.
2864
"""
2965

66+
gramian_weighting: DualProjWeighting
67+
3068
def __init__(
3169
self,
3270
pref_vector: Tensor | None = None,
@@ -54,39 +92,3 @@ def __repr__(self) -> str:
5492

5593
def __str__(self) -> str:
5694
return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"
57-
58-
59-
class DualProjWeighting(Weighting[PSDMatrix]):
60-
r"""
61-
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
62-
:class:`~torchjd.aggregation.DualProj`.
63-
64-
:param pref_vector: The preference vector to use. If not provided, defaults to
65-
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
66-
:param norm_eps: A small value to avoid division by zero when normalizing.
67-
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
68-
numerical errors when computing the gramian, it might not exactly be positive definite.
69-
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
70-
ensures that it is positive definite.
71-
:param solver: The solver used to optimize the underlying optimization problem.
72-
"""
73-
74-
def __init__(
75-
self,
76-
pref_vector: Tensor | None = None,
77-
norm_eps: float = 0.0001,
78-
reg_eps: float = 0.0001,
79-
solver: SUPPORTED_SOLVER = "quadprog",
80-
) -> None:
81-
super().__init__()
82-
self._pref_vector = pref_vector
83-
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
84-
self.norm_eps = norm_eps
85-
self.reg_eps = reg_eps
86-
self.solver: SUPPORTED_SOLVER = solver
87-
88-
def forward(self, gramian: PSDMatrix, /) -> Tensor:
89-
u = self.weighting(gramian)
90-
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
91-
w = project_weights(u, G, self.solver)
92-
return w

0 commit comments

Comments (0)