Skip to content

Commit c6d2866

Browse files
authored
docs(readme): use diff code block for usage example (#238)
1 parent 3a3c459 commit c6d2866

File tree

1 file changed

+37
-35
lines changed

1 file changed

+37
-35
lines changed

README.md

Lines changed: 37 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -59,41 +59,43 @@ The main way to use TorchJD is to replace the usual call to `loss.backward()` by
The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
using [UPGrad](https://torchjd.org/docs/aggregation/upgrad/).

Lines removed (the former plain Python code block):

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    optimizer.zero_grad()
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()
```
Lines added (the new diff-style code block):

```diff
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

+ from torchjd import mtl_backward
+ from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
+ aggregator = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    optimizer.zero_grad()
-   loss = loss1 + loss2
-   loss.backward()
+   mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()
```

> [!NOTE]
0 commit comments

Comments
 (0)