|
| 1 | +from collections.abc import Sequence |
| 2 | +from typing import Iterable |
| 3 | + |
| 4 | +from torch import Tensor |
| 5 | + |
| 6 | +from torchjd.autojac._transform._base import Transform |
| 7 | +from torchjd.autojac._transform._diagonalize import Diagonalize |
| 8 | +from torchjd.autojac._transform._init import Init |
| 9 | +from torchjd.autojac._transform._jac import Jac |
| 10 | +from torchjd.autojac._transform._ordered_set import OrderedSet |
| 11 | +from torchjd.autojac._utils import ( |
| 12 | + as_checked_ordered_set, |
| 13 | + check_optional_positive_chunk_size, |
| 14 | + get_leaf_tensors, |
| 15 | +) |
| 16 | + |
| 17 | + |
def jac(
    outputs: Sequence[Tensor] | Tensor,
    inputs: Iterable[Tensor] | None = None,
    retain_graph: bool = False,
    parallel_chunk_size: int | None = None,
) -> tuple[Tensor, ...]:
    r"""
    Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
    result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
    input ``t`` has shape ``[m] + t.shape``.

    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
        have one row for each value of each of these tensors.
    :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
        their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
        that were used to compute the ``outputs`` parameter.
    :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
        ``False``.
    :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
        backward pass. If set to ``None``, all coordinates of ``outputs`` will be differentiated in
        parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
        larger value results in faster differentiation, but also higher memory usage. Defaults to
        ``None``.

    .. note::
        The only difference between this function and :func:`torchjd.autojac.backward`, is that it
        returns the Jacobians as a tuple, while :func:`torchjd.autojac.backward` stores them in the
        ``.jac`` fields of the inputs.

    .. admonition:: Example

        The following example shows how to use ``jac``.

        >>> import torch
        >>>
        >>> from torchjd.autojac import jac
        >>>
        >>> param = torch.tensor([1., 2.], requires_grad=True)
        >>> # Compute arbitrary quantities that are function of param
        >>> y1 = torch.tensor([-1., 1.]) @ param
        >>> y2 = (param ** 2).sum()
        >>>
        >>> jacobians = jac([y1, y2], [param])
        >>>
        >>> jacobians
        (tensor([[-1.,  1.],
                [ 2.,  4.]]),)

    .. admonition:: Example

        The following example shows how to compute jacobians, combine them into a single Jacobian
        matrix, and compute its Gramian.

        >>> import torch
        >>>
        >>> from torchjd.autojac import jac
        >>>
        >>> weight = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)  # shape: [2, 2]
        >>> bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
        >>> # Compute arbitrary quantities that are function of weight and bias
        >>> input_vec = torch.tensor([1., -1.])
        >>> y1 = weight @ input_vec + bias  # shape: [2]
        >>> y2 = (weight ** 2).sum() + (bias ** 2).sum()  # shape: [] (scalar)
        >>>
        >>> jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]
        >>> jacobian_matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
        >>> combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1)  # shape: [3, 6]
        >>> gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T  # shape: [3, 3]
        >>> gramian
        tensor([[  3.,   0.,  -1.],
                [  0.,   3.,  -3.],
                [ -1.,  -3., 122.]])

        The obtained gramian is a symmetric matrix containing the dot products between all pairs of
        gradients. It's a strong indicator of gradient norm (the diagonal elements are the squared
        norms of the gradients) and conflict (a negative off-diagonal value means that the gradients
        conflict). In fact, most aggregators base their decision entirely on the gramian.

        In this case, we can see that the first two gradients (those of y1) both have a squared norm
        of 3, while the third gradient (that of y2) has a squared norm of 122. The first two
        gradients are exactly orthogonal (they have an inner product of 0), but they conflict with
        the third gradient (inner product of -1 and -3).

    .. warning::
        To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
        limitations: `it does not work on the output of compiled functions
        <https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
        <https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an
        RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you
        experience issues with ``jac`` try to use ``parallel_chunk_size=1`` to avoid relying on
        ``torch.vmap``.
    """

    check_optional_positive_chunk_size(parallel_chunk_size)

    outputs_ = as_checked_ordered_set(outputs, "outputs")
    if len(outputs_) == 0:
        raise ValueError("`outputs` cannot be empty")

    if inputs is None:
        # Default to the leaf tensors of the autograd graph that produced `outputs`.
        inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
        inputs_with_repetition = list(inputs_)
    else:
        inputs_with_repetition = list(inputs)  # Create a list to avoid emptying generator
        inputs_ = OrderedSet(inputs_with_repetition)

    jac_transform = _create_transform(
        outputs=outputs_,
        inputs=inputs_,
        retain_graph=retain_graph,
        parallel_chunk_size=parallel_chunk_size,
    )

    result = jac_transform({})
    # Return Jacobians in the caller-provided order (duplicates included). `input_tensor`
    # avoids shadowing the builtin `input`.
    return tuple(result[input_tensor] for input_tensor in inputs_with_repetition)
| 135 | + |
| 136 | + |
def _create_transform(
    outputs: OrderedSet[Tensor],
    inputs: OrderedSet[Tensor],
    retain_graph: bool,
    parallel_chunk_size: int | None,
) -> Transform:
    """Build the transform computing the Jacobians of ``outputs`` with respect to ``inputs``."""

    # Composition is applied right-to-left: seed gradient outputs of ones (Init),
    # turn those gradients into Jacobian seeds (Diagonalize), then differentiate
    # to obtain the required Jacobians (Jac).
    return (
        Jac(outputs, inputs, parallel_chunk_size, retain_graph)
        << Diagonalize(outputs)
        << Init(outputs)
    )
0 commit comments