Commit 2e1d302 (parent 9997bc4)

20 files changed: 47 additions & 47 deletions


latest/_sources/examples/amp.rst.txt

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.

 .. code-block:: python
-    :emphasize-lines: 2, 17, 27, 34, 36-38
+    :emphasize-lines: 2, 17, 27, 34-37

     import torch
     from torch.amp import GradScaler
@@ -48,10 +48,10 @@ following example shows the resulting code for a multi-task learning use-case.
         loss2 = loss_fn(output2, target2)

         scaled_losses = scaler.scale([loss1, loss2])
-        optimizer.zero_grad()
         mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
         scaler.step(optimizer)
         scaler.update()
+        optimizer.zero_grad()

 .. hint::
     Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For

latest/_sources/examples/basic_usage.rst.txt

Lines changed: 6 additions & 6 deletions
@@ -59,12 +59,6 @@ We can now compute the losses associated to each element of the batch.

 The last steps are similar to gradient descent-based optimization, but using the two losses.

-Reset the ``.grad`` field of each model parameter:
-
-.. code-block:: python
-
-    optimizer.zero_grad()
-
 Perform the Jacobian descent backward pass:

 .. code-block:: python
@@ -81,3 +75,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:
     optimizer.step()

 The model's parameters have been updated!
+
+As usual, you should now reset the ``.grad`` field of each model parameter:
+
+.. code-block:: python
+
+    optimizer.zero_grad()
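The pattern applied throughout this commit (backward, then ``step()``, then ``zero_grad()``) can be sketched with a plain gradient-descent loop; the model, loss, and data below are illustrative, not from the repository:

```python
import torch
from torch.nn import Linear, MSELoss

# Minimal single-loss training loop showing the updated ordering.
model = Linear(10, 1)
loss_fn = MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 10), torch.randn(16)
for _ in range(3):
    y_hat = model(x).squeeze(dim=1)  # shape: [16]
    loss = loss_fn(y_hat, y)         # shape: [] (scalar)
    loss.backward()                  # accumulate into each parameter's .grad
    optimizer.step()                 # update parameters from .grad
    optimizer.zero_grad()            # reset .grad, ready for the next iteration
```

Either ordering is functionally equivalent within a loop; clearing the gradients after the step simply leaves the model in a clean state between iterations.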

latest/_sources/examples/iwmtl.rst.txt

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
 The following example shows how to do that.

 .. code-block:: python
-    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -51,10 +51,10 @@ The following example shows how to do that.
     # Obtain the weights that lead to no conflict between reweighted gradients
     weights = weighting(gramian)  # shape: [16, 2]

-    optimizer.zero_grad()
     # Do the standard backward pass, but weighted using the obtained weights
     losses.backward(weights)
     optimizer.step()
+    optimizer.zero_grad()

 .. note::
     In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a

latest/_sources/examples/iwrm.rst.txt

Lines changed: 5 additions & 5 deletions
@@ -64,19 +64,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 loss = loss_fn(y_hat, y)  # shape: [] (scalar)
-                optimizer.zero_grad()
                 loss.backward()


                 optimizer.step()
+                optimizer.zero_grad()

         In this baseline example, the update may negatively affect the loss of some elements of the
         batch.

     .. tab-item:: autojac

         .. code-block:: python
-            :emphasize-lines: 5-6, 12, 16, 21, 23
+            :emphasize-lines: 5-6, 12, 16, 21-22

             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -99,19 +99,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 losses = loss_fn(y_hat, y)  # shape: [16]
-                optimizer.zero_grad()
                 backward(losses, aggregator)


                 optimizer.step()
+                optimizer.zero_grad()

         Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
         and use it to update the model such that no loss from the batch is (locally) increased.

     .. tab-item:: autogram (recommended)

         .. code-block:: python
-            :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+            :emphasize-lines: 5-6, 12, 16-17, 21-24

             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 losses = loss_fn(y_hat, y)  # shape: [16]
-                optimizer.zero_grad()
                 gramian = engine.compute_gramian(losses)  # shape: [16, 16]
                 weights = weighting(gramian)  # shape: [16]
                 losses.backward(weights)
                 optimizer.step()
+                optimizer.zero_grad()

         Here, the per-sample gradients are never fully stored in memory, leading to large
         improvements in memory usage and speed compared to autojac, in most practical cases. The
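The IWRM variants above all start from a vector of per-instance losses rather than a single mean. A minimal sketch of producing that vector (shapes mirror the example: a batch of 16 inputs of dimension 10; the model here is illustrative):

```python
import torch
from torch.nn import Linear, MSELoss

# reduction="none" keeps one loss per sample instead of averaging them,
# which is what autojac / autogram then operate on.
model = Linear(10, 1)
loss_fn = MSELoss(reduction="none")

x, y = torch.randn(16, 10), torch.randn(16)
y_hat = model(x).squeeze(dim=1)  # shape: [16]
losses = loss_fn(y_hat, y)       # shape: [16], one loss per instance
```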

latest/_sources/examples/lightning_integration.rst.txt

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.

 .. code-block:: python
-    :emphasize-lines: 9-10, 18, 32
+    :emphasize-lines: 9-10, 18, 31

     import torch
     from lightning import LightningModule, Trainer
@@ -43,9 +43,9 @@ The following code example demonstrates a basic multi-task learning setup using
         loss2 = mse_loss(output2, target2)

         opt = self.optimizers()
-        opt.zero_grad()
         mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
         opt.step()
+        opt.zero_grad()

     def configure_optimizers(self) -> OptimizerLRScheduler:
         optimizer = Adam(self.parameters(), lr=1e-3)

latest/_sources/examples/monitoring.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -63,6 +63,6 @@ they have a negative inner product).
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-    optimizer.zero_grad()
     mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
     optimizer.step()
+    optimizer.zero_grad()

latest/_sources/examples/mtl.rst.txt

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.


 .. code-block:: python
-    :emphasize-lines: 5-6, 19, 33
+    :emphasize-lines: 5-6, 19, 32

     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -52,9 +52,9 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
     loss1 = loss_fn(output1, target1)
     loss2 = loss_fn(output2, target2)

-    optimizer.zero_grad()
     mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
     optimizer.step()
+    optimizer.zero_grad()

 .. note::
     In this example, the Jacobian is only with respect to the shared parameters. The task-specific

latest/_sources/examples/partial_jd.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1)  # shape: [16]
         losses = loss_fn(y_hat, y)  # shape: [16]
-        optimizer.zero_grad()
         gramian = engine.compute_gramian(losses)
         weights = weighting(gramian)
         losses.backward(weights)
         optimizer.step()
+        optimizer.zero_grad()

latest/_sources/examples/rnn.rst.txt

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@ element of the output sequences. If the gradients of these losses are likely to
 descent can be leveraged to enhance optimization.

 .. code-block:: python
-    :emphasize-lines: 5-6, 10, 17, 20
+    :emphasize-lines: 5-6, 10, 17, 19

     import torch
     from torch.nn import RNN
@@ -26,9 +26,9 @@ descent can be leveraged to enhance optimization.
     output, _ = rnn(input)  # output is of shape [5, 3, 20].
     losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.

-    optimizer.zero_grad()
     backward(losses, aggregator, parallel_chunk_size=1)
     optimizer.step()
+    optimizer.zero_grad()

 .. note::
     At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and

latest/docs/autogram/engine/index.html

Lines changed: 5 additions & 5 deletions
@@ -349,12 +349,12 @@ <h1>Engine<a class="headerlink" href="#engine" title="Link to this heading">¶</
 <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># shape: [16]</span>
 <span class="n">losses</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="c1"># shape: [16]</span>

-<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
-<span class="hll"> <span class="n">gramian</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">compute_gramian</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span> <span class="c1"># shape: [16, 16]</span>
-</span><span class="hll"> <span class="n">weights</span> <span class="o">=</span> <span class="n">weighting</span><span class="p">(</span><span class="n">gramian</span><span class="p">)</span> <span class="c1"># shape: [16]</span>
+<span class="n">gramian</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">compute_gramian</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span> <span class="c1"># shape: [16, 16]</span>
+<span class="hll"> <span class="n">weights</span> <span class="o">=</span> <span class="n">weighting</span><span class="p">(</span><span class="n">gramian</span><span class="p">)</span> <span class="c1"># shape: [16]</span>
 </span><span class="hll"> <span class="n">losses</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">weights</span><span class="p">)</span>
-</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
-</pre></div>
+</span><span class="hll"> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
+</span><span class="hll"> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
+</span></pre></div>
 </div>
 <p>This is equivalent to just calling <code class="docutils literal notranslate"><span class="pre">torchjd.autojac.backward(losses,</span> <span class="pre">UPGrad())</span></code>. However,
 since the Jacobian never has to be entirely in memory, it is often much more
