@@ -50,7 +50,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac



-
    X = torch.randn(8, 16, 10)
    Y = torch.randn(8, 16)

@@ -78,15 +77,14 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
    .. tab-item:: autojac

       .. code-block:: python
-         :emphasize-lines: 5-7, 13, 17, 22-24
+         :emphasize-lines: 5-6, 12, 16, 21-23

          import torch
          from torch.nn import Linear, MSELoss, ReLU, Sequential
          from torch.optim import SGD

          from torchjd.aggregation import UPGrad
-         from torchjd.autojac import backward
-         from torchjd.utils import jac_to_grad
+         from torchjd.autojac import backward, jac_to_grad

          X = torch.randn(8, 16, 10)
          Y = torch.randn(8, 16)
@@ -115,7 +113,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
    .. tab-item:: autogram (recommended)

       .. code-block:: python
-         :emphasize-lines: 5-6, 13, 17-18, 22-25
+         :emphasize-lines: 5-6, 12, 16-17, 21-24

          import torch
          from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -124,7 +122,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
          from torchjd.aggregation import UPGradWeighting
          from torchjd.autogram import Engine

-
          X = torch.randn(8, 16, 10)
          Y = torch.randn(8, 16)

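
For context, here is how the full autojac example plausibly reads after this change. The surrounding model, loss, and training loop are reconstructed from the torchjd IWRM example; the exact signatures of ``backward`` and ``jac_to_grad`` are assumptions based on the imports in the diff, not confirmed by it.

.. code-block:: python

   import torch
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD

   from torchjd.aggregation import UPGrad
   from torchjd.autojac import backward, jac_to_grad  # new import path from this commit

   X = torch.randn(8, 16, 10)  # 8 batches of 16 instances with 10 features each
   Y = torch.randn(8, 16)      # 8 batches of 16 scalar targets

   model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))  # assumed architecture
   loss_fn = MSELoss(reduction="none")  # per-instance losses, as IWRM requires

   optimizer = SGD(model.parameters(), lr=0.1)
   aggregator = UPGrad()

   for x, y in zip(X, Y):
       y_hat = model(x).squeeze(dim=1)
       losses = loss_fn(y_hat, y)
       optimizer.zero_grad()
       backward(losses)  # assumed: stores per-instance Jacobians on the parameters
       jac_to_grad(model.parameters(), aggregator)  # assumed: aggregates them into .grad
       optimizer.step()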
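
And a minimal sketch of the autogram variant, matching the imports in the second tab. The ``Engine`` constructor, ``compute_gramian``, and the gramian-weighted backward call are assumptions about the autogram API that this diff does not show.

.. code-block:: python

   import torch
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD

   from torchjd.aggregation import UPGradWeighting
   from torchjd.autogram import Engine

   X = torch.randn(8, 16, 10)
   Y = torch.randn(8, 16)

   model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))  # assumed architecture
   loss_fn = MSELoss(reduction="none")

   optimizer = SGD(model.parameters(), lr=0.1)
   weighting = UPGradWeighting()
   engine = Engine(model, batch_dim=0)  # assumed constructor

   for x, y in zip(X, Y):
       y_hat = model(x).squeeze(dim=1)
       losses = loss_fn(y_hat, y)
       optimizer.zero_grad()
       gramian = engine.compute_gramian(losses)  # assumed API: Gramian of the Jacobian
       weights = weighting(gramian)              # UPGrad weights from the Gramian
       losses.backward(weights)                  # weighted backward pass
       optimizer.step()

Working from the Gramian rather than the full Jacobian is presumably why the docs mark this tab as recommended: the Jacobian itself never needs to be materialized.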