from collections.abc import Iterable, Sequence

from torch import Tensor

from ._transform import AccumulateJac, Jac, OrderedSet, Transform
from ._utils import (
    as_checked_ordered_set,
    check_optional_positive_chunk_size,
    create_jac_dict,
    get_leaf_tensors,
)


def backward(
    tensors: Sequence[Tensor] | Tensor,
    /,
    *,
    jac_tensors: Sequence[Tensor] | Tensor | None = None,
    inputs: Iterable[Tensor] | None = None,
    retain_graph: bool = False,
    parallel_chunk_size: int | None = None,
) -> None:
r"""
Computes the Jacobians of ``tensors`` with respect to ``inputs``, left-multiplied by
``jac_tensors`` (or identity if ``jac_tensors`` is ``None``), and accumulates the results in the
``.jac`` fields of the ``inputs``.
:param tensors: The tensor or tensors to differentiate. Should be non-empty.
:param jac_tensors: The initial Jacobians to backpropagate, analog to the ``grad_tensors``
parameter of :func:`torch.autograd.backward`. If provided, it must have the same structure
as ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the corresponding
tensor in ``tensors``, with an extra leading dimension representing the number of rows of
the resulting Jacobian (e.g. the number of losses). All tensors in ``jac_tensors`` must
have the same first dimension. If ``None``, defaults to the identity matrix. In this case,
the standard Jacobian of ``tensors`` is computed, with one row for each value in the
``tensors``.
:param inputs: The tensors with respect to which the Jacobians must be computed. These must have
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
that were used to compute the ``tensors`` parameter.
:param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
``False``.
:param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
backward pass. If set to ``None``, all coordinates of ``tensors`` will be differentiated in
parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
larger value results in faster differentiation, but also higher memory usage. Defaults to
``None``.
.. admonition::
Example
This example shows a simple usage of ``backward``.
>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are function of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2])
>>>
>>> param.jac
tensor([[-1., 1.],
[ 2., 4.]])
The ``.jac`` field of ``param`` now contains the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
.. admonition::
Example
This is the same example as before, except that we explicitly specify ``jac_tensors`` as
the rows of the identity matrix (which is equivalent to using the default ``None``).
>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are function of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> J1 = torch.tensor([1.0, 0.0])
>>> J2 = torch.tensor([0.0, 1.0])
>>>
>>> backward([y1, y2], jac_tensors=[J1, J2])
>>>
>>> param.jac
tensor([[-1., 1.],
[ 2., 4.]])
Instead of using the identity ``jac_tensors``, you can backpropagate some Jacobians obtained
by a call to :func:`torchjd.autojac.jac` on a later part of the computation graph.
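
    .. admonition::
        Example

        As a purely illustrative sketch, with arbitrarily chosen (non-identity) ``jac_tensors``
        values, the rows of the resulting Jacobian are re-weighted accordingly. Reusing the setup
        of the previous example:

        >>> import torch
        >>>
        >>> from torchjd.autojac import backward
        >>>
        >>> param = torch.tensor([1., 2.], requires_grad=True)
        >>> y1 = torch.tensor([-1., 1.]) @ param
        >>> y2 = (param ** 2).sum()
        >>>
        >>> backward([y1, y2], jac_tensors=[torch.tensor([2.0, 0.0]), torch.tensor([0.0, 1.0])])

        Since the first row now weights ``y1`` by ``2`` and ``y2`` by ``0``, ``param.jac`` should
        contain ``[[-2., 2.], [2., 4.]]``: the first row of the Jacobian scaled by two, and the
        second row unchanged.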

    .. warning::
        To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some
        limitations: `it does not work on the output of compiled functions
        <https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
        <https://github.com/SimplexLab/TorchJD/issues/184>`_ ``retains_grad=True``, or `when using
        an RNN on CUDA <https://github.com/SimplexLab/TorchJD/issues/220>`_, for instance. If you
        experience issues with ``backward``, try using ``parallel_chunk_size=1`` to avoid relying
        on ``torch.vmap``.
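
    .. admonition::
        Example

        As an illustrative sketch of the fallback suggested above, the first example should
        produce the same Jacobian when the coordinates are differentiated one at a time instead
        of through ``torch.vmap``:

        >>> import torch
        >>>
        >>> from torchjd.autojac import backward
        >>>
        >>> param = torch.tensor([1., 2.], requires_grad=True)
        >>> y1 = torch.tensor([-1., 1.]) @ param
        >>> y2 = (param ** 2).sum()
        >>>
        >>> backward([y1, y2], parallel_chunk_size=1)
        >>>
        >>> param.jac
        tensor([[-1.,  1.],
                [ 2.,  4.]])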
"""
    check_optional_positive_chunk_size(parallel_chunk_size)

    tensors_ = as_checked_ordered_set(tensors, "tensors")
    if len(tensors_) == 0:
        raise ValueError("`tensors` cannot be empty.")

    if inputs is None:
        # Default to the leaf tensors that were used to compute the tensors to differentiate.
        inputs_ = get_leaf_tensors(tensors=tensors_, excluded=set())
    else:
        inputs_ = OrderedSet(inputs)

    # Associate each tensor to differentiate with its initial Jacobian (the identity when
    # jac_tensors is None).
    jac_tensors_dict = create_jac_dict(tensors_, jac_tensors, "tensors", "jac_tensors")

    transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph)
    transform(jac_tensors_dict)


def _create_transform(
    tensors: OrderedSet[Tensor],
    inputs: OrderedSet[Tensor],
    parallel_chunk_size: int | None,
    retain_graph: bool,
) -> Transform:
"""Creates the backward transform that computes and accumulates Jacobians."""
# Transform that computes the required Jacobians.
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)
# Transform that accumulates the result in the .jac field of the inputs.
accumulate = AccumulateJac()
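
    # `accumulate << jac` reads like function composition: `jac` runs first, and `accumulate`
    # consumes its output.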
    return accumulate << jac