Commit e3673c1 (1 parent: 3ada480)


66 files changed: +3141 additions, −633 deletions

stable/.buildinfo

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # This file records the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: e5000e9e67322558c40cde17b9944e7c
+config: a7571c8a4da17ac2925af991bf4fc8fb
 tags: d77d1c0d9ca2f4c8421862c7c5a0d620

stable/_sources/docs/autojac/index.rst.txt

Lines changed: 2 additions & 0 deletions
@@ -10,3 +10,5 @@ autojac
 
    backward.rst
    mtl_backward.rst
+   jac.rst
+   jac_to_grad.rst
stable/_sources/docs/autojac/jac.rst.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+:hide-toc:
+
+jac
+===
+
+.. autofunction:: torchjd.autojac.jac
stable/_sources/docs/autojac/jac_to_grad.rst.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+:hide-toc:
+
+jac_to_grad
+===========
+
+.. autofunction:: torchjd.autojac.jac_to_grad

stable/_sources/examples/amp.rst.txt

Lines changed: 5 additions & 4 deletions
@@ -12,15 +12,15 @@ case, the losses) should preferably be scaled with a `GradScaler
 following example shows the resulting code for a multi-task learning use-case.
 
 .. code-block:: python
-    :emphasize-lines: 2, 17, 27, 34, 36-38
+    :emphasize-lines: 2, 17, 27, 34-35, 37-38
 
     import torch
     from torch.amp import GradScaler
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
     task1_module = Linear(3, 1)
@@ -48,10 +48,11 @@ following example shows the resulting code for a multi-task learning use-case.
     loss2 = loss_fn(output2, target2)
 
     scaled_losses = scaler.scale([loss1, loss2])
-    optimizer.zero_grad()
-    mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
+    mtl_backward(scaled_losses, features=features)
+    jac_to_grad(shared_module.parameters(), aggregator)
     scaler.step(optimizer)
     scaler.update()
+    optimizer.zero_grad()
 
 .. hint::
     Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For
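The reordering above stays compatible with ``GradScaler`` because loss scaling commutes with any linear aggregation of Jacobian rows: scaling every loss by a factor ``s`` scales the aggregated gradient by the same ``s``, which unscaling undoes. A small pure-Python sketch of this invariance, using a plain row-wise mean as an illustrative stand-in for the aggregator (not torchjd's actual ``UPGrad``):

```python
# Sketch: aggregating a Jacobian scaled by s, then unscaling,
# recovers the aggregation of the original Jacobian.
# All names here are illustrative, not library APIs.

def scale_rows(jac, s):
    # Scale every per-loss gradient row by s (what loss scaling does).
    return [[s * v for v in row] for row in jac]

def mean_rows(jac):
    # Stand-in linear aggregator: row-wise mean of the Jacobian.
    n = len(jac)
    return [sum(col) / n for col in zip(*jac)]

jac = [[1.0, 2.0], [3.0, 4.0]]          # two losses, 2-dim parameter
s = 1024.0
scaled = mean_rows(scale_rows(jac, s))  # aggregate the scaled Jacobian
unscaled = [v / s for v in scaled]      # GradScaler-style unscale
print(unscaled)  # [2.0, 3.0] == mean_rows(jac)
```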

stable/_sources/examples/basic_usage.rst.txt

Lines changed: 12 additions & 9 deletions
@@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``:
 
     from torchjd import autojac
     from torchjd.aggregation import UPGrad
+    from torchjd.autojac import jac_to_grad
 
 Define the model and the optimizer, as usual:
 
@@ -59,20 +60,16 @@ We can now compute the losses associated to each element of the batch.
 
 The last steps are similar to gradient descent-based optimization, but using the two losses.
 
-Reset the ``.grad`` field of each model parameter:
-
-.. code-block:: python
-
-    optimizer.zero_grad()
-
 Perform the Jacobian descent backward pass:
 
 .. code-block:: python
 
-    autojac.backward([loss1, loss2], aggregator)
+    autojac.backward([loss1, loss2])
+    jac_to_grad(model.parameters(), aggregator)
 
-This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
-Jacobian matrix.
+The first function will populate the ``.jac`` field of each model parameter with the corresponding
+Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
+field of the parameters. It also deletes the ``.jac`` fields to save some memory.
 
 Update each parameter based on its ``.grad`` field, using the ``optimizer``:
 
@@ -81,3 +78,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:
     optimizer.step()
 
 The model's parameters have been updated!
+
+As usual, you should now reset the ``.grad`` field of each model parameter:
+
+.. code-block:: python
+
+    optimizer.zero_grad()
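The diff above splits the old one-step ``backward(losses, aggregator)`` into two steps: ``backward`` stores a ``.jac`` on each parameter, then ``jac_to_grad`` aggregates it into ``.grad`` and frees the Jacobian. A minimal pure-Python sketch of that contract (the ``Param`` holder, the functions, and the mean aggregator are all illustrative stand-ins, not torchjd's implementation):

```python
# Illustrative sketch of the backward / jac_to_grad split described above.

class Param:
    """A parameter holding an optional .jac (one row per loss) and a .grad."""
    def __init__(self, value):
        self.value = value
        self.jac = None   # list of per-loss gradient rows
        self.grad = None  # single aggregated gradient

def backward(per_param_jacs, params):
    # Step 1: store the Jacobian (one gradient row per loss) on each param.
    for p, rows in zip(params, per_param_jacs):
        p.jac = rows

def mean_aggregator(jac):
    # Stand-in for UPGrad: plain row-wise mean of the Jacobian.
    n = len(jac)
    return [sum(col) / n for col in zip(*jac)]

def jac_to_grad(params, aggregator):
    # Step 2: aggregate each .jac into .grad, then delete the Jacobian
    # to save memory, as the updated docs describe.
    for p in params:
        p.grad = aggregator(p.jac)
        p.jac = None

p = Param(value=[0.0, 0.0])
backward([[[1.0, 2.0], [3.0, 4.0]]], [p])  # two losses, 2-dim parameter
jac_to_grad([p], mean_aggregator)
print(p.grad)  # [2.0, 3.0]
print(p.jac)   # None
```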

stable/_sources/examples/iwmtl.rst.txt

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
 The following example shows how to do that.
 
 .. code-block:: python
-    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
+    :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -51,10 +51,10 @@ The following example shows how to do that.
     # Obtain the weights that lead to no conflict between reweighted gradients
     weights = weighting(gramian)  # shape: [16, 2]
 
-    optimizer.zero_grad()
     # Do the standard backward pass, but weighted using the obtained weights
     losses.backward(weights)
     optimizer.step()
+    optimizer.zero_grad()
 
 .. note::
     In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a

stable/_sources/examples/iwrm.rst.txt

Lines changed: 8 additions & 8 deletions
@@ -64,26 +64,26 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 loss = loss_fn(y_hat, y)  # shape: [] (scalar)
-                optimizer.zero_grad()
                 loss.backward()
 
 
                 optimizer.step()
+                optimizer.zero_grad()
 
         In this baseline example, the update may negatively affect the loss of some elements of the
         batch.
 
     .. tab-item:: autojac
 
         .. code-block:: python
-            :emphasize-lines: 5-6, 12, 16, 21, 23
+            :emphasize-lines: 5-6, 12, 16, 21-23
 
             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
             from torch.optim import SGD
 
             from torchjd.aggregation import UPGrad
-            from torchjd.autojac import backward
+            from torchjd.autojac import backward, jac_to_grad
 
             X = torch.randn(8, 16, 10)
             Y = torch.randn(8, 16)
@@ -99,19 +99,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 losses = loss_fn(y_hat, y)  # shape: [16]
-                optimizer.zero_grad()
-                backward(losses, aggregator)
-
+                backward(losses)
+                jac_to_grad(model.parameters(), aggregator)
 
                 optimizer.step()
+                optimizer.zero_grad()
 
         Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
        and use it to update the model such that no loss from the batch is (locally) increased.
 
     .. tab-item:: autogram (recommended)
 
         .. code-block:: python
-            :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+            :emphasize-lines: 5-6, 12, 16-17, 21-24
 
             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 losses = loss_fn(y_hat, y)  # shape: [16]
-                optimizer.zero_grad()
                 gramian = engine.compute_gramian(losses)  # shape: [16, 16]
                 weights = weighting(gramian)  # shape: [16]
                 losses.backward(weights)
                 optimizer.step()
+                optimizer.zero_grad()
 
         Here, the per-sample gradients are never fully stored in memory, leading to large
         improvements in memory usage and speed compared to autojac, in most practical cases. The
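The autogram tab above never materializes the full Jacobian: it computes the Gramian ``G = J Jᵀ`` of the per-sample gradients, derives weights from it, and performs a single weighted backward pass. A hedged pure-Python sketch of that equivalence, with uniform weighting standing in for torchjd's actual weighting:

```python
# Sketch of the Gramian trick from the autogram tab: to choose combination
# weights, only G = J J^T (shape [n_losses, n_losses]) is needed, not the
# full Jacobian J. Uniform weighting is an illustrative stand-in.

def gramian(jacobian):
    # G[i][j] = <g_i, g_j> for per-sample gradients g_i.
    return [[sum(a * b for a, b in zip(gi, gj)) for gj in jacobian]
            for gi in jacobian]

def uniform_weighting(g):
    # Stand-in weighting: equal weight per loss.
    n = len(g)
    return [1.0 / n] * n

def weighted_gradient(jacobian, weights):
    # What losses.backward(weights) computes: sum_i w_i * g_i.
    return [sum(w * row[k] for w, row in zip(weights, jacobian))
            for k in range(len(jacobian[0]))]

J = [[1.0, 0.0], [0.0, 1.0]]    # two per-sample gradients
G = gramian(J)                  # [[1.0, 0.0], [0.0, 1.0]]
w = uniform_weighting(G)        # [0.5, 0.5]
print(weighted_gradient(J, w))  # [0.5, 0.5]
```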

stable/_sources/examples/lightning_integration.rst.txt

Lines changed: 5 additions & 4 deletions
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
 <../docs/autojac/mtl_backward>` at each training iteration.
 
 .. code-block:: python
-    :emphasize-lines: 9-10, 18, 32
+    :emphasize-lines: 9-10, 18, 31-32
 
     import torch
     from lightning import LightningModule, Trainer
@@ -22,7 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using
     from torch.utils.data import DataLoader, TensorDataset
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     class Model(LightningModule):
         def __init__(self):
@@ -43,9 +43,10 @@ The following code example demonstrates a basic multi-task learning setup using
             loss2 = mse_loss(output2, target2)
 
             opt = self.optimizers()
-            opt.zero_grad()
-            mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
+            mtl_backward([loss1, loss2], features=features)
+            jac_to_grad(self.feature_extractor.parameters(), UPGrad())
             opt.step()
+            opt.zero_grad()
 
         def configure_optimizers(self) -> OptimizerLRScheduler:
             optimizer = Adam(self.parameters(), lr=1e-3)

stable/_sources/examples/monitoring.rst.txt

Lines changed: 5 additions & 4 deletions
@@ -23,7 +23,7 @@ they have a negative inner product).
     from torch.optim import SGD
 
     from torchjd.aggregation import UPGrad
-    from torchjd.autojac import mtl_backward
+    from torchjd.autojac import mtl_backward, jac_to_grad
 
     def print_weights(_, __, weights: torch.Tensor) -> None:
         """Prints the extracted weights."""
@@ -49,7 +49,7 @@ they have a negative inner product).
     optimizer = SGD(params, lr=0.1)
     aggregator = UPGrad()
 
-    aggregator.weighting.weighting.register_forward_hook(print_weights)
+    aggregator.gramian_weighting.register_forward_hook(print_weights)
     aggregator.register_forward_hook(print_gd_similarity)
 
     inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
@@ -63,6 +63,7 @@ they have a negative inner product).
         loss1 = loss_fn(output1, target1)
         loss2 = loss_fn(output2, target2)
 
-        optimizer.zero_grad()
-        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+        mtl_backward([loss1, loss2], features=features)
+        jac_to_grad(shared_module.parameters(), aggregator)
         optimizer.step()
+        optimizer.zero_grad()
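The monitoring example relies on forward hooks: callables registered on the aggregator (or, after this commit, on its ``gramian_weighting``) that fire with the module, its inputs, and its output after each call. A minimal pure-Python sketch of that observer pattern, with a hypothetical ``Hookable`` class standing in for the real hook machinery:

```python
# Sketch of the forward-hook pattern used in the monitoring example:
# a hook registered on a component is invoked as hook(module, inputs,
# output) after every forward call. `Hookable` is an illustrative
# stand-in, not torch.nn.Module.

class Hookable:
    def __init__(self, fn):
        self.fn = fn
        self._hooks = []

    def register_forward_hook(self, hook):
        # Hooks fire after each call, in registration order.
        self._hooks.append(hook)

    def __call__(self, *inputs):
        output = self.fn(*inputs)
        for hook in self._hooks:
            hook(self, inputs, output)
        return output

seen = []
agg = Hookable(lambda xs: sum(xs) / len(xs))  # toy "aggregator"
agg.register_forward_hook(lambda mod, ins, out: seen.append(out))
agg([1.0, 3.0])
print(seen)  # [2.0]
```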
