File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed
tests/unit/autojac/_transform Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 11import torch
2+ from pytest import raises
23from torch .testing import assert_close
34
45from torchjd .autojac ._transform import (
@@ -248,3 +249,17 @@ def test_equivalence_jac_grads():
248249 assert_close (jac_A , torch .stack ([grad_1_A , grad_2_A ]))
249250 assert_close (jac_b , torch .stack ([grad_1_b , grad_2_b ]))
250251 assert_close (jac_c , torch .stack ([grad_1_c , grad_2_c ]))
252+
253+
254+ def test_stack_different_required_keys ():
255+ """Tests that the Stack transform fails on transforms with different required keys."""
256+
257+ a = torch .tensor (1.0 , requires_grad = True )
258+ y1 = a * 2.0
259+ y2 = a * 3.0
260+
261+ grad1 = Grad ([y1 ], [a ])
262+ grad2 = Grad ([y2 ], [a ])
263+
264+ with raises (ValueError ):
265+ _ = Stack ([grad1 , grad2 ])
You can’t perform that action at this time.
0 commit comments