@@ -50,6 +50,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 
 
 
+
             X = torch.randn(8, 16, 10)
             Y = torch.randn(8, 16)
 
@@ -64,26 +65,27 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 loss = loss_fn(y_hat, y)  # shape: [] (scalar)
-                optimizer.zero_grad()
                 loss.backward()
 
 
                 optimizer.step()
+                optimizer.zero_grad()
 
         In this baseline example, the update may negatively affect the loss of some elements of the
         batch.
 
     .. tab-item:: autojac
 
         .. code-block:: python
-            :emphasize-lines: 5-6, 12, 16, 21, 23
+            :emphasize-lines: 5-7, 13, 17, 22-24
 
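The caveat above can be made concrete with a small numeric sketch (illustrative only, independent of torchjd): two per-sample losses whose gradients conflict, where a step along the mean gradient lowers the average loss but still raises one of the individual losses.

```python
# Minimal numeric sketch: with conflicting per-sample gradients, a step
# along the mean gradient can increase one of the losses.
#
# Two linear per-sample losses of parameters w = (w0, w1):
#   l0(w) = w0 + 2*w1  -> gradient (1,  2)
#   l1(w) = w0 - 4*w1  -> gradient (1, -4)
def losses(w0, w1):
    return [w0 + 2 * w1, w0 - 4 * w1]

w = [1.0, 0.0]
g = [(1.0, 2.0), (1.0, -4.0)]  # per-sample gradients at w
mean_grad = [(g[0][0] + g[1][0]) / 2, (g[0][1] + g[1][1]) / 2]  # (1, -1)

lr = 0.1
w_new = [w[0] - lr * mean_grad[0], w[1] - lr * mean_grad[1]]

before = losses(*w)      # [1.0, 1.0]
after = losses(*w_new)   # l0 went up (~1.1) even though l1 and the mean went down
```

Since the losses here are linear, the first-order effect is exact: the average decreases, but the first sample's loss strictly increases.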
             import torch
             from torch.nn import Linear, MSELoss, ReLU, Sequential
             from torch.optim import SGD
 
             from torchjd.aggregation import UPGrad
             from torchjd.autojac import backward
+            from torchjd.utils import jac_to_grad
 
             X = torch.randn(8, 16, 10)
             Y = torch.randn(8, 16)
@@ -99,19 +101,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 losses = loss_fn(y_hat, y)  # shape: [16]
-                optimizer.zero_grad()
-                backward(losses, aggregator)
-
+                backward(losses)
+                jac_to_grad(model.parameters(), aggregator)
 
                 optimizer.step()
+                optimizer.zero_grad()
 
         Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
         and use it to update the model such that no loss from the batch is (locally) increased.
 
     .. tab-item:: autogram (recommended)
 
         .. code-block:: python
-            :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+            :emphasize-lines: 5-6, 13, 17-18, 22-25
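To see why aggregating the Jacobian can keep every loss from increasing locally, here is a hedged sketch using an MGDA-style minimum-norm convex combination of the gradients. This is a stand-in chosen for illustration, not UPGrad's actual aggregation rule: any direction d in the convex hull of the gradients with minimal norm satisfies g_i . d >= 0 for all i, so a step along -d does not increase any loss to first order.

```python
# Illustrative sketch (MGDA-style min-norm combination, not UPGrad itself):
# pick the update direction d as the convex combination of two per-sample
# gradients with minimal norm, via a simple grid search over the weight a.
g0 = (1.0, 2.0)
g1 = (1.0, -4.0)

best_d, best_norm2 = None, float("inf")
steps = 1000
for k in range(steps + 1):
    a = k / steps  # convex weight on g0
    d = (a * g0[0] + (1 - a) * g1[0], a * g0[1] + (1 - a) * g1[1])
    n2 = d[0] ** 2 + d[1] ** 2
    if n2 < best_norm2:
        best_d, best_norm2 = d, n2

# Directional derivatives along -best_d: nonnegative dot products mean
# neither loss increases to first order.
dot0 = g0[0] * best_d[0] + g0[1] * best_d[1]
dot1 = g1[0] * best_d[0] + g1[1] * best_d[1]
# best_d is close to (1, 0); both dot products are positive.
```

Compare with the mean gradient of the same two vectors, (1, -1), which has a negative dot product with g0: that is exactly the conflict the aggregator resolves.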
115117
116118 import torch
117119 from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -120,6 +122,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             from torchjd.aggregation import UPGradWeighting
             from torchjd.autogram import Engine
 
+
             X = torch.randn(8, 16, 10)
             Y = torch.randn(8, 16)
@@ -134,11 +137,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
             for x, y in zip(X, Y):
                 y_hat = model(x).squeeze(dim=1)  # shape: [16]
                 losses = loss_fn(y_hat, y)  # shape: [16]
-                optimizer.zero_grad()
                 gramian = engine.compute_gramian(losses)  # shape: [16, 16]
                 weights = weighting(gramian)  # shape: [16]
                 losses.backward(weights)
                 optimizer.step()
+                optimizer.zero_grad()
 
         Here, the per-sample gradients are never fully stored in memory, leading to large
         improvements in memory usage and speed compared to autojac, in most practical cases. The
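The memory argument behind the gramian approach can be illustrated with a small numeric sketch in plain Python (illustrative only, not torchjd internals): for a Jacobian J of shape [num_losses, num_params], the gramian G = J Jᵀ has shape [num_losses, num_losses] regardless of the number of parameters, and the weighted backward pass produces the aggregated direction Jᵀ w from the weights alone.

```python
# Sketch of the identity behind the gramian-based approach:
# G = J @ J.T collects every pairwise gradient inner product, and the
# weighted backward pass losses.backward(weights) yields J.T @ w.
J = [
    [1.0, 0.0, 2.0],
    [0.0, 1.0, -1.0],
]  # 2 losses, 3 parameters

num_losses, num_params = len(J), len(J[0])

# Gramian G = J @ J.T: shape [2, 2], independent of num_params.
G = [
    [sum(J[i][k] * J[j][k] for k in range(num_params)) for j in range(num_losses)]
    for i in range(num_losses)
]

w = [0.5, 0.5]  # e.g. a uniform weighting of the two losses

# Aggregated update direction J.T @ w, computed from the weights alone.
update = [sum(w[i] * J[i][k] for i in range(num_losses)) for k in range(num_params)]
```

Here G is [[5.0, -2.0], [-2.0, 2.0]] and the update is [0.5, 0.5, 0.5]; with 16 losses and millions of parameters, the 16x16 gramian is what makes this tractable.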