Skip to content

Commit 654dd88

Browse files
authored
Improve sequential differentiation (#222)
* Remove some tested parameters from test_value_is_correct to make it faster * Add parametrization of chunk_size in test_value_is_correct and in most test_jac tests * Add test_tensor_used_multiple_times in test_backward.py * Change tests with retains_grad to make them have m>1 * Change chunk_size from 3 to 2 in test_value_is_correct * Change the implementation of Jac to always make the chunks ourselves and use vmap only if necessary * Change the role of retain_graph to apply to the last differentiation * Add changelog entries --------- Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
1 parent 645629c commit 654dd88

File tree

8 files changed

+151
-142
lines changed

8 files changed

+151
-142
lines changed

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@ changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Changed
12+
13+
- Changed how the Jacobians are computed when calling `backward` or `mtl_backward` with
14+
`parallel_chunk_size=1` to not rely on `torch.vmap` in this case. Whenever `vmap` does
15+
not support something (compiled functions, RNN on CUDA, etc.), users should now be able to avoid
16+
using `vmap` by calling `backward` or `mtl_backward` with `parallel_chunk_size=1`.
17+
18+
- Changed the effect of the parameter `retain_graph` of `backward` and `mtl_backward`. When set to
19+
`False`, it now frees the graph only after all gradients have been computed. In most cases, users
20+
should now leave the default value `retain_graph=False`, no matter what the value of
21+
`parallel_chunk_size` is. This will reduce the memory overhead.
22+
1123
## [0.3.1] - 2024-12-21
1224

1325
### Changed

src/torchjd/autojac/_transform/jac.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import math
2+
from functools import partial
13
from itertools import accumulate
2-
from typing import Iterable, Sequence
4+
from typing import Callable, Iterable, Sequence
35

46
import torch
57
from torch import Size, Tensor
@@ -56,30 +58,67 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
5658
]
5759
)
5860

59-
def get_vjp(grad_outputs: Sequence[Tensor]) -> Tensor:
61+
def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
6062
optional_grads = torch.autograd.grad(
6163
outputs,
6264
inputs,
6365
grad_outputs=grad_outputs,
64-
retain_graph=self.retain_graph,
66+
retain_graph=retain_graph,
6567
create_graph=self.create_graph,
6668
allow_unused=True,
6769
)
6870
grads = _materialize(optional_grads, inputs=inputs)
6971
return torch.concatenate([grad.reshape([-1]) for grad in grads])
7072

71-
# Because of a limitation of vmap, this breaks when some tensors have `retains_grad=True`.
72-
# See https://pytorch.org/functorch/stable/ux_limitations.html for more information.
73-
# This also breaks when some tensors have been produced by compiled functions.
74-
grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=self.chunk_size)(jac_outputs)
75-
73+
# By the Jacobians constraint, this value should be the same for all jac_outputs.
74+
m = jac_outputs[0].shape[0]
75+
max_chunk_size = self.chunk_size if self.chunk_size is not None else m
76+
n_chunks = math.ceil(m / max_chunk_size)
77+
78+
# List of tensors of shape [k_i, n] where the k_i's sum to m
79+
jac_matrix_chunks = []
80+
81+
# First differentiations: always retain graph
82+
get_vjp_retain = partial(_get_vjp, retain_graph=True)
83+
for i in range(n_chunks - 1):
84+
start = i * max_chunk_size
85+
end = (i + 1) * max_chunk_size
86+
jac_outputs_chunk = [jac_output[start:end] for jac_output in jac_outputs]
87+
jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_retain))
88+
89+
# Last differentiation: retain the graph only if self.retain_graph==True
90+
get_vjp_last = partial(_get_vjp, retain_graph=self.retain_graph)
91+
start = (n_chunks - 1) * max_chunk_size
92+
jac_outputs_chunk = [jac_output[start:] for jac_output in jac_outputs]
93+
jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_last))
94+
95+
jac_matrix = torch.vstack(jac_matrix_chunks)
7696
lengths = [input.numel() for input in inputs]
77-
jacobian_matrices = _extract_sub_matrices(grouped_jacobian_matrix, lengths)
97+
jac_matrices = _extract_sub_matrices(jac_matrix, lengths)
7898

7999
shapes = [input.shape for input in inputs]
80-
jacobians = _reshape_matrices(jacobian_matrices, shapes)
81-
82-
return tuple(jacobians)
100+
jacs = _reshape_matrices(jac_matrices, shapes)
101+
102+
return tuple(jacs)
103+
104+
105+
def _get_jac_matrix_chunk(
106+
jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], Tensor]
107+
) -> Tensor:
108+
"""
109+
Computes the jacobian matrix chunk corresponding to the provided get_vjp function, either by
110+
calling get_vjp directly or by wrapping it into a call to ``torch.vmap``, depending on the shape
111+
of the provided ``jac_outputs_chunk``. Because of the numerous issues of vmap, we use it only if
112+
necessary (i.e. when the ``jac_outputs_chunk`` have more than 1 row).
113+
"""
114+
115+
chunk_size = jac_outputs_chunk[0].shape[0]
116+
if chunk_size == 1:
117+
grad_outputs = [tensor.squeeze() for tensor in jac_outputs_chunk]
118+
gradient_vector = get_vjp(grad_outputs)
119+
return gradient_vector.unsqueeze(0)
120+
else:
121+
return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk)
83122

