Skip to content

Commit 4e96aa4

Browse files
committed
Improve test_check_keys of grad
1 parent 799cb39 commit 4e96aa4

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

tests/unit/autojac/_transform/test_grad.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from pytest import raises
33

4-
from torchjd.autojac._transform import Grad, Gradients
4+
from torchjd.autojac._transform import Grad, Gradients, RequirementError
55

66
from ._dict_assertions import assert_tensor_dicts_are_close
77

@@ -284,7 +284,10 @@ def test_create_graph():
284284

285285

286286
def test_check_keys():
287-
"""Tests that the `check_keys` method works correctly."""
287+
"""
288+
Tests that the `check_keys` method works correctly: the input_keys should match the stored
289+
outputs.
290+
"""
288291

289292
x = torch.tensor(5.0)
290293
a1 = torch.tensor(2.0, requires_grad=True)
@@ -296,3 +299,9 @@ def test_check_keys():
296299
output_keys = grad.check_keys({y})
297300

298301
assert output_keys == {a1, a2}
302+
303+
with raises(RequirementError):
304+
grad.check_keys({y, x})
305+
306+
with raises(RequirementError):
307+
grad.check_keys(set())

0 commit comments

Comments
 (0)