Skip to content

Commit 1f0afca

Browse files
refactor(autojac): Isolate key checks in transforms (#279)
* Add method check_and_get_keys to all transforms * Move checks related to keys from __init__ to check_and_get_keys and make them recursive * Remove required_keys and output_keys properties * Remove check_keys_are call on the input keys in __call__ and remove check_keys_are from TensorDict * Replace _compute with __call__ * Change some tests of __init__ into equivalent tests of check_and_get_keys * Add new tests for check_and_get_keys * Extract the creation of the transform in backward and mlt_backward into their _create_transform functions * Rename _make_task_transform to _create_task_transform for uniformity * Add test_check_create_transform for these functions * Add changelog entry --------- Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent 8914b16 commit 1f0afca

24 files changed

+352
-239
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ changes that do not affect the user.
1616

1717
- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
1818
project onto the dual cone. This may minimally affect the output of these aggregators.
19+
- Refactored internal verifications in the autojac engine so that they do not run at runtime
20+
anymore. This should minimally improve the performance and reduce the memory usage of `backward`
21+
and `mtl_backward`.
1922

2023
### Fixed
2124
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
self.retain_graph = retain_graph
2121
self.create_graph = create_graph
2222

23-
def _compute(self, tensors: _A) -> _A:
23+
def __call__(self, tensors: _A) -> _A:
2424
tensor_outputs = [tensors[output] for output in self.outputs]
2525

2626
differentiated_tuple = self._differentiate(tensor_outputs)
@@ -38,12 +38,6 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]
3838
tensor_outputs should be.
3939
"""
4040

41-
@property
42-
def required_keys(self) -> set[Tensor]:
41+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
4342
# outputs in the forward direction become inputs in the backward direction, and vice-versa
44-
return set(self.outputs)
45-
46-
@property
47-
def output_keys(self) -> set[Tensor]:
48-
# outputs in the forward direction become inputs in the backward direction, and vice-versa
49-
return set(self.inputs)
43+
return set(self.outputs), set(self.inputs)

src/torchjd/autojac/_transform/accumulate.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class Accumulate(Transform[Gradients, EmptyTensorDict]):
1010
def __init__(self, required_keys: Iterable[Tensor]):
1111
self._required_keys = set(required_keys)
1212

13-
def _compute(self, gradients: Gradients) -> EmptyTensorDict:
13+
def __call__(self, gradients: Gradients) -> EmptyTensorDict:
1414
"""
1515
Accumulates gradients with respect to keys in their ``.grad`` field.
1616
"""
@@ -28,13 +28,8 @@ def _compute(self, gradients: Gradients) -> EmptyTensorDict:
2828

2929
return EmptyTensorDict()
3030

31-
@property
32-
def required_keys(self) -> set[Tensor]:
33-
return self._required_keys
34-
35-
@property
36-
def output_keys(self) -> set[Tensor]:
37-
return set()
31+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
32+
return self._required_keys, set()
3833

3934

4035
def _check_expects_grad(tensor: Tensor) -> None:

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,19 @@ def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
2323
self._aggregator_str = str(aggregator)
2424
self.transform = reshape << aggregate_matrices << matrixify
2525

26-
def _compute(self, input: Jacobians) -> Gradients:
26+
def __call__(self, input: Jacobians) -> Gradients:
2727
return self.transform(input)
2828

29-
@property
30-
def required_keys(self) -> set[Tensor]:
31-
return self.transform.required_keys
32-
33-
@property
34-
def output_keys(self) -> set[Tensor]:
35-
return self.transform.output_keys
29+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
30+
return self.transform.check_and_get_keys()
3631

3732

3833
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
3934
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
4035
self.key_order = ordered_set(key_order)
4136
self.aggregator = aggregator
4237

43-
def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
38+
def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
4439
"""
4540
Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using
4641
the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to
@@ -53,13 +48,9 @@ def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
5348
ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order)
5449
return self._aggregate_group(ordered_matrices, self.aggregator)
5550

