Skip to content

Commit a5c19da

Browse files
authored
docs: Add monitoring example (#337)
* Add monitoring.rst
* Add its test in test_rst.py
* Add a link in index.rst
1 parent aeb34ea commit a5c19da

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

docs/source/examples/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ This section contains some usage examples for TorchJD.
1515
dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
1616
- :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
1717
with one loss per output sequence element.
18+
- :doc:`Monitoring Aggregations <monitoring>` shows how to monitor the aggregation performed by the
19+
aggregator, to check whether Jacobian descent is suitable for your use case.
1820
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
1921
TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
2022
``LightningModule`` optimized by Jacobian descent.
@@ -27,5 +29,6 @@ This section contains some usage examples for TorchJD.
2729
iwrm.rst
2830
mtl.rst
2931
rnn.rst
32+
monitoring.rst
3033
lightning_integration.rst
3134
amp.rst
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
Monitoring aggregations
2+
=======================
3+
4+
The :doc:`Aggregator <../docs/aggregation/bases>` class is a subclass of :class:`torch.nn.Module`.
5+
This allows registering hooks, which can be used to monitor some information about aggregations.
6+
The following code example demonstrates registering a hook to compute and print the cosine
7+
similarity between the aggregation performed by :doc:`UPGrad <../docs/aggregation/upgrad>` and the
8+
average of the gradients, and another hook to compute and print the weights of the weighting of
9+
:doc:`UPGrad <../docs/aggregation/upgrad>`.
10+
11+
Updating the parameters of the model with the average gradient is equivalent to using gradient
12+
descent on the average of the losses. Observing a cosine similarity smaller than 1 means that
13+
Jacobian descent is doing something different than gradient descent. With
14+
:doc:`UPGrad <../docs/aggregation/upgrad>`, this happens when the original gradients conflict (i.e.
15+
they have a negative inner product).
16+
17+
.. code-block:: python
18+
:emphasize-lines: 9-11, 13-18, 33-34
19+
20+
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD
from torch.nn.functional import cosine_similarity

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

def print_weights(_, __, weights: torch.Tensor) -> None:
    """Prints the extracted weights."""
    print(f"Weights: {weights}")

def print_similarity_with_gd(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
    """Prints the cosine similarity between the aggregation and the average gradient."""
    matrix = inputs[0]  # first positional arg of the aggregator: the matrix being aggregated
    gd_output = matrix.mean(dim=0)  # row-mean = update that plain gradient descent would use
    similarity = cosine_similarity(aggregation, gd_output, dim=0)
    print(f"Cosine similarity: {similarity.item():.4f}")

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

aggregator.weighting.register_forward_hook(print_weights)
aggregator.register_forward_hook(print_similarity_with_gd)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for batch, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(batch)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    optimizer.zero_grad()
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()

tests/doc/test_rst.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,58 @@ def test_rnn():
209209
optimizer.step()
210210

211211

212+
def test_monitoring():
    """Runs the code of docs/source/examples/monitoring.rst to check that it executes cleanly.

    Registers forward hooks on an UPGrad aggregator and its weighting to print the
    aggregation weights and the cosine similarity between the aggregation and the
    average gradient, then performs a few multi-task training steps with mtl_backward.
    """
    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.nn.functional import cosine_similarity
    from torch.optim import SGD

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    def print_weights(_, __, weights: torch.Tensor) -> None:
        """Prints the extracted weights."""
        print(f"Weights: {weights}")

    def print_similarity_with_gd(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
        """Prints the cosine similarity between the aggregation and the average gradient."""
        matrix = inputs[0]  # first positional arg of the aggregator: the matrix being aggregated
        gd_output = matrix.mean(dim=0)  # row-mean = update that plain gradient descent would use
        similarity = cosine_similarity(aggregation, gd_output, dim=0)
        print(f"Cosine similarity: {similarity.item():.4f}")

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = [
        *shared_module.parameters(),
        *task1_module.parameters(),
        *task2_module.parameters(),
    ]

    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    aggregator = UPGrad()

    aggregator.weighting.register_forward_hook(print_weights)
    aggregator.register_forward_hook(print_similarity_with_gd)

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for batch, target1, target2 in zip(inputs, task1_targets, task2_targets):
        features = shared_module(batch)
        output1 = task1_module(features)
        output2 = task2_module(features)
        loss1 = loss_fn(output1, target1)
        loss2 = loss_fn(output2, target2)

        optimizer.zero_grad()
        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
        optimizer.step()
262+
263+
212264
def test_amp():
213265
import torch
214266
from torch.amp import GradScaler

0 commit comments

Comments
 (0)