-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathaggregate.py
More file actions
136 lines (106 loc) · 5.44 KB
/
aggregate.py
File metadata and controls
136 lines (106 loc) · 5.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from collections import OrderedDict
from typing import Hashable, Iterable, TypeVar
import torch
from torch import Tensor
from torchjd.aggregation import Aggregator
from ._utils import _OrderedSet, ordered_set
from .base import Transform
from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
# Generic type variables for the dictionary helpers below.
# Keys must be hashable so they can be used as dict keys (here they are
# torch Tensors, which hash by identity).
_KeyType = TypeVar("_KeyType", bound=Hashable)
_ValueType = TypeVar("_ValueType")
class Aggregate(Transform[Jacobians, Gradients]):
    """
    Transform that turns ``Jacobians`` into ``Gradients`` by aggregating.

    It composes three sub-transforms, applied right-to-left via ``<<``:
    flatten each jacobian into a matrix (``_Matrixify``), aggregate the
    concatenated matrices with ``aggregator`` (``_AggregateMatrices``), and
    reshape the resulting gradient vectors to the shapes of their keys
    (``_Reshape``).

    :param aggregator: The :class:`~torchjd.aggregation.Aggregator` used to
        reduce the united jacobian matrix into a gradient vector.
    :param key_order: The tensors (keys) whose jacobians are aggregated, in
        the order used when concatenating their matrices.
    """

    def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
        # Materialize ``key_order`` once: it is consumed by three different
        # sub-transform constructors, so a one-shot iterator (e.g. a
        # generator) would otherwise be exhausted after the first use,
        # silently leaving the later transforms with no keys.
        keys = list(key_order)
        matrixify = _Matrixify(keys)
        aggregate_matrices = _AggregateMatrices(aggregator, keys)
        reshape = _Reshape(keys)
        self._aggregator_str = str(aggregator)
        # ``<<`` composes right-to-left: matrixify runs first, reshape last.
        self.transform = reshape << aggregate_matrices << matrixify

    def _compute(self, input: Jacobians) -> Gradients:
        """Apply the composed transform to the input jacobians."""
        return self.transform(input)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """Delegate key checking to the composed transform."""
        return self.transform.check_keys()
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
    """
    Transform that aggregates jacobian matrices into gradient vectors.

    :param aggregator: The Aggregator applied to the united jacobian matrix.
    :param key_order: The keys whose matrices are concatenated, in order.
    """

    def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
        self.key_order = ordered_set(key_order)
        self.aggregator = aggregator

    def _compute(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
        """
        Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using
        the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to
        the part of the obtained gradient vector, that corresponds to the jacobian matrix given for
        that key.

        :param jacobian_matrices: The dictionary of jacobian matrices to aggregate. The first
            dimension of each jacobian matrix should be the same.
        """
        ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order)
        return self._aggregate_group(ordered_matrices, self.aggregator)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """Both the required and output keys are exactly ``self.key_order``."""
        keys = set(self.key_order)
        return keys, keys

    @staticmethod
    def _select_ordered_subdict(
        dictionary: dict[_KeyType, _ValueType], ordered_keys: _OrderedSet[_KeyType]
    ) -> OrderedDict[_KeyType, _ValueType]:
        """
        Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``.
        Returns an OrderedDict in the same order as the provided ``ordered_keys``.
        """
        return OrderedDict([(key, dictionary[key]) for key in ordered_keys])

    @staticmethod
    def _aggregate_group(
        jacobian_matrices: OrderedDict[Tensor, Tensor], aggregator: Aggregator
    ) -> GradientVectors:
        """
        Unites the jacobian matrices and aggregates them using an
        :class:`~torchjd.aggregation.bases.Aggregator`. Returns the obtained gradient vectors.
        """
        # Nothing to aggregate: return an empty result rather than calling
        # the aggregator on an empty matrix.
        if len(jacobian_matrices) == 0:
            return EmptyTensorDict()

        united_jacobian_matrix = _AggregateMatrices._unite(jacobian_matrices)
        united_gradient_vector = aggregator(united_jacobian_matrix)
        gradient_vectors = _AggregateMatrices._disunite(united_gradient_vector, jacobian_matrices)
        return gradient_vectors

    @staticmethod
    def _unite(jacobian_matrices: OrderedDict[Tensor, Tensor]) -> Tensor:
        """Concatenate all jacobian matrices along their column dimension."""
        return torch.cat(list(jacobian_matrices.values()), dim=1)

    @staticmethod
    def _disunite(
        united_gradient_vector: Tensor, jacobian_matrices: OrderedDict[Tensor, Tensor]
    ) -> GradientVectors:
        """
        Split ``united_gradient_vector`` back into per-key gradient vectors,
        each slice having as many elements as the corresponding jacobian
        matrix has columns.

        :raises ValueError: If the vector's length does not match the total
            number of columns across the jacobian matrices.
        """
        expected_length = sum(matrix.shape[1] for matrix in jacobian_matrices.values())
        if len(united_gradient_vector) != expected_length:
            # Note the trailing spaces inside the literals: adjacent string
            # literals are concatenated with no separator.
            raise ValueError(
                "Parameter `united_gradient_vector` should be a vector with length equal to the "
                "sum of the numbers of columns in the jacobian matrices. Found "
                f"`len(united_gradient_vector) = {len(united_gradient_vector)}` and the sum of the "
                f"numbers of columns in the jacobian matrices is {expected_length}."
            )

        gradient_vectors = {}
        start = 0
        for key, jacobian_matrix in jacobian_matrices.items():
            end = start + jacobian_matrix.shape[1]
            gradient_vectors[key] = united_gradient_vector[start:end]
            start = end
        return GradientVectors(gradient_vectors)
class _Matrixify(Transform[Jacobians, JacobianMatrices]):
    """
    Transform that flattens each jacobian into a 2D matrix, keeping the first
    dimension (one row per objective) and collapsing the rest into columns.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = set(required_keys)

    def _compute(self, jacobians: Jacobians) -> JacobianMatrices:
        matrices = {}
        for key, jacobian in jacobians.items():
            # Preserve the leading dimension; flatten everything else.
            matrices[key] = jacobian.view(jacobian.shape[0], -1)
        return JacobianMatrices(matrices)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        required = self._required_keys
        return required, required
class _Reshape(Transform[GradientVectors, Gradients]):
    """
    Transform that turns each flat gradient vector back into a gradient with
    the shape of its key tensor.
    """

    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = set(required_keys)

    def _compute(self, gradient_vectors: GradientVectors) -> Gradients:
        reshaped = {}
        for key, vector in gradient_vectors.items():
            # The key is itself a tensor; its gradient takes the key's shape.
            reshaped[key] = vector.view(key.shape)
        return Gradients(reshaped)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        required = self._required_keys
        return required, required