File tree Expand file tree Collapse file tree
tests/unit/autojac/_transform Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -72,6 +72,28 @@ def test_empty_inputs_2(chunk_size: int | None):
7272 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
7373
7474
75+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
76+ def test_empty_outputs (chunk_size : int | None ):
77+ """
78+ Tests that the Jac transform works correctly when the `outputs` parameter is an empty
79+ `Iterable`.
80+ """
81+
82+ a1 = torch .tensor (1.0 , requires_grad = True )
83+ a2 = torch .tensor (1.0 , requires_grad = True )
84+ input = Jacobians ({})
85+
86+ jac = Jac (outputs = [], inputs = [a1 , a2 ], chunk_size = chunk_size )
87+
88+ jacobians = jac (input )
89+ expected_jacobians = {
90+ a1 : torch .empty_like (a1 ).unsqueeze (0 )[:0 ], # Jacobian with no row
91+ a2 : torch .empty_like (a2 ).unsqueeze (0 )[:0 ], # Jacobian with no row
92+ }
93+
94+ assert_tensor_dicts_are_close (jacobians , expected_jacobians )
95+
96+
7597def test_retain_graph ():
7698 """Tests that the `Jac` transform behaves as expected with the `retain_graph` flag."""
7799
You can’t perform that action at this time.
0 commit comments