Skip to content

Commit f37207a

Browse files
committed
Change shape of inputs
1 parent 27a4245 commit f37207a

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

tests/unit/autojac/_transform/test_grad.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ def test_empty_outputs():
6868
`Iterable`.
6969
"""
7070

71-
a = torch.tensor(1.0, requires_grad=True)
71+
a1 = torch.tensor(1.0, requires_grad=True)
72+
a2 = torch.tensor([1.0, 2.0], requires_grad=True)
7273
input = Gradients({})
7374

74-
grad = Grad(outputs=[], inputs=[a])
75+
grad = Grad(outputs=[], inputs=[a1, a2])
7576

7677
gradients = grad(input)
77-
expected_gradients = {a: torch.zeros_like(a)}
78+
expected_gradients = {a1: torch.zeros_like(a1), a2: torch.zeros_like(a2)}
7879

7980
assert_tensor_dicts_are_close(gradients, expected_gradients)
8081

tests/unit/autojac/_transform/test_jac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_empty_outputs(chunk_size: int | None):
8080
"""
8181

8282
a1 = torch.tensor(1.0, requires_grad=True)
83-
a2 = torch.tensor(1.0, requires_grad=True)
83+
a2 = torch.tensor([1.0, 2.0], requires_grad=True)
8484
input = Jacobians({})
8585

8686
jac = Jac(outputs=[], inputs=[a1, a2], chunk_size=chunk_size)

0 commit comments

Comments
 (0)