File tree Expand file tree Collapse file tree
tests/unit/autojac/_transform Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11import torch
22from pytest import raises
33
4- from torchjd .autojac ._transform import Grad , Gradients
4+ from torchjd .autojac ._transform import Grad , Gradients , RequirementError
55
66from ._dict_assertions import assert_tensor_dicts_are_close
77
@@ -284,7 +284,10 @@ def test_create_graph():
284284
285285
286286def test_check_keys ():
287- """Tests that the `check_keys` method works correctly."""
287+ """
288+ Tests that the `check_keys` method works correctly: the input_keys should match the stored
289+ outputs.
290+ """
288291
289292 x = torch .tensor (5.0 )
290293 a1 = torch .tensor (2.0 , requires_grad = True )
@@ -296,3 +299,9 @@ def test_check_keys():
296299 output_keys = grad .check_keys ({y })
297300
298301 assert output_keys == {a1 , a2 }
302+
303+ with raises (RequirementError ):
304+ grad .check_keys ({y , x })
305+
306+ with raises (RequirementError ):
307+ grad .check_keys (set ())
You can’t perform that action at this time.
0 commit comments