Commit 99b4260 (parent 8de14e0)
feat: Use .jac fields

19 files changed: +258 −281 lines changed

docs/source/examples/amp.rst

Lines changed: 5 additions & 3 deletions
@@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.

 .. code-block:: python
-   :emphasize-lines: 2, 17, 27, 34, 36-38
+   :emphasize-lines: 2, 18, 28, 35-36, 38-39

    import torch
    from torch.amp import GradScaler
@@ -21,6 +21,7 @@ following example shows the resulting code for a multi-task learning use-case.
    from torchjd.aggregation import UPGrad
    from torchjd.autojac import mtl_backward
+   from torchjd.utils import jac_to_grad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
@@ -48,10 +49,11 @@ following example shows the resulting code for a multi-task learning use-case.
        loss2 = loss_fn(output2, target2)

        scaled_losses = scaler.scale([loss1, loss2])
-       optimizer.zero_grad()
-       mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
+       mtl_backward(losses=scaled_losses, features=features)
+       jac_to_grad(shared_module.parameters(), aggregator)
        scaler.step(optimizer)
        scaler.update()
+       optimizer.zero_grad()

 .. hint::
    Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For

docs/source/examples/basic_usage.rst

Lines changed: 12 additions & 9 deletions
@@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``:

    from torchjd import autojac
    from torchjd.aggregation import UPGrad
+   from torchjd.utils import jac_to_grad

 Define the model and the optimizer, as usual:

@@ -59,20 +60,16 @@ We can now compute the losses associated to each element of the batch.

 The last steps are similar to gradient descent-based optimization, but using the two losses.

-Reset the ``.grad`` field of each model parameter:
-
-.. code-block:: python
-
-   optimizer.zero_grad()
-
 Perform the Jacobian descent backward pass:

 .. code-block:: python

-   autojac.backward([loss1, loss2], aggregator)
+   autojac.backward([loss1, loss2])
+   jac_to_grad(model.parameters(), aggregator)

-This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
-Jacobian matrix.
+The first function will populate the ``.jac`` field of each model parameter with the corresponding
+Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
+field of the parameters. It also resets the ``.jac`` fields to ``None`` to save some memory.

 Update each parameter based on its ``.grad`` field, using the ``optimizer``:

@@ -81,3 +78,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:
    optimizer.step()

 The model's parameters have been updated!
+
+As usual, you should now reset the ``.grad`` field of each model parameter:
+
+.. code-block:: python
+
+   optimizer.zero_grad()
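The two-step flow described above (``backward`` fills a per-parameter ``.jac`` field, then ``jac_to_grad`` aggregates it into ``.grad`` and frees the Jacobian) can be sketched in plain Python. This is illustrative only: ``Param``, ``fake_backward``, ``jac_to_grad_sketch``, and ``mean_rows`` are made-up stand-ins, not torchjd's actual implementation, and a simple row-mean stands in for a real aggregator such as ``UPGrad``.

```python
class Param:
    """Stand-in for a model parameter with .jac and .grad fields."""
    def __init__(self):
        self.jac = None   # Jacobian: one gradient row per loss
        self.grad = None  # aggregated gradient

def fake_backward(params, jacobians):
    # Hypothetical stand-in for autojac.backward: store each parameter's
    # Jacobian (one row per loss) in its .jac field.
    for p, j in zip(params, jacobians):
        p.jac = j

def jac_to_grad_sketch(params, aggregate):
    # Aggregate each parameter's Jacobian rows into a single gradient,
    # then reset .jac to None to free memory, as the docs describe.
    for p in params:
        p.grad = aggregate(p.jac)
        p.jac = None

p = Param()
fake_backward([p], [[[1.0, 2.0], [3.0, 4.0]]])  # 2 losses, 2-dim parameter

# Toy aggregator: mean of the Jacobian rows (real aggregators are smarter).
mean_rows = lambda rows: [sum(col) / len(rows) for col in zip(*rows)]
jac_to_grad_sketch([p], mean_rows)
print(p.grad, p.jac)  # [2.0, 3.0] None
```

The point of the split is visible here: the Jacobian exists only between the two calls, so its memory can be released as soon as the aggregated gradient is stored.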

docs/source/examples/iwmtl.rst

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
 The following example shows how to do that.

 .. code-block:: python
-   :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+   :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -51,10 +51,10 @@ The following example shows how to do that.
        # Obtain the weights that lead to no conflict between reweighted gradients
        weights = weighting(gramian)  # shape: [16, 2]

-       optimizer.zero_grad()
        # Do the standard backward pass, but weighted using the obtained weights
        losses.backward(weights)
        optimizer.step()
+       optimizer.zero_grad()

 .. note::
    In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a

docs/source/examples/iwrm.rst

Lines changed: 10 additions & 7 deletions
@@ -50,6 +50,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac



+
       X = torch.randn(8, 16, 10)
       Y = torch.randn(8, 16)

@@ -64,26 +65,27 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
       for x, y in zip(X, Y):
           y_hat = model(x).squeeze(dim=1)  # shape: [16]
           loss = loss_fn(y_hat, y)  # shape: [] (scalar)
-          optimizer.zero_grad()
           loss.backward()

           optimizer.step()
+          optimizer.zero_grad()

    In this baseline example, the update may negatively affect the loss of some elements of the
    batch.

 .. tab-item:: autojac

    .. code-block:: python
-      :emphasize-lines: 5-6, 12, 16, 21, 23
+      :emphasize-lines: 5-7, 13, 17, 22-24

       import torch
       from torch.nn import Linear, MSELoss, ReLU, Sequential
       from torch.optim import SGD

       from torchjd.aggregation import UPGrad
       from torchjd.autojac import backward
+      from torchjd.utils import jac_to_grad

       X = torch.randn(8, 16, 10)
       Y = torch.randn(8, 16)
@@ -99,19 +101,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
       for x, y in zip(X, Y):
           y_hat = model(x).squeeze(dim=1)  # shape: [16]
           losses = loss_fn(y_hat, y)  # shape: [16]
-          optimizer.zero_grad()
-          backward(losses, aggregator)
-
+          backward(losses)
+          jac_to_grad(model.parameters(), aggregator)

           optimizer.step()
+          optimizer.zero_grad()

    Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
    and use it to update the model such that no loss from the batch is (locally) increased.

 .. tab-item:: autogram (recommended)

    .. code-block:: python
-      :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+      :emphasize-lines: 5-6, 13, 17-18, 22-25

       import torch
       from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -120,6 +122,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
       from torchjd.aggregation import UPGradWeighting
       from torchjd.autogram import Engine

+
       X = torch.randn(8, 16, 10)
       Y = torch.randn(8, 16)

@@ -134,11 +137,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
       for x, y in zip(X, Y):
           y_hat = model(x).squeeze(dim=1)  # shape: [16]
           losses = loss_fn(y_hat, y)  # shape: [16]
-          optimizer.zero_grad()
          gramian = engine.compute_gramian(losses)  # shape: [16, 16]
           weights = weighting(gramian)  # shape: [16]
           losses.backward(weights)
           optimizer.step()
+          optimizer.zero_grad()

    Here, the per-sample gradients are never fully stored in memory, leading to large
    improvements in memory usage and speed compared to autojac, in most practical cases. The
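The autogram tab above rests on the identity behind Gramian-based reweighting: with ``J`` the per-sample Jacobian, the Gramian is ``G = J Jᵀ``, and once weights ``w`` are chosen from ``G``, the update direction is ``Jᵀ w``, which is what ``losses.backward(weights)`` effectively computes. A plain-Python sketch of that identity (illustrative only: ``uniform_weighting`` is a made-up stand-in for ``UPGradWeighting``, and unlike autogram this sketch materializes ``J`` explicitly):

```python
def matmul(A, B):
    # Naive matrix product of two lists-of-rows.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

# 3 per-sample gradients of a 2-dimensional model: J has shape [3, 2].
J = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
G = matmul(J, transpose(J))  # Gramian J Jᵀ, shape [3, 3]

def uniform_weighting(gramian):
    # Made-up placeholder for a real weighting such as UPGradWeighting:
    # just weight every sample equally.
    n = len(gramian)
    return [1.0 / n] * n

w = uniform_weighting(G)
# Update direction Jᵀ w: the weighted combination of per-sample gradients.
update = [sum(w_i * row[k] for w_i, row in zip(w, J)) for k in range(len(J[0]))]
print(G)       # [[1.0, 0.0, 1.0], [0.0, 4.0, 2.0], [1.0, 2.0, 2.0]]
print(update)  # [2/3, 1.0]
```

Because the weights depend on ``J`` only through the much smaller ``G``, an engine can accumulate ``G`` on the fly and never keep the full per-sample Jacobian in memory.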

docs/source/examples/lightning_integration.rst

Lines changed: 5 additions & 3 deletions
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.

 .. code-block:: python
-   :emphasize-lines: 9-10, 18, 32
+   :emphasize-lines: 9-11, 19, 32-33

    import torch
    from lightning import LightningModule, Trainer
@@ -23,6 +23,7 @@ The following code example demonstrates a basic multi-task learning setup using

    from torchjd.aggregation import UPGrad
    from torchjd.autojac import mtl_backward
+   from torchjd.utils import jac_to_grad

    class Model(LightningModule):
        def __init__(self):
@@ -43,9 +44,10 @@ The following code example demonstrates a basic multi-task learning setup using
            loss2 = mse_loss(output2, target2)

            opt = self.optimizers()
-           opt.zero_grad()
-           mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
+           mtl_backward(losses=[loss1, loss2], features=features)
+           jac_to_grad(self.feature_extractor.parameters(), UPGrad())
            opt.step()
+           opt.zero_grad()

        def configure_optimizers(self) -> OptimizerLRScheduler:
            optimizer = Adam(self.parameters(), lr=1e-3)

docs/source/examples/monitoring.rst

Lines changed: 5 additions & 3 deletions
@@ -15,7 +15,7 @@ Jacobian descent is doing something different than gradient descent. With
 they have a negative inner product).

 .. code-block:: python
-   :emphasize-lines: 9-11, 13-18, 33-34
+   :emphasize-lines: 10-12, 14-19, 34-35

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -24,6 +24,7 @@ they have a negative inner product).

    from torchjd.aggregation import UPGrad
    from torchjd.autojac import mtl_backward
+   from torchjd.utils import jac_to_grad

    def print_weights(_, __, weights: torch.Tensor) -> None:
        """Prints the extracted weights."""
@@ -63,6 +64,7 @@ they have a negative inner product).
        loss1 = loss_fn(output1, target1)
        loss2 = loss_fn(output2, target2)

-       optimizer.zero_grad()
-       mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+       mtl_backward(losses=[loss1, loss2], features=features)
+       jac_to_grad(shared_module.parameters(), aggregator)
        optimizer.step()
+       optimizer.zero_grad()

docs/source/examples/mtl.rst

Lines changed: 5 additions & 3 deletions
@@ -19,14 +19,15 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.


 .. code-block:: python
-   :emphasize-lines: 5-6, 19, 33
+   :emphasize-lines: 5-7, 20, 33-34

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import UPGrad
    from torchjd.autojac import mtl_backward
+   from torchjd.utils import jac_to_grad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
@@ -52,9 +53,10 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
        loss1 = loss_fn(output1, target1)
        loss2 = loss_fn(output2, target2)

-       optimizer.zero_grad()
-       mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+       mtl_backward(losses=[loss1, loss2], features=features)
+       jac_to_grad(shared_module.parameters(), aggregator)
        optimizer.step()
+       optimizer.zero_grad()

 .. note::
    In this example, the Jacobian is only with respect to the shared parameters. The task-specific

docs/source/examples/partial_jd.rst

Lines changed: 1 addition & 1 deletion
@@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
    for x, y in zip(X, Y):
        y_hat = model(x).squeeze(dim=1)  # shape: [16]
        losses = loss_fn(y_hat, y)  # shape: [16]
-       optimizer.zero_grad()
        gramian = engine.compute_gramian(losses)
        weights = weighting(gramian)
        losses.backward(weights)
        optimizer.step()
+       optimizer.zero_grad()

docs/source/examples/rnn.rst

Lines changed: 5 additions & 3 deletions
@@ -6,14 +6,15 @@ element of the output sequences. If the gradients of these losses are likely to
 descent can be leveraged to enhance optimization.

 .. code-block:: python
-   :emphasize-lines: 5-6, 10, 17, 20
+   :emphasize-lines: 5-7, 11, 18, 20-21

    import torch
    from torch.nn import RNN
    from torch.optim import SGD

    from torchjd.aggregation import UPGrad
    from torchjd.autojac import backward
+   from torchjd.utils import jac_to_grad

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
@@ -26,9 +27,10 @@ descent can be leveraged to enhance optimization.
    output, _ = rnn(input)  # output is of shape [5, 3, 20].
    losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.

-   optimizer.zero_grad()
-   backward(losses, aggregator, parallel_chunk_size=1)
+   backward(losses, parallel_chunk_size=1)
+   jac_to_grad(rnn.parameters(), aggregator)
    optimizer.step()
+   optimizer.zero_grad()

 .. note::
    At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
