|
| 1 | +from contextlib import nullcontext as does_not_raise |
| 2 | + |
1 | 3 | import torch |
| 4 | +from pytest import mark, raises |
| 5 | +from unit._utils import ExceptionContext |
2 | 6 |
|
3 | 7 | from torchjd.autojac._transform import Select, TensorDict |
4 | 8 |
|
@@ -53,3 +57,28 @@ def test_conjunction_of_selects_is_select(): |
53 | 57 | expected_output = select(input) |
54 | 58 |
|
55 | 59 | assert_tensor_dicts_are_close(output, expected_output) |
| 60 | + |
| 61 | + |
| 62 | +@mark.parametrize( |
| 63 | + ["key_indices", "required_key_indices", "expectation"], |
| 64 | + [ |
| 65 | + ({0}, {0, 1}, does_not_raise()), |
| 66 | + ({0}, {1}, raises(ValueError)), |
| 67 | + ({0, 1}, {0}, raises(ValueError)), |
| 68 | + ({}, {0}, does_not_raise()), |
| 69 | + ], |
| 70 | +) |
| 71 | +def test_keys_check( |
| 72 | + key_indices: set[int], required_key_indices: set[int], expectation: ExceptionContext |
| 73 | +): |
| 74 | + """ |
| 75 | + Tests that the Select transform correctly checks that the keys are a subset of the required |
| 76 | + keys. |
| 77 | + """ |
| 78 | + |
| 79 | + all_keys = [torch.tensor(i) for i in range(10)] |
| 80 | + keys = [all_keys[i] for i in key_indices] |
| 81 | + required_keys = [all_keys[i] for i in required_key_indices] |
| 82 | + |
| 83 | + with expectation: |
| 84 | + _ = Select(keys, required_keys) |
0 commit comments