56-
@property
57-
def required_keys(self) -> set[Tensor]:
58-
return set(self.key_order)
59-
60-
@property
61-
def output_keys(self) -> set[Tensor]:
62-
return set(self.key_order)
51+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
52+
keys = set(self.key_order)
53+
return keys, keys
6354

6455
@staticmethod
6556
def _select_ordered_subdict(
@@ -120,36 +111,26 @@ class _Matrixify(Transform[Jacobians, JacobianMatrices]):
120111
def __init__(self, required_keys: Iterable[Tensor]):
121112
self._required_keys = set(required_keys)
122113

123-
def _compute(self, jacobians: Jacobians) -> JacobianMatrices:
114+
def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
124115
jacobian_matrices = {
125116
key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
126117
}
127118
return JacobianMatrices(jacobian_matrices)
128119

129-
@property
130-
def required_keys(self) -> set[Tensor]:
131-
return self._required_keys
132-
133-
@property
134-
def output_keys(self) -> set[Tensor]:
135-
return self._required_keys
120+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
121+
return self._required_keys, self._required_keys
136122

137123

138124
class _Reshape(Transform[GradientVectors, Gradients]):
139125
def __init__(self, required_keys: Iterable[Tensor]):
140126
self._required_keys = set(required_keys)
141127

142-
def _compute(self, gradient_vectors: GradientVectors) -> Gradients:
128+
def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
143129
gradients = {
144130
key: gradient_vector.view(key.shape)
145131
for key, gradient_vector in gradient_vectors.items()
146132
}
147133
return Gradients(gradients)
148134

149-
@property
150-
def required_keys(self) -> set[Tensor]:
151-
return self._required_keys
152-
153-
@property
154-
def output_keys(self) -> set[Tensor]:
155-
return self._required_keys
135+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
136+
return self._required_keys, self._required_keys

src/torchjd/autojac/_transform/base.py

Lines changed: 36 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -40,72 +40,54 @@ def __str__(self) -> str:
4040
return type(self).__name__
4141

4242
@abstractmethod
43-
def _compute(self, input: _B) -> _C:
44-
"""Applies the transform to the input."""
45-
4643
def __call__(self, input: _B) -> _C:
47-
input.check_keys_are(self.required_keys)
48-
return self._compute(input)
44+
"""Applies the transform to the input."""
4945

50-
@property
5146
@abstractmethod
52-
def required_keys(self) -> set[Tensor]:
53-
"""
54-
Returns the set of keys that the transform requires to be present in its input TensorDicts.
47+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
5548
"""
49+
Returns a pair containing (in order) the required keys and the output keys of the Transform
50+
and recursively checks that the transform is valid.
5651
57-
@property
58-
@abstractmethod
59-
def output_keys(self) -> set[Tensor]:
60-
"""Returns the set of keys that will be present in the output of the transform."""
52+
The required keys are the set of keys that the transform requires to be present in its input
53+
TensorDicts. The output keys are the set of keys that will be present in the output
54+
TensorDicts of the transform.
55+
56+
Since the computation of the required and output keys and the verification that the
57+
transform is valid are sometimes intertwined operations, we do them in a single method.
58+
"""
6159

6260
__lshift__ = compose
6361
__or__ = conjunct
6462

6563

6664
class Composition(Transform[_A, _C]):
6765
def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
68-
if outer.required_keys != inner.output_keys:
69-
raise ValueError(
70-
"The `output_keys` of `inner` must match with the `required_keys` of "
71-
f"outer. Found {outer.required_keys} and {inner.output_keys}"
72-
)
7366
self.outer = outer
7467
self.inner = inner
7568

7669
def __str__(self) -> str:
7770
return str(self.outer) + " ∘ " + str(self.inner)
7871

79-
def _compute(self, input: _A) -> _C:
72+
def __call__(self, input: _A) -> _C:
8073
intermediate = self.inner(input)
8174
return self.outer(intermediate)
8275

83-
@property
84-
def required_keys(self) -> set[Tensor]:
85-
return self.inner.required_keys
86-
87-
@property
88-
def output_keys(self) -> set[Tensor]:
89-
return self.outer.output_keys
76+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
77+
outer_required_keys, outer_output_keys = self.outer.check_and_get_keys()
78+
inner_required_keys, inner_output_keys = self.inner.check_and_get_keys()
79+
if outer_required_keys != inner_output_keys:
80+
raise ValueError(
81+
"The `output_keys` of `inner` must match with the `required_keys` of "
82+
f"outer. Found {outer_required_keys} and {inner_output_keys}"
83+
)
84+
return inner_required_keys, outer_output_keys
9085

9186

9287
class Conjunction(Transform[_A, _B]):
9388
def __init__(self, transforms: Sequence[Transform[_A, _B]]):
9489
self.transforms = transforms
9590

96-
self._required_keys = set(
97-
key for transform in transforms for key in transform.required_keys
98-
)
99-
for transform in transforms:
100-
if transform.required_keys != self.required_keys:
101-
raise ValueError("All transforms should require the same set of keys.")
102-
103-
output_keys_with_duplicates = [key for t in transforms for key in t.output_keys]
104-
self._output_keys = set(output_keys_with_duplicates)
105-
106-
if len(self._output_keys) != len(output_keys_with_duplicates):
107-
raise ValueError("The sets of output keys of transforms should be disjoint.")
108-
10991
def __str__(self) -> str:
11092
strings = []
11193
for t in self.transforms:
@@ -116,14 +98,22 @@ def __str__(self) -> str:
11698
strings.append(s)
11799
return "(" + " | ".join(strings) + ")"
118100

119-
def _compute(self, tensor_dict: _A) -> _B:
101+
def __call__(self, tensor_dict: _A) -> _B:
120102
output = _union([transform(tensor_dict) for transform in self.transforms])
121103
return output
122104

123-
@property
124-
def required_keys(self) -> set[Tensor]:
125-
return self._required_keys
105+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
106+
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]
107+
108+
required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
109+
for transform_required_keys, _ in keys_pairs:
110+
if transform_required_keys != required_keys:
111+
raise ValueError("All transforms should require the same set of keys.")
112+
113+
output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys]
114+
output_keys = set(output_keys_with_duplicates)
115+
116+
if len(output_keys) != len(output_keys_with_duplicates):
117+
raise ValueError("The sets of output keys of transforms should be disjoint.")
126118

127-
@property
128-
def output_keys(self) -> set[Tensor]:
129-
return self._output_keys
119+
return required_keys, output_keys

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, considered: Iterable[Tensor]):
1818
self.indices.append((begin, end))
1919
begin = end
2020

21-
def _compute(self, tensors: Gradients) -> Jacobians:
21+
def __call__(self, tensors: Gradients) -> Jacobians:
2222
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered]
2323
diagonal_matrix = torch.cat(flattened_considered_values).diag()
2424
diagonalized_tensors = {
@@ -27,10 +27,6 @@ def _compute(self, tensors: Gradients) -> Jacobians:
2727
}
2828
return Jacobians(diagonalized_tensors)
2929

30-
@property
31-
def required_keys(self) -> set[Tensor]:
32-
return set(self.considered)
33-
34-
@property
35-
def output_keys(self) -> set[Tensor]:
36-
return set(self.considered)
30+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
31+
keys = set(self.considered)
32+
return keys, keys

src/torchjd/autojac/_transform/init.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class Init(Transform[EmptyTensorDict, Gradients]):
1111
def __init__(self, values: Iterable[Tensor]):
1212
self.values = set(values)
1313

14-
def _compute(self, input: EmptyTensorDict) -> Gradients:
14+
def __call__(self, input: EmptyTensorDict) -> Gradients:
1515
r"""
1616
Computes the gradients of the ``value`` with respect to itself. Returns the result as a
1717
dictionary. The only key of the dictionary is ``value``. The corresponding gradient is a
@@ -21,10 +21,5 @@ def _compute(self, input: EmptyTensorDict) -> Gradients:
2121

2222
return Gradients({value: torch.ones_like(value) for value in self.values})
2323

24-
@property
25-
def required_keys(self) -> set[Tensor]:
26-
return set()
27-
28-
@property
29-
def output_keys(self) -> set[Tensor]:
30-
return self.values
24+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
25+
return set(), self.values

src/torchjd/autojac/_transform/select.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,13 @@ def __init__(self, keys: Iterable[Tensor], required_keys: Iterable[Tensor]):
1111
self.keys = set(keys)
1212
self._required_keys = set(required_keys)
1313

14-
if not self.keys.issubset(self._required_keys):
15-
raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`")
16-
17-
def _compute(self, tensor_dict: _A) -> _A:
14+
def __call__(self, tensor_dict: _A) -> _A:
1815
output = {key: tensor_dict[key] for key in self.keys}
1916
return type(tensor_dict)(output)
2017

21-
@property
22-
def required_keys(self) -> set[Tensor]:
23-
return self._required_keys
18+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
19+
required_keys = self._required_keys
20+
if not self.keys.issubset(required_keys):
21+
raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`")
2422

25-
@property
26-
def output_keys(self) -> set[Tensor]:
27-
return self.keys
23+
return required_keys, self.keys

src/torchjd/autojac/_transform/stack.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,22 @@ class Stack(Transform[_A, Jacobians]):
1212
def __init__(self, transforms: Sequence[Transform[_A, Gradients]]):
1313
self.transforms = transforms
1414

15-
self._required_keys = {key for transform in transforms for key in transform.required_keys}
16-
self._output_keys = {key for transform in transforms for key in transform.output_keys}
17-
18-
for transform in transforms:
19-
if transform.required_keys != self.required_keys:
20-
raise ValueError("All transforms should require the same set of keys.")
21-
22-
def _compute(self, input: _A) -> Jacobians:
15+
def __call__(self, input: _A) -> Jacobians:
2316
results = [transform(input) for transform in self.transforms]
2417
result = _stack(results)
2518
return result
2619

27-
@property
28-
def required_keys(self) -> set[Tensor]:
29-
return self._required_keys
20+
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
21+
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]
22+
23+
required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
24+
output_keys = set(key for _, output_keys in keys_pairs for key in output_keys)
25+
26+
for transform_required_keys, _ in keys_pairs:
27+
if transform_required_keys != required_keys:
28+
raise ValueError("All transforms should require the same set of keys.")
3029

31-
@property
32-
def output_keys(self) -> set[Tensor]:
33-
return self._output_keys
30+
return required_keys, output_keys
3431

3532

3633
def _stack(gradient_dicts: list[Gradients]) -> Jacobians:

src/torchjd/autojac/_transform/tensor_dict.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,6 @@ def __init__(self, tensor_dict: dict[Tensor, Tensor]):
1414
self._check_all_pairs(tensor_dict)
1515
super().__init__(tensor_dict)
1616

17-
def check_keys_are(self, keys: set[Tensor]) -> None:
18-
"""
19-
Checks that the keys in the mapping are the same as the provided ``keys``.
20-
21-
:param keys: Keys that the mapping should (exclusively) contain.
22-
"""
23-
24-
if set(keys) != set(self.keys()):
25-
raise ValueError(
26-
f"The keys of the {self.__class__.__name__} should be {keys}. Found self.keys = "
27-
f"{self.keys()}."
28-
)
29-
3017
@staticmethod
3118
def _check_dict(tensor_dict: dict[Tensor, Tensor]) -> None:
3219
pass

0 commit comments

Comments
 (0)