@@ -23,10 +23,11 @@ def jac(
2323) -> tuple[Tensor, ...]:
2424 r"""
2525 Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
26- result as a tuple, with one element per input tensor.
26+ result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
27+ input ``t`` has shape ``[m] + t.shape``, where ``m`` is the total number of values in ``outputs``.
2728
28- :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
29- matrices will have one row for each value of each of these tensors.
29+ :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
30+ have one row for each value of each of these tensors.
3031 :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
3132 their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
3233 that were used to compute the ``outputs`` parameter.
@@ -38,6 +39,11 @@ def jac(
3839 larger value results in faster differentiation, but also higher memory usage. Defaults to
3940 ``None``.
4041
42+ .. note::
42+ The only difference between this function and :func:`torchjd.autojac.backward` is that it
44+ returns the Jacobians as a tuple, while :func:`torchjd.autojac.backward` stores them in the
45+ ``.jac`` fields of the inputs.
46+
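A minimal sketch of this difference, assuming ``backward`` accepts the same ``tensors`` and
``inputs`` arguments as ``jac`` (the exact signature is an assumption here):

>>> import torch
>>> from torchjd.autojac import backward, jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = param.sum()
>>> y2 = (param ** 2).sum()
>>>
>>> jacobians = jac([y1, y2], [param])  # Jacobian of [y1, y2] returned as a tuple of one tensor
>>> backward([y1, y2], [param])         # assumed call: the same Jacobian would instead be stored in param.jac
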
4147 .. admonition::
4248 Example
4349
@@ -58,8 +64,41 @@ def jac(
5864 (tensor([[-1., 1.],
5965 [ 2., 4.]]),)
6066
61- The returned tuple contains a single tensor (because there is a single param), that is the
62- Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
67+ .. admonition::
68+ Example
69+
70+ The following example shows how to compute Jacobians, combine them into a single Jacobian
71+ matrix, and compute its Gramian.
72+
73+ >>> import torch
74+ >>>
75+ >>> from torchjd.autojac import jac
76+ >>>
77+ >>> weight = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True) # shape: [2, 2]
78+ >>> bias = torch.tensor([0.5, -0.5], requires_grad=True) # shape: [2]
79+ >>> # Compute arbitrary quantities that are functions of weight and bias
80+ >>> input_vec = torch.tensor([1., -1.])
81+ >>> y1 = weight @ input_vec + bias # shape: [2]
82+ >>> y2 = (weight ** 2).sum() + (bias ** 2).sum() # shape: [] (scalar)
83+ >>>
84+ >>> jacobians = jac([y1, y2], [weight, bias]) # shapes: [3, 2, 2], [3, 2]
85+ >>> jacobian_matrices = tuple(J.flatten(1) for J in jacobians) # shapes: [3, 4], [3, 2]
86+ >>> combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1) # shape: [3, 6]
87+ >>> gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T # shape: [3, 3]
88+ >>> gramian
89+ tensor([[ 3., 0., -1.],
90+ [ 0., 3., -3.],
91+ [ -1., -3., 122.]])
92+
93+ The obtained Gramian is a symmetric matrix containing the dot products between all pairs of
94+ gradients. It captures both the gradient norms (its diagonal elements are the squared norms of
95+ the gradients) and their conflicts (a negative off-diagonal value means that the corresponding
96+ gradients conflict). In fact, most aggregators base their decisions entirely on the Gramian.
97+
98+ In this case, we can see that the first two gradients (those of ``y1``) both have a squared
99+ norm of 3, while the third gradient (that of ``y2``) has a squared norm of 122. The first two
100+ gradients are exactly orthogonal (their inner product is 0), but they both conflict with the
101+ third gradient (inner products of -1 and -3, respectively).
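
To make the degree of conflict easier to read, the Gramian can be normalized into a matrix of
cosine similarities. A minimal sketch, reusing the ``gramian`` computed above:

>>> norms = gramian.diag().sqrt()                  # gradient norms: [sqrt(3), sqrt(3), sqrt(122)]
>>> cosines = gramian / torch.outer(norms, norms)  # cosine similarity between each pair of gradients
>>> # For instance, the (0, 2) entry is -1 / sqrt(3 * 122) ≈ -0.05, a mild conflict between the
>>> # first gradient of y1 and the gradient of y2.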
63102
64103 .. warning::
65104 To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some