File tree Expand file tree Collapse file tree 2 files changed +40
-0
lines changed
tests/unit/autojac/_transform Expand file tree Collapse file tree 2 files changed +40
-0
lines changed Original file line number Diff line number Diff line change @@ -62,6 +62,24 @@ def test_empty_inputs_2():
6262 assert_tensor_dicts_are_close (gradients , expected_gradients )
6363
6464
65+ def test_empty_outputs ():
66+ """
67+ Tests that the Grad transform works correctly when the `outputs` parameter is an empty
68+ `Iterable`.
69+ """
70+
71+ a1 = torch .tensor (1.0 , requires_grad = True )
72+ a2 = torch .tensor ([1.0 , 2.0 ], requires_grad = True )
73+ input = Gradients ({})
74+
75+ grad = Grad (outputs = [], inputs = [a1 , a2 ])
76+
77+ gradients = grad (input )
78+ expected_gradients = {a1 : torch .zeros_like (a1 ), a2 : torch .zeros_like (a2 )}
79+
80+ assert_tensor_dicts_are_close (gradients , expected_gradients )
81+
82+
6583def test_retain_graph ():
6684 """Tests that the `Grad` transform behaves as expected with the `retain_graph` flag."""
6785
Original file line number Diff line number Diff line change @@ -72,6 +72,28 @@ def test_empty_inputs_2(chunk_size: int | None):
7272 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
7373
7474
75+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
76+ def test_empty_outputs (chunk_size : int | None ):
77+ """
78+ Tests that the Jac transform works correctly when the `outputs` parameter is an empty
79+ `Iterable`.
80+ """
81+
82+ a1 = torch .tensor (1.0 , requires_grad = True )
83+ a2 = torch .tensor ([1.0 , 2.0 ], requires_grad = True )
84+ input = Jacobians ({})
85+
86+ jac = Jac (outputs = [], inputs = [a1 , a2 ], chunk_size = chunk_size )
87+
88+ jacobians = jac (input )
89+ expected_jacobians = {
90+ a1 : torch .empty_like (a1 ).unsqueeze (0 )[:0 ], # Jacobian with no row
91+ a2 : torch .empty_like (a2 ).unsqueeze (0 )[:0 ], # Jacobian with no row
92+ }
93+
94+ assert_tensor_dicts_are_close (jacobians , expected_jacobians )
95+
96+
7597def test_retain_graph ():
7698 """Tests that the `Jac` transform behaves as expected with the `retain_graph` flag."""
7799
You can’t perform that action at this time.
0 commit comments