Skip to content

Commit b217983

Browse files
committed
addressing feedback and comments
1 parent 16ff7e6 commit b217983

1 file changed

Lines changed: 31 additions & 15 deletions

File tree

README.md

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ two complementary approaches:
2323
literature (geometric mean, softmax weighting, [etc.](#supported-scalarizers)). This is often a good baseline.
2424
- **[Jacobian descent](https://arxiv.org/pdf/2406.16232)**: compute the Jacobian matrix of losses
2525
with respect to parameters and aggregate it into an update direction using state-of-the-art
26-
aggregators (UPGrad, MGDA, CAGrad, [and many more](#supported-aggregators-and-weightings)).
26+
aggregators (UPGrad, MGDA, CAGrad, [and many more](#supported-aggregators-and-weightings)).
2727
This in particular allows taking conflict-free
2828
optimization directions, which can resolve problems that may be impossible to solve with standard
2929
scalarizers.
@@ -57,22 +57,30 @@ a standard training loop to use scalarization:
5757

5858
+ from torchjd.scalarization import GeometricMean
5959

60-
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
61-
optimizer = SGD(model.parameters(), lr=0.1)
62-
criterion = MSELoss()
60+
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
61+
task1_module = Linear(3, 1)
62+
task2_module = Linear(3, 1)
63+
params = [*shared_module.parameters(), *task1_module.parameters(), *task2_module.parameters()]
64+
65+
loss_fn = MSELoss()
66+
optimizer = SGD(params, lr=0.1)
6367
+ scalarizer = GeometricMean()
6468

65-
inputs = torch.randn(16, 10)
66-
task1_targets, task2_targets = torch.randn(16, 1), torch.randn(16, 1)
67-
68-
output = model(inputs)
69-
- loss = criterion(output, task1_targets) + criterion(output, task2_targets)
70-
- loss.backward()
71-
+ losses = torch.stack([criterion(output, task1_targets), criterion(output, task2_targets)])
72-
+ loss = scalarizer(losses)
73-
+ loss.backward()
74-
optimizer.step()
75-
optimizer.zero_grad()
69+
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
70+
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
71+
task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
72+
73+
for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
74+
features = shared_module(input)
75+
loss1 = loss_fn(task1_module(features), target1)
76+
loss2 = loss_fn(task2_module(features), target2)
77+
78+
- loss = loss1 + loss2
79+
- loss.backward()
80+
+ loss = scalarizer(torch.stack([loss1, loss2]))
81+
+ loss.backward()
82+
optimizer.step()
83+
optimizer.zero_grad()
7684
```
7785

7886
### Jacobian descent
@@ -116,6 +124,14 @@ Here is how to change a standard multi-task training loop to use Jacobian descen
116124
optimizer.zero_grad()
117125
```
118126

127+
### The `autojac` engine
128+
129+
The [`autojac` engine](https://torchjd.org/stable/docs/autojac/) provides fine-grained control
130+
over Jacobian computation and aggregation. It lets you compute Jacobians with respect to specific
131+
layers or activations (partial Jacobian descent), store them in `.jac` fields for inspection, and
132+
apply any aggregator independently. See the [autojac examples](https://torchjd.org/stable/examples/)
133+
for more details.
134+
119135
### The `autogram` engine
120136

121137
TorchJD also provides the [`autogram` engine](https://torchjd.org/stable/docs/autogram/engine/),

0 commit comments

Comments
 (0)