
Commit 9e855e2

Merge branch 'main' into revamp-interface

2 parents: 9a2a0ec + e890e65

7 files changed: 16 additions & 7 deletions


README.md

Lines changed: 3 additions & 3 deletions
@@ -111,11 +111,11 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
   loss1 = loss_fn(output1, target1)
   loss2 = loss_fn(output2, target2)
 
-  optimizer.zero_grad()
 - loss = loss1 + loss2
 - loss.backward()
 + mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
   optimizer.step()
+  optimizer.zero_grad()
 ```
 
 > [!NOTE]
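
For context, a minimal end-to-end sketch of the multi-task loop this hunk edits. Only the calls shown in the diff come from the source; the model, heads, data, and import paths (`torchjd.mtl_backward`, `torchjd.aggregation.UPGrad`) are assumptions for illustration.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

from torchjd import mtl_backward        # assumed import path
from torchjd.aggregation import UPGrad  # assumed import path

shared = Sequential(Linear(10, 5), ReLU())  # shared feature extractor (assumed)
head1, head2 = Linear(5, 1), Linear(5, 1)   # one head per task (assumed)
params = [*shared.parameters(), *head1.parameters(), *head2.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = MSELoss()
aggregator = UPGrad()

input = torch.randn(16, 10)
target1, target2 = torch.randn(16, 1), torch.randn(16, 1)

features = shared(input)
loss1 = loss_fn(head1(features), target1)
loss2 = loss_fn(head2(features), target2)

# Backward pass that resolves conflicts between the two tasks' gradients.
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
optimizer.zero_grad()  # new placement: clear gradients only after they are applied
```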
@@ -150,12 +150,12 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 - loss = loss_fn(output, target) # shape [1]
 + losses = loss_fn(output, target) # shape [16]
 
-  optimizer.zero_grad()
 - loss.backward()
 + gramian = engine.compute_gramian(losses) # shape: [16, 16]
 + weights = weighting(gramian) # shape: [16]
 + losses.backward(weights)
   optimizer.step()
+  optimizer.zero_grad()
 ```
 
 Lastly, you can even combine the two approaches by considering multiple tasks and each element of
@@ -201,10 +201,10 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
     # Obtain the weights that lead to no conflict between reweighted gradients
     weights = weighting(gramian) # shape: [16, 2]
 
-    optimizer.zero_grad()
     # Do the standard backward pass, but weighted using the obtained weights
     losses.backward(weights)
     optimizer.step()
+    optimizer.zero_grad()
 ```
 
 > [!NOTE]
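
All three hunks above apply the same reordering: `optimizer.zero_grad()` moves from before the backward pass to after `optimizer.step()`. The two orderings are equivalent as long as gradients are cleared exactly once per iteration; clearing after the step simply avoids carrying stale gradients between iterations. A plain-PyTorch sketch of the pattern (hypothetical model and data):

```python
import torch

model = torch.nn.Linear(10, 1)  # stand-in model (assumed)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

for _ in range(2):
    loss = loss_fn(model(torch.randn(16, 10)), torch.randn(16, 1))
    loss.backward()        # accumulates gradients into .grad
    optimizer.step()       # consumes .grad
    optimizer.zero_grad()  # .grad is no longer needed once applied
```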

docs/source/examples/amp.rst

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ following example shows the resulting code for a multi-task learning use-case.
     scaler.step(optimizer)
     scaler.update()
     optimizer.zero_grad()
+    optimizer.zero_grad()
 
 .. hint::
     Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For
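
The surrounding lines follow the standard PyTorch AMP recipe: backward on the scaled loss, `scaler.step`, `scaler.update`, then clear the gradients. A generic sketch of that recipe (standard `torch.cuda.amp` API, not this file's exact multi-task example; assumes a CUDA device):

```python
import torch

device = "cuda"  # GradScaler is only useful together with CUDA autocast
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

input = torch.randn(16, 10, device=device)
target = torch.randn(16, 1, device=device)

with torch.autocast(device_type=device, dtype=torch.float16):
    loss = loss_fn(model(input), target)  # some ops run in float16 here

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients, then steps
scaler.update()                # adjusts the scale factor for the next iteration
optimizer.zero_grad()          # clear gradients after the step
```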

docs/source/examples/iwrm.rst

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 
     optimizer.step()
    optimizer.zero_grad()
+    optimizer.zero_grad()
 
 In this baseline example, the update may negatively affect the loss of some elements of the
 batch.
@@ -106,6 +107,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 
     optimizer.step()
    optimizer.zero_grad()
+    optimizer.zero_grad()
 
 Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
 and use it to update the model such that no loss from the batch is (locally) increased.
@@ -142,6 +144,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     losses.backward(weights)
     optimizer.step()
    optimizer.zero_grad()
+    optimizer.zero_grad()
 
 Here, the per-sample gradients are never fully stored in memory, leading to large
 improvements in memory usage and speed compared to autojac, in most practical cases. The
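
For reference, the autojac variant that these explanations compare against can be sketched as follows. The `torchjd.autojac.backward(losses, UPGrad())` call is the one quoted later in this commit (in the `_engine.py` docstring); the model, data, and `UPGrad` import path are placeholders, and `MSELoss(reduction="none")` is the standard way to keep one loss per sample.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

import torchjd
from torchjd.aggregation import UPGrad  # assumed import path

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = MSELoss(reduction="none")  # per-sample losses instead of their mean

input = torch.randn(16, 10)
target = torch.randn(16)

output = model(input).squeeze(dim=1)  # shape: [16]
losses = criterion(output, target)    # shape: [16]

# Aggregate the 16 per-sample gradients so that no loss is (locally) increased.
torchjd.autojac.backward(losses, UPGrad())
optimizer.step()
optimizer.zero_grad()
```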

src/torchjd/autogram/_engine.py

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ class Engine:
     Train a model using Gramian-based Jacobian descent.
 
     .. code-block:: python
-        :emphasize-lines: 5-6, 15-16, 18-19, 26-28
+        :emphasize-lines: 5-6, 15-16, 18-19, 26-29
 
         import torch
         from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -103,11 +103,11 @@ class Engine:
         output = model(input).squeeze(dim=1) # shape: [16]
         losses = criterion(output, target) # shape: [16]
 
-        optimizer.zero_grad()
         gramian = engine.compute_gramian(losses) # shape: [16, 16]
         weights = weighting(gramian) # shape: [16]
         losses.backward(weights)
         optimizer.step()
+        optimizer.zero_grad()
 
     This is equivalent to just calling ``torchjd.autojac.backward(losses, UPGrad())``. However,
     since the Jacobian never has to be entirely in memory, it is often much more
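
The equivalence claimed in this docstring rests on the fact that, for a Jacobian J of the losses, the weighting needs J only through the Gramian G = J Jᵀ, and `losses.backward(weights)` then accumulates Jᵀw into the parameter gradients; the engine computes G without ever materializing J. A toy tensor sketch of that identity (illustration only, not the engine's implementation):

```python
import torch

J = torch.randn(16, 7)  # hypothetical Jacobian: 16 losses, 7 parameters
G = J @ J.T             # shape: [16, 16], what compute_gramian returns
w = torch.rand(16)      # whatever weights the weighting derives from G
direction = J.T @ w     # what losses.backward(w) accumulates into .grad
```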

tests/doc/test_autogram.py

Lines changed: 1 addition & 1 deletion
@@ -26,8 +26,8 @@ def test_engine():
     output = model(input).squeeze(dim=1) # shape: [16]
     losses = criterion(output, target) # shape: [16]
 
-    optimizer.zero_grad()
     gramian = engine.compute_gramian(losses) # shape: [16, 16]
     weights = weighting(gramian) # shape: [16]
     losses.backward(weights)
     optimizer.step()
+    optimizer.zero_grad()

tests/doc/test_rst.py

Lines changed: 5 additions & 0 deletions
@@ -48,6 +48,7 @@ def test_amp():
     scaler.step(optimizer)
     scaler.update()
     optimizer.zero_grad()
+    optimizer.zero_grad()
 
 
 def test_basic_usage():
@@ -122,6 +123,7 @@ def test_iwmtl():
     losses.backward(weights)
     optimizer.step()
     optimizer.zero_grad()
+    optimizer.zero_grad()
 
 
 def test_iwrm():
@@ -145,6 +147,7 @@ def test_autograd():
     loss.backward()
     optimizer.step()
     optimizer.zero_grad()
+    optimizer.zero_grad()
 
 def test_autojac():
     import torch
@@ -200,6 +203,7 @@ def test_autogram():
     losses.backward(weights)
     optimizer.step()
     optimizer.zero_grad()
+    optimizer.zero_grad()
 
 test_autograd()
 test_autojac()
@@ -400,6 +404,7 @@ def test_partial_jd():
     losses.backward(weights)
     optimizer.step()
     optimizer.zero_grad()
+    optimizer.zero_grad()
 
 
 def test_rnn():

tests/unit/autogram/test_engine.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch
         loss_fn = make_mse_loss_fn(targets)
         autogram_forward_backward(model, inputs, loss_fn, engine, weighting)
         optimizer.step()
-        model.zero_grad()
+        optimizer.zero_grad()
 
 
 @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS)
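
The swap from `model.zero_grad()` to `optimizer.zero_grad()` is behavior-preserving here only because the optimizer tracks all of the model's parameters; the two calls diverge when it does not. A minimal illustration (hypothetical setup, relying on the `set_to_none=True` default of recent PyTorch):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
# The optimizer deliberately tracks only the first layer's parameters.
optimizer = torch.optim.SGD(model[0].parameters(), lr=0.1)

model(torch.randn(2, 4)).sum().backward()
optimizer.zero_grad()                    # clears .grad on the first layer only
assert model[1].weight.grad is not None  # the second layer still holds a gradient
model.zero_grad()                        # clears .grad on every parameter
assert model[1].weight.grad is None
```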
