-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathaggregate.py
More file actions
134 lines (106 loc) · 5.38 KB
/
aggregate.py
File metadata and controls
134 lines (106 loc) · 5.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from collections import OrderedDict
from typing import Hashable, TypeVar
import torch
from torch import Tensor
from torchjd.aggregation import Aggregator
from .base import RequirementError, Transform
from .ordered_set import OrderedSet
from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
# Generic key type used by ``_AggregateMatrices._select_ordered_subdict``. It is bound to
# ``Hashable`` because the keys are used to index a dictionary.
_KeyType = TypeVar("_KeyType", bound=Hashable)
# Generic (unconstrained) value type used by ``_AggregateMatrices._select_ordered_subdict``.
_ValueType = TypeVar("_ValueType")
class Aggregate(Transform[Jacobians, Gradients]):
    """
    Transform mapping a dictionary of jacobians to a dictionary of gradients by means of an
    ``aggregator``.

    The work is delegated to a composed transform built from three private steps: ``_Matrixify``
    (jacobians to jacobian matrices), ``_AggregateMatrices`` (jacobian matrices to gradient
    vectors) and ``_Reshape`` (gradient vectors to gradients).

    :param aggregator: The aggregator used to reduce the united jacobian matrix.
    :param key_order: The order in which the jacobian matrices of the keys are concatenated.
    """

    def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
        to_matrices = _Matrixify()
        combine = _AggregateMatrices(aggregator, key_order)
        to_gradients = _Reshape()

        # Keep a textual description of the aggregator on the instance.
        self._aggregator_str = str(aggregator)
        # Judging by the type parameters of the three steps, ``<<`` chains them so that the
        # right-most transform is applied first.
        self.transform = to_gradients << combine << to_matrices

    def __call__(self, input: Jacobians) -> Gradients:
        """
        Applies the composed transform to ``input``.

        :param input: The jacobians to aggregate.
        """
        pipeline = self.transform
        return pipeline(input)

    def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
        """
        Delegates the key check to the composed transform and returns its result.

        :param input_keys: The keys of the dictionary that will be given to this transform.
        """
        pipeline = self.transform
        return pipeline.check_keys(input_keys)
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
    """
    Transform aggregating a dictionary of jacobian matrices into a dictionary of gradient
    vectors.

    :param aggregator: The aggregator applied to the united jacobian matrix.
    :param key_order: The order in which the jacobian matrices are concatenated (column-wise)
        before being aggregated.
    """

    def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
        self.key_order = key_order
        self.aggregator = aggregator

    def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
        """
        Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using
        the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to
        the part of the obtained gradient vector that corresponds to the jacobian matrix given for
        that key.

        :param jacobian_matrices: The dictionary of jacobian matrices to aggregate. The first
            dimension of each jacobian matrix should be the same.
        """
        ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order)
        return self._aggregate_group(ordered_matrices, self.aggregator)

    def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
        """
        Checks that ``input_keys`` contains exactly the keys of ``key_order`` and returns
        ``input_keys`` unchanged.

        :param input_keys: The keys of the dictionary that will be given to this transform.
        :raises RequirementError: If ``input_keys`` differs from the set of keys in ``key_order``.
        """
        if set(self.key_order) != input_keys:
            raise RequirementError(
                f"The input_keys must match the key_order. Found input_keys {input_keys} and "
                f"key_order {self.key_order}."
            )
        return input_keys

    @staticmethod
    def _select_ordered_subdict(
        dictionary: dict[_KeyType, _ValueType], ordered_keys: OrderedSet[_KeyType]
    ) -> OrderedDict[_KeyType, _ValueType]:
        """
        Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``.
        Returns an OrderedDict in the same order as the provided ``ordered_keys``.

        :param dictionary: The dictionary to select values from.
        :param ordered_keys: The keys to select, in the desired order.
        :raises KeyError: If one of ``ordered_keys`` is not a key of ``dictionary``.
        """
        return OrderedDict((key, dictionary[key]) for key in ordered_keys)

    @staticmethod
    def _aggregate_group(
        jacobian_matrices: OrderedDict[Tensor, Tensor], aggregator: Aggregator
    ) -> GradientVectors:
        """
        Unites the jacobian matrices and aggregates them using an
        :class:`~torchjd.aggregation.bases.Aggregator`. Returns the obtained gradient vectors.

        :param jacobian_matrices: The jacobian matrices to aggregate, in concatenation order.
        :param aggregator: The aggregator applied to the united jacobian matrix.
        """
        # ``torch.cat`` cannot concatenate an empty list, so the empty case is handled upfront.
        if not jacobian_matrices:
            return EmptyTensorDict()

        united_jacobian_matrix = _AggregateMatrices._unite(jacobian_matrices)
        united_gradient_vector = aggregator(united_jacobian_matrix)
        return _AggregateMatrices._disunite(united_gradient_vector, jacobian_matrices)

    @staticmethod
    def _unite(jacobian_matrices: OrderedDict[Tensor, Tensor]) -> Tensor:
        """
        Concatenates the jacobian matrices along their second dimension (columns), following
        the iteration order of ``jacobian_matrices``.
        """
        return torch.cat(list(jacobian_matrices.values()), dim=1)

    @staticmethod
    def _disunite(
        united_gradient_vector: Tensor, jacobian_matrices: OrderedDict[Tensor, Tensor]
    ) -> GradientVectors:
        """
        Splits ``united_gradient_vector`` into one gradient vector per key of
        ``jacobian_matrices``. The slice given to a key has as many elements as the jacobian
        matrix of that key has columns, and slices are taken in the iteration order of
        ``jacobian_matrices``.

        :param united_gradient_vector: The vector to split.
        :param jacobian_matrices: The jacobian matrices whose numbers of columns define the
            sizes of the slices.
        :raises ValueError: If the length of ``united_gradient_vector`` differs from the total
            number of columns of the jacobian matrices.
        """
        expected_length = sum(matrix.shape[1] for matrix in jacobian_matrices.values())
        actual_length = len(united_gradient_vector)
        if actual_length != expected_length:
            raise ValueError(
                "Parameter `united_gradient_vector` should be a vector with length equal to the sum "
                "of the numbers of columns in the jacobian matrices. Found "
                f"`len(united_gradient_vector) = {actual_length}` and the sum of the "
                f"numbers of columns in the jacobian matrices is {expected_length}."
            )

        gradient_vectors = {}
        start = 0
        for key, jacobian_matrix in jacobian_matrices.items():
            # Each key owns the contiguous range of entries matching its columns in the united
            # jacobian matrix.
            end = start + jacobian_matrix.shape[1]
            gradient_vectors[key] = united_gradient_vector[start:end]
            start = end

        return GradientVectors(gradient_vectors)
