@@ -59,41 +59,43 @@ The main way to use TorchJD is to replace the usual call to `loss.backward()` by
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/docs/aggregation/upgrad/).
 
-```python
-import torch
-from torch.nn import Linear, MSELoss, ReLU, Sequential
-from torch.optim import SGD
-
-from torchjd import mtl_backward
-from torchjd.aggregation import UPGrad
-
-shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
-task1_module = Linear(3, 1)
-task2_module = Linear(3, 1)
-params = [
-    *shared_module.parameters(),
-    *task1_module.parameters(),
-    *task2_module.parameters(),
-]
-
-loss_fn = MSELoss()
-optimizer = SGD(params, lr=0.1)
-aggregator = UPGrad()
-
-inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
-task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
-task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
-
-for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
-    features = shared_module(input)
-    output1 = task1_module(features)
-    output2 = task2_module(features)
-    loss1 = loss_fn(output1, target1)
-    loss2 = loss_fn(output2, target2)
-
-    optimizer.zero_grad()
-    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
-    optimizer.step()
+```diff
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
+
++ from torchjd import mtl_backward
++ from torchjd.aggregation import UPGrad
+
+shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+task1_module = Linear(3, 1)
+task2_module = Linear(3, 1)
+params = [
+    *shared_module.parameters(),
+    *task1_module.parameters(),
+    *task2_module.parameters(),
+]
+
+loss_fn = MSELoss()
+optimizer = SGD(params, lr=0.1)
++ aggregator = UPGrad()
+
+inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
+task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
+task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task
+
+for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
+    features = shared_module(input)
+    output1 = task1_module(features)
+    output2 = task2_module(features)
+    loss1 = loss_fn(output1, target1)
+    loss2 = loss_fn(output2, target2)
+
+    optimizer.zero_grad()
+- loss = loss1 + loss2
+- loss.backward()
++ mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
+    optimizer.step()
 ```
 
 > [!NOTE]