Skip to content

Commit d7bbfae

Browse files
committed
feat: Add autojac.jac
1 parent 4de55ab commit d7bbfae

File tree

7 files changed

+190
-1
lines changed

7 files changed

+190
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ changes that do not affect the user.
1010

1111
### Added
1212

13+
- Added the function `torchjd.autojac.jac` to compute the Jacobian of some outputs with respect to
14+
some inputs, without doing any aggregation. Its interface is very similar to
15+
`torch.autograd.grad`.
1316
- Added `__all__` in the `__init__.py` of packages. This should prevent PyLance from triggering warnings when importing from `torchjd`.
1417

1518
## [0.8.0] - 2025-11-13

docs/source/docs/autojac/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ autojac
1010

1111
backward.rst
1212
mtl_backward.rst
13+
jac.rst

docs/source/docs/autojac/jac.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
:hide-toc:
2+
3+
jac
4+
===
5+
6+
.. autofunction:: torchjd.autojac.jac

src/torchjd/autojac/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from ._backward import backward
9+
from ._jac import jac
910
from ._mtl_backward import mtl_backward
1011

11-
__all__ = ["backward", "mtl_backward"]
12+
__all__ = ["backward", "jac", "mtl_backward"]

src/torchjd/autojac/_jac.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from collections.abc import Sequence
2+
from typing import Iterable
3+
4+
from torch import Tensor
5+
6+
from torchjd.autojac._transform._base import Transform
7+
from torchjd.autojac._transform._diagonalize import Diagonalize
8+
from torchjd.autojac._transform._init import Init
9+
from torchjd.autojac._transform._jac import Jac
10+
from torchjd.autojac._transform._ordered_set import OrderedSet
11+
from torchjd.autojac._utils import (
12+
as_checked_ordered_set,
13+
check_optional_positive_chunk_size,
14+
get_leaf_tensors,
15+
)
16+
17+
18+
def jac(
    outputs: Sequence[Tensor] | Tensor,
    inputs: Iterable[Tensor] | None = None,
    retain_graph: bool = False,
    parallel_chunk_size: int | None = None,
) -> tuple[Tensor, ...]:
    r"""
    Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
    result as a tuple, with one element per input tensor.

    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
        matrices will have one row for each value of each of these tensors.
    :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
        their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
        that were used to compute the ``outputs`` parameter.
    :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
        ``False``.
    :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
        backward pass. If set to ``None``, all coordinates of ``outputs`` will be differentiated in
        parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
        larger value results in faster differentiation, but also higher memory usage. Defaults to
        ``None``.
    :raises ValueError: If ``outputs`` or ``inputs`` is empty, or if ``parallel_chunk_size`` is
        provided and is not positive.

    .. admonition::
        Example

        The following example shows how to use ``jac``.

        >>> import torch
        >>>
        >>> from torchjd.autojac import jac
        >>>
        >>> param = torch.tensor([1., 2.], requires_grad=True)
        >>> # Compute arbitrary quantities that are function of param
        >>> y1 = torch.tensor([-1., 1.]) @ param
        >>> y2 = (param ** 2).sum()
        >>>
        >>> jacobians = jac([y1, y2], [param])
        >>>
        >>> jacobians
        (tensor([[-1.,  1.],
                [ 2.,  4.]]),)

        The returned tuple contains a single tensor (because there is a single param), that is the
        Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.

    .. warning::
        To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
        limitations: `it does not work on the output of compiled functions
        <https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
        <https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an
        RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you
        experience issues with ``jac``, try to use ``parallel_chunk_size=1`` to avoid relying on
        ``torch.vmap``.
    """

    check_optional_positive_chunk_size(parallel_chunk_size)
    outputs_ = as_checked_ordered_set(outputs, "outputs")

    # Fail early on empty outputs, before attempting to traverse the autograd graph
    # to find default inputs.
    if len(outputs_) == 0:
        raise ValueError("`outputs` cannot be empty")

    if inputs is None:
        # Default to the leaf tensors of the graph that produced `outputs`.
        inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
    else:
        inputs_ = OrderedSet(inputs)

    if len(inputs_) == 0:
        raise ValueError("`inputs` cannot be empty")

    jac_transform = _create_transform(
        outputs=outputs_,
        inputs=inputs_,
        retain_graph=retain_graph,
        parallel_chunk_size=parallel_chunk_size,
    )

    # The transform maps each input tensor to its Jacobian; return them in input order.
    result = jac_transform({})
    return tuple(result.values())
