Skip to content

Commit d5ed6f3

Browse files
authored
docs(autojac): Document Transforms (#326)
1 parent aa87f95 commit d5ed6f3

File tree

10 files changed

+170
-49
lines changed

10 files changed

+170
-49
lines changed

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,28 @@
99

1010

1111
class Differentiate(Transform[_A, _A], ABC):
12+
"""
13+
Abstract base class for transforms responsible for differentiating some outputs with respect to
14+
some inputs.
15+
16+
:param outputs: Tensors to differentiate.
17+
:param inputs: Tensors with respect to which we differentiate.
18+
:param retain_graph: If False, the graph used to compute the grads will be freed.
19+
:param create_graph: If True, graph of the derivative will be constructed, allowing to compute
20+
higher order derivative products.
21+
22+
.. note:: The order of outputs and inputs only matters because we have no guarantee that
23+
torch.autograd.grad is *exactly* equivariant to input permutations and invariant to output
24+
(with their corresponding grad_output) permutations.
25+
"""
26+
1227
def __init__(
1328
self,
1429
outputs: OrderedSet[Tensor],
1530
inputs: OrderedSet[Tensor],
1631
retain_graph: bool,
1732
create_graph: bool,
1833
):
19-
# The order of outputs and inputs only matters because we have no guarantee that
20-
# torch.autograd.grad is *exactly* equivariant to input permutations and invariant to
21-
# output (with their corresponding grad_output) permutations.
22-
2334
self.outputs = list(outputs)
2435
self.inputs = list(inputs)
2536
self.retain_graph = retain_graph

src/torchjd/autojac/_transform/accumulate.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55

66

77
class Accumulate(Transform[Gradients, EmptyTensorDict]):
8-
def __call__(self, gradients: Gradients) -> EmptyTensorDict:
9-
"""
10-
Accumulates gradients with respect to keys in their ``.grad`` field.
11-
"""
8+
"""Transform that accumulates gradients with respect to keys into their ``grad`` field."""
129

10+
def __call__(self, gradients: Gradients) -> EmptyTensorDict:
1311
for key in gradients.keys():
1412
_check_expects_grad(key)
1513
if hasattr(key, "grad") and key.grad is not None:

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515

1616

1717
class Aggregate(Transform[Jacobians, Gradients]):
18+
"""
19+
Transform aggregating Jacobians into Gradients.
20+
21+
It does so by reshaping these Jacobians into matrices, concatenating them into a single matrix,
22+
applying an aggregator to it, separating the result back into one gradient vector per key, and
23+
finally reshaping those into gradients of the same shape as their corresponding keys.
24+
25+
:param aggregator: The aggregator used to aggregate the concatenated jacobian matrix.
26+
:param key_order: Order in which the different jacobian matrices must be concatenated.
27+
"""
28+
1829
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
1930
matrixify = _Matrixify()
2031
aggregate_matrices = _AggregateMatrices(aggregator, key_order)
@@ -31,6 +42,16 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
3142

3243

3344
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
45+
"""
46+
Transform aggregating JacobianMatrices into GradientVectors.
47+
48+
It does so by concatenating the matrices into a single matrix, applying an aggregator to it and
49+
separating the result back into one gradient vector per key.
50+
51+
:param aggregator: The aggregator used to aggregate the concatenated jacobian matrix.
52+
:param key_order: Order in which the different jacobian matrices must be concatenated.
53+
"""
54+
3455
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
3556
self.key_order = key_order
3657
self.aggregator = aggregator
@@ -112,6 +133,8 @@ def _disunite(
112133

113134

114135
class _Matrixify(Transform[Jacobians, JacobianMatrices]):
136+
"""Transform reshaping Jacobians into JacobianMatrices."""
137+
115138
def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
116139
jacobian_matrices = {
117140
key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
@@ -123,6 +146,8 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
123146

124147

125148
class _Reshape(Transform[GradientVectors, Gradients]):
149+
"""Transform reshaping GradientVectors into Gradients."""
150+
126151
def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
127152
gradients = {
128153
key: gradient_vector.view(key.shape)

src/torchjd/autojac/_transform/base.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,9 @@ class RequirementError(ValueError):
1515

1616

1717
class Transform(Generic[_B, _C], ABC):
18-
r"""
18+
"""
1919
Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
20-
descent backward phase. A transform maps a :class:`~torchjd.transform.tensor_dict.TensorDict` to
21-
another. The input :class:`~torchjd.transform.tensor_dict.TensorDict` has keys `required_keys`
22-
and the output :class:`~torchjd.transform.tensor_dict.TensorDict` has keys `output_keys`.
23-
24-
Formally a transform is a function:
25-
26-
.. math::
27-
f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}
28-
29-
where we have ``p`` `required_keys`, ``q`` `output_keys`, ``n_i`` is the number of elements in
30-
the value associated to the ``i`` th `required_key` of the input
31-
:class:`~torchjd.transform.tensor_dict.TensorDict` and ``m_j`` is the number of elements in the
32-
value associated to the ``j`` th `output_key` of the output
33-
:class:`~torchjd.transform.tensor_dict.TensorDict`.
34-
35-
As they are mathematical functions, transforms can be composed together as long as their
36-
domains and range meaningfully match.
20+
descent backward phase. A transform maps a TensorDict to another.
3721
"""
3822

3923
def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
@@ -67,6 +51,13 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
6751

6852

6953
class Composition(Transform[_A, _C]):
54+
"""
55+
Transform corresponding to the composition of two transforms inner and outer.
56+
57+
:param inner: The transform to apply first, to the input.
58+
:param outer: The transform to apply second, to the result of ``inner``.
59+
"""
60+
7061
def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
7162
self.outer = outer
7263
self.inner = inner
@@ -85,6 +76,13 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
8576

8677

8778
class Conjunction(Transform[_A, _B]):
79+
"""
80+
Transform applying several transforms to the same input, and combining the results (by union)
81+
into a single TensorDict.
82+
83+
:param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
84+
"""
85+
8886
def __init__(self, transforms: Sequence[Transform[_A, _B]]):
8987
self.transforms = transforms
9088

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,51 @@
77

88

99
class Diagonalize(Transform[Gradients, Jacobians]):
10+
"""
11+
Transform diagonalizing Gradients into Jacobians.
12+
13+
The first dimension of the returned Jacobians will be equal to the total number of elements in
14+
the tensors of the input tensor dict. The exact behavior of the diagonalization is best
15+
explained by some examples.
16+
17+
Example 1:
18+
The input is one tensor of shape [3] and of value [1 2 3].
19+
The output Jacobian will be:
20+
[[1 0 0]
21+
[0 2 0]
22+
[0 0 3]]
23+
24+
Example 2:
25+
The input is one tensor of shape [2, 2] and of value [[4 5] [6 7]].
26+
The output Jacobian will be:
27+
[[[4 0] [0 0]]
28+
[[0 5] [0 0]]
29+
[[0 0] [6 0]]
30+
[[0 0] [0 7]]]
31+
32+
Example 3:
33+
The input is two tensors, of shapes [3] and [2, 2] and of values [1 2 3] and [[4 5] [6 7]].
34+
If the key_order has the tensor of shape [3] appear first and the one of shape [2, 2] appear
35+
second, the output Jacobians will be:
36+
[[1 0 0]
37+
[0 2 0]
38+
[0 0 3]
39+
[0 0 0]
40+
[0 0 0]
41+
[0 0 0]
42+
[0 0 0]] and
43+
[[[0 0] [0 0]]
44+
[[0 0] [0 0]]
45+
[[0 0] [0 0]]
46+
[[4 0] [0 0]]
47+
[[0 5] [0 0]]
48+
[[0 0] [6 0]]
49+
[[0 0] [0 7]]]
50+
51+
:param key_order: The order in which the keys are represented in the rows of the output
52+
Jacobians.
53+
"""
54+
1055
def __init__(self, key_order: OrderedSet[Tensor]):
1156
self.key_order = key_order
1257
self.indices: list[tuple[int, int]] = []

src/torchjd/autojac/_transform/grad.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@
1010

1111

1212
class Grad(Differentiate[Gradients]):
13+
"""
14+
Transform computing the gradient of each output element with respect to each input tensor, and
15+
applying the linear transformations represented by the provided grad_outputs to the results.
16+
17+
:param outputs: Tensors to differentiate.
18+
:param inputs: Tensors with respect to which we differentiate.
19+
:param retain_graph: If False, the graph used to compute the grads will be freed. Defaults to
20+
False.
21+
:param create_graph: If True, graph of the derivative will be constructed, allowing to compute
22+
higher order derivative products. Defaults to False.
23+
24+
.. note:: The order of outputs and inputs only matters because we have no guarantee that
25+
torch.autograd.grad is *exactly* equivariant to input permutations and invariant to output
26+
(with their corresponding grad_output) permutations.
27+
"""
28+
1329
def __init__(
1430
self,
1531
outputs: OrderedSet[Tensor],
@@ -21,14 +37,15 @@ def __init__(
2137

2238
def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
2339
"""
24-
Computes the gradient of each output with respect to each input, and applies the linear
25-
transformations represented by the grad_outputs to the results.
40+
Computes the gradient of each output element with respect to each input tensor, and applies
41+
the linear transformations represented by the grad_outputs to the results.
2642
27-
Returns one gradient per input.
43+
Returns one gradient per input, corresponding to the sum of the scaled gradients with
44+
respect to this input.
2845
29-
:param grad_outputs: The sequence of scalar tensors to scale the obtained gradients with.
30-
Its length should be equal to the length of ``outputs``. Each grad_output should have
31-
the same shape as the corresponding output.
46+
:param grad_outputs: The sequence of tensors to scale the obtained gradients with. Its
47+
length should be equal to the length of ``outputs``. Each grad_output should have the
48+
same shape as the corresponding output.
3249
"""
3350

3451
if len(self.inputs) == 0:

src/torchjd/autojac/_transform/init.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,16 @@
88

99

1010
class Init(Transform[EmptyTensorDict, Gradients]):
11+
"""
12+
Transform returning Gradients filled with ones for each of the provided values.
13+
14+
:param values: Tensors for which Gradients must be returned.
15+
"""
16+
1117
def __init__(self, values: Set[Tensor]):
1218
self.values = values
1319

1420
def __call__(self, input: EmptyTensorDict) -> Gradients:
15-
r"""
16-
Computes the gradients of the ``value`` with respect to itself. Returns the result as a
17-
dictionary. The only key of the dictionary is ``value``. The corresponding gradient is a
18-
tensor of 1s of identical shape, because :math:`\frac{\partial v}{\partial v} = 1` for any
19-
:math:`v`.
20-
"""
21-
2221
return Gradients({value: torch.ones_like(value) for value in self.values})
2322

2423
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:

src/torchjd/autojac/_transform/jac.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,26 @@
1313

1414

1515
class Jac(Differentiate[Jacobians]):
16+
"""
17+
Transform computing the jacobian of each output with respect to each input, and applying the
18+
linear transformations represented by the argument jac_outputs to the results.
19+
20+
:param outputs: Tensors to differentiate.
21+
:param inputs: Tensors with respect to which we differentiate.
22+
:param chunk_size: The number of scalars to differentiate simultaneously. If set to ``None``,
23+
all outputs will be differentiated in parallel at once. If set to ``1``, all will be
24+
differentiated sequentially. A larger value results in faster differentiation, but also
25+
higher memory usage. Defaults to ``None``.
26+
:param retain_graph: If False, the graph used to compute the grads will be freed. Defaults to
27+
False.
28+
:param create_graph: If True, graph of the derivative will be constructed, allowing to compute
29+
higher order derivative products. Defaults to False.
30+
31+
.. note:: The order of outputs and inputs only matters because we have no guarantee that
32+
torch.autograd.grad is *exactly* equivariant to input permutations and invariant to output
33+
(with their corresponding grad_output) permutations.
34+
"""
35+
1636
def __init__(
1737
self,
1838
outputs: OrderedSet[Tensor],

src/torchjd/autojac/_transform/select.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88

99
class Select(Transform[_A, _A]):
10+
"""
11+
Transform returning a subset of the provided TensorDict.
12+
13+
:param keys: The keys that should be included in the returned subset.
14+
"""
15+
1016
def __init__(self, keys: Set[Tensor]):
1117
self.keys = keys
1218

src/torchjd/autojac/_transform/stack.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99

1010

1111
class Stack(Transform[_A, Jacobians]):
12+
"""
13+
Transform applying several transforms to the same input, and combining the results (by stacking)
14+
into a single TensorDict.
15+
16+
The set of keys of the resulting dict is the union of the sets of keys of the input dicts.
17+
18+
:param transforms: The transforms to apply. Their outputs may have different sets of keys. If a
19+
key is absent in some output dicts, the corresponding stacked tensor is filled with zeroes
20+
at the positions corresponding to those dicts.
21+
"""
22+
1223
def __init__(self, transforms: Sequence[Transform[_A, Gradients]]):
1324
self.transforms = transforms
1425

@@ -22,13 +33,6 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
2233

2334

2435
def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
25-
"""
26-
Transforms a list of tensor dicts into a single dict of (stacked) tensors. The set of keys of
27-
the resulting dict is the union of the sets of keys of the input dicts.
28-
If a key is absent in some input dicts, the corresponding stacked tensor is filled with zeroes
29-
at the positions corresponding to those dicts.
30-
"""
31-
3236
# It is important to first remove duplicate keys before computing their associated
3337
# stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute
3438
# unique_keys, and only then, we compute the stacked tensors.
@@ -41,9 +45,7 @@ def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
4145

4246

4347
def _stack_one_key(gradient_dicts: list[Gradients], input: Tensor) -> Tensor:
44-
"""
45-
Makes the stacked tensor corresponding to a given key, from a list of tensor dicts.
46-
"""
48+
"""Makes the stacked tensor corresponding to a given key, from a list of tensor dicts."""
4749

4850
optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts]
4951
gradients = materialize(optional_gradients, [input] * len(optional_gradients))

0 commit comments

Comments
 (0)