Skip to content

Commit 1bb5c35

Browse files
authored
test(autojac): Add test_stack_different_required_keys (#274)
1 parent b514092 commit 1bb5c35

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from pytest import raises
23
from torch.testing import assert_close
34

45
from torchjd.autojac._transform import (
@@ -248,3 +249,17 @@ def test_equivalence_jac_grads():
248249
assert_close(jac_A, torch.stack([grad_1_A, grad_2_A]))
249250
assert_close(jac_b, torch.stack([grad_1_b, grad_2_b]))
250251
assert_close(jac_c, torch.stack([grad_1_c, grad_2_c]))
252+
253+
254+
def test_stack_different_required_keys():
255+
"""Tests that the Stack transform fails on transforms with different required keys."""
256+
257+
a = torch.tensor(1.0, requires_grad=True)
258+
y1 = a * 2.0
259+
y2 = a * 3.0
260+
261+
grad1 = Grad([y1], [a])
262+
grad2 = Grad([y2], [a])
263+
264+
with raises(ValueError):
265+
_ = Stack([grad1, grad2])

0 commit comments

Comments
 (0)