Skip to content

Commit eb6e11f

Browse files
refactor(aggregation): Refactor weighting structure (#347)
* Move weighting base classes to _weighting_bases.py and aggregator base classes to aggregator_bases.py * Add Matrix and PSDMatrix annotated types (not fully used yet) * Make Weighting generic on the type of input stat it takes * Add Composition * Add _GramianWeightedAggregator as an alias for a _WeightedAggregator that composes its (gramian-based) weighting with the compute_gramian function * Change gramian-based Weightings to not compute the gramian themselves * Change gramian-based Aggregators to be _GramianWeightedAggregator * Make some aggregators store some of their weighting's parameters to have access to them in __str__ and __repr__ * Adapt some weighting tests --------- Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
1 parent e760f22 commit eb6e11f

File tree

25 files changed

+242
-204
lines changed

25 files changed

+242
-204
lines changed

docs/source/docs/aggregation/bases.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Aggregator (abstract)
44
=====================
55

6-
.. automodule:: torchjd.aggregation.bases
6+
.. automodule:: torchjd.aggregation.aggregator_bases
77
:members:
88
:undoc-members:
99
:show-inheritance:

src/torchjd/aggregation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from ._utils.check_dependencies import (
22
OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
33
)
4+
from .aggregator_bases import Aggregator
45
from .aligned_mtl import AlignedMTL
5-
from .bases import Aggregator
66
from .config import ConFIG
77
from .constant import Constant
88
from .dualproj import DualProj

src/torchjd/aggregation/_utils/pref_vector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from torch import Tensor
22

3-
from torchjd.aggregation.bases import _Weighting
3+
from torchjd.aggregation._weighting_bases import Matrix, Weighting
44
from torchjd.aggregation.constant import _ConstantWeighting
55

66
from .str import vector_to_str
77

88

9-
def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
9+
def pref_vector_to_weighting(
10+
pref_vector: Tensor | None, default: Weighting[Matrix]
11+
) -> Weighting[Matrix]:
1012
"""
1113
Returns the weighting associated to a given preference vector, with a fallback to a default
1214
weighting if the preference vector is None.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Annotated, Callable, Generic, TypeVar
5+
6+
from torch import Tensor, nn
7+
8+
_T = TypeVar("_T", contravariant=True)
9+
_FnInputT = TypeVar("_FnInputT")
10+
_FnOutputT = TypeVar("_FnOutputT")
11+
Matrix = Annotated[Tensor, "ndim=2"]
12+
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]
13+
14+
15+
class Weighting(Generic[_T], nn.Module, ABC):
16+
r"""
17+
Abstract base class for all weighting methods. It has the role of extracting a vector of weights
18+
of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`.
19+
"""
20+
21+
@abstractmethod
22+
def forward(self, stat: _T) -> Tensor:
23+
"""Computes the vector of weights from the input stat."""
24+
25+
# Override to make type hints and documentation more specific
26+
def __call__(self, stat: _T) -> Tensor:
27+
"""Computes the vector of weights from the input stat and applies all registered hooks."""
28+
29+
return super().__call__(stat)
30+
31+
def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]:
32+
return _Composition(self, fn)
33+
34+
__lshift__ = _compose
35+
36+
37+
class _Composition(Weighting[_T]):
38+
"""
39+
Weighting that composes a Weighting with a function, so that the Weighting is applied to the
40+
output of the function.
41+
"""
42+
43+
def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]):
44+
super().__init__()
45+
self.fn = fn
46+
self.weighting = weighting
47+
48+
def forward(self, stat: _T) -> Tensor:
49+
return self.weighting(self.fn(stat))

src/torchjd/aggregation/bases.py renamed to src/torchjd/aggregation/aggregator_bases.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
from torch import Tensor, nn
44

5+
from ._utils.gramian import compute_gramian
6+
from ._weighting_bases import Matrix, PSDMatrix, Weighting
7+
58

69
class Aggregator(nn.Module, ABC):
710
r"""
@@ -42,35 +45,15 @@ def __str__(self) -> str:
4245
return f"{self.__class__.__name__}"
4346

4447

45-
class _Weighting(nn.Module, ABC):
46-
r"""
47-
Abstract base class for all weighting methods. It has the role of extracting a vector of weights
48-
of dimension :math:`m` from a matrix of dimension :math:`m \times n`.
49-
"""
50-
51-
def __init__(self):
52-
super().__init__()
53-
54-
@abstractmethod
55-
def forward(self, matrix: Tensor) -> Tensor:
56-
"""Computes the vector of weights from the input matrix."""
57-
58-
# Override to make type hints and documentation more specific
59-
def __call__(self, matrix: Tensor) -> Tensor:
60-
"""Computes the vector of weights from the input matrix and applies all registered hooks."""
61-
62-
return super().__call__(matrix)
63-
64-
6548
class _WeightedAggregator(Aggregator):
6649
"""
67-
:class:`~torchjd.aggregation.bases.Aggregator` that combines the rows of the input matrix with
68-
weights given by applying a :class:`~torchjd.aggregation.bases._Weighting` to the matrix.
50+
Aggregator that combines the rows of the input jacobian matrix with weights given by applying a
51+
Weighting to it.
6952
7053
:param weighting: The object responsible for extracting the vector of weights from the matrix.
7154
"""
7255

73-
def __init__(self, weighting: _Weighting):
56+
def __init__(self, weighting: Weighting[Matrix]):
7457
super().__init__()
7558
self.weighting = weighting
7659

@@ -91,3 +74,15 @@ def forward(self, matrix: Tensor) -> Tensor:
9174
weights = self.weighting(matrix)
9275
vector = self.combine(matrix, weights)
9376
return vector
77+
78+
79+
class _GramianWeightedAggregator(_WeightedAggregator):
80+
"""
81+
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
82+
Weighting to it.
83+
84+
:param weighting: The object responsible for extracting the vector of weights from the gramian.
85+
"""
86+
87+
def __init__(self, weighting: Weighting[PSDMatrix]):
88+
super().__init__(weighting << compute_gramian)

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from ._utils.gramian import compute_gramian
3231
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
33-
from .bases import _WeightedAggregator, _Weighting
32+
from ._weighting_bases import PSDMatrix, Weighting
33+
from .aggregator_bases import _GramianWeightedAggregator
3434
from .mean import _MeanWeighting
3535

3636

37-
class AlignedMTL(_WeightedAggregator):
37+
class AlignedMTL(_GramianWeightedAggregator):
3838
"""
39-
:class:`~torchjd.aggregation.bases.Aggregator` as defined in Algorithm 1 of
39+
:class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in Algorithm 1 of
4040
`Independent Component Alignment for Multi-Task Learning
4141
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.
4242
@@ -65,7 +65,7 @@ def __init__(self, pref_vector: Tensor | None = None):
6565
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
6666
self._pref_vector = pref_vector
6767

68-
super().__init__(weighting=_AlignedMTLWrapper(weighting))
68+
super().__init__(_AlignedMTLWrapper(weighting))
6969

7070
def __repr__(self) -> str:
7171
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
@@ -74,26 +74,24 @@ def __str__(self) -> str:
7474
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
7575

7676

77-
class _AlignedMTLWrapper(_Weighting):
77+
class _AlignedMTLWrapper(Weighting[PSDMatrix]):
7878
"""
79-
Wrapper of :class:`~torchjd.aggregation.bases._Weighting` that corrects the extracted
79+
Wrapper of :class:`~torchjd.aggregation._weighting_bases.Weighting` that corrects the extracted
8080
weights with the balance transformation defined in Algorithm 1 of `Independent Component
8181
Alignment for Multi-Task Learning
8282
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.
8383
84-
:param weighting: The wrapped :class:`~torchjd.aggregation.bases._Weighting`
84+
:param weighting: The wrapped :class:`~torchjd.aggregation._weighting_bases.Weighting`
8585
responsible for extracting weight vectors from the input matrices.
8686
"""
8787

88-
def __init__(self, weighting: _Weighting):
88+
def __init__(self, weighting: Weighting[PSDMatrix]):
8989
super().__init__()
9090
self.weighting = weighting
9191

92-
def forward(self, matrix: Tensor) -> Tensor:
93-
w = self.weighting(matrix)
94-
95-
M = compute_gramian(matrix)
96-
B = self._compute_balance_transformation(M)
92+
def forward(self, gramian: Tensor) -> Tensor:
93+
w = self.weighting(gramian)
94+
B = self._compute_balance_transformation(gramian)
9795
alpha = B @ w
9896

9997
return alpha

src/torchjd/aggregation/cagrad.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._utils.check_dependencies import check_dependencies_are_installed
2+
from ._weighting_bases import PSDMatrix, Weighting
23

34
check_dependencies_are_installed(["cvxpy", "clarabel"])
45

@@ -7,14 +8,14 @@
78
import torch
89
from torch import Tensor
910

10-
from ._utils.gramian import compute_gramian, normalize
11+
from ._utils.gramian import normalize
1112
from ._utils.non_differentiable import raise_non_differentiable_error
12-
from .bases import _WeightedAggregator, _Weighting
13+
from .aggregator_bases import _GramianWeightedAggregator
1314

1415

15-
class CAGrad(_WeightedAggregator):
16+
class CAGrad(_GramianWeightedAggregator):
1617
"""
17-
:class:`~torchjd.aggregation.bases.Aggregator` as defined in Algorithm 1 of
18+
:class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in Algorithm 1 of
1819
`Conflict-Averse Gradient Descent for Multi-task Learning
1920
<https://arxiv.org/pdf/2110.14048.pdf>`_.
2021
@@ -43,24 +44,24 @@ class CAGrad(_WeightedAggregator):
4344
"""
4445

4546
def __init__(self, c: float, norm_eps: float = 0.0001):
46-
super().__init__(weighting=_CAGradWeighting(c=c, norm_eps=norm_eps))
47+
super().__init__(_CAGradWeighting(c=c, norm_eps=norm_eps))
48+
self._c = c
49+
self._norm_eps = norm_eps
4750

4851
# This prevents considering the computed weights as constant w.r.t. the matrix.
4952
self.register_full_backward_pre_hook(raise_non_differentiable_error)
5053

5154
def __repr__(self) -> str:
52-
return (
53-
f"{self.__class__.__name__}(c={self.weighting.c}, norm_eps={self.weighting.norm_eps})"
54-
)
55+
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
5556

5657
def __str__(self) -> str:
57-
c_str = str(self.weighting.c).rstrip("0")
58+
c_str = str(self._c).rstrip("0")
5859
return f"CAGrad{c_str}"
5960

6061

61-
class _CAGradWeighting(_Weighting):
62+
class _CAGradWeighting(Weighting[PSDMatrix]):
6263
"""
63-
:class:`~torchjd.aggregation.bases._Weighting` that extracts weights using the CAGrad
64+
:class:`~torchjd.aggregation._weighting_bases.Weighting` that extracts weights using the CAGrad
6465
algorithm, as defined in algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task
6566
Learning <https://arxiv.org/pdf/2110.14048.pdf>`_.
6667
@@ -85,11 +86,7 @@ def __init__(self, c: float, norm_eps: float):
8586
self.c = c
8687
self.norm_eps = norm_eps
8788

88-
def forward(self, matrix: Tensor) -> Tensor:
89-
gramian = compute_gramian(matrix)
90-
return self._compute_from_gramian(gramian)
91-
92-
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
89+
def forward(self, gramian: Tensor) -> Tensor:
9390
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
9491

9592
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030

3131
from ._utils.non_differentiable import raise_non_differentiable_error
3232
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
33-
from .bases import Aggregator
33+
from .aggregator_bases import Aggregator
3434
from .sum import _SumWeighting
3535

3636

3737
class ConFIG(Aggregator):
3838
"""
39-
:class:`~torchjd.aggregation.bases.Aggregator` as defined in Equation 2 of `ConFIG: Towards
40-
Conflict-free Training of Physics Informed Neural Networks <https://arxiv.org/pdf/2408.11104>`_.
39+
:class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG:
40+
Towards Conflict-free Training of Physics Informed Neural Networks
41+
<https://arxiv.org/pdf/2408.11104>`_.
4142
4243
:param pref_vector: The preference vector used to weight the rows. If not provided, defaults to
4344
equal weights of 1.

src/torchjd/aggregation/constant.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from torch import Tensor
22

33
from ._utils.str import vector_to_str
4-
from .bases import _WeightedAggregator, _Weighting
4+
from ._weighting_bases import Matrix, Weighting
5+
from .aggregator_bases import _WeightedAggregator
56

67

78
class Constant(_WeightedAggregator):
89
"""
9-
:class:`~torchjd.aggregation.bases.Aggregator` that makes a linear combination of the rows of
10-
the provided matrix, with constant, pre-determined weights.
10+
:class:`~torchjd.aggregation.aggregator_bases.Aggregator` that makes a linear combination of the
11+
rows of the provided matrix, with constant, pre-determined weights.
1112
1213
:param weights: The weights associated to the rows of the input matrices.
1314
@@ -37,9 +38,9 @@ def __str__(self) -> str:
3738
return f"{self.__class__.__name__}([{weights_str}])"
3839

3940

40-
class _ConstantWeighting(_Weighting):
41+
class _ConstantWeighting(Weighting[Matrix]):
4142
"""
42-
:class:`~torchjd.aggregation.bases._Weighting` that returns constant, pre-determined
43+
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
4344
weights.
4445
4546
:param weights: The weights associated to the rows of the input matrices.

0 commit comments

Comments (0)