Commit d826990: Merge branch 'main' into check-todos
2 parents: 760fe9e + 179bdfd

37 files changed: +812 −665

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@ repos:
       ]

   - repo: https://github.com/pycqa/isort
-    rev: 6.1.0
+    rev: 7.0.0
     hooks:
       - id: isort # Sort imports.
         args: [
@@ -31,7 +31,7 @@ repos:
       ]

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 25.9.0
+    rev: 25.12.0
     hooks:
       - id: black # Format code.
         args: [--line-length=100]

CHANGELOG.md

Lines changed: 36 additions & 3 deletions

@@ -3,14 +3,47 @@
 All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
-and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). This changelog does not include internal
-changes that do not affect the user.
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). This
+changelog does not include internal changes that do not affect the user.

 ## [Unreleased]

+### Changed
+
+- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the
+  Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a
+  new function `torchjd.autojac.jac_to_grad` should then be called to aggregate those `.jac`
+  fields into `.grad` fields.
+  This means that users now have more control over what they do with the Jacobians (they can
+  easily aggregate them group by group or even param by param if they want), but it now requires
+  an extra line of code to do the Jacobian descent step. To update, please change:
+  ```python
+  backward(losses, aggregator)
+  ```
+  to
+  ```python
+  backward(losses)
+  jac_to_grad(model.parameters(), aggregator)
+  ```
+  and
+  ```python
+  mtl_backward(losses, features, aggregator)
+  ```
+  to
+  ```python
+  mtl_backward(losses, features)
+  jac_to_grad(shared_module.parameters(), aggregator)
+  ```
+
+- Removed an unnecessary internal cloning of gradients. This should slightly improve the memory
+  efficiency of `autojac`.
+
+## [0.8.1] - 2026-01-07
+
 ### Added

-- Added `__all__` in the `__init__.py` of packages. This should prevent PyLance from triggering warnings when importing from `torchjd`.
+- Added `__all__` in the `__init__.py` of packages. This should prevent PyLance from triggering
+  warnings when importing from `torchjd`.

 ## [0.8.0] - 2025-11-13
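The two-step flow introduced above (populate `.jac`, then aggregate into `.grad`) can be pictured with a small pure-Python sketch. This is illustrative only: `Param`, `mean_aggregator`, and `jac_to_grad_sketch` are hypothetical stand-ins, and torchjd's real `jac_to_grad` works on torch tensors with aggregators such as `UPGrad`.

```python
# Conceptual sketch of the new two-step torchjd flow: "backward" fills .jac
# (one gradient row per loss) and "jac_to_grad" reduces those rows into .grad.
# Pure Python, for illustration only -- not torchjd code.

class Param:
    """Stand-in for a model parameter with .jac and .grad fields."""
    def __init__(self):
        self.jac = None   # list of per-loss gradient rows, filled by "backward"
        self.grad = None  # single aggregated gradient, filled by "jac_to_grad"

def mean_aggregator(jacobian):
    """Toy aggregator: average the per-loss gradient rows element-wise."""
    n_losses = len(jacobian)
    n_coords = len(jacobian[0])
    return [sum(row[i] for row in jacobian) / n_losses for i in range(n_coords)]

def jac_to_grad_sketch(params, aggregator):
    """Aggregate each parameter's .jac into .grad, then free the .jac field."""
    for p in params:
        if p.jac is not None:
            p.grad = aggregator(p.jac)
            p.jac = None  # released to save memory, as torchjd also does

p = Param()
p.jac = [[1.0, 2.0], [3.0, 4.0]]  # Jacobian: 2 losses x 2 coordinates
jac_to_grad_sketch([p], mean_aggregator)
print(p.grad)  # -> [2.0, 3.0]
```

Aggregating per parameter like this is exactly the extra freedom the breaking change gives: nothing forces all parameters to share one aggregator call.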

README.md

Lines changed: 3 additions & 3 deletions

@@ -111,11 +111,11 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-   optimizer.zero_grad()
 -  loss = loss1 + loss2
 -  loss.backward()
 +  mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()
+   optimizer.zero_grad()
 ```

 > [!NOTE]
@@ -150,12 +150,12 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 -  loss = loss_fn(output, target)  # shape [1]
 +  losses = loss_fn(output, target)  # shape [16]

-   optimizer.zero_grad()
 -  loss.backward()
 +  gramian = engine.compute_gramian(losses)  # shape: [16, 16]
 +  weights = weighting(gramian)  # shape: [16]
 +  losses.backward(weights)
    optimizer.step()
+   optimizer.zero_grad()
 ```

 Lastly, you can even combine the two approaches by considering multiple tasks and each element of
@@ -201,10 +201,10 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

-   optimizer.zero_grad()
    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)
    optimizer.step()
+   optimizer.zero_grad()
 ```

 > [!NOTE]

docs/source/docs/autojac/index.rst

Lines changed: 1 addition & 0 deletions

@@ -10,3 +10,4 @@ autojac

    backward.rst
    mtl_backward.rst
+   jac_to_grad.rst
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+:hide-toc:
+
+jac_to_grad
+===========
+
+.. autofunction:: torchjd.autojac.jac_to_grad

docs/source/examples/amp.rst

Lines changed: 5 additions & 4 deletions

@@ -12,15 +12,15 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.

 .. code-block:: python
-    :emphasize-lines: 2, 17, 27, 34, 36-38
+    :emphasize-lines: 2, 17, 27, 34-35, 37-38

     import torch
     from torch.amp import GradScaler
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD

     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad

     shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
     task1_module = Linear(3, 1)
@@ -48,10 +48,11 @@ following example shows the resulting code for a multi-task learning use-case.
         loss2 = loss_fn(output2, target2)

         scaled_losses = scaler.scale([loss1, loss2])
-        optimizer.zero_grad()
-        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
+        mtl_backward(losses=scaled_losses, features=features)
+        jac_to_grad(shared_module.parameters(), aggregator)
         scaler.step(optimizer)
         scaler.update()
+        optimizer.zero_grad()

 .. hint::
     Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For

docs/source/examples/basic_usage.rst

Lines changed: 12 additions & 9 deletions

@@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``:

     from torchjd import autojac
     from torchjd.aggregation import UPGrad
+    from torchjd.autojac import jac_to_grad

 Define the model and the optimizer, as usual:

@@ -59,20 +60,16 @@ We can now compute the losses associated to each element of the batch.

 The last steps are similar to gradient descent-based optimization, but using the two losses.

-Reset the ``.grad`` field of each model parameter:
-
-.. code-block:: python
-
-    optimizer.zero_grad()
-
 Perform the Jacobian descent backward pass:

 .. code-block:: python

-    autojac.backward([loss1, loss2], aggregator)
+    autojac.backward([loss1, loss2])
+    jac_to_grad(model.parameters(), aggregator)

-This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
-Jacobian matrix.
+The first function will populate the ``.jac`` field of each model parameter with the corresponding
+Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
+field of the parameters. It also deletes the ``.jac`` fields to save some memory.

 Update each parameter based on its ``.grad`` field, using the ``optimizer``:

@@ -81,3 +78,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:

     optimizer.step()

 The model's parameters have been updated!
+
+As usual, you should now reset the ``.grad`` field of each model parameter:
+
+.. code-block:: python
+
+    optimizer.zero_grad()
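Why can the reset safely move after ``optimizer.step()``? Because PyTorch accumulates into ``.grad`` across backward passes, the only requirement is that ``.grad`` is cleared before the next backward call. A toy accumulation model makes this concrete (illustrative pure Python; ``ToyParam``, ``toy_backward``, and ``toy_step`` are hypothetical names, not the PyTorch API):

```python
# Toy model of PyTorch's gradient accumulation: backward() adds into .grad,
# so .grad must be cleared once per iteration. Clearing right after step()
# is equivalent to clearing right before the next backward().

class ToyParam:
    def __init__(self):
        self.value = 0.0
        self.grad = 0.0

def toy_backward(p, g):
    p.grad += g  # accumulate, as PyTorch does, rather than overwrite

def toy_step(p, lr=0.1):
    p.value -= lr * p.grad

p = ToyParam()
for g in (1.0, 2.0):
    toy_backward(p, g)  # without the reset below, iteration 2 would see 3.0
    toy_step(p)
    p.grad = 0.0        # the "zero_grad after step" placement

print(round(p.value, 10))  # -0.3
```

With the reset, each step uses only its own gradient (0.1 * 1.0 then 0.1 * 2.0); without it, the second step would use the accumulated 3.0.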

docs/source/examples/iwmtl.rst

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
 The following example shows how to do that.

 .. code-block:: python
-    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -51,10 +51,10 @@ The following example shows how to do that.
         # Obtain the weights that lead to no conflict between reweighted gradients
         weights = weighting(gramian)  # shape: [16, 2]

-        optimizer.zero_grad()
         # Do the standard backward pass, but weighted using the obtained weights
         losses.backward(weights)
         optimizer.step()
+        optimizer.zero_grad()

 .. note::
     In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a

docs/source/examples/iwrm.rst

Lines changed: 8 additions & 8 deletions

@@ -64,26 +64,26 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
         for x, y in zip(X, Y):
             y_hat = model(x).squeeze(dim=1)  # shape: [16]
             loss = loss_fn(y_hat, y)  # shape: [] (scalar)
-            optimizer.zero_grad()
             loss.backward()
             optimizer.step()
+            optimizer.zero_grad()

     In this baseline example, the update may negatively affect the loss of some elements of the
     batch.

 .. tab-item:: autojac

     .. code-block:: python
-        :emphasize-lines: 5-6, 12, 16, 21, 23
+        :emphasize-lines: 5-6, 12, 16, 21-23

         import torch
         from torch.nn import Linear, MSELoss, ReLU, Sequential
         from torch.optim import SGD

         from torchjd.aggregation import UPGrad
-        from torchjd.autojac import backward
+        from torchjd.autojac import backward, jac_to_grad

         X = torch.randn(8, 16, 10)
         Y = torch.randn(8, 16)
@@ -99,19 +99,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
         for x, y in zip(X, Y):
             y_hat = model(x).squeeze(dim=1)  # shape: [16]
             losses = loss_fn(y_hat, y)  # shape: [16]
-            optimizer.zero_grad()
-            backward(losses, aggregator)
-
+            backward(losses)
+            jac_to_grad(model.parameters(), aggregator)
             optimizer.step()
+            optimizer.zero_grad()

     Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
     and use it to update the model such that no loss from the batch is (locally) increased.

 .. tab-item:: autogram (recommended)

     .. code-block:: python
-        :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+        :emphasize-lines: 5-6, 12, 16-17, 21-24

         import torch
         from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
         for x, y in zip(X, Y):
             y_hat = model(x).squeeze(dim=1)  # shape: [16]
             losses = loss_fn(y_hat, y)  # shape: [16]
-            optimizer.zero_grad()
             gramian = engine.compute_gramian(losses)  # shape: [16, 16]
             weights = weighting(gramian)  # shape: [16]
             losses.backward(weights)
             optimizer.step()
+            optimizer.zero_grad()

     Here, the per-sample gradients are never fully stored in memory, leading to large
     improvements in memory usage and speed compared to autojac, in most practical cases. The
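The Gramian trick used in the autogram tab can be checked on a tiny example. The sketch below is illustrative pure Python (not the torchjd engine): it shows that the Gramian ``J @ J.T`` only has shape ``[num_losses, num_losses]`` regardless of the number of parameters, and that once weights are chosen from it, the final gradient is just the weighted sum of per-loss gradients, which is what ``losses.backward(weights)`` produces.

```python
# Illustrative sketch: the Gramian of a Jacobian J (losses x params) is
# G[i][j] = <J[i], J[j]>, and the update direction for weights w is J.T @ w.

def gramian(J):
    """Pairwise inner products of the per-loss gradient rows of J."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in J] for ri in J]

def weighted_gradient(J, w):
    """The single gradient J.T @ w that a weighted backward pass produces."""
    n_params = len(J[0])
    return [sum(w[i] * J[i][k] for i in range(len(J))) for k in range(n_params)]

# Toy Jacobian: 2 losses, 3 parameters.
J = [[1.0, 0.0, 2.0],
     [0.0, 2.0, 0.0]]

G = gramian(J)               # shape [2, 2], independent of the 3 parameters
w = [0.5, 0.5]               # weights a weighting scheme might return from G
g = weighted_gradient(J, w)  # shape [3]

print(G)  # [[5.0, 0.0], [0.0, 4.0]]
print(g)  # [0.5, 1.0, 1.0]
```

This is why the autogram path never needs to hold the full per-sample gradients at once: only the small Gramian and one weighted backward pass are required.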

docs/source/examples/lightning_integration.rst

Lines changed: 5 additions & 4 deletions

@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.

 .. code-block:: python
-    :emphasize-lines: 9-10, 18, 32
+    :emphasize-lines: 9-10, 18, 31-32

     import torch
     from lightning import LightningModule, Trainer
@@ -22,7 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using
     from torch.utils.data import DataLoader, TensorDataset

     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad

     class Model(LightningModule):
         def __init__(self):
@@ -43,9 +43,10 @@ The following code example demonstrates a basic multi-task learning setup using
             loss2 = mse_loss(output2, target2)

             opt = self.optimizers()
-            opt.zero_grad()
-            mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
+            mtl_backward(losses=[loss1, loss2], features=features)
+            jac_to_grad(self.feature_extractor.parameters(), UPGrad())
             opt.step()
+            opt.zero_grad()

     def configure_optimizers(self) -> OptimizerLRScheduler:
         optimizer = Adam(self.parameters(), lr=1e-3)