97+
98+
99+
def _create_transform(
    outputs: OrderedSet[Tensor],
    inputs: OrderedSet[Tensor],
    retain_graph: bool,
    parallel_chunk_size: int | None,
) -> Transform:
    """
    Compose the transform pipeline computing the Jacobians of ``outputs`` with respect to
    ``inputs``.
    """

    # Transform that creates gradient outputs containing only ones.
    # Renamed locals so nothing shadows the module-level `jac` function.
    init_transform = Init(outputs)

    # Transform that turns the gradients into Jacobians.
    diagonalize_transform = Diagonalize(outputs)

    # Transform that computes the required Jacobians.
    jac_transform = Jac(outputs, inputs, parallel_chunk_size, retain_graph)

    # Composition is right-to-left: init, then diagonalize, then differentiate.
    return jac_transform << diagonalize_transform << init_transform

tests/doc/test_jac.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""
2+
This file contains the test of the jac usage example, with a verification of the value of the obtained jacobians tuple.
3+
"""
4+
5+
from torch.testing import assert_close
6+
7+
8+
def test_jac():
    """Checks the documented ``jac`` usage example and the values it produces."""

    import torch

    from torchjd.autojac import jac

    param = torch.tensor([1.0, 2.0], requires_grad=True)
    # Compute arbitrary quantities that are function of param
    y1 = torch.tensor([-1.0, 1.0]) @ param
    y2 = (param**2).sum()

    jacobians = jac([y1, y2], [param])

    # A single input tensor should yield a single Jacobian.
    expected = torch.tensor([[-1.0, 1.0], [2.0, 4.0]])
    assert len(jacobians) == 1
    assert_close(jacobians[0], expected, rtol=0.0, atol=1e-04)

tests/unit/autojac/test_jac.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from utils.tensors import tensor_
2+
3+
from torchjd.autojac import jac
4+
from torchjd.autojac._jac import _create_transform
5+
from torchjd.autojac._transform import OrderedSet
6+
7+
8+
def test_check_create_transform():
    """Tests that _create_transform creates a valid Transform."""

    x1 = tensor_([1.0, 2.0], requires_grad=True)
    x2 = tensor_([3.0, 4.0], requires_grad=True)

    out1 = tensor_([-1.0, 1.0]) @ x1 + x2.sum()
    out2 = (x1**2).sum() + x2.norm()

    transform = _create_transform(
        outputs=OrderedSet([out1, out2]),
        inputs=OrderedSet([x1, x2]),
        retain_graph=False,
        parallel_chunk_size=None,
    )

    # The transform's output keys should be exactly the inputs.
    assert transform.check_keys(set()) == {x1, x2}
26+
27+
28+
def test_jac():
    """Tests that jac returns one Jacobian of the right shape per input tensor."""

    x1 = tensor_([1.0, 2.0], requires_grad=True)
    x2 = tensor_([3.0, 4.0], requires_grad=True)
    all_inputs = [x1, x2]

    out1 = tensor_([-1.0, 1.0]) @ x1 + x2.sum()
    out2 = (x1**2).sum() + x2.norm()
    all_outputs = [out1, out2]

    jacobians = jac(all_outputs, all_inputs)

    assert len(jacobians) == len(all_inputs)
    # Each Jacobian has one row per output value, and trailing dims matching its input.
    for jacobian, inp in zip(jacobians, all_inputs):
        assert jacobian.shape == (len(all_outputs), *inp.shape)

0 commit comments

Comments
 (0)