@@ -64,26 +64,26 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
      for x, y in zip(X, Y):
          y_hat = model(x).squeeze(dim=1)  # shape: [16]
          loss = loss_fn(y_hat, y)  # shape: [] (scalar)
-         optimizer.zero_grad()
          loss.backward()


          optimizer.step()
+         optimizer.zero_grad()

   In this baseline example, the update may negatively affect the loss of some elements of the
   batch.
.. tab-item:: autojac

   .. code-block:: python
-     :emphasize-lines: 5-6, 12, 16, 21, 23
+     :emphasize-lines: 5-6, 12, 16, 21-23

      import torch
      from torch.nn import Linear, MSELoss, ReLU, Sequential
      from torch.optim import SGD

      from torchjd.aggregation import UPGrad
-     from torchjd.autojac import backward
+     from torchjd.autojac import backward, jac_to_grad

      X = torch.randn(8, 16, 10)
      Y = torch.randn(8, 16)
@@ -99,19 +99,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
      for x, y in zip(X, Y):
          y_hat = model(x).squeeze(dim=1)  # shape: [16]
          losses = loss_fn(y_hat, y)  # shape: [16]
-         optimizer.zero_grad()
-         backward(losses, aggregator)
-
+         backward(losses)
+         jac_to_grad(model.parameters(), aggregator)

          optimizer.step()
+         optimizer.zero_grad()

   Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
   and use it to update the model such that no loss from the batch is (locally) increased.
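As a side note, here is a minimal NumPy sketch of what conflict-avoiding aggregation of Jacobian rows means. This is not torchjd's UPGrad itself, but the simpler min-norm convex combination (in the style of MGDA) applied to two made-up gradients, which is enough to show the key property:

```python
import numpy as np

# Two hypothetical per-sample gradients (rows of a Jacobian),
# chosen so that their plain mean conflicts with g2.
g1 = np.array([4.0, 1.0])
g2 = np.array([-2.0, 1.0])

mean = (g1 + g2) / 2.0
# Moving along the mean would locally increase the second loss:
print(mean @ g2)  # -1.0 (negative: conflict)

# Min-norm convex combination w*g1 + (1-w)*g2
# (closed form for two vectors, clamped to [0, 1]).
w = (g2 - g1) @ g2 / ((g1 - g2) @ (g1 - g2))
w = min(max(w, 0.0), 1.0)
agg = w * g1 + (1.0 - w) * g2

# The aggregated direction conflicts with neither gradient:
print(agg @ g1, agg @ g2)  # both non-negative
```

The aggregators in `torchjd.aggregation` generalize this idea to arbitrary numbers of rows, with different guarantees per aggregator.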
.. tab-item:: autogram (recommended)

   .. code-block:: python
-     :emphasize-lines: 5-6, 12, 16-17, 21, 23-25
+     :emphasize-lines: 5-6, 12, 16-17, 21-24

      import torch
      from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
      for x, y in zip(X, Y):
          y_hat = model(x).squeeze(dim=1)  # shape: [16]
          losses = loss_fn(y_hat, y)  # shape: [16]
-         optimizer.zero_grad()
          gramian = engine.compute_gramian(losses)  # shape: [16, 16]
          weights = weighting(gramian)  # shape: [16]
          losses.backward(weights)
          optimizer.step()
+         optimizer.zero_grad()

   Here, the per-sample gradients are never fully stored in memory, leading to large
   improvements in memory usage and speed compared to autojac, in most practical cases. The
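To see where the saving comes from: for m losses and p parameters, the full Jacobian J has m × p entries, while the Gramian G = JJᵀ has only m × m, and the final update direction Jᵀw can be accumulated by a single weighted backward pass. A small NumPy sketch of the shapes involved (illustrative sizes; here J is materialized only to demonstrate the relationship, which is precisely what autogram avoids doing):

```python
import numpy as np

m, p = 16, 100_000  # 16 per-sample losses, many parameters (made-up sizes)
rng = np.random.default_rng(0)
J = rng.standard_normal((m, p))  # full Jacobian: m * p entries

G = J @ J.T                 # Gramian: only m * m entries
weights = np.ones(m) / m    # stand-in for weighting(gramian)

# The update direction J.T @ weights is what losses.backward(weights)
# accumulates into .grad, without ever storing J in full.
update = J.T @ weights
print(G.shape, update.shape)  # (16, 16) (100000,)
```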