114 changes: 73 additions & 41 deletions README.md
@@ -55,8 +55,23 @@ Some aggregators may have additional dependencies. Please refer to the
[installation documentation](https://torchjd.org/stable/installation) for them.

## Usage
The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to
`torchjd.backward` or `torchjd.mtl_backward`, depending on the use-case.
There are two main ways to use TorchJD. The first is to replace the usual call to
`loss.backward()` with a call to
[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
on the use case. This computes the Jacobian of the vector of losses with respect to the model
parameters and aggregates it with the specified
[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator).
The second, which is preferable when optimizing the vector of per-sample losses, is to use the
[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine.html). Instead of
computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
this Gramian, using a
[`Weighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Weighting),
and used to combine the losses of the batch. Assuming each element of the batch is
processed independently of the others, this approach is equivalent to
[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
generally much faster due to the lower memory usage.
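
As a rough illustration of how the Jacobian, its Gramian, and the per-sample weights relate, the sketch below uses plain PyTorch only and does not call TorchJD. The toy linear model, the batch of 4 samples, and the uniform weights standing in for a real `Weighting` are all assumptions made for this example.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: 4 samples, a linear model with 3 parameters.
x = torch.randn(4, 3)
y = torch.randn(4)
theta = torch.randn(3, requires_grad=True)

losses = (x @ theta - y) ** 2  # per-sample losses, shape [4]

# Jacobian of the losses w.r.t. theta, built one row (one sample) at a time.
rows = [torch.autograd.grad(loss, theta, retain_graph=True)[0] for loss in losses]
jacobian = torch.stack(rows)  # shape [4, 3]

# Gramian of the Jacobian: pairwise inner products of the per-sample gradients.
gramian = jacobian @ jacobian.T  # shape [4, 4]

# Stand-in for a Weighting: uniform weights (a real weighting would be computed from the Gramian).
weights = torch.full((4,), 0.25)

# Backpropagating the losses with these weights accumulates J^T w into theta.grad.
losses.backward(weights)
update = jacobian.T @ weights  # same vector, computed from the explicit Jacobian
```

Here the Jacobian is materialized explicitly only to make the comparison possible; the point of working with the Gramian is that the weights depend on the `[4, 4]` matrix rather than on the full `[4, 3]` Jacobian, whose second dimension grows with the number of parameters.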

The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
@@ -66,7 +81,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

+ from torchjd import mtl_backward
+ from torchjd.autojac import mtl_backward
+ from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -104,49 +119,66 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
> In this example, the Jacobian is only with respect to the shared parameters. The task-specific
> parameters are simply updated via the gradient of their task’s loss with respect to them.

The following example shows how to use TorchJD to minimize the vector of per-instance losses with
Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).

```diff
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

+ from torchjd.autogram import Engine
+ from torchjd.aggregation import UPGradWeighting

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())

- loss_fn = MSELoss()
+ loss_fn = MSELoss(reduction="none")
optimizer = SGD(model.parameters(), lr=0.1)

+ weighting = UPGradWeighting()
+ engine = Engine(model.modules())

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
targets = torch.randn(8, 16)  # 8 batches of 16 targets

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape [16]
-     loss = loss_fn(output, target)  # scalar (shape [])
+     losses = loss_fn(output, target)  # shape [16]

    optimizer.zero_grad()
-     loss.backward()
+     gramian = engine.compute_gramian(losses)  # shape: [16, 16]
+     weights = weighting(gramian)  # shape: [16]
+     losses.backward(weights)
    optimizer.step()
```
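
The switch to `reduction="none"` in the diff above is what exposes one loss per sample. A minimal shape check in plain PyTorch, assuming an arbitrary single `Linear` layer and random data (neither is taken from the example above), also shows that uniform weights of 1/16 reproduce the gradient of the usual mean loss:

```python
import torch
from torch.nn import Linear, MSELoss

torch.manual_seed(0)

# Hypothetical single-layer model and one batch of 16 samples.
model = Linear(10, 1)
x = torch.randn(16, 10)
target = torch.randn(16)
output = model(x).squeeze(dim=1)  # shape [16]

mean_loss = MSELoss()(output, target)  # 0-dimensional scalar
losses = MSELoss(reduction="none")(output, target)  # shape [16]

# Gradient of the mean loss, computed without touching .grad.
grad_mean = torch.autograd.grad(mean_loss, model.weight, retain_graph=True)[0]

# Weighted backward pass over the per-sample losses with uniform weights.
uniform = torch.full_like(losses, 1 / 16)
losses.backward(uniform)
# model.weight.grad now matches grad_mean up to floating-point error.
```

A non-uniform weight vector, such as one produced by a weighting from the Gramian, is passed to `losses.backward` in exactly the same way.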

More usage examples can be found [here](https://torchjd.org/stable/examples/).

## Supported Aggregators
## Supported Aggregators and Weightings
TorchJD provides many existing aggregators from the literature, listed in the following table.

<!-- recommended aggregators first, then alphabetical order -->
| Aggregator | Publication |
|-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) (recommended) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl/) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad/) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
| [ConFIG](https://torchjd.org/stable/docs/aggregation/config/) | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
| [Constant](https://torchjd.org/stable/docs/aggregation/constant/) | - |
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj/) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop/) | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
| [IMTL-G](https://torchjd.org/stable/docs/aggregation/imtl_g/) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
| [Krum](https://torchjd.org/stable/docs/aggregation/krum/) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
| [Mean](https://torchjd.org/stable/docs/aggregation/mean/) | - |
| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda/) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
| [Nash-MTL](https://torchjd.org/stable/docs/aggregation/nash_mtl/) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad/) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
| [Random](https://torchjd.org/stable/docs/aggregation/random/) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
| [Sum](https://torchjd.org/stable/docs/aggregation/sum/) | - |
| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean/) | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |

The following example shows how to instantiate
[UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) and aggregate a simple matrix `J` with
it.
```python
from torch import tensor
from torchjd.aggregation import UPGrad

A = UPGrad()
J = tensor([[-4., 1., 1.], [6., 1., 1.]])

A(J)
# Output: tensor([0.2929, 1.9004, 1.9004])
```

> [!TIP]
> When using TorchJD, you generally don't have to use aggregators directly. You simply instantiate
> one and pass it to the backward function (`torchjd.backward` or `torchjd.mtl_backward`), which
> will in turn apply it to the Jacobian matrix that it will compute.
| Aggregator | Weighting | Publication |
|------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad.html#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDA) | [MGDAWeighting](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDAWeighting) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
| [NashMTL](https://torchjd.org/stable/docs/aggregation/nash_mtl#torchjd.aggregation.NashMTL) | - | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGrad) | [PCGradWeighting](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGradWeighting) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
| [Random](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.Random) | [RandomWeighting](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.RandomWeighting) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
| [Sum](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.Sum) | [SumWeighting](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.SumWeighting) | - |
| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean#torchjd.aggregation.TrimmedMean) | - | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |

## Contribution
Please read the [Contribution page](CONTRIBUTING.md).