Grouping
========

When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in
multi-task learning, the cosine similarities between task gradients can be computed at different
granularities. The `Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four
strategies, each partitioning the shared parameter vector differently:

1. **Together** (baseline): one group covering all shared parameters. Corresponds to the
   ``whole_model`` strategy in the paper.

2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
   Corresponds to the ``enc_dec`` strategy in the paper.

3. **Per layer**: one group per leaf module of the encoder. Corresponds to the ``all_layer``
   strategy in the paper.

4. **Per tensor**: one group per individual parameter tensor. Corresponds to the ``all_matrix``
   strategy in the paper.

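To see why the granularity matters, note that two task gradients can agree overall while still conflicting within a single block. The following sketch (plain PyTorch, independent of TorchJD; the vectors and block boundaries are made up for illustration) compares the whole-vector cosine similarity with per-block similarities:

```python
import torch


def cos(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cosine_similarity(a, b, dim=0)


# Two task "gradients" over a shared parameter vector made of two blocks
# (think: the flattened weight and bias of one layer).
g1 = torch.tensor([1.0, 1.0, 1.0, 1.0])
g2 = torch.tensor([1.0, 1.0, -1.0, 0.5])

blocks = [(0, 2), (2, 4)]  # per-tensor grouping: two blocks of size 2

whole = cos(g1, g2)  # one similarity at whole-model granularity
per_block = [cos(g1[i:j], g2[i:j]) for i, j in blocks]

# The full vectors are positively aligned, yet the second block conflicts.
print(f"whole: {whole:.3f}")
print("per block:", [f"{c:.3f}" for c in per_block])
```

A coarser grouping can thus hide a conflict that a finer grouping exposes, which is exactly the trade-off the four strategies navigate.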
In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
For :class:`~torchjd.aggregation.Stateful` aggregators, each instance independently maintains its
own state (e.g. the EMA state :math:`\hat{\phi}` in :class:`~torchjd.aggregation.GradVac`),
matching the per-block targets from the original paper.

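The independence of the per-group state can be pictured with a toy EMA tracker. This is a hypothetical stand-in for a stateful aggregator's internals, not the TorchJD API; the class name and update rule are assumptions for illustration only:

```python
import torch


class EmaState:
    """Toy stand-in for a stateful aggregator's EMA state (not the TorchJD API)."""

    def __init__(self, beta: float = 0.9):
        self.beta = beta
        self.value = None  # lazily initialised EMA, e.g. of observed cosine similarities

    def update(self, observation: torch.Tensor) -> torch.Tensor:
        if self.value is None:
            self.value = observation.clone()
        else:
            self.value = self.beta * self.value + (1 - self.beta) * observation
        return self.value


# One instance per group: each group's state evolves on its own.
groups = {"encoder": EmaState(), "decoder": EmaState()}
groups["encoder"].update(torch.tensor(0.8))
groups["decoder"].update(torch.tensor(-0.2))
groups["encoder"].update(torch.tensor(0.0))

print(groups["encoder"].value)  # moved from 0.8 toward 0.0 by the second update
print(groups["decoder"].value)  # untouched by the encoder group's updates
```

This mirrors why a dedicated aggregator instance per group is required: sharing one instance across groups would entangle their states.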
.. note::
    The grouping is orthogonal to the choice of
    :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions
    determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians
    are partitioned for aggregation.

.. note::
    The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to
    any :class:`~torchjd.aggregation.Aggregator`.

1. Together
-----------

A single :class:`~torchjd.aggregation.Aggregator` instance aggregates all shared parameters
together. Cosine similarities are computed between the full task gradient vectors.

.. testcode::
    :emphasize-lines: 14, 21

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = ...  # shared encoder (definition omitted in this excerpt)
    task1_head = ...  # task-specific heads (omitted)
    task2_head = ...
    optimizer = ...  # optimizer over all parameters (omitted)

    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    aggregator = GradVac()

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        loss1 = loss_fn(task1_head(features), y1)
        loss2 = loss_fn(task2_head(features), y2)
        mtl_backward([loss1, loss2], features=features)
        jac_to_grad(encoder.parameters(), aggregator)
        optimizer.step()
        optimizer.zero_grad()

2. Per network
--------------

One :class:`~torchjd.aggregation.Aggregator` instance per top-level sub-network. Here the model
is split into an encoder and a decoder; cosine similarities are computed separately within each.
Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks
to receive Jacobians, which are then aggregated independently.

.. testcode::
    :emphasize-lines: 8-9, 15-16, 24-25

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    task1_head = ...  # task-specific heads (definitions omitted in this excerpt)
    task2_head = ...
    encoder = ...  # shared encoder
    decoder = ...  # shared decoder
    optimizer = ...  # optimizer over all parameters (omitted)

    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    encoder_aggregator = GradVac()
    decoder_aggregator = GradVac()

    for x, y1, y2 in zip(inputs, t1, t2):
        enc_out = encoder(x)
        dec_out = decoder(enc_out)
        loss1 = loss_fn(task1_head(dec_out), y1)
        loss2 = loss_fn(task2_head(dec_out), y2)
        mtl_backward([loss1, loss2], features=dec_out)
        jac_to_grad(encoder.parameters(), encoder_aggregator)
        jac_to_grad(decoder.parameters(), decoder_aggregator)
        optimizer.step()
        optimizer.zero_grad()

3. Per layer
------------

One :class:`~torchjd.aggregation.Aggregator` instance per leaf module. Cosine similarities are
computed per layer between the task gradients.

.. testcode::
    :emphasize-lines: 14-15, 22-23

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = ...  # shared encoder (definition omitted in this excerpt)
    task1_head = ...  # task-specific heads (omitted)
    task2_head = ...
    optimizer = ...  # optimizer over all parameters (omitted)

    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]
    aggregators = [GradVac() for _ in leaf_layers]

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        loss1 = loss_fn(task1_head(features), y1)
        loss2 = loss_fn(task2_head(features), y2)
        mtl_backward([loss1, loss2], features=features)
        for layer, aggregator in zip(leaf_layers, aggregators):
            jac_to_grad(layer.parameters(), aggregator)
        optimizer.step()
        optimizer.zero_grad()

4. Per tensor
-------------

One :class:`~torchjd.aggregation.Aggregator` instance per individual parameter tensor. Cosine
similarities are computed per tensor between the task gradients (e.g. the weights and biases of
each layer are treated as separate groups).

.. testcode::
    :emphasize-lines: 14-15, 22-23

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = ...  # shared encoder (definition omitted in this excerpt)
    task1_head = ...  # task-specific heads (omitted)
    task2_head = ...
    optimizer = ...  # optimizer over all parameters (omitted)

    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    shared_params = list(encoder.parameters())
    aggregators = [GradVac() for _ in shared_params]

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        loss1 = loss_fn(task1_head(features), y1)
        loss2 = loss_fn(task2_head(features), y2)
        mtl_backward([loss1, loss2], features=features)
        for param, aggregator in zip(shared_params, aggregators):
            jac_to_grad([param], aggregator)
        optimizer.step()
        optimizer.zero_grad()