@@ -15,25 +15,9 @@ class RequirementError(ValueError):
1515
1616
1717class Transform (Generic [_B , _C ], ABC ):
18- r """
18+ """
1919 Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
20- descent backward phase. A transform maps a :class:`~torchjd.transform.tensor_dict.TensorDict` to
21- another. The input :class:`~torchjd.transform.tensor_dict.TensorDict` has keys `required_keys`
22- and the output :class:`~torchjd.transform.tensor_dict.TensorDict` has keys `output_keys`.
23-
24- Formally a transform is a function:
25-
26- .. math::
27- f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}
28-
29- where we have ``p`` `required_keys`, ``q`` `output_keys`, ``n_i`` is the number of elements in
30- the value associated to the ``i`` th `required_key` of the input
31- :class:`~torchjd.transform.tensor_dict.TensorDict` and ``m_j`` is the number of elements in the
32- value associated to the ``j`` th `output_key` of the output
33- :class:`~torchjd.transform.tensor_dict.TensorDict`.
34-
35- As they are mathematical functions, transforms can be composed together as long as their
36- domains and range meaningfully match.
20+ descent backward phase. A transform maps a TensorDict to another.
3721 """
3822
3923 def compose (self , other : Transform [_A , _B ]) -> Transform [_A , _C ]:
@@ -67,6 +51,13 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
6751
6852
6953class Composition (Transform [_A , _C ]):
54+ """
55+ Transform corresponding to the composition of two transforms inner and outer.
56+
57+ :param inner: The transform to apply first, to the input.
58+ :param outer: The transform to apply second, to the result of ``inner``.
59+ """
60+
7061 def __init__ (self , outer : Transform [_B , _C ], inner : Transform [_A , _B ]):
7162 self .outer = outer
7263 self .inner = inner
@@ -85,6 +76,13 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
8576
8677
8778class Conjunction (Transform [_A , _B ]):
79+ """
80+ Transform applying several transforms to the same input, and combining the results (by union)
81+ into a single TensorDict.
82+
83+ :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
84+ """
85+
8886 def __init__ (self , transforms : Sequence [Transform [_A , _B ]]):
8987 self .transforms = transforms
9088
0 commit comments