Commit 02acff3

1 parent 5820690 commit 02acff3

58 files changed

Lines changed: 1525 additions & 163 deletions

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
:hide-toc:
2+
3+
GradVac
4+
=======
5+
6+
.. autoclass:: torchjd.aggregation.GradVac
7+
:members:
8+
:undoc-members:
9+
:exclude-members: forward, eps, beta
10+
11+
.. autoclass:: torchjd.aggregation.GradVacWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward, eps, beta

stable/_sources/docs/aggregation/index.rst.txt

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,10 @@ Abstract base classes
2222
:undoc-members:
2323
:exclude-members: forward
2424

25+
.. autoclass:: torchjd.aggregation.Stateful
26+
:members:
27+
:undoc-members:
28+
2529

2630
.. toctree::
2731
:hidden:
@@ -35,6 +39,7 @@ Abstract base classes
3539
dualproj.rst
3640
flattening.rst
3741
graddrop.rst
42+
gradvac.rst
3843
imtl_g.rst
3944
krum.rst
4045
mean.rst

stable/_sources/examples/amp.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ case, the losses) should preferably be scaled with a `GradScaler
1111
<https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_ to avoid gradient underflow. The
1212
following example shows the resulting code for a multi-task learning use-case.
1313

14-
.. code-block:: python
14+
.. testcode::
1515
:emphasize-lines: 2, 17, 27, 34-35, 37-38
1616

1717
import torch

stable/_sources/examples/basic_usage.rst.txt

Lines changed: 8 additions & 8 deletions
@@ -12,7 +12,7 @@ the parameters are updated using the resulting aggregation.
1212

1313
Import several classes from ``torch`` and ``torchjd``:
1414

15-
.. code-block:: python
15+
.. testcode::
1616

1717
import torch
1818
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -24,14 +24,14 @@ Import several classes from ``torch`` and ``torchjd``:
2424

2525
Define the model and the optimizer, as usual:
2626

27-
.. code-block:: python
27+
.. testcode::
2828

2929
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
3030
optimizer = SGD(model.parameters(), lr=0.1)
3131

3232
Define the aggregator that will be used to combine the rows of the Jacobian matrix:
3333

34-
.. code-block:: python
34+
.. testcode::
3535

3636
aggregator = UPGrad()
3737

@@ -41,7 +41,7 @@ negatively affected by the update.
4141

4242
Now that everything is defined, we can train the model. Define the input and the associated target:
4343

44-
.. code-block:: python
44+
.. testcode::
4545

4646
input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10
4747
target1 = torch.randn(16) # First batch of 16 targets
@@ -51,7 +51,7 @@ Here, we generate fake inputs and labels for the sake of the example.
5151

5252
We can now compute the losses associated with each element of the batch.
5353

54-
.. code-block:: python
54+
.. testcode::
5555

5656
loss_fn = MSELoss()
5757
output = model(input)
@@ -62,7 +62,7 @@ The last steps are similar to gradient descent-based optimization, but using the
6262

6363
Perform the Jacobian descent backward pass:
6464

65-
.. code-block:: python
65+
.. testcode::
6666

6767
autojac.backward([loss1, loss2])
6868
jac_to_grad(model.parameters(), aggregator)
@@ -73,14 +73,14 @@ field of the parameters. It also deletes the ``.jac`` fields to save some memory.
7373

7474
Update each parameter based on its ``.grad`` field, using the ``optimizer``:
7575

76-
.. code-block:: python
76+
.. testcode::
7777

7878
optimizer.step()
7979

8080
The model's parameters have been updated!
8181

8282
As usual, you should now reset the ``.grad`` field of each model parameter:
8383

84-
.. code-block:: python
84+
.. testcode::
8585

8686
optimizer.zero_grad()
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
1+
Grouping
2+
========
3+
4+
Aggregation can be performed independently on groups of parameters, at different granularities. The
5+
`Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four strategies to partition
6+
the parameters:
7+
8+
1. **Together** (baseline): one group covering all parameters. Corresponds to the ``whole_model``
9+
strategy in the paper.
10+
11+
2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
12+
Corresponds to the ``enc_dec`` strategy in the paper.
13+
14+
3. **Per layer**: one group per leaf module of the network. Corresponds to the ``all_layer`` strategy
15+
in the paper.
16+
17+
4. **Per tensor**: one group per individual parameter tensor. Corresponds to the ``all_matrix``
18+
strategy in the paper.
19+
20+
In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
21+
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
22+
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
23+
independently maintains its own state (e.g. the EMA :math:`\hat{\phi}` state in
24+
:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).
25+
26+
.. note::
27+
Grouping is orthogonal to the choice between
28+
:func:`~torchjd.autojac.backward` and :func:`~torchjd.autojac.mtl_backward`. Those functions
29+
determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians
30+
are partitioned for aggregation.
31+
32+
.. note::
33+
The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to
34+
any :class:`~torchjd.aggregation.Aggregator`.
35+
36+
1. Together
37+
-----------
38+
39+
A single :class:`~torchjd.aggregation.Aggregator` instance aggregates all shared parameters
40+
together. Cosine similarities are computed between the full task gradient vectors.
41+
42+
.. testcode::
43+
:emphasize-lines: 14, 21
44+
45+
import torch
46+
from torch.nn import Linear, MSELoss, ReLU, Sequential
47+
from torch.optim import SGD
48+
49+
from torchjd.aggregation import GradVac
50+
from torchjd.autojac import jac_to_grad, mtl_backward
51+
52+
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
53+
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
54+
optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
55+
loss_fn = MSELoss()
56+
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
57+
58+
aggregator = GradVac()
59+
60+
for x, y1, y2 in zip(inputs, t1, t2):
61+
    features = encoder(x)
62+
    loss1 = loss_fn(task1_head(features), y1)
63+
    loss2 = loss_fn(task2_head(features), y2)
64+
    mtl_backward([loss1, loss2], features=features)
65+
    jac_to_grad(encoder.parameters(), aggregator)
66+
    optimizer.step()
67+
    optimizer.zero_grad()
68+
69+
2. Per network
70+
--------------
71+
72+
One :class:`~torchjd.aggregation.Aggregator` instance per top-level sub-network. Here the model
73+
is split into an encoder and a decoder; cosine similarities are computed separately within each.
74+
Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks
75+
to receive Jacobians, which are then aggregated independently.
76+
77+
.. testcode::
78+
:emphasize-lines: 8-9, 15-16, 24-25
79+
80+
import torch
81+
from torch.nn import Linear, MSELoss, ReLU, Sequential
82+
from torch.optim import SGD
83+
84+
from torchjd.aggregation import GradVac
85+
from torchjd.autojac import jac_to_grad, mtl_backward
86+
87+
encoder = Sequential(Linear(10, 5), ReLU())
88+
decoder = Sequential(Linear(5, 3), ReLU())
89+
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
90+
optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
91+
loss_fn = MSELoss()
92+
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
93+
94+
encoder_aggregator = GradVac()
95+
decoder_aggregator = GradVac()
96+
97+
for x, y1, y2 in zip(inputs, t1, t2):
98+
    enc_out = encoder(x)
99+
    dec_out = decoder(enc_out)
100+
    loss1 = loss_fn(task1_head(dec_out), y1)
101+
    loss2 = loss_fn(task2_head(dec_out), y2)
102+
    mtl_backward([loss1, loss2], features=dec_out)
103+
    jac_to_grad(encoder.parameters(), encoder_aggregator)
104+
    jac_to_grad(decoder.parameters(), decoder_aggregator)
105+
    optimizer.step()
106+
    optimizer.zero_grad()
107+
108+
3. Per layer
109+
------------
110+
111+
One :class:`~torchjd.aggregation.Aggregator` instance per leaf module. Cosine similarities are
112+
computed per-layer between the task gradients.
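
To see why the granularity can matter, here is a small sketch in plain Python with made-up gradient values (they are illustrative assumptions, not TorchJD output). The two task gradients look well aligned when compared as whole vectors, yet they point in opposite directions within one of the layers.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equally sized lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Made-up gradients of two tasks, split by (hypothetical) layer name.
task1 = {"layer0": [1.0, 0.0], "layer1": [3.0, 0.0]}
task2 = {"layer0": [-1.0, 0.0], "layer1": [3.0, 0.0]}

# "Together": compare the concatenation of all layers.
whole = cosine(task1["layer0"] + task1["layer1"], task2["layer0"] + task2["layer1"])

# "Per layer": compare the task gradients separately within each layer.
per_layer = {name: cosine(task1[name], task2[name]) for name in task1}

print(f"whole model: {whole:.2f}")
for name, value in per_layer.items():
    print(f"{name}: {value:.2f}")
```

In this toy case the conflict in ``layer0`` is only visible at the per-layer granularity, which is the situation a dedicated aggregator per group is meant to handle.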
113+
114+
.. testcode::
115+
:emphasize-lines: 14-15, 22-23
116+
117+
import torch
118+
from torch.nn import Linear, MSELoss, ReLU, Sequential
119+
from torch.optim import SGD
120+
121+
from torchjd.aggregation import GradVac
122+
from torchjd.autojac import jac_to_grad, mtl_backward
123+
124+
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
125+
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
126+
optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
127+
loss_fn = MSELoss()
128+
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
129+
130+
leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]
131+
aggregators = [GradVac() for _ in leaf_layers]
132+
133+
for x, y1, y2 in zip(inputs, t1, t2):
134+
    features = encoder(x)
135+
    loss1 = loss_fn(task1_head(features), y1)
136+
    loss2 = loss_fn(task2_head(features), y2)
137+
    mtl_backward([loss1, loss2], features=features)
138+
    for layer, aggregator in zip(leaf_layers, aggregators):
139+
        jac_to_grad(layer.parameters(), aggregator)
140+
    optimizer.step()
141+
    optimizer.zero_grad()
142+
143+
4. Per parameter
144+
----------------
145+
146+
One :class:`~torchjd.aggregation.Aggregator` instance per individual parameter tensor. Cosine
147+
similarities are computed per-tensor between the task gradients (e.g. weights and biases of each
148+
layer are treated as separate groups).
149+
150+
.. testcode::
151+
:emphasize-lines: 14-15, 22-23
152+
153+
import torch
154+
from torch.nn import Linear, MSELoss, ReLU, Sequential
155+
from torch.optim import SGD
156+
157+
from torchjd.aggregation import GradVac
158+
from torchjd.autojac import jac_to_grad, mtl_backward
159+
160+
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
161+
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
162+
optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
163+
loss_fn = MSELoss()
164+
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
165+
166+
shared_params = list(encoder.parameters())
167+
aggregators = [GradVac() for _ in shared_params]
168+
169+
for x, y1, y2 in zip(inputs, t1, t2):
170+
    features = encoder(x)
171+
    loss1 = loss_fn(task1_head(features), y1)
172+
    loss2 = loss_fn(task2_head(features), y2)
173+
    mtl_backward([loss1, loss2], features=features)
174+
    for param, aggregator in zip(shared_params, aggregators):
175+
        jac_to_grad([param], aggregator)
176+
    optimizer.step()
177+
    optimizer.zero_grad()

stable/_sources/examples/index.rst.txt

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,9 @@ This section contains some usage examples for TorchJD.
2929
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
3030
TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
3131
``LightningModule`` optimized by Jacobian descent.
32+
- :doc:`Grouping <grouping>` shows how to apply an aggregator independently per parameter group
33+
(e.g. per layer), so that conflict resolution happens at a finer granularity than the full
34+
parameter vector.
3235
- :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD.
3336

3437
.. toctree::
@@ -43,3 +46,4 @@ This section contains some usage examples for TorchJD.
4346
monitoring.rst
4447
lightning_integration.rst
4548
amp.rst
49+
grouping.rst

stable/_sources/examples/iwmtl.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
99

1010
The following example shows how to do that.
1111

12-
.. code-block:: python
12+
.. testcode::
1313
:emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41
1414

1515
import torch

stable/_sources/examples/iwrm.rst.txt

Lines changed: 3 additions & 3 deletions
@@ -41,7 +41,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
4141
.. tab-set::
4242
.. tab-item:: autograd (baseline)
4343

44-
.. code-block:: python
44+
.. testcode::
4545

4646
import torch
4747
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -75,7 +75,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
7575

7676
.. tab-item:: autojac
7777

78-
.. code-block:: python
78+
.. testcode::
7979
:emphasize-lines: 5-6, 12, 16, 21-23
8080

8181
import torch
@@ -110,7 +110,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
110110

111111
.. tab-item:: autogram (recommended)
112112

113-
.. code-block:: python
113+
.. testcode::
114114
:emphasize-lines: 5-6, 12, 16-17, 21-24
115115

116116
import torch

stable/_sources/examples/lightning_integration.rst.txt

Lines changed: 12 additions & 1 deletion
@@ -10,7 +10,18 @@ The following code example demonstrates a basic multi-task learning setup using
1010
:class:`~lightning.pytorch.core.LightningModule` that will call :doc:`mtl_backward
1111
<../docs/autojac/mtl_backward>` at each training iteration.
1212

13-
.. code-block:: python
13+
.. testsetup::
14+
15+
import warnings
16+
import logging
17+
from lightning.fabric.utilities.warnings import PossibleUserWarning
18+
19+
logging.disable(logging.INFO)
20+
warnings.filterwarnings("ignore", category=DeprecationWarning)
21+
warnings.filterwarnings("ignore", category=FutureWarning)
22+
warnings.filterwarnings("ignore", category=PossibleUserWarning)
23+
24+
.. testcode::
1425
:emphasize-lines: 9-10, 18, 31-32
1526

1627
import torch

stable/_sources/examples/monitoring.rst.txt

Lines changed: 25 additions & 1 deletion
@@ -14,7 +14,12 @@ Jacobian descent is doing something different than gradient descent. With
1414
:doc:`UPGrad <../docs/aggregation/upgrad>`, this happens when the original gradients conflict (i.e.
1515
they have a negative inner product).
1616

17-
.. code-block:: python
17+
.. testsetup::
18+
19+
import torch
20+
torch.manual_seed(0)
21+
22+
.. testcode::
1823
:emphasize-lines: 9-11, 13-18, 33-34
1924

2025
import torch
@@ -67,3 +72,22 @@ they have a negative inner product).
6772
jac_to_grad(shared_module.parameters(), aggregator)
6873
optimizer.step()
6974
optimizer.zero_grad()
75+
76+
.. testoutput::
77+
78+
Weights: tensor([0.5000, 0.5000])
79+
Cosine similarity: 1.0000
80+
Weights: tensor([0.5000, 0.5000])
81+
Cosine similarity: 1.0000
82+
Weights: tensor([0.5000, 0.5000])
83+
Cosine similarity: 1.0000
84+
Weights: tensor([0.6618, 1.0554])
85+
Cosine similarity: 0.9249
86+
Weights: tensor([0.6569, 1.2146])
87+
Cosine similarity: 0.8661
88+
Weights: tensor([0.5004, 0.5060])
89+
Cosine similarity: 1.0000
90+
Weights: tensor([0.5000, 0.5000])
91+
Cosine similarity: 1.0000
92+
Weights: tensor([0.5746, 1.1607])
93+
Cosine similarity: 0.9301
