File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed
tests/unit/autojac/_transform Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change 11import torch
2+ from pytest import raises
23from torch .testing import assert_close
34
45from torchjd .autojac ._transform import (
@@ -248,3 +249,18 @@ def test_equivalence_jac_grads():
248249 assert_close (jac_A , torch .stack ([grad_1_A , grad_2_A ]))
249250 assert_close (jac_b , torch .stack ([grad_1_b , grad_2_b ]))
250251 assert_close (jac_c , torch .stack ([grad_1_c , grad_2_c ]))
252+
253+
254+ def test_stack_different_required_keys ():
255+ """Tests that the Stack transform fails on transforms with different required keys."""
256+
257+ a1 = torch .tensor (1.0 , requires_grad = True )
258+ a2 = torch .tensor (2.0 , requires_grad = True )
259+ y1 = a1 * 2.0
260+ y2 = a2 * 3.0
261+
262+ transform1 = Grad ([y1 ], [torch .tensor (1.0 , requires_grad = True )])
263+ transform2 = Grad ([y2 ], [torch .tensor (2.0 , requires_grad = True )])
264+
265+ with raises (ValueError ):
266+ _ = Stack ([transform1 , transform2 ])
You can’t perform that action at this time.
0 commit comments