diff --git a/tests/unit/autojac/_transform/test_select.py b/tests/unit/autojac/_transform/test_select.py index 76fb72cfc..7cf37b165 100644 --- a/tests/unit/autojac/_transform/test_select.py +++ b/tests/unit/autojac/_transform/test_select.py @@ -1,4 +1,8 @@ +from contextlib import nullcontext as does_not_raise + import torch +from pytest import mark, raises +from unit._utils import ExceptionContext from torchjd.autojac._transform import Select, TensorDict @@ -53,3 +57,28 @@ def test_conjunction_of_selects_is_select(): expected_output = select(input) assert_tensor_dicts_are_close(output, expected_output) + + +@mark.parametrize( + ["key_indices", "required_key_indices", "expectation"], + [ + ([0], [0, 1], does_not_raise()), + ([0], [1], raises(ValueError)), + ([0, 1], [0], raises(ValueError)), + ([], [0], does_not_raise()), + ], +) +def test_keys_check( + key_indices: list[int], required_key_indices: list[int], expectation: ExceptionContext +): + """ + Tests that the Select transform correctly checks that the keys are a subset of the required + keys. + """ + + all_keys = [torch.tensor(i) for i in range(2)] + keys = [all_keys[i] for i in key_indices] + required_keys = [all_keys[i] for i in required_key_indices] + + with expectation: + _ = Select(keys, required_keys)