
Commit 8614939

1 parent 75a9571 commit 8614939

47 files changed

Lines changed: 1038 additions & 73 deletions

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+:hide-toc:
+
+Flattening
+==========
+
+.. autoclass:: torchjd.aggregation.Flattening
+   :members:
+   :undoc-members:
+   :exclude-members: forward
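A minimal usage sketch of the new class, assuming the API shown in the IWMTL example added later in this commit (the [m, k, k, m] Gramian layout is an assumption consistent with that example):

    import torch

    from torchjd.aggregation import Flattening, UPGradWeighting

    # Flattening wraps a Weighting so that a matrix of weights can be
    # extracted from a higher-dimensional (generalized) Gramian.
    m, k, p = 4, 2, 10  # batch size, tasks, parameters (illustrative sizes)
    jac = torch.randn(m, k, p)  # jac[i, j] stands for the gradient of loss [i, j]
    gramian = torch.einsum("ika,jla->iklj", jac, jac)  # shape: [4, 2, 2, 4]

    weighting = Flattening(UPGradWeighting())
    weights = weighting(gramian)  # shape: [4, 2], one weight per loss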
latest/_sources/docs/aggregation/index.rst.txt

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,11 @@ Abstract base classes
    :undoc-members:
    :exclude-members: forward
 
+.. autoclass:: torchjd.aggregation.GeneralizedWeighting
+   :members:
+   :undoc-members:
+   :exclude-members: forward
+
 
 .. toctree::
    :hidden:
@@ -28,6 +33,7 @@ Abstract base classes
    config.rst
    constant.rst
    dualproj.rst
+   flattening.rst
    graddrop.rst
    imtl_g.rst
    krum.rst

latest/_sources/examples/index.rst.txt

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,10 @@ This section contains some usage examples for TorchJD.
 - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
   descent is used to optimize the vector of per-task losses of a multi-task model, using the
   dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
+- :doc:`Instance-Wise Multi-Task Learning (IWMTL) <iwmtl>` shows how to combine multi-task learning
+  with instance-wise risk minimization: one loss per task and per element of the batch, using the
+  :doc:`autogram.Engine <../docs/autogram/engine>` and a :doc:`GeneralizedWeighting
+  <../docs/aggregation/index>`.
 - :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
   with one loss per output sequence element.
 - :doc:`Monitoring Aggregations <monitoring>` shows how to monitor the aggregation performed by the
@@ -34,6 +38,7 @@ This section contains some usage examples for TorchJD.
    iwrm.rst
    partial_jd.rst
    mtl.rst
+   iwmtl.rst
    rnn.rst
    monitoring.rst
    lightning_integration.rst
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+Instance-Wise Multi-Task Learning (IWMTL)
+=========================================
+
+When training a model with multiple tasks, the gradients of the individual tasks are likely to
+conflict. This is particularly true when looking at the individual (per-sample) gradients.
+The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian
+of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from
+this Gramian to reweight the gradients and resolve conflicts entirely.
+
+The following example shows how to do that.
+
+.. code-block:: python
+    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+
+    import torch
+    from torch.nn import Linear, MSELoss, ReLU, Sequential
+    from torch.optim import SGD
+
+    from torchjd.aggregation import Flattening, UPGradWeighting
+    from torchjd.autogram import Engine
+
+    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+    task1_module = Linear(3, 1)
+    task2_module = Linear(3, 1)
+    params = [
+        *shared_module.parameters(),
+        *task1_module.parameters(),
+        *task2_module.parameters(),
+    ]
+
+    optimizer = SGD(params, lr=0.1)
+    mse = MSELoss(reduction="none")
+    weighting = Flattening(UPGradWeighting())
+    engine = Engine(shared_module.modules(), batch_dim=0)
+
+    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+    task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
+    task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task
+
+    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
+        features = shared_module(input)  # shape: [16, 3]
+        out1 = task1_module(features).squeeze(1)  # shape: [16]
+        out2 = task2_module(features).squeeze(1)  # shape: [16]
+
+        # Compute the matrix of losses: one loss per element of the batch and per task
+        losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]
+
+        # Compute the Gramian (inner products between pairs of gradients of the losses)
+        gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]
+
+        # Obtain the weights that lead to no conflict between reweighted gradients
+        weights = weighting(gramian)  # shape: [16, 2]
+
+        optimizer.zero_grad()
+        # Do the standard backward pass, but weighted using the obtained weights
+        losses.backward(weights)
+        optimizer.step()
+
+.. note::
+    In this example, the tensor of losses is a matrix rather than a vector. The Gramian is thus a
+    4D tensor rather than a matrix, and a
+    :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as
+    :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of
+    weights from it. More information about ``GeneralizedWeighting`` can be found in the
+    :doc:`../../docs/aggregation/index` page.
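To make the "no conflict" claim concrete, here is a minimal sketch that checks it numerically. It assumes the generalized Gramian is laid out as gramian[i, k, l, j] = <grad of loss [i, k], grad of loss [j, l]>, which is consistent with the [16, 2, 2, 16] shape above but remains an assumption:

    import torch

    from torchjd.aggregation import Flattening, UPGradWeighting

    # Build a generalized Gramian from an explicit (fake) Jacobian, so that the
    # layout assumption is self-consistent: jac[i, k] stands for the gradient
    # of the loss of batch element i for task k.
    m, k, p = 16, 2, 50  # batch size, tasks, parameters (illustrative sizes)
    jac = torch.randn(m, k, p)
    gramian = torch.einsum("ika,jla->iklj", jac, jac)  # shape: [16, 2, 2, 16]

    weights = Flattening(UPGradWeighting())(gramian)  # shape: [16, 2]

    # No-conflict property: the reweighted update direction should have a
    # non-negative inner product with the gradient of every individual loss.
    update = torch.einsum("ika,ik->a", jac, weights)
    inner_products = torch.einsum("ika,a->ik", jac, update)
    assert (inner_products >= -1e-5).all()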

latest/_sources/examples/iwrm.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     params = model.parameters()
     optimizer = SGD(params, lr=0.1)
     weighting = UPGradWeighting()
-    engine = Engine(model.modules())
+    engine = Engine(model.modules(), batch_dim=0)
 
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1)  # shape: [16]

latest/_sources/examples/partial_jd.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
 
     # Create the autogram engine that will compute the Gramian of the
     # Jacobian with respect to the two last Linear layers' parameters.
-    engine = Engine(model[2:].modules())
+    engine = Engine(model[2:].modules(), batch_dim=0)
 
     params = model.parameters()
     optimizer = SGD(params, lr=0.1)
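Both changes above pass batch_dim=0 to the Engine. A minimal sketch of what this argument expresses, assuming the API shown in these examples: it declares that dimension 0 of the module inputs and outputs indexes the elements of the batch, which is what lets the engine produce a per-sample Gramian like the [16, 16] one below.

    import torch
    from torch.nn import Linear, ReLU, Sequential

    from torchjd.autogram import Engine

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))

    # batch_dim=0 declares that dim 0 of the activations is the batch
    # dimension (the input below is batch-first).
    engine = Engine(model.modules(), batch_dim=0)

    x = torch.randn(16, 10)  # a batch of 16 inputs of length 10
    losses = model(x).squeeze(dim=1) ** 2  # one illustrative loss per sample
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]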

latest/_sources/index.rst.txt

Lines changed: 7 additions & 2 deletions
@@ -38,8 +38,8 @@ per-task losses has to be minimized. To start using TorchJD for multi-task learn
 Another more interesting application is to consider separately the loss of each element in the
 batch. This is what we define as :doc:`Instance-Wise Risk Minimization <examples/iwrm>` (IWRM).
 
-For IWRM, in many cases, there exists an algorithm that is both equivalent to Jacobian descent, and
-much more efficient. This algorithm, called Gramian-based Jacobian descent, consists in computing
+The Gramian-based Jacobian descent algorithm provides a very efficient alternative way of
+performing Jacobian descent. It consists in computing
 the Gramian of the Jacobian iteratively during the backward pass (without ever storing the full
 Jacobian in memory), weighting the losses using the information of the Gramian, and then computing
 the gradient of the obtained weighted loss. The iterative computation of the Gramian corresponds to
@@ -48,6 +48,11 @@ Algorithm 3 of
 documentation and usage example of this algorithm is provided in
 :doc:`autogram.Engine <docs/autogram/engine>`.
 
+The original usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently
+for :doc:`IWRM <examples/iwrm>`. Another direct application is when considering one loss per element
+of the batch and per task, in the context of multi-task learning. We call this
+:doc:`Instance-Wise Multi-Task Learning <examples/iwmtl>` (IWMTL).
+
 TorchJD is open-source, under MIT License. The source code is available on
 `GitHub <https://github.com/TorchJD/torchjd>`_.
 
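The paragraphs added here summarize Gramian-based Jacobian descent; a minimal sketch of the underlying identity in plain PyTorch, materializing the Jacobian for clarity (which the autogram engine avoids):

    import torch

    from torchjd.aggregation import UPGradWeighting

    # Gramian-based Jacobian descent in miniature (illustrative sizes): weight
    # the losses using the Gramian of the Jacobian, then take the gradient of
    # the weighted sum of the losses, which equals jacobian.T @ weights.
    jacobian = torch.randn(5, 100)  # 5 losses, 100 parameters
    gramian = jacobian @ jacobian.T  # shape: [5, 5]

    weights = UPGradWeighting()(gramian)  # shape: [5]
    update = jacobian.T @ weights  # gradient of the weighted loss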
latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,7 @@
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/iwrm/">Instance-Wise Risk Minimization (IWRM)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/partial_jd/">Partial Jacobian Descent for IWRM</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/mtl/">Multi-Task Learning (MTL)</a></li>
+<li class="toctree-l2"><a class="reference internal" href="../../../examples/iwmtl/">Instance-Wise Multi-Task Learning (IWMTL)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/rnn/">Recurrent Neural Network (RNN)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/monitoring/">Monitoring aggregations</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/lightning_integration/">PyTorch Lightning Integration</a></li>
@@ -199,6 +200,7 @@
 <li class="toctree-l2"><a class="reference internal" href="../config/">ConFIG</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../constant/">Constant</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../dualproj/">DualProj</a></li>
+<li class="toctree-l2"><a class="reference internal" href="../flattening/">Flattening</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../graddrop/">GradDrop</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../imtl_g/">IMTL-G</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../krum/">Krum</a></li>

latest/docs/aggregation/cagrad/index.html

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,7 @@
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/iwrm/">Instance-Wise Risk Minimization (IWRM)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/partial_jd/">Partial Jacobian Descent for IWRM</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/mtl/">Multi-Task Learning (MTL)</a></li>
+<li class="toctree-l2"><a class="reference internal" href="../../../examples/iwmtl/">Instance-Wise Multi-Task Learning (IWMTL)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/rnn/">Recurrent Neural Network (RNN)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/monitoring/">Monitoring aggregations</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/lightning_integration/">PyTorch Lightning Integration</a></li>
@@ -199,6 +200,7 @@
 <li class="toctree-l2"><a class="reference internal" href="../config/">ConFIG</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../constant/">Constant</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../dualproj/">DualProj</a></li>
+<li class="toctree-l2"><a class="reference internal" href="../flattening/">Flattening</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../graddrop/">GradDrop</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../imtl_g/">IMTL-G</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../krum/">Krum</a></li>

latest/docs/aggregation/config/index.html

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,7 @@
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/iwrm/">Instance-Wise Risk Minimization (IWRM)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/partial_jd/">Partial Jacobian Descent for IWRM</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/mtl/">Multi-Task Learning (MTL)</a></li>
+<li class="toctree-l2"><a class="reference internal" href="../../../examples/iwmtl/">Instance-Wise Multi-Task Learning (IWMTL)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/rnn/">Recurrent Neural Network (RNN)</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/monitoring/">Monitoring aggregations</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../../../examples/lightning_integration/">PyTorch Lightning Integration</a></li>
@@ -199,6 +200,7 @@
 <li class="toctree-l2 current current-page"><a class="current reference internal" href="#">ConFIG</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../constant/">Constant</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../dualproj/">DualProj</a></li>
+<li class="toctree-l2"><a class="reference internal" href="../flattening/">Flattening</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../graddrop/">GradDrop</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../imtl_g/">IMTL-G</a></li>
 <li class="toctree-l2"><a class="reference internal" href="../krum/">Krum</a></li>
