
Commit 3762c6d

Add other usage example of autojac.jac
1 parent 33bce01 commit 3762c6d

2 files changed: 68 additions & 5 deletions


src/torchjd/autojac/_jac.py

Lines changed: 44 additions & 5 deletions
@@ -23,10 +23,11 @@ def jac(
 ) -> tuple[Tensor, ...]:
     r"""
     Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
-    result as a tuple, with one element per input tensor.
+    result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
+    input ``t`` has shape ``[m] + t.shape``, where ``m`` is the total number of output values.
 
-    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
-        matrices will have one row for each value of each of these tensors.
+    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
+        have one row for each value of each of these tensors.
     :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
         their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
         that were used to compute the ``outputs`` parameter.
@@ -38,6 +39,11 @@ def jac(
         larger value results in faster differentiation, but also higher memory usage. Defaults to
         ``None``.
 
+    .. note::
+        The only difference between this function and :func:`torchjd.autojac.backward` is that it
+        returns the Jacobians as a tuple, while :func:`torchjd.autojac.backward` stores them in the
+        ``.jac`` fields of the inputs.
+
     .. admonition::
         Example
 
@@ -58,8 +64,41 @@ def jac(
        (tensor([[-1.,  1.],
                [ 2.,  4.]]),)
 
-        The returned tuple contains a single tensor (because there is a single param), that is the
-        Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
+    .. admonition::
+        Example
+
+        The following example shows how to compute Jacobians, combine them into a single Jacobian
+        matrix, and compute its Gramian.
+
+        >>> import torch
+        >>>
+        >>> from torchjd.autojac import jac
+        >>>
+        >>> weight = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)  # shape: [2, 2]
+        >>> bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
+        >>> # Compute arbitrary quantities that are functions of weight and bias
+        >>> input_vec = torch.tensor([1., -1.])
+        >>> y1 = weight @ input_vec + bias  # shape: [2]
+        >>> y2 = (weight ** 2).sum() + (bias ** 2).sum()  # shape: [] (scalar)
+        >>>
+        >>> jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]
+        >>> jacobian_matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
+        >>> combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1)  # shape: [3, 6]
+        >>> gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T  # shape: [3, 3]
+        >>> gramian
+        tensor([[  3.,   0.,  -1.],
+                [  0.,   3.,  -3.],
+                [ -1.,  -3., 122.]])
+
+        The obtained Gramian is a symmetric matrix containing the dot products between all pairs of
+        gradients. It is a strong indicator of gradient norm (the diagonal elements are the squared
+        norms of the gradients) and of conflict (a negative off-diagonal value means that the
+        gradients conflict). In fact, most aggregators base their decisions entirely on the Gramian.
+
+        In this case, we can see that the first two gradients (those of ``y1``) both have a squared
+        norm of 3, while the third gradient (that of ``y2``) has a squared norm of 122. The first
+        two gradients are exactly orthogonal (their inner product is 0), but they both conflict
+        with the third gradient (inner products of -1 and -3).
 
     .. warning::
         To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
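
The Gramian values shown above can be cross-checked without torchjd: the sketch below (not part of this commit) rebuilds the same 3 x 3 Gramian with plain torch.autograd.grad, differentiating the three output values one at a time and stacking the flattened gradients into rows.

import torch

# Same toy setup as in the new docstring example.
weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
bias = torch.tensor([0.5, -0.5], requires_grad=True)
input_vec = torch.tensor([1.0, -1.0])
y1 = weight @ input_vec + bias                 # 2 values
y2 = (weight ** 2).sum() + (bias ** 2).sum()   # 1 value

# One gradient per output value, flattened and concatenated over the two inputs,
# gives a 3 x 6 Jacobian matrix (6 = 4 weight entries + 2 bias entries).
rows = []
for out in [y1[0], y1[1], y2]:
    grads = torch.autograd.grad(out, [weight, bias], retain_graph=True)
    rows.append(torch.cat([g.flatten() for g in grads]))
jacobian_matrix = torch.stack(rows)            # shape: [3, 6]

gramian = jacobian_matrix @ jacobian_matrix.T  # shape: [3, 3]
print(gramian)
# tensor([[  3.,   0.,  -1.],
#         [  0.,   3.,  -3.],
#         [ -1.,  -3., 122.]])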

tests/doc/test_jac.py

Lines changed: 24 additions & 0 deletions
@@ -18,3 +18,27 @@ def test_jac():
 
     assert len(jacobians) == 1
     assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)
+
+
+def test_jac_2():
+    import torch
+
+    from torchjd.autojac import jac
+
+    weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)  # shape: [2, 2]
+    bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
+    # Compute arbitrary quantities that are functions of weight and bias
+    input_vec = torch.tensor([1.0, -1.0])
+    y1 = weight @ input_vec + bias  # shape: [2]
+    y2 = (weight**2).sum() + (bias**2).sum()  # shape: [] (scalar)
+    jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]
+    jacobian_matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
+    combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1)  # shape: [3, 6]
+    gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T  # shape: [3, 3]
+
+    assert_close(
+        gramian,
+        torch.tensor([[3.0, 0.0, -1.0], [0.0, 3.0, -3.0], [-1.0, -3.0, 122.0]]),
+        rtol=0.0,
+        atol=1e-04,
+    )
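
The "conflict" reading mentioned in the new docstring text can be made concrete from the Gramian alone: the diagonal holds the squared gradient norms and the off-diagonal entries hold the inner products, so pairwise cosine similarities between gradients follow directly. A minimal sketch, reusing the Gramian from the example (the presentation below is illustrative and not part of torchjd):

import torch

# Gramian from the example above.
gramian = torch.tensor([[3.0, 0.0, -1.0], [0.0, 3.0, -3.0], [-1.0, -3.0, 122.0]])

# Gradient norms are the square roots of the diagonal entries.
norms = gramian.diagonal().sqrt()             # approximately [1.73, 1.73, 11.05]

# Cosine similarity between gradients i and j: G[i, j] / (||g_i|| * ||g_j||).
cosine = gramian / torch.outer(norms, norms)
print(cosine)
# Negative off-diagonal entries (here, the pairs (1, 3) and (2, 3)) indicate
# conflicting gradients; a value of 0 means the gradients are orthogonal.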