84123

85124
def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:

src/torchjd/autojac/_utils.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,6 @@ def _as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
2121
return output
2222

2323

24-
def _check_retain_graph_compatible_with_chunk_size(
25-
tensors: list[Tensor],
26-
retain_graph: bool,
27-
parallel_chunk_size: int | None,
28-
) -> None:
29-
tensors_numel = sum([tensor.numel() for tensor in tensors])
30-
if parallel_chunk_size is not None and parallel_chunk_size < tensors_numel and not retain_graph:
31-
raise ValueError(
32-
"When using `retain_graph=False`, parameter `parallel_chunk_size` must be `None` or "
33-
"large enough to compute all gradients in parallel."
34-
)
35-
36-
3724
def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> set[Tensor]:
3825
"""
3926
Gets the leaves of the autograd graph of all specified ``tensors``.

src/torchjd/autojac/backward.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@
55
from torchjd.aggregation import Aggregator
66

77
from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac
8-
from ._utils import (
9-
_as_tensor_list,
10-
_check_optional_positive_chunk_size,
11-
_check_retain_graph_compatible_with_chunk_size,
12-
_get_leaf_tensors,
13-
)
8+
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
149

1510

1611
def backward(
@@ -37,8 +32,7 @@ def backward(
3732
backward pass. If set to ``None``, all coordinates of ``tensors`` will be differentiated in
3833
parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
3934
larger value results in faster differentiation, but also higher memory usage. Defaults to
40-
``None``. If ``parallel_chunk_size`` is not large enough to differentiate all tensors
41-
simultaneously, ``retain_graph`` has to be set to ``True``.
35+
``None``.
4236
4337
.. admonition::
4438
Example
@@ -64,13 +58,13 @@ def backward(
6458
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
6559
6660
.. warning::
67-
``backward`` relies on a usage of ``torch.vmap`` that is not compatible with compiled
68-
functions. The arguments of ``backward`` should thus not come from a compiled model. Check
69-
https://github.com/pytorch/pytorch/issues/138422 for the status of this issue.
70-
71-
.. warning::
72-
Because of a limitation of ``torch.vmap``, tensors in the computation graph of the
73-
``tensors`` parameter should not have their ``retains_grad`` parameter set to ``True``.
61+
To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some
62+
limitations: `it does not work on the output of compiled functions
63+
<https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
64+
<https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an
65+
RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you
66+
experience issues with ``backward`` try to use ``parallel_chunk_size=1`` to avoid relying on
67+
``torch.vmap``.
7468
"""
7569
_check_optional_positive_chunk_size(parallel_chunk_size)
7670

@@ -79,8 +73,6 @@ def backward(
7973
if len(tensors) == 0:
8074
raise ValueError("`tensors` cannot be empty")
8175

82-
_check_retain_graph_compatible_with_chunk_size(tensors, retain_graph, parallel_chunk_size)
83-
8476
if inputs is None:
8577
inputs = _get_leaf_tensors(tensors=tensors, excluded=set())
8678
else:

src/torchjd/autojac/mtl_backward.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616
Stack,
1717
Transform,
1818
)
19-
from ._utils import (
20-
_as_tensor_list,
21-
_check_optional_positive_chunk_size,
22-
_check_retain_graph_compatible_with_chunk_size,
23-
_get_leaf_tensors,
24-
)
19+
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
2520

2621

