Commit 7616559

docs: Update README.md for autogram (#410)

* Update `README.md` for autogram
* Fix usage example
* Add IWMTL example in README.md
* Improve formatting
* Add disclaimer about future changes for autogram

1 parent f36c7f7 commit 7616559

File tree: 2 files changed (+130, -41 lines)


README.md

Lines changed: 127 additions & 40 deletions
@@ -55,8 +55,24 @@ Some aggregators may have additional dependencies. Please refer to the
 [installation documentation](https://torchjd.org/stable/installation) for them.
 
 ## Usage
-The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to
-`torchjd.backward` or `torchjd.mtl_backward`, depending on the use-case.
+There are two main ways to use TorchJD. The first is to replace the usual call to
+`loss.backward()` with a call to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
+[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
+on the use case. This computes the Jacobian of the vector of losses with respect to the model
+parameters and aggregates it with the specified
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator).
+Whenever you want to optimize the vector of per-sample losses, you should instead use the
+[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine.html). Instead of
+computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
+memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
+this Gramian, using a
+[`Weighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Weighting),
+and used to combine the losses of the batch. Assuming each element of the batch is
+processed independently from the others, this approach is equivalent to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
+generally much faster thanks to its lower memory usage. Note that we're still working on making
+`autogram` faster and more memory-efficient, and its interface may change in future releases.
 
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
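Editorial aside: the Gramian-based approach described in the added text relies on a simple identity. For any vector of loss weights, the squared norm of the resulting update direction, and its inner products with every per-loss gradient, can be computed from the Gramian alone, without ever storing the full Jacobian. A minimal NumPy sketch of that identity (illustrative only, not torchjd code):

```python
import numpy as np

# Toy Jacobian: 3 per-sample losses, a model with 5 parameters.
J = np.array([
    [ 1.0, 0.0, 2.0, -1.0, 0.5],
    [-1.0, 1.0, 0.0,  0.5, 0.0],
    [ 0.5, 2.0, 1.0,  0.0, 1.0],
])

# Gramian of the Jacobian: pairwise inner products between per-loss gradients.
# Its size depends only on the number of losses, not the number of parameters.
G = J @ J.T  # shape: (3, 3)

# For any weights w, the update direction is J.T @ w. Its squared norm and its
# inner products with the per-loss gradients can be read off G alone.
w = np.array([0.5, 0.3, 0.2])
update = J.T @ w
assert np.isclose(update @ update, w @ G @ w)
assert np.allclose(J @ update, G @ w)
```

This is why computing the Gramian layer by layer, instead of the full Jacobian, is enough to choose the weights.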
@@ -66,7 +82,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-+ from torchjd import mtl_backward
++ from torchjd.autojac import mtl_backward
 + from torchjd.aggregation import UPGrad
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
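For intuition about why a dedicated aggregator is used here instead of plain gradient averaging, consider two conflicting task gradients. The sketch below uses a simplified PCGrad-style projection (PCGrad is one of the aggregators listed in this README; this is not torchjd's UPGrad nor its API) on the same toy matrix that appears elsewhere in this diff:

```python
import numpy as np

# Two conflicting task gradients for the shared parameters.
g1 = np.array([-4.0, 1.0, 1.0])
g2 = np.array([6.0, 1.0, 1.0])

# Plain averaging can conflict with a task: the mean direction has a negative
# inner product with g1, so following it would increase task 1's loss.
mean = (g1 + g2) / 2
assert mean @ g1 < 0

# Simplified PCGrad-style fix: project each gradient away from its component
# that conflicts with the other gradient, then average the projections.
def project_away(g, other):
    inner = g @ other
    if inner < 0:
        g = g - inner / (other @ other) * other
    return g

agg = (project_away(g1, g2) + project_away(g2, g1)) / 2
assert agg @ g1 > 0 and agg @ g2 > 0  # no conflict with either task here
```

UPGrad and the other aggregators make this non-conflict guarantee in more principled ways; see the publications linked in the table below.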
@@ -104,49 +120,120 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 > In this example, the Jacobian is only with respect to the shared parameters. The task-specific
 > parameters are simply updated via the gradient of their task’s loss with respect to them.
 
-More usage examples can be found [here](https://torchjd.org/stable/examples/).
+The following example shows how to use TorchJD to minimize the vector of per-instance losses with
+Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
-## Supported Aggregators
-TorchJD provides many existing aggregators from the literature, listed in the following table.
+```diff
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
 
-<!-- recommended aggregators first, then alphabetical order -->
-| Aggregator | Publication |
-|---|---|
-| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) (recommended) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
-| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl/) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
-| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad/) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
-| [ConFIG](https://torchjd.org/stable/docs/aggregation/config/) | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
-| [Constant](https://torchjd.org/stable/docs/aggregation/constant/) | - |
-| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj/) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
-| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop/) | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
-| [IMTL-G](https://torchjd.org/stable/docs/aggregation/imtl_g/) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
-| [Krum](https://torchjd.org/stable/docs/aggregation/krum/) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
-| [Mean](https://torchjd.org/stable/docs/aggregation/mean/) | - |
-| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda/) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
-| [Nash-MTL](https://torchjd.org/stable/docs/aggregation/nash_mtl/) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
-| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad/) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
-| [Random](https://torchjd.org/stable/docs/aggregation/random/) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
-| [Sum](https://torchjd.org/stable/docs/aggregation/sum/) | - |
-| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean/) | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
-
-The following example shows how to instantiate
-[UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) and aggregate a simple matrix `J` with
-it.
-```python
-from torch import tensor
-from torchjd.aggregation import UPGrad
++ from torchjd.autogram import Engine
++ from torchjd.aggregation import UPGradWeighting
+
+model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())
 
-A = UPGrad()
-J = tensor([[-4., 1., 1.], [6., 1., 1.]])
+- loss_fn = MSELoss()
++ loss_fn = MSELoss(reduction="none")
+optimizer = SGD(model.parameters(), lr=0.1)
 
-A(J)
-# Output: tensor([0.2929, 1.9004, 1.9004])
++ weighting = UPGradWeighting()
++ engine = Engine(model, batch_dim=0)
+
+inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+targets = torch.randn(8, 16)  # 8 batches of 16 targets
+
+for input, target in zip(inputs, targets):
+    output = model(input).squeeze(dim=1)  # shape: [16]
+-    loss = loss_fn(output, target)  # shape: []
++    losses = loss_fn(output, target)  # shape: [16]
+
+    optimizer.zero_grad()
+-    loss.backward()
++    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
++    weights = weighting(gramian)  # shape: [16]
++    losses.backward(weights)
+    optimizer.step()
 ```
 
-> [!TIP]
-> When using TorchJD, you generally don't have to use aggregators directly. You simply instantiate
-> one and pass it to the backward function (`torchjd.backward` or `torchjd.mtl_backward`), which
-> will in turn apply it to the Jacobian matrix that it will compute.
+Lastly, you can even combine the two approaches by considering multiple tasks and each element of
+the batch independently. We call this Instance-Wise Multitask Learning (IWMTL).
+
+```python
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
+
+from torchjd.aggregation import Flattening, UPGradWeighting
+from torchjd.autogram import Engine
+
+shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+task1_module = Linear(3, 1)
+task2_module = Linear(3, 1)
+params = [
+    *shared_module.parameters(),
+    *task1_module.parameters(),
+    *task2_module.parameters(),
+]
+
+optimizer = SGD(params, lr=0.1)
+mse = MSELoss(reduction="none")
+weighting = Flattening(UPGradWeighting())
+engine = Engine(shared_module, batch_dim=0)
+
+inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
+task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task
+
+for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
+    features = shared_module(input)  # shape: [16, 3]
+    out1 = task1_module(features).squeeze(1)  # shape: [16]
+    out2 = task2_module(features).squeeze(1)  # shape: [16]
+
+    # Compute the matrix of losses: one loss per element of the batch and per task
+    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]
+
+    # Compute the Gramian (inner products between pairs of gradients of the losses)
+    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]
+
+    # Obtain the weights that lead to no conflict between the reweighted gradients
+    weights = weighting(gramian)  # shape: [16, 2]
+
+    optimizer.zero_grad()
+    # Do the standard backward pass, but weighted using the obtained weights
+    losses.backward(weights)
+    optimizer.step()
+```
+
+> [!NOTE]
+> Here, because the losses form a matrix instead of a simple vector, we compute a *generalized
+> Gramian* and we extract weights from it using a
+> [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.GeneralizedWeighting).
+
+More usage examples can be found [here](https://torchjd.org/stable/examples/).
+
+## Supported Aggregators and Weightings
+TorchJD provides many existing aggregators from the literature, listed in the following table.
+
+<!-- recommended aggregators first, then alphabetical order -->
+| Aggregator | Weighting | Publication |
+|---|---|---|
+| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad.html#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
+| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
+| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
+| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
+| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
+| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
+| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
+| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
+| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
+| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
+| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDA) | [MGDAWeighting](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDAWeighting) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
+| [NashMTL](https://torchjd.org/stable/docs/aggregation/nash_mtl#torchjd.aggregation.NashMTL) | - | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
+| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGrad) | [PCGradWeighting](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGradWeighting) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
+| [Random](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.Random) | [RandomWeighting](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.RandomWeighting) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
+| [Sum](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.Sum) | [SumWeighting](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.SumWeighting) | - |
+| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean#torchjd.aggregation.TrimmedMean) | - | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
 
 ## Contribution
 Please read the [Contribution page](CONTRIBUTING.md).
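Editorial aside: the [16, 2, 2, 16] generalized-Gramian shape in the IWMTL example can be understood as a reshaping of the ordinary Gramian of the flattened loss matrix, indexed by (sample, task, task, sample). A NumPy sketch of one plausible axis layout (this layout is an assumption for illustration; torchjd may order the axes differently internally):

```python
import numpy as np

b, t, n = 4, 2, 10  # batch size, number of tasks, number of shared parameters
rng = np.random.default_rng(0)

# Jacobian of the flattened [b, t] loss matrix w.r.t. the shared parameters.
J = rng.standard_normal((b * t, n))

# Ordinary Gramian of the flattened losses.
G_flat = J @ J.T  # shape: (b*t, b*t)

# Viewed as a generalized Gramian indexed by (sample, task, task, sample),
# analogous to the [16, 2, 2, 16] shape in the example.
G = G_flat.reshape(b, t, b, t).transpose(0, 1, 3, 2)  # shape: (b, t, t, b)

# Entry G[i, a, c, j] is the inner product between the gradient of sample i's
# task-a loss and the gradient of sample j's task-c loss.
assert np.isclose(G[0, 1, 0, 3], J[0 * t + 1] @ J[3 * t + 0])
```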

src/torchjd/autogram/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -7,7 +7,9 @@
 
 Due to computing the Gramian iteratively over the layers, without ever having to store the full
 Jacobian in memory, this method is much more memory-efficient than
-:doc:`autojac <../autojac/index>`, which makes it often much faster.
+:doc:`autojac <../autojac/index>`, which makes it often much faster. Note that we're still working
+on making autogram faster and more memory-efficient, and its interface may change in future
+releases.
 
 The list of Weightings compatible with ``autogram`` is:
 
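The iterative, layer-by-layer computation this docstring describes rests on the fact that a Gramian decomposes into a sum of per-layer contributions, because the full Jacobian is the horizontal concatenation of per-layer Jacobian blocks. A NumPy sketch of that identity (illustrative, independent of torchjd's actual implementation):

```python
import numpy as np

m = 3  # number of losses
rng = np.random.default_rng(0)

# Per-layer Jacobian blocks; the full Jacobian concatenates them horizontally.
layer_sizes = [7, 4, 5]
blocks = [rng.standard_normal((m, n)) for n in layer_sizes]

# Reference: Gramian computed from the fully materialized Jacobian.
J = np.concatenate(blocks, axis=1)  # shape: (m, sum(layer_sizes))
G_full = J @ J.T

# Equivalent accumulation: add each layer's small Gramian contribution, so
# each layer's Jacobian block can be discarded right after it is used.
G = np.zeros((m, m))
for B in blocks:
    G += B @ B.T

assert np.allclose(G, G_full)
```

Peak memory is then governed by the largest single block rather than the full Jacobian.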