Skip to content

Commit 4aa3d6b

Browse files
authored
feat(aggregation): Add ConFIG (#223)
* Add ConFIG
* Add ConFIG unit tests
* Add ConFIG to the list of supported Aggregators in README.md
* Add ConFIG to the interactive plotter
* Add documentation entry for ConFIG
* Add doc test for ConFIG
* Add changelog entry
1 parent bf6a38a commit 4aa3d6b

File tree

9 files changed

+148
-0
lines changed

9 files changed

+148
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ changes that do not affect the user.
1212

1313
- Added Python 3.13 classifier in pyproject.toml (we now also run tests on Python 3.13 in the CI).
1414

15+
### Added
16+
17+
- New aggregator `ConFIG` from [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104).
19+
1520
## [0.4.1] - 2025-01-02
1621

1722
### Fixed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo
113113
| [UPGrad](https://torchjd.org/docs/aggregation/upgrad/) (recommended) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
114114
| [AlignedMTL](https://torchjd.org/docs/aggregation/aligned_mtl/) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
115115
| [CAGrad](https://torchjd.org/docs/aggregation/cagrad/) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
116+
| [ConFIG](https://torchjd.org/docs/aggregation/config/) | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
116117
| [Constant](https://torchjd.org/docs/aggregation/constant/) | - |
117118
| [DualProj](https://torchjd.org/docs/aggregation/dualproj/) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
118119
| [GradDrop](https://torchjd.org/docs/aggregation/graddrop/) | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
:hide-toc:
2+
3+
ConFIG
4+
======
5+
6+
.. automodule:: torchjd.aggregation.config
7+
:members:
8+
:undoc-members:
9+
:exclude-members: forward

docs/source/docs/aggregation/index.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ In TorchJD, an aggregator is a class that inherits from the abstract class
3838
- |no|
3939
- |no|
4040
- |yes|
41+
* - :doc:`ConFIG <config>`
42+
- |no|
43+
- |yes|
44+
- |yes|
4145
* - :doc:`Constant <constant>`
4246
- |no|
4347
- |yes|
@@ -140,6 +144,7 @@ In TorchJD, an aggregator is a class that inherits from the abstract class
140144
upgrad.rst
141145
aligned_mtl.rst
142146
cagrad.rst
147+
config.rst
143148
constant.rst
144149
dualproj.rst
145150
graddrop.rst

src/torchjd/aggregation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .aligned_mtl import AlignedMTL
22
from .bases import Aggregator
33
from .cagrad import CAGrad
4+
from .config import ConFIG
45
from .constant import Constant
56
from .dualproj import DualProj
67
from .graddrop import GradDrop

src/torchjd/aggregation/config.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# The code of this file was partly adapted from
2+
# https://github.com/tum-pbs/ConFIG/tree/main/conflictfree.
3+
# It is therefore also subject to the following license.
4+
#
5+
# MIT License
6+
#
7+
# Copyright (c) 2024 TUM Physics-based Simulation
8+
#
9+
# Permission is hereby granted, free of charge, to any person obtaining a copy
10+
# of this software and associated documentation files (the "Software"), to deal
11+
# in the Software without restriction, including without limitation the rights
12+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
# copies of the Software, and to permit persons to whom the Software is
14+
# furnished to do so, subject to the following conditions:
15+
#
16+
# The above copyright notice and this permission notice shall be included in all
17+
# copies or substantial portions of the Software.
18+
#
19+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25+
# SOFTWARE.
26+
27+
28+
import torch
29+
from torch import Tensor
30+
31+
from torchjd.aggregation._pref_vector_utils import (
32+
_check_pref_vector,
33+
_pref_vector_to_str_suffix,
34+
_pref_vector_to_weighting,
35+
)
36+
from torchjd.aggregation.bases import Aggregator
37+
from torchjd.aggregation.sum import _SumWeighting
38+
39+
40+
class ConFIG(Aggregator):
    """
    :class:`~torchjd.aggregation.bases.Aggregator` as defined in Equation 2 of `ConFIG: Towards
    Conflict-free Training of Physics Informed Neural Networks <https://arxiv.org/pdf/2408.11104>`_.

    :param pref_vector: The preference vector used to weight the rows. If not provided, defaults to
        equal weights of 1.

    .. admonition::
        Example

        Use ConFIG to aggregate a matrix.

        >>> from torch import tensor
        >>> from torchjd.aggregation import ConFIG
        >>>
        >>> A = ConFIG()
        >>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
        >>>
        >>> A(J)
        tensor([0.1588, 2.0706, 2.0706])

    .. note::
        This implementation was adapted from the `official implementation
        <https://github.com/tum-pbs/ConFIG/tree/main/conflictfree>`_.
    """

    def __init__(self, pref_vector: Tensor | None = None):
        super().__init__()
        # Validate before storing: invalid preference vectors are rejected here.
        _check_pref_vector(pref_vector)
        self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
        self._pref_vector = pref_vector

    def forward(self, matrix: Tensor) -> Tensor:
        weights = self.weighting(matrix)

        # Normalize each row of the matrix. A zero row yields 0/0 = NaN, which is
        # mapped back to zero so it contributes nothing to the pseudo-inverse.
        row_norms = matrix.norm(dim=1).unsqueeze(1)
        units = torch.nan_to_num(matrix / row_norms, 0.0)

        # Equation 2 of the paper: the conflict-free direction solves the
        # least-squares system formed by the unit gradients and the weights.
        best_direction = torch.linalg.pinv(units) @ weights

        direction_norm = best_direction.norm()
        if direction_norm == 0:
            # Degenerate case (e.g. all-zero matrix): return a zero vector.
            unit_target_vector = torch.zeros_like(best_direction)
        else:
            unit_target_vector = best_direction / direction_norm

        # Rescale by the total projection of all gradients onto the target direction.
        projections = [torch.dot(row, unit_target_vector) for row in matrix]
        length = torch.sum(torch.stack(projections))

        return length * unit_target_vector

    def __repr__(self) -> str:
        return f"{type(self).__name__}(pref_vector={self._pref_vector!r})"

    def __str__(self) -> str:
        return "ConFIG" + _pref_vector_to_str_suffix(self._pref_vector)

tests/doc/test_aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ def test_cagrad():
3737
assert_close(A(J), tensor([0.1835, 1.2041, 1.2041]), rtol=0, atol=1e-4)
3838

3939

40+
def test_config():
    from torch import tensor

    from torchjd.aggregation import ConFIG

    # Mirrors the doctest in the ConFIG docstring.
    aggregator = ConFIG()
    jacobian = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
    expected = tensor([0.1588, 2.0706, 2.0706])

    assert_close(aggregator(jacobian), expected, rtol=0, atol=1e-4)
49+
50+
4051
def test_constant():
4152
from torch import tensor
4253

tests/plots/interactive_plotter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
MGDA,
1515
AlignedMTL,
1616
CAGrad,
17+
ConFIG,
1718
DualProj,
1819
GradDrop,
1920
Mean,
@@ -44,6 +45,7 @@ def main() -> None:
4445
aggregators = [
4546
AlignedMTL(),
4647
CAGrad(c=0.5),
48+
ConFIG(),
4749
DualProj(),
4850
GradDrop(),
4951
IMTLG(),
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
from pytest import mark
3+
4+
from torchjd.aggregation import ConFIG
5+
6+
from ._property_testers import ExpectedStructureProperty
7+
8+
9+
# NOTE(review): only the expected-structure property is checked here. Some
# permutation-invariance property tests were observed to fail with the pinv-based
# implementation — presumably a numerical sensitivity of torch.linalg.pinv;
# confirm the cause before adding a permutation-invariance property test.
@mark.parametrize("aggregator", [ConFIG()])
class TestConfig(ExpectedStructureProperty):
    """Check that ConFIG outputs have the structure expected of an Aggregator."""

    pass
14+
15+
16+
def test_representations():
    """Check repr and str of ConFIG, both with and without a preference vector."""
    without_pref = ConFIG()
    assert repr(without_pref) == "ConFIG(pref_vector=None)"
    assert str(without_pref) == "ConFIG"

    with_pref = ConFIG(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
    assert repr(with_pref) == "ConFIG(pref_vector=tensor([1., 2., 3.]))"
    assert str(with_pref) == "ConFIG([1., 2., 3.])"

0 commit comments

Comments
 (0)