Skip to content

Commit 79f5501

Browse files
committed
1 parent 7e93f30 commit 79f5501

67 files changed

Lines changed: 2626 additions & 923 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

latest/_sources/docs/aggregation/aligned_mtl.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ Aligned-MTL
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.AlignedMTLWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/bases.rst.txt

Lines changed: 0 additions & 10 deletions
This file was deleted.

latest/_sources/docs/aggregation/cagrad.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ CAGrad
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.CAGradWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/constant.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ Constant
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.ConstantWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/dualproj.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ DualProj
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.DualProjWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/imtl_g.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ IMTL-G
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.IMTLGWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/index.rst.txt

Lines changed: 13 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -1,146 +1,27 @@
1-
Aggregation
1+
aggregation
22
===========
33

4-
A mapping :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` reducing any matrix
5-
:math:`J \in \mathbb R^{m\times n}` into its aggregation :math:`\mathcal A(J) \in \mathbb R^n` is
6-
called an aggregator.
4+
.. automodule:: torchjd.aggregation
5+
:no-members:
76

8-
In the context of JD, the matrix to aggregate is a Jacobian whose rows are the gradients of the
9-
individual objectives. The aggregator is used to reduce this matrix into an update vector for the
10-
parameters of the model
7+
Abstract base classes
8+
---------------------
119

12-
In TorchJD, an aggregator is a class that inherits from the abstract class
13-
:doc:`Aggregator <bases>`. We provide the following list of aggregators from the literature:
14-
15-
.. role:: raw-html(raw)
16-
:format: html
17-
18-
.. |yes| replace:: :raw-html:`<center><font color="#28b528">✔</font></center>`
19-
.. |no| replace:: :raw-html:`<center><font color="#e63232">✘</font></center>`
20-
21-
.. list-table::
22-
:widths: 25 15 15 15
23-
:header-rows: 1
24-
25-
* - :doc:`Aggregator <bases>`
26-
- :ref:`Non-conflicting <Non-conflicting>`
27-
- :ref:`Linear under scaling <Linear under scaling>`
28-
- :ref:`Weighted <Weighted>`
29-
* - :doc:`UPGrad <upgrad>` (recommended)
30-
- |yes|
31-
- |yes|
32-
- |yes|
33-
* - :doc:`Aligned-MTL <aligned_mtl>`
34-
- |no|
35-
- |no|
36-
- |yes|
37-
* - :doc:`CAGrad <cagrad>`
38-
- |no|
39-
- |no|
40-
- |yes|
41-
* - :doc:`ConFIG <config>`
42-
- |no|
43-
- |yes|
44-
- |yes|
45-
* - :doc:`Constant <constant>`
46-
- |no|
47-
- |yes|
48-
- |yes|
49-
* - :doc:`DualProj <dualproj>`
50-
- |yes|
51-
- |no|
52-
- |yes|
53-
* - :doc:`GradDrop <graddrop>`
54-
- |no|
55-
- |no|
56-
- |no|
57-
* - :doc:`IMTL-G <imtl_g>`
58-
- |no|
59-
- |no|
60-
- |yes|
61-
* - :doc:`Krum <krum>`
62-
- |no|
63-
- |no|
64-
- |yes|
65-
* - :doc:`Mean <mean>`
66-
- |no|
67-
- |yes|
68-
- |yes|
69-
* - :doc:`MGDA <mgda>`
70-
- |yes|
71-
- |no|
72-
- |yes|
73-
* - :doc:`Nash-MTL <nash_mtl>`
74-
- |yes|
75-
- |no|
76-
- |yes|
77-
* - :doc:`PCGrad <pcgrad>`
78-
- |no|
79-
- |yes|
80-
- |yes|
81-
* - :doc:`Random <random>`
82-
- |no|
83-
- |yes|
84-
- |yes|
85-
* - :doc:`Sum <sum>`
86-
- |no|
87-
- |yes|
88-
- |yes|
89-
* - :doc:`Trimmed Mean <trimmed_mean>`
90-
- |no|
91-
- |no|
92-
- |no|
93-
94-
.. hint::
95-
This table is an adaptation of the one available in `Jacobian Descent For Multi-Objective
96-
Optimization <https://arxiv.org/pdf/2406.16232>`_. The paper provides precise justification of
97-
the properties in Section 2.2 as well as proofs in Appendix B.
98-
99-
.. _Non-conflicting:
100-
.. admonition::
101-
Non-conflicting
102-
103-
An aggregator :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` is said to be
104-
*non-conflicting* if for any :math:`J\in\mathbb R^{m\times n}`, :math:`J\cdot\mathcal A(J)` is a
105-
vector with only non-negative elements.
106-
107-
In other words, :math:`\mathcal A` is non-conflicting whenever the aggregation of any matrix has
108-
non-negative inner product with all rows of that matrix. In the context of JD, this ensures that
109-
no objective locally increases.
110-
111-
.. _Linear under scaling:
112-
.. admonition::
113-
Linear under scaling
114-
115-
An aggregator :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` is said to be
116-
*linear under scaling* if for any :math:`J\in\mathbb R^{m\times n}`, the mapping from any
117-
positive :math:`c\in\mathbb R^{n}` to :math:`\mathcal A(\operatorname{diag}(c)\cdot J)` is
118-
linear in :math:`c`.
119-
120-
In other words, :math:`\mathcal A` is linear under scaling whenever scaling a row of the matrix
121-
to aggregate scales its influence proportionally. In the context of JD, this ensures that even
122-
when the gradient norms are imbalanced, each gradient will contribute to the update
123-
proportionally to its norm.
124-
125-
.. _Weighted:
126-
.. admonition::
127-
Weighted
128-
129-
An aggregator :math:`\mathcal A: \mathbb R^{m\times n} \to \mathbb R^n` is said to be *weighted*
130-
if for any :math:`J\in\mathbb R^{m\times n}`, there exists a weight vector
131-
:math:`w\in\mathbb R^m` such that :math:`\mathcal A(J)=J^\top w`.
132-
133-
In other words, :math:`\mathcal A` is weighted whenever the aggregation of any matrix is always
134-
in the span of the rows of that matrix. This ensures a higher precision of the Taylor
135-
approximation that JD relies on.
10+
.. autoclass:: torchjd.aggregation.Aggregator
11+
:members:
12+
:undoc-members:
13+
:exclude-members: forward
13614

15+
.. autoclass:: torchjd.aggregation.Weighting
16+
:members:
17+
:undoc-members:
18+
:exclude-members: forward
13719

13820

13921
.. toctree::
14022
:hidden:
14123
:maxdepth: 1
14224

143-
bases.rst
14425
upgrad.rst
14526
aligned_mtl.rst
14627
cagrad.rst

latest/_sources/docs/aggregation/krum.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ Krum
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.KrumWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/mean.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ Mean
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.MeanWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

latest/_sources/docs/aggregation/mgda.rst.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ MGDA
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.MGDAWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

0 commit comments

Comments
 (0)