Skip to content

Commit 9fbd3d3

Browse files
committed
Add test_stack_different_required_keys
1 parent b514092 commit 9fbd3d3

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from pytest import raises
23
from torch.testing import assert_close
34

45
from torchjd.autojac._transform import (
@@ -248,3 +249,18 @@ def test_equivalence_jac_grads():
248249
assert_close(jac_A, torch.stack([grad_1_A, grad_2_A]))
249250
assert_close(jac_b, torch.stack([grad_1_b, grad_2_b]))
250251
assert_close(jac_c, torch.stack([grad_1_c, grad_2_c]))
252+
253+
254+
def test_stack_different_required_keys():
255+
"""Tests that the Stack transform fails on transforms with different required keys."""
256+
257+
a1 = torch.tensor(1.0, requires_grad=True)
258+
a2 = torch.tensor(2.0, requires_grad=True)
259+
y1 = a1 * 2.0
260+
y2 = a2 * 3.0
261+
262+
transform1 = Grad([y1], [torch.tensor(1.0, requires_grad=True)])
263+
transform2 = Grad([y2], [torch.tensor(2.0, requires_grad=True)])
264+
265+
with raises(ValueError):
266+
_ = Stack([transform1, transform2])

0 commit comments

Comments
 (0)