
Commit 662dd0f

Add RNN example (#225)
* Add examples/rnn.rst
* Add link to the RNN example in examples/index.rst
* Add doc test in test_rst.py
* Add changelog entry
1 parent 654dd88 commit 662dd0f

File tree

4 files changed: +69 -0 lines changed


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ changes that do not affect the user.
   should now leave the default value `retain_graph=False`, no matter what the value of
   `parallel_chunk_size` is. This will reduce the memory overhead.
 
+### Added
+
+- RNN training usage example in the documentation.
+
 ## [0.3.1] - 2024-12-21
 
 ### Changed

docs/source/examples/index.rst

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,8 @@ This section contains some usage examples for TorchJD.
 - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
   descent is used to optimize the vector of per-task losses of a multi-task model, using the
   dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
+- :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
+  with one loss per output sequence element.
 - :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
   TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
   ``LightningModule`` optimized by Jacobian descent.
@@ -23,4 +25,5 @@ This section contains some usage examples for TorchJD.
   basic_usage.rst
   iwrm.rst
   mtl.rst
+  rnn.rst
   lightning_integration.rst

docs/source/examples/rnn.rst

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+Recurrent Neural Network (RNN)
+==============================
+
+When training recurrent neural networks for sequence modelling, we can easily obtain one loss per
+element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian
+descent can be leveraged to enhance optimization.
+
+.. code-block:: python
+    :emphasize-lines: 5-6, 10, 17, 20
+
+    import torch
+    from torch.nn import RNN
+    from torch.optim import SGD
+
+    from torchjd import backward
+    from torchjd.aggregation import UPGrad
+
+    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
+    optimizer = SGD(rnn.parameters(), lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
+    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
+
+    for input, target in zip(inputs, targets):
+        output, _ = rnn(input)  # output is of shape [5, 3, 20].
+        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
+
+        optimizer.zero_grad()
+        backward(losses, aggregator, parallel_chunk_size=1)
+        optimizer.step()
+
+.. note::
+    At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
+    ``torch.nn.RNN`` when running on CUDA (see `this issue
+    <https://github.com/TorchJD/torchjd/issues/220>`_ for more info), so we advise setting
+    ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can
+    check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side.
tests/doc/test_rst.py

Lines changed: 24 additions & 0 deletions
@@ -183,3 +183,27 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
     )
 
     trainer.fit(model=model, train_dataloaders=train_loader)
+
+
+def test_rnn():
+    import torch
+    from torch.nn import RNN
+    from torch.optim import SGD
+
+    from torchjd import backward
+    from torchjd.aggregation import UPGrad
+
+    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
+    optimizer = SGD(rnn.parameters(), lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
+    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
+
+    for input, target in zip(inputs, targets):
+        output, _ = rnn(input)  # output is of shape [5, 3, 20].
+        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
+
+        optimizer.zero_grad()
+        backward(losses, aggregator, parallel_chunk_size=1)
+        optimizer.step()
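The key step in the added test is the shape of the per-element losses. A minimal sketch, assuming only ``torch`` is installed (no ``torchjd``), confirming that averaging the squared error over the batch and hidden dimensions leaves one loss per sequence element:

```python
import torch

rnn = torch.nn.RNN(input_size=10, hidden_size=20, num_layers=2)
input = torch.randn(5, 3, 10)   # (seq_len, batch, input_size), batch_first=False by default.
target = torch.randn(5, 3, 20)

output, _ = rnn(input)          # (seq_len, batch, hidden_size) == (5, 3, 20).
losses = ((output - target) ** 2).mean(dim=[1, 2])

print(losses.shape)  # torch.Size([5]): one loss per sequence element.
```

Because ``losses`` is a vector rather than a scalar, it is exactly the kind of quantity that ``torchjd.backward`` aggregates, in contrast to the usual ``losses.mean().backward()``.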
