-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_select.py
More file actions
73 lines (53 loc) · 2.07 KB
/
test_select.py
File metadata and controls
73 lines (53 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from pytest import raises
from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_
from torchjd.autojac._transform import RequirementError, Select
def test_partition():
"""
Tests that the Select transform works correctly by applying 2 different Selects to a TensorDict,
whose keys form a partition of the keys of the TensorDict.
"""
key1 = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
key2 = tensor_([1.0, 3.0, 5.0])
key3 = tensor_(2.0)
value1 = torch.ones_like(key1)
value2 = torch.ones_like(key2)
value3 = torch.ones_like(key3)
input = {key1: value1, key2: value2, key3: value3}
select1 = Select({key1, key2})
select2 = Select({key3})
output1 = select1(input)
expected_output1 = {key1: value1, key2: value2}
assert_tensor_dicts_are_close(output1, expected_output1)
output2 = select2(input)
expected_output2 = {key3: value3}
assert_tensor_dicts_are_close(output2, expected_output2)
def test_conjunction_of_selects_is_select():
"""
Tests that the conjunction of 2 Select transforms is equivalent to directly using a Select with
the union of the keys of the 2 Selects.
"""
x1 = tensor_(5.0)
x2 = tensor_(6.0)
x3 = tensor_(7.0)
input = {x1: torch.ones_like(x1), x2: torch.ones_like(x2), x3: torch.ones_like(x3)}
select1 = Select({x1})
select2 = Select({x2})
conjunction_of_selects = select1 | select2
select = Select({x1, x2})
output = conjunction_of_selects(input)
expected_output = select(input)
assert_tensor_dicts_are_close(output, expected_output)
def test_check_keys():
"""
Tests that the `check_keys` method works correctly: the set of keys to select should be a subset
of the set of required_keys.
"""
key1 = tensor_([1.0])
key2 = tensor_([2.0])
key3 = tensor_([3.0])
output_keys = Select({key1, key2}).check_keys({key1, key2, key3})
assert output_keys == {key1, key2}
with raises(RequirementError):
Select({key1, key2}).check_keys({key1})