2722
def mtl_backward(
@@ -60,8 +55,7 @@ def mtl_backward(
6055
backward pass. If set to ``None``, all coordinates of ``tensors`` will be differentiated in
6156
parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
6257
larger value results in faster differentiation, but also higher memory usage. Defaults to
63-
``None``. If ``parallel_chunk_size`` is not large enough to differentiate all tensors
64-
simultaneously, ``retain_graph`` has to be set to ``True``.
58+
``None``.
6559
6660
.. admonition::
6761
Example
@@ -75,13 +69,13 @@ def mtl_backward(
7569
respect to those parameters will be accumulated into their ``.grad`` fields.
7670
7771
.. warning::
78-
``mtl_backward`` relies on a usage of ``torch.vmap`` that is not compatible with compiled
79-
functions. The arguments of ``mtl_backward`` should thus not come from a compiled model.
80-
Check https://github.com/pytorch/pytorch/issues/138422 for the status of this issue.
81-
82-
.. warning::
83-
Because of a limitation of ``torch.vmap``, tensors in the computation graph of the
84-
``features`` parameter should not have their ``retains_grad`` parameter set to ``True``.
72+
To differentiate in parallel, ``mtl_backward`` relies on ``torch.vmap``, which has some
73+
limitations: `it does not work on the output of compiled functions
74+
<https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
75+
<https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an
76+
RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you
77+
experience issues with ``mtl_backward``, try to use ``parallel_chunk_size=1`` to avoid relying on
78+
``torch.vmap``.
8579
"""
8680

8781
_check_optional_positive_chunk_size(parallel_chunk_size)
@@ -96,7 +90,6 @@ def mtl_backward(
9690
if len(features) == 0:
9791
raise ValueError("`features` cannot be empty.")
9892

99-
_check_retain_graph_compatible_with_chunk_size(features, retain_graph, parallel_chunk_size)
10093
_check_no_overlap(shared_params, tasks_params)
10194
_check_losses_are_scalar(losses)
10295

tests/unit/autojac/_transform/test_jac.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import torch
2-
from pytest import raises
2+
from pytest import mark, raises
33
from unit.conftest import DEVICE
44

55
from torchjd.autojac._transform import Jac, Jacobians
66

77
from ._dict_assertions import assert_tensor_dicts_are_close
88

99

10-
def test_single_input():
10+
@mark.parametrize("chunk_size", [1, 3, None])
11+
def test_single_input(chunk_size: int | None):
1112
"""
1213
Tests that the Jac transform works correctly for an example of multiple differentiation. Here,
1314
the function considered is: `y = [a1 * x, a2 * x]`. We want to compute the jacobians of `y` with
@@ -20,7 +21,7 @@ def test_single_input():
2021
y = torch.stack([a1 * x, a2 * x])
2122
input = Jacobians({y: torch.eye(2, device=DEVICE)})
2223

23-
jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=None)
24+
jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=chunk_size)
2425

2526
jacobians = jac(input)
2627
expected_jacobians = {
@@ -31,7 +32,8 @@ def test_single_input():
3132
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
3233

3334

34-
def test_empty_inputs_1():
35+
@mark.parametrize("chunk_size", [1, 3, None])
36+
def test_empty_inputs_1(chunk_size: int | None):
3537
"""
3638
Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`.
3739
"""
@@ -41,15 +43,16 @@ def test_empty_inputs_1():
4143
y = torch.stack([y1, y2])
4244
input = Jacobians({y: torch.eye(2, device=DEVICE)})
4345

44-
jac = Jac(outputs=[y], inputs=[], chunk_size=None)
46+
jac = Jac(outputs=[y], inputs=[], chunk_size=chunk_size)
4547

4648
jacobians = jac(input)
4749
expected_jacobians = {}
4850

4951
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
5052

5153

52-
def test_empty_inputs_2():
54+
@mark.parametrize("chunk_size", [1, 3, None])
55+
def test_empty_inputs_2(chunk_size: int | None):
5356
"""
5457
Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`.
5558
"""
@@ -62,7 +65,7 @@ def test_empty_inputs_2():
6265
y = torch.stack([y1, y2])
6366
input = Jacobians({y: torch.eye(2, device=DEVICE)})
6467

65-
jac = Jac(outputs=[y], inputs=[], chunk_size=None)
68+
jac = Jac(outputs=[y], inputs=[], chunk_size=chunk_size)
6669

6770
jacobians = jac(input)
6871
expected_jacobians = {}
@@ -122,7 +125,8 @@ def test_two_levels():
122125
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
123126

124127

125-
def test_multiple_outputs_1():
128+
@mark.parametrize("chunk_size", [1, 3, None])
129+
def test_multiple_outputs_1(chunk_size: int | None):
126130
"""
127131
Tests that the Jac transform works correctly when the `outputs` contains 3 vectors.
128132
The input (jac_outputs) is not the same for all outputs, so that this test also checks that the
@@ -143,7 +147,7 @@ def test_multiple_outputs_1():
143147
jac_output3 = torch.cat([zeros_2x2, zeros_2x2, identity_2x2])
144148
input = Jacobians({y1: jac_output1, y2: jac_output2, y3: jac_output3})
145149

146-
jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=None)
150+
jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=chunk_size)
147151

148152
jacobians = jac(input)
149153
zero_scalar = torch.tensor(0.0, device=DEVICE)
@@ -155,7 +159,8 @@ def test_multiple_outputs_1():
155159
assert_tensor_dicts_are_close(jacobians, expected_jacobians)
156160

157161

158-
def test_multiple_outputs_2():
162+
@mark.parametrize("chunk_size", [1, 3, None])
163+
def test_multiple_outputs_2(chunk_size: int | None):
159164
"""
160165
Same as test_multiple_outputs_1 but with different jac_outputs, so the returned jacobians are of
161166
different shapes.
@@ -175,7 +180,7 @@ def test_multiple_outputs_2():
175180
jac_output3 = torch.stack([zeros_2, zeros_2, ones_2])
176181
input = Jacobians({y1: jac_output1, y2: jac_output2, y3: jac_output3})
177182

178-
jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=None)
183+
jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=chunk_size)
179184

180185
jacobians = jac(input)
181186
expected_jacobians = {

0 commit comments

Comments
 (0)