Commit 7e3637e

feat(autojac): Use jac_to_grad to aggregate .jac fields (#510)

* Move accumulation code to _accumulation.py
* Rename Accumulate to AccumulateGrad
* Add AccumulateJac
* Add jac_to_grad
* Remove aggregation responsibility from backward and mtl_backward
* Remove Aggregate transform (move its code to jac_to_grad.py)
* Simplify _disunite_gradient (use split instead of manually looping)
* Add utils/asserts.py file to check things about .jac or .grad fields
* Update tests and usage examples according to all those changes
* Add changelog entry

1 parent: e890e65

28 files changed: +758 -627 lines
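One item in the commit message above, simplifying `_disunite_gradient` by using a split instead of a manual loop, can be illustrated with a small dependency-free sketch. The function body and the `param_sizes` argument below are hypothetical reconstructions of the idea (cutting one flat gradient sequence into consecutive per-parameter chunks), not torchjd's actual implementation.

```python
from itertools import islice

# Hypothetical sketch of "_disunite_gradient (use split instead of
# manually looping)": the flat gradient is consumed chunk by chunk in
# a single comprehension, rather than sliced with a hand-written
# offset loop. Not torchjd's actual code.

def disunite_gradient(flat_grad, param_sizes):
    it = iter(flat_grad)
    # Each chunk consumes the next `size` elements of the iterator.
    return [list(islice(it, size)) for size in param_sizes]

flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
chunks = disunite_gradient(flat, [2, 1, 3])
print(chunks)  # [[0.1, 0.2], [0.3], [0.4, 0.5, 0.6]]
```

In torchjd itself this would operate on tensors (e.g. via a tensor split) rather than on Python lists; the sketch only shows the control-flow simplification.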

CHANGELOG.md (25 additions, 0 deletions)

@@ -10,6 +10,31 @@ changelog does not include internal changes that do not affect the user.
 
 ### Changed
 
+- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the
+  Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a new
+  function `torchjd.autojac.jac_to_grad` should then be called to aggregate those `.jac` fields into
+  `.grad` fields.
+  This means that users now have more control over what they do with the Jacobians (they can easily
+  aggregate them group by group or even param by param if they want), but it now requires an extra
+  line of code to do the Jacobian descent step. To update, please change:
+  ```python
+  backward(losses, aggregator)
+  ```
+  to
+  ```python
+  backward(losses)
+  jac_to_grad(model.parameters(), aggregator)
+  ```
+  and
+  ```python
+  mtl_backward(losses, features, aggregator)
+  ```
+  to
+  ```python
+  mtl_backward(losses, features)
+  jac_to_grad(shared_module.parameters(), aggregator)
+  ```
 - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
   efficiency of `autojac`.

docs/source/docs/autojac/index.rst (1 addition, 0 deletions)

@@ -10,3 +10,4 @@ autojac
 
    backward.rst
    mtl_backward.rst
+   jac_to_grad.rst
docs/source/docs/autojac/jac_to_grad.rst (new file, 6 additions, 0 deletions)

@@ -0,0 +1,6 @@
+:hide-toc:
+
+jac_to_grad
+===========
+
+.. autofunction:: torchjd.autojac.jac_to_grad

docs/source/examples/amp.rst (4 additions, 3 deletions)

@@ -12,15 +12,15 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.
 
 .. code-block:: python
-    :emphasize-lines: 2, 17, 27, 34-37
+    :emphasize-lines: 2, 17, 27, 34-35, 37-38
 
     import torch
     from torch.amp import GradScaler
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
     task1_module = Linear(3, 1)
@@ -48,7 +48,8 @@ following example shows the resulting code for a multi-task learning use-case.
     loss2 = loss_fn(output2, target2)
 
     scaled_losses = scaler.scale([loss1, loss2])
-    mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
+    mtl_backward(losses=scaled_losses, features=features)
+    jac_to_grad(shared_module.parameters(), aggregator)
     scaler.step(optimizer)
     scaler.update()
     optimizer.zero_grad()

docs/source/examples/basic_usage.rst (6 additions, 3 deletions)

@@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``:
 
     from torchjd import autojac
    from torchjd.aggregation import UPGrad
+    from torchjd.autojac import jac_to_grad
 
 Define the model and the optimizer, as usual:
 
@@ -63,10 +64,12 @@ Perform the Jacobian descent backward pass:
 
 .. code-block:: python
 
-    autojac.backward([loss1, loss2], aggregator)
+    autojac.backward([loss1, loss2])
+    jac_to_grad(model.parameters(), aggregator)
 
-This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
-Jacobian matrix.
+The first function will populate the ``.jac`` field of each model parameter with the corresponding
+Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
+field of the parameters. It also deletes the ``.jac`` fields to save some memory.
 
 Update each parameter based on its ``.grad`` field, using the ``optimizer``:

docs/source/examples/iwrm.rst (4 additions, 4 deletions)

@@ -76,14 +76,14 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 .. tab-item:: autojac
 
     .. code-block:: python
-        :emphasize-lines: 5-6, 12, 16, 21-22
+        :emphasize-lines: 5-6, 12, 16, 21-23
 
         import torch
         from torch.nn import Linear, MSELoss, ReLU, Sequential
         from torch.optim import SGD
 
         from torchjd.aggregation import UPGrad
-        from torchjd.autojac import backward
+        from torchjd.autojac import backward, jac_to_grad
 
         X = torch.randn(8, 16, 10)
         Y = torch.randn(8, 16)
@@ -99,8 +99,8 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
         for x, y in zip(X, Y):
            y_hat = model(x).squeeze(dim=1)  # shape: [16]
            losses = loss_fn(y_hat, y)  # shape: [16]
-            backward(losses, aggregator)
-
+            backward(losses)
+            jac_to_grad(model.parameters(), aggregator)
 
            optimizer.step()
            optimizer.zero_grad()

docs/source/examples/lightning_integration.rst (4 additions, 3 deletions)

@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.
 
 .. code-block:: python
-    :emphasize-lines: 9-10, 18, 31
+    :emphasize-lines: 9-10, 18, 31-32
 
     import torch
    from lightning import LightningModule, Trainer
@@ -22,7 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using
     from torch.utils.data import DataLoader, TensorDataset
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     class Model(LightningModule):
        def __init__(self):
@@ -43,7 +43,8 @@ The following code example demonstrates a basic multi-task learning setup using
            loss2 = mse_loss(output2, target2)
 
            opt = self.optimizers()
-            mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
+            mtl_backward(losses=[loss1, loss2], features=features)
+            jac_to_grad(self.feature_extractor.parameters(), UPGrad())
            opt.step()
            opt.zero_grad()

docs/source/examples/monitoring.rst (3 additions, 2 deletions)

@@ -23,7 +23,7 @@ they have a negative inner product).
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     def print_weights(_, __, weights: torch.Tensor) -> None:
        """Prints the extracted weights."""
@@ -63,6 +63,7 @@ they have a negative inner product).
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)
 
-    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+    mtl_backward(losses=[loss1, loss2], features=features)
+    jac_to_grad(shared_module.parameters(), aggregator)
     optimizer.step()
     optimizer.zero_grad()

docs/source/examples/mtl.rst (4 additions, 3 deletions)

@@ -19,14 +19,14 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
 
 
 .. code-block:: python
-    :emphasize-lines: 5-6, 19, 32
+    :emphasize-lines: 5-6, 19, 32-33
 
     import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD
 
    from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
@@ -52,7 +52,8 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)
 
-    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+    mtl_backward(losses=[loss1, loss2], features=features)
+    jac_to_grad(shared_module.parameters(), aggregator)
     optimizer.step()
     optimizer.zero_grad()

docs/source/examples/rnn.rst (4 additions, 3 deletions)

@@ -6,14 +6,14 @@ element of the output sequences. If the gradients of these losses are likely to
 descent can be leveraged to enhance optimization.
 
 .. code-block:: python
-    :emphasize-lines: 5-6, 10, 17, 19
+    :emphasize-lines: 5-6, 10, 17, 19-20
 
     import torch
    from torch.nn import RNN
    from torch.optim import SGD
 
    from torchjd.aggregation import UPGrad
-    from torchjd.autojac import backward
+    from torchjd.autojac import backward, jac_to_grad
 
    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
@@ -26,7 +26,8 @@ descent can be leveraged to enhance optimization.
     output, _ = rnn(input)  # output is of shape [5, 3, 20].
     losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
 
-    backward(losses, aggregator, parallel_chunk_size=1)
+    backward(losses, parallel_chunk_size=1)
+    jac_to_grad(rnn.parameters(), aggregator)
     optimizer.step()
     optimizer.zero_grad()
