-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathdiagonalize.py
More file actions
32 lines (26 loc) · 1.11 KB
/
diagonalize.py
File metadata and controls
32 lines (26 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Iterable
import torch
from torch import Tensor
from ._utils import ordered_set
from .base import Transform
from .tensor_dict import Gradients, Jacobians
class Diagonalize(Transform[Gradients, Jacobians]):
    """
    Transform turning the gradients of the ``considered`` tensors into Jacobians.

    The gradients of all considered tensors are flattened and concatenated into a vector
    ``v`` of length ``N`` (the total number of elements of the considered tensors). ``v`` is
    placed on the diagonal of an ``N x N`` matrix. The Jacobian of each considered tensor is
    the block of columns of that matrix belonging to the tensor, reshaped to
    ``(N, *tensor.shape)``. Across all returned Jacobians, row ``i`` therefore holds a
    single (possibly zero) entry, ``v[i]``; every other entry is zero.

    :param considered: The tensors whose gradients are diagonalized. They are passed through
        ``ordered_set`` (presumably de-duplicating while keeping insertion order — confirm in
        ``._utils``), and that order defines the layout of ``v``.
    """

    def __init__(self, considered: Iterable[Tensor]):
        self.considered = ordered_set(considered)

        # Half-open [begin, end) range occupied by each considered tensor inside the
        # concatenated, flattened vector built by ``_compute``. Entries are in the same
        # order as ``self.considered``.
        self.indices: list[tuple[int, int]] = []
        begin = 0
        for tensor in self.considered:
            end = begin + tensor.numel()
            self.indices.append((begin, end))
            begin = end

    def _compute(self, tensors: Gradients) -> Jacobians:
        """
        Build the Jacobian of each considered tensor from its gradient in ``tensors``.

        :param tensors: Gradients indexed by tensor. Must contain a gradient for every
            considered tensor.
        :returns: Jacobians mapping each considered tensor ``key`` to a tensor of shape
            ``(N, *key.shape)``, where ``N`` is the total number of elements of the
            considered tensors. Returns empty Jacobians when nothing is considered.
        """
        if not self.indices:
            # Nothing to diagonalize. Without this guard, torch.cat would raise on the empty list.
            return Jacobians({})

        flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered]
        # 1-D input -> square matrix with the concatenated gradient values on its diagonal.
        diagonal_matrix = torch.cat(flattened_considered_values).diag()
        diagonalized_tensors = {
            # Keep only the columns owned by ``key`` and give them the key's shape. Each
            # resulting tensor has N rows, one per element of the concatenated vector.
            key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape)
            for (begin, end), key in zip(self.indices, self.considered)
        }
        return Jacobians(diagonalized_tensors)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """
        Return the pair of key sets of this transform.

        Both sets contain exactly the considered tensors. They are presumably
        (required input keys, output keys) as defined by the ``Transform`` interface —
        confirm against ``.base``. Two distinct set objects are returned, so a caller that
        mutates one does not silently alter the other.
        """
        return set(self.considered), set(self.considered)