-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_base.py
More file actions
137 lines (95 loc) · 3.69 KB
/
test_base.py
File metadata and controls
137 lines (95 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import typing
import torch
from pytest import raises
from torch import Tensor
from torchjd.autojac._transform._utils import _B, _C
from torchjd.autojac._transform.base import Conjunction, Transform
from torchjd.autojac._transform.tensor_dict import TensorDict
class FakeTransform(Transform[_B, _C]):
    """
    Stand-in ``Transform`` whose `required_keys` and `output_keys` are fixed at construction time.
    It is used to exercise the key checks that happen when composing and conjuncting transforms.
    """

    def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]):
        # Both sets are stored as given; ``check_keys`` returns these very objects.
        self._output_keys = output_keys
        self._required_keys = required_keys

    def __str__(self):
        # Single-letter representation, so that composed / conjuncted strings stay short.
        return "T"

    def _compute(self, input: _B) -> _C:
        # The input is deliberately unused: only the keys of the result matter here, so each
        # output key is mapped to its own freshly created empty tensor.
        placeholders: dict[Tensor, Tensor] = {}
        for key in self._output_keys:
            placeholders[key] = torch.empty(0)

        # The cast exists purely to satisfy the type checker.
        return typing.cast(_C, placeholders)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        # Report the (required keys, output keys) pair supplied to the constructor.
        return (self._required_keys, self._output_keys)
def test_call_checks_keys():
    """
    Tests that calling a ``Transform`` on a dictionary whose keys are not exactly its
    `required_keys` raises a ``ValueError``, while a dictionary with exactly those keys is accepted.
    """

    a1 = torch.randn([2])
    a2 = torch.randn([3])
    transform = FakeTransform(required_keys={a1}, output_keys={a1, a2})

    # Exactly the required key: accepted.
    transform(TensorDict({a1: a2}))

    # Wrong key, no key at all, and one key too many: all rejected.
    rejected_inputs = ({a2: a1}, {}, {a1: a2, a2: a1})
    for raw_dict in rejected_inputs:
        with raises(ValueError):
            transform(TensorDict(raw_dict))
def test_compose_checks_keys():
    """
    Tests that composing ``Transform``s verifies that the `output_keys` of the inner transform
    match the `required_keys` of the outer transform.
    """

    a1 = torch.randn([2])
    a2 = torch.randn([3])
    needs_a1 = FakeTransform(required_keys={a1}, output_keys={a1, a2})
    needs_a2 = FakeTransform(required_keys={a2}, output_keys={a1})

    # Inner transform outputs {a1}, outer transform requires {a1}: valid.
    matching = needs_a1 << needs_a2
    matching.check_keys()

    # Inner transform outputs {a1, a2}, outer transform requires only {a2}: invalid.
    with raises(ValueError):
        mismatched = needs_a2 << needs_a1
        mismatched.check_keys()
def test_conjunct_checks_required_keys():
    """
    Tests that conjuncting ``Transform``s verifies that all of the provided transforms share the
    same `required_keys`.
    """

    a1 = torch.randn([2])
    a2 = torch.randn([3])
    first_on_a1 = FakeTransform(required_keys={a1}, output_keys=set())
    second_on_a1 = FakeTransform(required_keys={a1}, output_keys=set())
    on_a2 = FakeTransform(required_keys={a2}, output_keys=set())

    # Identical required keys: valid.
    agreeing = first_on_a1 | second_on_a1
    agreeing.check_keys()

    # Differing required keys, with two and then three operands: invalid.
    with raises(ValueError):
        pair = second_on_a1 | on_a2
        pair.check_keys()

    with raises(ValueError):
        triple = first_on_a1 | second_on_a1 | on_a2
        triple.check_keys()
def test_conjunct_checks_output_keys():
    """
    Tests that conjuncting ``Transform``s verifies that the `output_keys` of the provided
    transforms are disjoint.
    """

    a1 = torch.randn([2])
    a2 = torch.randn([3])
    outputs_both = FakeTransform(required_keys=set(), output_keys={a1, a2})
    outputs_a1 = FakeTransform(required_keys=set(), output_keys={a1})
    outputs_a2 = FakeTransform(required_keys=set(), output_keys={a2})

    # {a1} and {a2} do not overlap: valid.
    disjoint = outputs_a1 | outputs_a2
    disjoint.check_keys()

    # Overlapping output keys, with two and then three operands: invalid.
    with raises(ValueError):
        pair = outputs_both | outputs_a2
        pair.check_keys()

    with raises(ValueError):
        triple = outputs_both | outputs_a1 | outputs_a2
        triple.check_keys()
def test_empty_conjunction():
    """
    Tests that a conjunction of no transform at all can be built and called, and that calling it
    on an empty dictionary yields an empty dictionary.
    """

    empty_conjunction = Conjunction([])
    result = empty_conjunction(TensorDict({}))

    assert len(result) == 0
def test_str():
    """
    Tests that the __str__ method gives the expected representation for a transform that mixes
    compositions and conjunctions.
    """

    leaf = FakeTransform(required_keys=set(), output_keys=set())

    # ``<<`` binds tighter than ``|``, so the middle operand of `left` is a three-way composition.
    left = leaf | leaf << leaf << leaf | leaf
    right = leaf | leaf
    transform = left << leaf << right

    expected = "(T | T ∘ T ∘ T | T) ∘ T ∘ (T | T)"
    assert str(transform) == expected