11from collections import OrderedDict
2- from typing import Hashable , Iterable , TypeVar
2+ from typing import Hashable , TypeVar
33
44import torch
55from torch import Tensor
66
77from torchjd .aggregation import Aggregator
88
9- from ._utils import _OrderedSet , ordered_set
10- from .base import Transform
9+ from .base import RequirementError , Transform
10+ from .ordered_set import OrderedSet
1111from .tensor_dict import EmptyTensorDict , Gradients , GradientVectors , JacobianMatrices , Jacobians
1212
1313_KeyType = TypeVar ("_KeyType" , bound = Hashable )
1414_ValueType = TypeVar ("_ValueType" )
1515
1616
1717class Aggregate (Transform [Jacobians , Gradients ]):
18- def __init__ (self , aggregator : Aggregator , key_order : Iterable [Tensor ]):
19- matrixify = _Matrixify (key_order )
18+ def __init__ (self , aggregator : Aggregator , key_order : OrderedSet [Tensor ]):
19+ matrixify = _Matrixify ()
2020 aggregate_matrices = _AggregateMatrices (aggregator , key_order )
21- reshape = _Reshape (key_order )
21+ reshape = _Reshape ()
2222
2323 self ._aggregator_str = str (aggregator )
2424 self .transform = reshape << aggregate_matrices << matrixify
2525
2626 def __call__ (self , input : Jacobians ) -> Gradients :
2727 return self .transform (input )
2828
29- def check_and_get_keys (self ) -> tuple [ set [Tensor ], set [Tensor ] ]:
30- return self .transform .check_and_get_keys ( )
29+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
30+ return self .transform .check_keys ( input_keys )
3131
3232
3333class _AggregateMatrices (Transform [JacobianMatrices , GradientVectors ]):
34- def __init__ (self , aggregator : Aggregator , key_order : Iterable [Tensor ]):
35- self .key_order = ordered_set (key_order )
34+ def __init__ (self , aggregator : Aggregator , key_order : OrderedSet [Tensor ]):
35+ self .key_order = OrderedSet (key_order )
3636 self .aggregator = aggregator
3737
3838 def __call__ (self , jacobian_matrices : JacobianMatrices ) -> GradientVectors :
@@ -48,13 +48,17 @@ def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
4848 ordered_matrices = self ._select_ordered_subdict (jacobian_matrices , self .key_order )
4949 return self ._aggregate_group (ordered_matrices , self .aggregator )
5050
51- def check_and_get_keys (self ) -> tuple [set [Tensor ], set [Tensor ]]:
52- keys = set (self .key_order )
53- return keys , keys
51+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
52+ if not set (self .key_order ) == input_keys :
53+ raise RequirementError (
54+ f"The input_keys must match the key_order. Found input_keys { input_keys } and"
55+ f"key_order { self .key_order } ."
56+ )
57+ return input_keys
5458
5559 @staticmethod
5660 def _select_ordered_subdict (
57- dictionary : dict [_KeyType , _ValueType ], ordered_keys : _OrderedSet [_KeyType ]
61+ dictionary : dict [_KeyType , _ValueType ], ordered_keys : OrderedSet [_KeyType ]
5862 ) -> OrderedDict [_KeyType , _ValueType ]:
5963 """
6064 Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``.
@@ -108,29 +112,23 @@ def _disunite(
108112
109113
110114class _Matrixify (Transform [Jacobians , JacobianMatrices ]):
111- def __init__ (self , required_keys : Iterable [Tensor ]):
112- self ._required_keys = set (required_keys )
113-
114115 def __call__ (self , jacobians : Jacobians ) -> JacobianMatrices :
115116 jacobian_matrices = {
116117 key : jacobian .view (jacobian .shape [0 ], - 1 ) for key , jacobian in jacobians .items ()
117118 }
118119 return JacobianMatrices (jacobian_matrices )
119120
120- def check_and_get_keys (self ) -> tuple [ set [Tensor ], set [Tensor ] ]:
121- return self . _required_keys , self . _required_keys
121+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
122+ return input_keys
122123
123124
124125class _Reshape (Transform [GradientVectors , Gradients ]):
125- def __init__ (self , required_keys : Iterable [Tensor ]):
126- self ._required_keys = set (required_keys )
127-
128126 def __call__ (self , gradient_vectors : GradientVectors ) -> Gradients :
129127 gradients = {
130128 key : gradient_vector .view (key .shape )
131129 for key , gradient_vector in gradient_vectors .items ()
132130 }
133131 return Gradients (gradients )
134132
135- def check_and_get_keys (self ) -> tuple [ set [Tensor ], set [Tensor ] ]:
136- return self . _required_keys , self . _required_keys
133+ def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
134+ return input_keys
0 commit comments