Skip to content

Commit 82fcc38

Browse files
committed
Add test_keys_check for Select
1 parent b514092 commit 82fcc38

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/unit/autojac/_transform/test_select.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
from contextlib import nullcontext as does_not_raise
2+
13
import torch
4+
from pytest import mark, raises
5+
from unit._utils import ExceptionContext
26

37
from torchjd.autojac._transform import Select, TensorDict
48

@@ -53,3 +57,28 @@ def test_conjunction_of_selects_is_select():
5357
expected_output = select(input)
5458

5559
assert_tensor_dicts_are_close(output, expected_output)
60+
61+
62+
@mark.parametrize(
63+
["key_indices", "required_key_indices", "expectation"],
64+
[
65+
([0], [0, 1], does_not_raise()),
66+
([0], [1], raises(ValueError)),
67+
([0, 1], [0], raises(ValueError)),
68+
([], [0], does_not_raise()),
69+
],
70+
)
71+
def test_keys_check(
72+
key_indices: list[int], required_key_indices: list[int], expectation: ExceptionContext
73+
):
74+
"""
75+
Tests that the Select transform correctly checks that the keys are a subset of the required
76+
keys.
77+
"""
78+
79+
all_keys = [torch.tensor(i) for i in range(10)]
80+
keys = [all_keys[i] for i in key_indices]
81+
required_keys = [all_keys[i] for i in required_key_indices]
82+
83+
with expectation:
84+
_ = Select(keys, required_keys)

0 commit comments

Comments
 (0)