
Commit fefc53a (1 parent: 91c91d3)

Improve grouping example

* Add link to the paper
* Simplify some formulations
* Rename strategies: whole model => together; encoder-decoder => per network; all layers => per layer; all matrices => per tensor
* Place a bit less emphasis on GradVac: rename gradvac to aggregator when possible
* Create losses in separate lines in the code examples
* Remove a few redundant sentences from a note


docs/source/examples/grouping.rst

Lines changed: 62 additions & 51 deletions
@@ -3,42 +3,45 @@ Grouping
 
 When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in
 multi-task learning, the cosine similarities between task gradients can be computed at different
-granularities. The GradVac paper introduces four strategies, each partitioning the shared
-parameter vector differently:
+granularities. The `Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four
+strategies, each partitioning the shared parameter vector differently:
 
-1. **Whole Model** (default) — one group covering all shared parameters.
-2. **Encoder-Decoder** — one group per top-level sub-network (e.g. encoder and decoder separately).
-3. **All Layers** — one group per leaf module of the encoder.
-4. **All Matrices** — one group per individual parameter tensor.
+1. **Together** (baseline): one group covering all shared parameters. Corresponds to the
+   ``whole_model`` strategy in the paper.
+
+2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
+   Corresponds to the ``enc_dec`` strategy in the paper.
+
+3. **Per layer**: one group per leaf module of the encoder. Corresponds to the ``all_layer``
+   strategy in the paper.
+
+4. **Per tensor**: one group per individual parameter tensor. Corresponds to the ``all_matrix``
+   strategy in the paper.
 
 In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
 after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
-For stateful aggregators such as :class:`~torchjd.aggregation.GradVac`, each instance
-independently maintains its own EMA state :math:`\hat{\phi}`, matching the per-block targets from
-the original paper.
+For :class:`~torchjd.aggregation.Stateful` aggregators, each instance independently maintains its
+own state (e.g. the EMA state :math:`\hat{\phi}` in :class:`~torchjd.aggregation.GradVac`), matching
+the per-block targets from the original paper.
 
 .. note::
     The grouping is orthogonal to the choice of
     :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions
     determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians
-    are partitioned for aggregation. Calling :func:`~torchjd.autojac.jac_to_grad` once on all shared
-    parameters corresponds to the Whole Model strategy. Splitting those parameters into
-    sub-networks and calling :func:`~torchjd.autojac.jac_to_grad` separately on each — with a
-    dedicated aggregator per sub-network — gives an arbitrary custom grouping, such as the
-    Encoder-Decoder strategy described in the GradVac paper for encoder-decoder architectures.
+    are partitioned for aggregation.
 
 .. note::
     The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to
-    any aggregator.
+    any :class:`~torchjd.aggregation.Aggregator`.
 
-1. Whole Model
---------------
+1. Together
+-----------
 
-A single :class:`~torchjd.aggregation.GradVac` instance aggregates all shared parameters
+A single :class:`~torchjd.aggregation.Aggregator` instance aggregates all shared parameters
 together. Cosine similarities are computed between the full task gradient vectors.
 
 .. testcode::
-    :emphasize-lines: 14, 19
+    :emphasize-lines: 14, 21
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -53,25 +56,27 @@ together. Cosine similarities are computed between the full task gradient vector
     loss_fn = MSELoss()
     inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
 
-    gradvac = GradVac()
+    aggregator = GradVac()
 
     for x, y1, y2 in zip(inputs, t1, t2):
         features = encoder(x)
-        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
-        jac_to_grad(encoder.parameters(), gradvac)
+        loss1 = loss_fn(task1_head(features), y1)
+        loss2 = loss_fn(task2_head(features), y2)
+        mtl_backward([loss1, loss2], features=features)
+        jac_to_grad(encoder.parameters(), aggregator)
         optimizer.step()
         optimizer.zero_grad()
 
-2. Encoder-Decoder
-------------------
+2. Per network
+--------------
 
-One :class:`~torchjd.aggregation.GradVac` instance per top-level sub-network. Here the model
+One :class:`~torchjd.aggregation.Aggregator` instance per top-level sub-network. Here the model
 is split into an encoder and a decoder; cosine similarities are computed separately within each.
 Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks
 to receive Jacobians, which are then aggregated independently.
 
 .. testcode::
-    :emphasize-lines: 8-9, 15-16, 22-23
+    :emphasize-lines: 8-9, 15-16, 24-25
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -87,26 +92,28 @@ to receive Jacobians, which are then aggregated independently.
     loss_fn = MSELoss()
     inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
 
-    encoder_gradvac = GradVac()
-    decoder_gradvac = GradVac()
+    encoder_aggregator = GradVac()
+    decoder_aggregator = GradVac()
 
    for x, y1, y2 in zip(inputs, t1, t2):
         enc_out = encoder(x)
         dec_out = decoder(enc_out)
-        mtl_backward([loss_fn(task1_head(dec_out), y1), loss_fn(task2_head(dec_out), y2)], features=dec_out)
-        jac_to_grad(encoder.parameters(), encoder_gradvac)
-        jac_to_grad(decoder.parameters(), decoder_gradvac)
+        loss1 = loss_fn(task1_head(dec_out), y1)
+        loss2 = loss_fn(task2_head(dec_out), y2)
+        mtl_backward([loss1, loss2], features=dec_out)
+        jac_to_grad(encoder.parameters(), encoder_aggregator)
+        jac_to_grad(decoder.parameters(), decoder_aggregator)
         optimizer.step()
         optimizer.zero_grad()
 
-3. All Layers
--------------
+3. Per layer
+------------
 
-One :class:`~torchjd.aggregation.GradVac` instance per leaf module. Cosine similarities are
-computed between the per-layer blocks of the task gradients.
+One :class:`~torchjd.aggregation.Aggregator` instance per leaf module. Cosine similarities are
+computed per-layer between the task gradients.
 
 .. testcode::
-    :emphasize-lines: 14-15, 20-21
+    :emphasize-lines: 14-15, 22-23
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -121,26 +128,28 @@ computed between the per-layer blocks of the task gradients.
     loss_fn = MSELoss()
     inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
 
-    leaf_layers = [m for m in encoder.modules() if not list(m.children()) and list(m.parameters())]
-    gradvacs = [GradVac() for _ in leaf_layers]
+    leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]
+    aggregators = [GradVac() for _ in leaf_layers]
 
     for x, y1, y2 in zip(inputs, t1, t2):
         features = encoder(x)
-        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
-        for layer, gradvac in zip(leaf_layers, gradvacs):
-            jac_to_grad(layer.parameters(), gradvac)
+        loss1 = loss_fn(task1_head(features), y1)
+        loss2 = loss_fn(task2_head(features), y2)
+        mtl_backward([loss1, loss2], features=features)
+        for layer, aggregator in zip(leaf_layers, aggregators):
+            jac_to_grad(layer.parameters(), aggregator)
         optimizer.step()
         optimizer.zero_grad()
 
-4. All Matrices
----------------
+4. Per tensor
+-------------
 
-One :class:`~torchjd.aggregation.GradVac` instance per individual parameter tensor. Cosine
-similarities are computed between the per-tensor blocks of the task gradients (e.g. weights and
-biases of each layer are treated as separate groups).
+One :class:`~torchjd.aggregation.Aggregator` instance per individual parameter tensor. Cosine
+similarities are computed per-tensor between the task gradients (e.g. weights and biases of each
+layer are treated as separate groups).
 
 .. testcode::
-    :emphasize-lines: 14-15, 20-21
+    :emphasize-lines: 14-15, 22-23
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -156,12 +165,14 @@ biases of each layer are treated as separate groups).
     inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
 
     shared_params = list(encoder.parameters())
-    gradvacs = [GradVac() for _ in shared_params]
+    aggregators = [GradVac() for _ in shared_params]
 
     for x, y1, y2 in zip(inputs, t1, t2):
         features = encoder(x)
-        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
-        for param, gradvac in zip(shared_params, gradvacs):
-            jac_to_grad([param], gradvac)
+        loss1 = loss_fn(task1_head(features), y1)
+        loss2 = loss_fn(task2_head(features), y2)
+        mtl_backward([loss1, loss2], features=features)
+        for param, aggregator in zip(shared_params, aggregators):
+            jac_to_grad([param], aggregator)
         optimizer.step()
         optimizer.zero_grad()
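The per-layer example selects leaf modules with the comprehension `[m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]`. As a dependency-free sketch of what that filter keeps, here is a toy stand-in for `torch.nn.Module` traversal (the `Module` class and its attributes below are illustrative assumptions, not the real torch or TorchJD API; only the recursive `modules()`/`parameters()`/`children()` semantics are meant to match):

```python
# Toy stand-in for torch.nn.Module traversal (illustrative only):
class Module:
    def __init__(self, children=(), params=()):
        self._children = list(children)
        self._params = list(params)

    def children(self):
        # Direct sub-modules only.
        return iter(self._children)

    def parameters(self):
        # Own parameters plus those of all descendants, like torch.
        yield from self._params
        for child in self._children:
            yield from child.parameters()

    def modules(self):
        # Self plus all descendants, depth-first.
        yield self
        for child in self._children:
            yield from child.modules()

# A toy encoder shaped like Sequential(Linear, ReLU, Linear).
linear1 = Module(params=["w1", "b1"])
relu = Module()                      # parameter-free activation
linear2 = Module(params=["w2", "b2"])
encoder = Module(children=[linear1, relu, linear2])

# Same filter as in the "Per layer" example: keep modules that have
# parameters but no children, i.e. parameterized leaves.
leaf_layers = [
    m for m in encoder.modules()
    if list(m.parameters()) and not list(m.children())
]
print(len(leaf_layers))  # 2
```

The two `Linear` stand-ins pass the filter; the `ReLU` stand-in is dropped because it owns no parameters, and the container is dropped because it has children, so each parameterized layer gets exactly one aggregator.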