class _Matrixify(Transform[Jacobians, JacobianMatrices]):
    """
    Transform flattening each jacobian into a matrix: the first dimension is kept and all
    remaining dimensions are merged into a single one.
    """

    def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
        """
        Returns the dictionary mapping each key of ``jacobians`` to its jacobian reshaped into a
        matrix with ``jacobian.shape[0]`` rows.

        :param jacobians: The jacobians to turn into matrices.
        """
        # ``reshape`` is used rather than ``view`` so that jacobians whose memory layout does
        # not allow merging the trailing dimensions (e.g. transposed or permuted tensors) are
        # still supported. When a view is possible, ``reshape`` returns one without copying.
        jacobian_matrices = {
            key: jacobian.reshape(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
        }
        return JacobianMatrices(jacobian_matrices)

    def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
        """
        Returns ``input_keys`` unchanged: this transform has no requirement on its keys and its
        output has the same keys as its input.
        """
        return input_keys
class _Reshape(Transform[GradientVectors, Gradients]):
    """
    Transform giving each gradient vector the shape of the key it is associated with.
    """

    def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
        """
        Returns the dictionary mapping each key of ``gradient_vectors`` to its gradient vector
        viewed with the shape of that key.

        :param gradient_vectors: The gradient vectors to reshape.
        """
        reshaped = {}
        for tensor, vector in gradient_vectors.items():
            # The key is itself a tensor; its shape is the target shape of the gradient.
            reshaped[tensor] = vector.view(tensor.shape)
        return Gradients(reshaped)

    def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
        """
        Returns ``input_keys`` unchanged: this transform has no requirement on its keys and its
        output has the same keys as its input.
        """
        return input_keys