Skip to content

Commit 3e843d0

Browse files
feat(autojac): Add autojac.jac (#505)
Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
1 parent 81648e5 commit 3e843d0

File tree

7 files changed

+490
-1
lines changed

7 files changed

+490
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ changelog does not include internal changes that do not affect the user.
1010

1111
### Added
1212

13+
- Added the function `torchjd.autojac.jac`. It's the same as `torchjd.autojac.backward` except that
14+
it returns the Jacobians as a tuple instead of storing them in the `.jac` fields of the inputs.
15+
Its interface is analogous to that of `torch.autograd.grad`.
1316
- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose
1417
between `"min"`, `"median"`, and `"rmse"` scaling.
1518
- Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`.

docs/source/docs/autojac/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ autojac
1010

1111
backward.rst
1212
mtl_backward.rst
13+
jac.rst
1314
jac_to_grad.rst

docs/source/docs/autojac/jac.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
:hide-toc:
2+
3+
jac
4+
===
5+
6+
.. autofunction:: torchjd.autojac.jac

src/torchjd/autojac/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
"""
77

88
from ._backward import backward
9+
from ._jac import jac
910
from ._jac_to_grad import jac_to_grad
1011
from ._mtl_backward import mtl_backward
1112

12-
__all__ = ["backward", "jac_to_grad", "mtl_backward"]
13+
__all__ = ["backward", "jac", "jac_to_grad", "mtl_backward"]

src/torchjd/autojac/_jac.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
from collections.abc import Sequence
2+
from typing import Iterable
3+
4+
from torch import Tensor
5+
6+
from torchjd.autojac._transform._base import Transform
7+
from torchjd.autojac._transform._diagonalize import Diagonalize
8+
from torchjd.autojac._transform._init import Init
9+
from torchjd.autojac._transform._jac import Jac
10+
from torchjd.autojac._transform._ordered_set import OrderedSet
11+
from torchjd.autojac._utils import (
12+
as_checked_ordered_set,
13+
check_optional_positive_chunk_size,
14+
get_leaf_tensors,
15+
)
16+
17+
18+
def jac(
    outputs: Sequence[Tensor] | Tensor,
    inputs: Iterable[Tensor] | None = None,
    retain_graph: bool = False,
    parallel_chunk_size: int | None = None,
) -> tuple[Tensor, ...]:
    r"""
    Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
    result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
    input ``t`` has shape ``[m] + t.shape``.

    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
        have one row for each value of each of these tensors.
    :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
        their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
        that were used to compute the ``outputs`` parameter.
    :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
        ``False``.
    :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
        backward pass. If set to ``None``, all coordinates of ``outputs`` will be differentiated in
        parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
        larger value results in faster differentiation, but also higher memory usage. Defaults to
        ``None``.

    .. note::
        The only difference between this function and :func:`torchjd.autojac.backward`, is that it
        returns the Jacobians as a tuple, while :func:`torchjd.autojac.backward` stores them in the
        ``.jac`` fields of the inputs.

    .. admonition::
        Example

        The following example shows how to use ``jac``.

        >>> import torch
        >>>
        >>> from torchjd.autojac import jac
        >>>
        >>> param = torch.tensor([1., 2.], requires_grad=True)
        >>> # Compute arbitrary quantities that are function of param
        >>> y1 = torch.tensor([-1., 1.]) @ param
        >>> y2 = (param ** 2).sum()
        >>>
        >>> jacobians = jac([y1, y2], [param])
        >>>
        >>> jacobians
        (tensor([[-1., 1.],
                [ 2., 4.]]),)

    .. admonition::
        Example

        The following example shows how to compute jacobians, combine them into a single Jacobian
        matrix, and compute its Gramian.

        >>> import torch
        >>>
        >>> from torchjd.autojac import jac
        >>>
        >>> weight = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)  # shape: [2, 2]
        >>> bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
        >>> # Compute arbitrary quantities that are function of weight and bias
        >>> input_vec = torch.tensor([1., -1.])
        >>> y1 = weight @ input_vec + bias  # shape: [2]
        >>> y2 = (weight ** 2).sum() + (bias ** 2).sum()  # shape: [] (scalar)
        >>>
        >>> jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]
        >>> jacobian_matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
        >>> combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1)  # shape: [3, 6]
        >>> gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T  # shape: [3, 3]
        >>> gramian
        tensor([[  3.,   0.,  -1.],
                [  0.,   3.,  -3.],
                [ -1.,  -3., 122.]])

        The obtained gramian is a symmetric matrix containing the dot products between all pairs of
        gradients. It's a strong indicator of gradient norm (the diagonal elements are the squared
        norms of the gradients) and conflict (a negative off-diagonal value means that the gradients
        conflict). In fact, most aggregators base their decision entirely on the gramian.

        In this case, we can see that the first two gradients (those of y1) both have a squared norm
        of 3, while the third gradient (that of y2) has a squared norm of 122. The first two
        gradients are exactly orthogonal (they have an inner product of 0), but they conflict with
        the third gradient (inner product of -1 and -3).

    .. warning::
        To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
        limitations: `it does not work on the output of compiled functions
        <https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
        <https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an
        RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you
        experience issues with ``jac`` try to use ``parallel_chunk_size=1`` to avoid relying on
        ``torch.vmap``.
    """

    check_optional_positive_chunk_size(parallel_chunk_size)

    outputs_ = as_checked_ordered_set(outputs, "outputs")

    if len(outputs_) == 0:
        raise ValueError("`outputs` cannot be empty")

    if inputs is None:
        # Default to the leaf tensors of the graph that produced `outputs`.
        inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
        inputs_with_repetition = list(inputs_)
    else:
        inputs_with_repetition = list(inputs)  # Create a list to avoid emptying generator
        inputs_ = OrderedSet(inputs_with_repetition)

    jac_transform = _create_transform(
        outputs=outputs_,
        inputs=inputs_,
        retain_graph=retain_graph,
        parallel_chunk_size=parallel_chunk_size,
    )

    # The transform maps each input tensor to its Jacobian; index by the caller's
    # original (possibly repeated) input order. `inp` avoids shadowing the builtin `input`.
    result = jac_transform({})
    return tuple(result[inp] for inp in inputs_with_repetition)
136+
137+
def _create_transform(
    outputs: OrderedSet[Tensor],
    inputs: OrderedSet[Tensor],
    retain_graph: bool,
    parallel_chunk_size: int | None,
) -> Transform:
    """Compose the transforms that compute the Jacobians of ``outputs`` w.r.t. ``inputs``."""

    # Seed the differentiation with gradient outputs containing only ones.
    init_transform = Init(outputs)

    # Turn those gradients into Jacobians (one row per output coordinate).
    diagonalize_transform = Diagonalize(outputs)

    # Compute the required Jacobians with respect to the inputs.
    differentiate_transform = Jac(outputs, inputs, parallel_chunk_size, retain_graph)

    composed = differentiate_transform << diagonalize_transform << init_transform
    return composed

tests/doc/test_jac.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
This file contains the test of the jac usage example, with a verification of the value of the obtained jacobians tuple.
3+
"""
4+
5+
from torch.testing import assert_close
6+
7+
8+
def test_jac():
    """Run the first ``jac`` docstring example and verify the returned Jacobian."""
    import torch

    from torchjd.autojac import jac

    param = torch.tensor([1.0, 2.0], requires_grad=True)
    # Arbitrary differentiable quantities depending on param
    y1 = torch.tensor([-1.0, 1.0]) @ param
    y2 = (param**2).sum()

    jacobians = jac([y1, y2], [param])

    expected = torch.tensor([[-1.0, 1.0], [2.0, 4.0]])
    assert len(jacobians) == 1
    assert_close(jacobians[0], expected, rtol=0.0, atol=1e-04)
21+
22+
23+
def test_jac_2():
    """Run the second ``jac`` docstring example and verify the resulting Gramian."""
    import torch

    from torchjd.autojac import jac

    weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)  # shape: [2, 2]
    bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
    # Arbitrary differentiable quantities depending on weight and bias
    input_vec = torch.tensor([1.0, -1.0])
    y1 = weight @ input_vec + bias  # shape: [2]
    y2 = (weight**2).sum() + (bias**2).sum()  # shape: [] (scalar)

    jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]

    # Flatten each Jacobian, concatenate into one matrix, and form its Gramian.
    matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
    combined = torch.concat(matrices, dim=1)  # shape: [3, 6]
    gramian = combined @ combined.T  # shape: [3, 3]

    expected_gramian = torch.tensor(
        [[3.0, 0.0, -1.0], [0.0, 3.0, -3.0], [-1.0, -3.0, 122.0]]
    )
    assert_close(gramian, expected_gramian, rtol=0.0, atol=1e-04)

0 commit comments

Comments
 (0)