66
77from torchjd .aggregation import Aggregator
88
9- from .base import Transform
9+ from .base import RequirementError , Transform
1010from .ordered_set import OrderedSet
1111from .tensor_dict import EmptyTensorDict , Gradients , GradientVectors , JacobianMatrices , Jacobians
1212
1616
1717class Aggregate (Transform [Jacobians , Gradients ]):
1818 def __init__ (self , aggregator : Aggregator , key_order : Iterable [Tensor ]):
19- matrixify = _Matrixify (key_order )
19+ matrixify = _Matrixify ()
2020 aggregate_matrices = _AggregateMatrices (aggregator , key_order )
21- reshape = _Reshape (key_order )
21+ reshape = _Reshape ()
2222
2323 self ._aggregator_str = str (aggregator )
2424 self .transform = reshape << aggregate_matrices << matrixify
2525
2626 def __call__ (self , input : Jacobians ) -> Gradients :
2727 return self .transform (input )
2828
29- def check_and_get_keys (self ) -> tuple [ set [Tensor ], set [Tensor ] ]:
30- return self .transform .check_and_get_keys ( )
29+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
30+ return self .transform .check_keys ( input_keys )
3131
3232
3333class _AggregateMatrices (Transform [JacobianMatrices , GradientVectors ]):
@@ -48,9 +48,12 @@ def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
4848 ordered_matrices = self ._select_ordered_subdict (jacobian_matrices , self .key_order )
4949 return self ._aggregate_group (ordered_matrices , self .aggregator )
5050
51- def check_and_get_keys (self ) -> tuple [set [Tensor ], set [Tensor ]]:
52- keys = set (self .key_order )
53- return keys , keys
51+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
52+ if not set (self .key_order ) == input_keys :
53+ raise RequirementError (
54+ f"The input_keys must match the key_order. Found { input_keys } and { self .key_order } "
55+ )
56+ return input_keys
5457
5558 @staticmethod
5659 def _select_ordered_subdict (
@@ -108,29 +111,23 @@ def _disunite(
108111
109112
110113class _Matrixify (Transform [Jacobians , JacobianMatrices ]):
111- def __init__ (self , required_keys : Iterable [Tensor ]):
112- self ._required_keys = set (required_keys )
113-
114114 def __call__ (self , jacobians : Jacobians ) -> JacobianMatrices :
115115 jacobian_matrices = {
116116 key : jacobian .view (jacobian .shape [0 ], - 1 ) for key , jacobian in jacobians .items ()
117117 }
118118 return JacobianMatrices (jacobian_matrices )
119119
120- def check_and_get_keys (self ) -> tuple [ set [Tensor ], set [Tensor ] ]:
121- return self . _required_keys , self . _required_keys
120+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
121+ return input_keys
122122
123123
124124class _Reshape (Transform [GradientVectors , Gradients ]):
125- def __init__ (self , required_keys : Iterable [Tensor ]):
126- self ._required_keys = set (required_keys )
127-
128125 def __call__ (self , gradient_vectors : GradientVectors ) -> Gradients :
129126 gradients = {
130127 key : gradient_vector .view (key .shape )
131128 for key , gradient_vector in gradient_vectors .items ()
132129 }
133130 return Gradients (gradients )
134131
135- def check_and_get_keys (self ) -> tuple [ set [Tensor ], set [Tensor ] ]:
136- return self . _required_keys , self . _required_keys
132+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
133+ return input_keys
0 commit comments