Skip to content

Commit e4502f7

Browse files
committed
Replace check_and_get_keys by check_keys
1 parent 922e746 commit e4502f7

9 files changed

Lines changed: 69 additions & 83 deletions

File tree

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from torch import Tensor
55

6-
from .base import _A, Transform
6+
from .base import _A, RequirementError, Transform
77
from .ordered_set import OrderedSet
88

99

@@ -38,6 +38,11 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]
3838
tensor_outputs should be.
3939
"""
4040

41-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
42-
# outputs in the forward direction become inputs in the backward direction, and vice-versa
43-
return set(self.outputs), set(self.inputs)
41+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
42+
outputs = set(self.outputs)
43+
if not outputs.issubset(input_keys):
44+
raise RequirementError(
45+
f"The input_keys needs to be a super set of the outputs. Found {input_keys} and "
46+
f"{outputs}"
47+
)
48+
return set(self.inputs)

src/torchjd/autojac/_transform/accumulate.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,13 @@ def __call__(self, gradients: Gradients) -> EmptyTensorDict:
2828

2929
return EmptyTensorDict()
3030

31-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
32-
return self._required_keys, set()
31+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
32+
if not self._required_keys.issubset(input_keys):
33+
raise RequirementError(
34+
f"The input_keys needs to be a super set of the required_keys. Found {input_keys} "
35+
f"and {self._required_keys}"
36+
)
37+
return set()
3338

3439

3540
def _check_expects_grad(tensor: Tensor) -> None:

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from torchjd.aggregation import Aggregator
88

9-
from .base import Transform
9+
from .base import RequirementError, Transform
1010
from .ordered_set import OrderedSet
1111
from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
1212

@@ -16,18 +16,18 @@
1616

1717
class Aggregate(Transform[Jacobians, Gradients]):
1818
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
19-
matrixify = _Matrixify(key_order)
19+
matrixify = _Matrixify()
2020
aggregate_matrices = _AggregateMatrices(aggregator, key_order)
21-
reshape = _Reshape(key_order)
21+
reshape = _Reshape()
2222

2323
self._aggregator_str = str(aggregator)
2424
self.transform = reshape << aggregate_matrices << matrixify
2525

2626
def __call__(self, input: Jacobians) -> Gradients:
2727
return self.transform(input)
2828

29-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
30-
return self.transform.check_and_get_keys()
29+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
30+
return self.transform.check_keys(input_keys)
3131

3232

3333
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
@@ -48,9 +48,12 @@ def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
4848
ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order)
4949
return self._aggregate_group(ordered_matrices, self.aggregator)
5050

51-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
52-
keys = set(self.key_order)
53-
return keys, keys
51+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
52+
if not set(self.key_order) == input_keys:
53+
raise RequirementError(
54+
f"The input_keys must match the key_order. Found {input_keys} and {self.key_order}"
55+
)
56+
return input_keys
5457

5558
@staticmethod
5659
def _select_ordered_subdict(
@@ -108,29 +111,23 @@ def _disunite(
108111

109112

110113
class _Matrixify(Transform[Jacobians, JacobianMatrices]):
111-
def __init__(self, required_keys: Iterable[Tensor]):
112-
self._required_keys = set(required_keys)
113-
114114
def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
115115
jacobian_matrices = {
116116
key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
117117
}
118118
return JacobianMatrices(jacobian_matrices)
119119

120-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
121-
return self._required_keys, self._required_keys
120+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
121+
return input_keys
122122

123123

124124
class _Reshape(Transform[GradientVectors, Gradients]):
125-
def __init__(self, required_keys: Iterable[Tensor]):
126-
self._required_keys = set(required_keys)
127-
128125
def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
129126
gradients = {
130127
key: gradient_vector.view(key.shape)
131128
for key, gradient_vector in gradient_vectors.items()
132129
}
133130
return Gradients(gradients)
134131

135-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
136-
return self._required_keys, self._required_keys
132+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
133+
return input_keys

src/torchjd/autojac/_transform/base.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,13 @@ def __call__(self, input: _B) -> _C:
4848
"""Applies the transform to the input."""
4949

5050
@abstractmethod
51-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
51+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
5252
"""
53-
Returns a pair containing (in order) the required keys and the output keys of the Transform
54-
and recursively checks that the transform is valid.
53+
Checks the keys of the Transform for the provided input_keys and returns the corresponding
54+
output keys for recursion.
5555
56-
The required keys are the set of keys that the transform requires to be present in its input
57-
TensorDicts. The output keys are the set of keys that will be present in the output
58-
TensorDicts of the transform.
59-
60-
Since the computation of the required and output keys and the verification that the
61-
transform is valid are sometimes intertwined operations, we do them in a single method.
56+
The output keys are the set of keys that will be present in the output TensorDict of the
57+
transform given that the provided TensorDict has the provided input_keys.
6258
"""
6359

6460
__lshift__ = compose
@@ -77,15 +73,9 @@ def __call__(self, input: _A) -> _C:
7773
intermediate = self.inner(input)
7874
return self.outer(intermediate)
7975

80-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
81-
outer_required_keys, outer_output_keys = self.outer.check_and_get_keys()
82-
inner_required_keys, inner_output_keys = self.inner.check_and_get_keys()
83-
if outer_required_keys != inner_output_keys:
84-
raise RequirementError(
85-
"The `output_keys` of `inner` must match with the `required_keys` of "
86-
f"outer. Found {outer_required_keys} and {inner_output_keys}"
87-
)
88-
return inner_required_keys, outer_output_keys
76+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
77+
intermediate_keys = self.inner.check_keys(input_keys)
78+
return self.outer.check_keys(intermediate_keys)
8979

9080

9181
class Conjunction(Transform[_A, _B]):
@@ -106,18 +96,11 @@ def __call__(self, tensor_dict: _A) -> _B:
10696
output = _union([transform(tensor_dict) for transform in self.transforms])
10797
return output
10898

109-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
110-
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]
111-
112-
required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
113-
for transform_required_keys, _ in keys_pairs:
114-
if transform_required_keys != required_keys:
115-
raise RequirementError("All transforms should require the same set of keys.")
116-
117-
output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys]
118-
output_keys = set(output_keys_with_duplicates)
99+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
100+
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
101+
output_keys = set(output_keys_list)
119102

120-
if len(output_keys) != len(output_keys_with_duplicates):
103+
if len(output_keys) != len(output_keys_list):
121104
raise RequirementError("The sets of output keys of transforms should be disjoint.")
122105

123-
return required_keys, output_keys
106+
return output_keys

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from .base import Transform
6+
from .base import RequirementError, Transform
77
from .ordered_set import OrderedSet
88
from .tensor_dict import Gradients, Jacobians
99

@@ -27,6 +27,11 @@ def __call__(self, tensors: Gradients) -> Jacobians:
2727
}
2828
return Jacobians(diagonalized_tensors)
2929

30-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
31-
keys = set(self.considered)
32-
return keys, keys
30+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
31+
considered = set(self.considered)
32+
if not considered.issubset(input_keys):
33+
raise RequirementError(
34+
f"The input_keys needs to be a super set of the considered keys. Found {input_keys} "
35+
f"and {considered}"
36+
)
37+
return considered

src/torchjd/autojac/_transform/init.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from .base import Transform
6+
from .base import RequirementError, Transform
77
from .tensor_dict import EmptyTensorDict, Gradients
88

99

@@ -21,5 +21,7 @@ def __call__(self, input: EmptyTensorDict) -> Gradients:
2121

2222
return Gradients({value: torch.ones_like(value) for value in self.values})
2323

24-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
25-
return set(), self.values
24+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
25+
if input_keys == set():
26+
raise RequirementError(f"Init expects an empty set of input_keys. Found {input_keys}")
27+
return self.values

src/torchjd/autojac/_transform/select.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,17 @@
77

88

99
class Select(Transform[_A, _A]):
10-
def __init__(self, keys: Iterable[Tensor], required_keys: Iterable[Tensor]):
10+
def __init__(self, keys: Iterable[Tensor]):
1111
self.keys = set(keys)
12-
self._required_keys = set(required_keys)
1312

1413
def __call__(self, tensor_dict: _A) -> _A:
1514
output = {key: tensor_dict[key] for key in self.keys}
1615
return type(tensor_dict)(output)
1716

18-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
19-
required_keys = self._required_keys
20-
if not self.keys.issubset(required_keys):
17+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
18+
if not self.keys.issubset(input_keys):
2119
raise RequirementError(
22-
"Parameter `keys` should be a subset of parameter `required_keys`"
20+
f"The input_keys needs to be a super set of the keys to select. Found {input_keys} "
21+
f"and {self.keys}"
2322
)
24-
25-
return required_keys, self.keys
23+
return self.keys

src/torchjd/autojac/_transform/stack.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._utils import _A, _materialize, dicts_union
7-
from .base import RequirementError, Transform
7+
from .base import Transform
88
from .tensor_dict import Gradients, Jacobians
99

1010

@@ -17,17 +17,8 @@ def __call__(self, input: _A) -> Jacobians:
1717
result = _stack(results)
1818
return result
1919

20-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
21-
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]
22-
23-
required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
24-
output_keys = set(key for _, output_keys in keys_pairs for key in output_keys)
25-
26-
for transform_required_keys, _ in keys_pairs:
27-
if transform_required_keys != required_keys:
28-
raise RequirementError("All transforms should require the same set of keys.")
29-
30-
return required_keys, output_keys
20+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
21+
return {key for transform in self.transforms for key in transform.check_keys(input_keys)}
3122

3223

3324
def _stack(gradient_dicts: list[Gradients]) -> Jacobians:

src/torchjd/autojac/mtl_backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ def _create_task_transform(
179179

180180
# Transform that accumulates the gradients w.r.t. the task-specific parameters into their
181181
# .grad fields.
182-
accumulate = Accumulate(task_params) << Select(task_params, to_differentiate)
182+
accumulate = Accumulate(task_params) << Select(task_params)
183183

184184
# Transform that backpropagates the gradients of the losses w.r.t. the features.
185-
backpropagate = Select(features, to_differentiate)
185+
backpropagate = Select(features)
186186

187187
# Transform that accumulates the gradient of the losses w.r.t. the task-specific parameters into
188188
# their .grad fields and backpropagates the gradient of the losses w.r.t. to the features.

0 commit comments

Comments
 (0)