Commit e890e65 (parent: 502973f)

docs: Fix order of zero_grad (#512)

File tree

14 files changed: +41 -41 lines changed

README.md

Lines changed: 3 additions & 3 deletions

@@ -111,11 +111,11 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 loss1 = loss_fn(output1, target1)
 loss2 = loss_fn(output2, target2)

-optimizer.zero_grad()
 - loss = loss1 + loss2
 - loss.backward()
 + mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
 optimizer.step()
+optimizer.zero_grad()
 ```

 > [!NOTE]

@@ -150,12 +150,12 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
 - loss = loss_fn(output, target) # shape [1]
 + losses = loss_fn(output, target) # shape [16]

-optimizer.zero_grad()
 - loss.backward()
 + gramian = engine.compute_gramian(losses) # shape: [16, 16]
 + weights = weighting(gramian) # shape: [16]
 + losses.backward(weights)
 optimizer.step()
+optimizer.zero_grad()
 ```

 Lastly, you can even combine the two approaches by considering multiple tasks and each element of

@@ -201,10 +201,10 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
 # Obtain the weights that lead to no conflict between reweighted gradients
 weights = weighting(gramian) # shape: [16, 2]

-optimizer.zero_grad()
 # Do the standard backward pass, but weighted using the obtained weights
 losses.backward(weights)
 optimizer.step()
+optimizer.zero_grad()
 ```

 > [!NOTE]

docs/source/examples/amp.rst

Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.

 .. code-block:: python
-    :emphasize-lines: 2, 17, 27, 34, 36-38
+    :emphasize-lines: 2, 17, 27, 34-37

     import torch
     from torch.amp import GradScaler

@@ -48,10 +48,10 @@ following example shows the resulting code for a multi-task learning use-case.
     loss2 = loss_fn(output2, target2)

     scaled_losses = scaler.scale([loss1, loss2])
-    optimizer.zero_grad()
     mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
     scaler.step(optimizer)
     scaler.update()
+    optimizer.zero_grad()

 .. hint::
     Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For

docs/source/examples/basic_usage.rst

Lines changed: 6 additions & 6 deletions

@@ -59,12 +59,6 @@ We can now compute the losses associated to each element of the batch.

 The last steps are similar to gradient descent-based optimization, but using the two losses.

-Reset the ``.grad`` field of each model parameter:
-
-.. code-block:: python
-
-    optimizer.zero_grad()
-
 Perform the Jacobian descent backward pass:

 .. code-block:: python

@@ -81,3 +75,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:
     optimizer.step()

 The model's parameters have been updated!
+
+As usual, you should now reset the ``.grad`` field of each model parameter:
+
+.. code-block:: python
+
+    optimizer.zero_grad()

docs/source/examples/iwmtl.rst

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
 The following example shows how to do that.

 .. code-block:: python
-    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential

@@ -51,10 +51,10 @@ The following example shows how to do that.
     # Obtain the weights that lead to no conflict between reweighted gradients
     weights = weighting(gramian) # shape: [16, 2]

-    optimizer.zero_grad()
     # Do the standard backward pass, but weighted using the obtained weights
     losses.backward(weights)
     optimizer.step()
+    optimizer.zero_grad()

 .. note::
     In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a

docs/source/examples/iwrm.rst

Lines changed: 5 additions & 5 deletions

@@ -64,19 +64,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1) # shape: [16]
         loss = loss_fn(y_hat, y) # shape: [] (scalar)
-        optimizer.zero_grad()
         loss.backward()


         optimizer.step()
+        optimizer.zero_grad()

 In this baseline example, the update may negatively affect the loss of some elements of the
 batch.

 .. tab-item:: autojac

 .. code-block:: python
-    :emphasize-lines: 5-6, 12, 16, 21, 23
+    :emphasize-lines: 5-6, 12, 16, 21-22

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential

@@ -99,19 +99,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1) # shape: [16]
         losses = loss_fn(y_hat, y) # shape: [16]
-        optimizer.zero_grad()
         backward(losses, aggregator)


         optimizer.step()
+        optimizer.zero_grad()

 Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
 and use it to update the model such that no loss from the batch is (locally) increased.

 .. tab-item:: autogram (recommended)

 .. code-block:: python
-    :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+    :emphasize-lines: 5-6, 12, 16-17, 21-24

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential

@@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1) # shape: [16]
         losses = loss_fn(y_hat, y) # shape: [16]
-        optimizer.zero_grad()
         gramian = engine.compute_gramian(losses) # shape: [16, 16]
         weights = weighting(gramian) # shape: [16]
         losses.backward(weights)
         optimizer.step()
+        optimizer.zero_grad()

 Here, the per-sample gradients are never fully stored in memory, leading to large
 improvements in memory usage and speed compared to autojac, in most practical cases. The

docs/source/examples/lightning_integration.rst

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.

 .. code-block:: python
-    :emphasize-lines: 9-10, 18, 32
+    :emphasize-lines: 9-10, 18, 31

     import torch
     from lightning import LightningModule, Trainer

@@ -43,9 +43,9 @@ The following code example demonstrates a basic multi-task learning setup using
         loss2 = mse_loss(output2, target2)

         opt = self.optimizers()
-        opt.zero_grad()
         mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
         opt.step()
+        opt.zero_grad()

     def configure_optimizers(self) -> OptimizerLRScheduler:
         optimizer = Adam(self.parameters(), lr=1e-3)

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion

@@ -63,6 +63,6 @@ they have a negative inner product).
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-    optimizer.zero_grad()
     mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
     optimizer.step()
+    optimizer.zero_grad()

docs/source/examples/mtl.rst

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.


 .. code-block:: python
-    :emphasize-lines: 5-6, 19, 33
+    :emphasize-lines: 5-6, 19, 32

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential

@@ -52,9 +52,9 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
     loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

-    optimizer.zero_grad()
     mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
     optimizer.step()
+    optimizer.zero_grad()

 .. note::
     In this example, the Jacobian is only with respect to the shared parameters. The task-specific

docs/source/examples/partial_jd.rst

Lines changed: 1 addition & 1 deletion

@@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1) # shape: [16]
         losses = loss_fn(y_hat, y) # shape: [16]
-        optimizer.zero_grad()
         gramian = engine.compute_gramian(losses)
         weights = weighting(gramian)
         losses.backward(weights)
         optimizer.step()
+        optimizer.zero_grad()

docs/source/examples/rnn.rst

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@ element of the output sequences. If the gradients of these losses are likely to
 descent can be leveraged to enhance optimization.

 .. code-block:: python
-    :emphasize-lines: 5-6, 10, 17, 20
+    :emphasize-lines: 5-6, 10, 17, 19

     import torch
     from torch.nn import RNN

@@ -26,9 +26,9 @@ descent can be leveraged to enhance optimization.
     output, _ = rnn(input) # output is of shape [5, 3, 20].
     losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

-    optimizer.zero_grad()
     backward(losses, aggregator, parallel_chunk_size=1)
     optimizer.step()
+    optimizer.zero_grad()

 .. note::
     At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
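
Across all of these files, the change is the same: the call to ``optimizer.zero_grad()`` moves from before the backward pass to after ``optimizer.step()``, so gradients are reset only once the update has been applied. A minimal plain-PyTorch sketch of the resulting loop ordering (the model, data, and learning rate here are illustrative placeholders, not taken from the repository):

```python
import torch
from torch.nn import Linear, MSELoss
from torch.optim import SGD

# Illustrative model and data; only the ordering of the loop body matters here.
model = Linear(10, 1)
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.1)

X = torch.randn(4, 16, 10)  # 4 batches of 16 samples
Y = torch.randn(4, 16, 1)

for x, y in zip(X, Y):
    loss = loss_fn(model(x), y)
    loss.backward()        # accumulate gradients into each parameter's .grad
    optimizer.step()       # update parameters from .grad
    optimizer.zero_grad()  # reset .grad last, matching the order in this commit
```

With this ordering, the ``.grad`` fields stay populated between ``backward()`` and ``step()`` and are cleared only after the update, which is the order now shown consistently in every example.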
