-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_stack.py
More file actions
103 lines (73 loc) · 3.08 KB
/
test_stack.py
File metadata and controls
103 lines (73 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Iterable
import torch
from torch import Tensor
from torchjd.autojac._transform import EmptyTensorDict, Gradients, Stack, Transform
from ._dict_assertions import assert_tensor_dicts_are_close
class FakeGradientsTransform(Transform[EmptyTensorDict, Gradients]):
    """
    Test-only transform that maps an empty input to a gradient of ones for each of its keys.
    It mimics what Init does, but is implemented independently so these tests do not rely on Init.
    """

    def __init__(self, keys: Iterable[Tensor]):
        # Deduplicate the provided keys; each key is a Tensor used both as dict key and as the
        # template for the shape/dtype of its fake gradient.
        self.keys = set(keys)

    def _compute(self, input: EmptyTensorDict) -> Gradients:
        ones_per_key = {tensor: torch.ones_like(tensor) for tensor in self.keys}
        return Gradients(ones_per_key)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        # Requires no input keys and produces exactly the keys it was constructed with.
        return set(), self.keys
def test_single_key():
    """
    Tests that the Stack transform stacks gradients into a jacobian in the simplest setting:
    two (identical) transforms that both produce a gradient for the same single key.
    """

    shared_key = torch.zeros([3, 4])
    gradients_transform = FakeGradientsTransform([shared_key])
    stack = Stack([gradients_transform, gradients_transform])

    output = stack(EmptyTensorDict())

    # Two stacked gradients of ones -> jacobian of ones with a leading dimension of 2.
    expected_output = {shared_key: torch.ones([2, 3, 4])}
    assert_tensor_dicts_are_close(output, expected_output)
def test_disjoint_key_sets():
    """
    Tests that the Stack transform stacks gradients into a jacobian when its inner transforms
    have disjoint output key sets. Positions for which a transform produced no gradient should
    be filled with zeros.
    """

    first_key = torch.zeros([1, 2])
    second_key = torch.zeros([3])
    first_transform = FakeGradientsTransform([first_key])
    second_transform = FakeGradientsTransform([second_key])
    stack = Stack([first_transform, second_transform])

    output = stack(EmptyTensorDict())

    # Each key gets ones in the row of the transform that produced it, zeros elsewhere.
    expected_output = {
        first_key: torch.tensor([[[1.0, 1.0]], [[0.0, 0.0]]]),
        second_key: torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]),
    }
    assert_tensor_dicts_are_close(output, expected_output)
def test_overlapping_key_sets():
    """
    Tests that the Stack transform stacks gradients into a jacobian when the output key sets of
    its transforms overlap (non-empty intersection, but not equal). Positions for which a
    transform produced no gradient should be filled with zeros.
    """

    key_a = torch.zeros([1, 2])
    key_b = torch.zeros([3])
    key_c = torch.zeros([4])
    transform_ab = FakeGradientsTransform([key_a, key_b])
    transform_bc = FakeGradientsTransform([key_b, key_c])
    stack = Stack([transform_ab, transform_bc])

    output = stack(EmptyTensorDict())

    # key_b is shared, so both rows are ones; key_a and key_c are each missing from one
    # transform, so the corresponding row is zeros.
    expected_output = {
        key_a: torch.tensor([[[1.0, 1.0]], [[0.0, 0.0]]]),
        key_b: torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
        key_c: torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]),
    }
    assert_tensor_dicts_are_close(output, expected_output)
def test_empty():
    """Tests that a Stack built from no transforms maps an empty input to an empty output."""

    stack = Stack([])
    empty = EmptyTensorDict({})

    output = stack(empty)

    assert_tensor_dicts_are_close(output, EmptyTensorDict({}))