-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_diagonalize.py
More file actions
119 lines (93 loc) · 3.28 KB
/
test_diagonalize.py
File metadata and controls
119 lines (93 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from pytest import raises
from torchjd.autojac._transform import Diagonalize, Gradients, RequirementError
from torchjd.autojac._transform.ordered_set import OrderedSet
from ._dict_assertions import assert_tensor_dicts_are_close
def test_single_input():
    """Checks Diagonalize on a mapping that holds exactly one key.

    A 3-element key paired with an all-ones value is expected to be turned into the 3x3 identity.
    """

    tensor_key = torch.tensor([1.0, 2.0, 3.0])
    gradients = Gradients({tensor_key: torch.ones_like(tensor_key)})

    transform = Diagonalize(OrderedSet([tensor_key]))
    result = transform(gradients)

    # One row per element of the key, with a single 1 on the diagonal.
    expected = {tensor_key: torch.eye(3)}

    assert_tensor_dicts_are_close(result, expected)
def test_multiple_inputs():
    """Checks Diagonalize on a mapping that holds several keys of different shapes.

    The keys have 4, 3 and 1 elements, so every expected value has 8 rows. Each key's rows are
    one-hot over its own elements within its slice of the 8 rows and zero everywhere else.
    """

    matrix_key = torch.tensor([[1.0, 2.0], [4.0, 5.0]])
    vector_key = torch.tensor([1.0, 3.0, 5.0])
    scalar_key = torch.tensor(1.0)
    ordered_keys = [matrix_key, vector_key, scalar_key]

    # Every key is paired with an all-ones tensor of its own shape.
    gradients = Gradients({key: torch.ones_like(key) for key in ordered_keys})

    transform = Diagonalize(OrderedSet(ordered_keys))
    result = transform(gradients)

    # Rows 0-3 belong to the matrix key.
    expected_matrix = torch.tensor(
        [
            [[1.0, 0.0], [0.0, 0.0]],
            [[0.0, 1.0], [0.0, 0.0]],
            [[0.0, 0.0], [1.0, 0.0]],
            [[0.0, 0.0], [0.0, 1.0]],
            [[0.0, 0.0], [0.0, 0.0]],
            [[0.0, 0.0], [0.0, 0.0]],
            [[0.0, 0.0], [0.0, 0.0]],
            [[0.0, 0.0], [0.0, 0.0]],
        ],
    )
    # Rows 4-6 belong to the vector key.
    expected_vector = torch.tensor(
        [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0],
        ],
    )
    # Row 7 belongs to the scalar key.
    expected_scalar = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])

    expected = {
        matrix_key: expected_matrix,
        vector_key: expected_vector,
        scalar_key: expected_scalar,
    }

    assert_tensor_dicts_are_close(result, expected)
def test_permute_order():
    """
    Checks that swapping the order of the keys given to Diagonalize swaps the output values: once
    the values of the permuted output are swapped back, it must match the non-permuted output.
    """

    first_key = torch.tensor(2.0)
    second_key = torch.tensor(1.0)
    gradients = Gradients(
        {
            first_key: torch.ones_like(first_key),
            second_key: torch.ones_like(second_key),
        },
    )

    swapped_transform = Diagonalize(OrderedSet([second_key, first_key]))
    reference_transform = Diagonalize(OrderedSet([first_key, second_key]))

    swapped_result = swapped_transform(gradients)

    # Undo the permutation: each key takes the value that the other key received.
    restored_result = {
        first_key: swapped_result[second_key],
        second_key: swapped_result[first_key],
    }
    reference_result = reference_transform(gradients)

    assert_tensor_dicts_are_close(restored_result, reference_result)
def test_check_keys():
    """
    Checks the `check_keys` method of Diagonalize: matching input keys are returned as output
    keys, while a missing key or an extra key raises a RequirementError.
    """

    considered_key = torch.tensor([1.0])
    extra_key = torch.tensor([1.0])
    transform = Diagonalize(OrderedSet([considered_key]))

    # Exact match: the same key set comes back.
    assert transform.check_keys({considered_key}) == {considered_key}

    # The considered key is missing.
    with raises(RequirementError):
        transform.check_keys(set())

    # A key that was never considered is present.
    with raises(RequirementError):
        transform.check_keys({considered_key, extra_key})