Skip to content

Commit b96a9a1

Browse files
authored
Merge branch 'main' into stationarity_property
2 parents 1ec464a + a2e77c7 commit b96a9a1

2 files changed

Lines changed: 44 additions & 0 deletions

File tree

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from pytest import raises
23
from torch.testing import assert_close
34

45
from torchjd.autojac._transform import (
@@ -248,3 +249,17 @@ def test_equivalence_jac_grads():
248249
assert_close(jac_A, torch.stack([grad_1_A, grad_2_A]))
249250
assert_close(jac_b, torch.stack([grad_1_b, grad_2_b]))
250251
assert_close(jac_c, torch.stack([grad_1_c, grad_2_c]))
252+
253+
254+
def test_stack_different_required_keys():
255+
"""Tests that the Stack transform fails on transforms with different required keys."""
256+
257+
a = torch.tensor(1.0, requires_grad=True)
258+
y1 = a * 2.0
259+
y2 = a * 3.0
260+
261+
grad1 = Grad([y1], [a])
262+
grad2 = Grad([y2], [a])
263+
264+
with raises(ValueError):
265+
_ = Stack([grad1, grad2])

tests/unit/autojac/_transform/test_select.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
from contextlib import nullcontext as does_not_raise
2+
13
import torch
4+
from pytest import mark, raises
5+
from unit._utils import ExceptionContext
26

37
from torchjd.autojac._transform import Select, TensorDict
48

@@ -53,3 +57,28 @@ def test_conjunction_of_selects_is_select():
5357
expected_output = select(input)
5458

5559
assert_tensor_dicts_are_close(output, expected_output)
60+
61+
62+
@mark.parametrize(
63+
["key_indices", "required_key_indices", "expectation"],
64+
[
65+
([0], [0, 1], does_not_raise()),
66+
([0], [1], raises(ValueError)),
67+
([0, 1], [0], raises(ValueError)),
68+
([], [0], does_not_raise()),
69+
],
70+
)
71+
def test_keys_check(
72+
key_indices: list[int], required_key_indices: list[int], expectation: ExceptionContext
73+
):
74+
"""
75+
Tests that the Select transform correctly checks that the keys are a subset of the required
76+
keys.
77+
"""
78+
79+
all_keys = [torch.tensor(i) for i in range(2)]
80+
keys = [all_keys[i] for i in key_indices]
81+
required_keys = [all_keys[i] for i in required_key_indices]
82+
83+
with expectation:
84+
_ = Select(keys, required_keys)

0 commit comments

Comments
 (0)