We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e951575 commit 7a29ee7Copy full SHA for 7a29ee7
1 file changed
tests/unit/autojac/_transform/test_grad.py
@@ -62,6 +62,23 @@ def test_empty_inputs_2():
62
assert_tensor_dicts_are_close(gradients, expected_gradients)
63
64
65
+def test_empty_outputs():
66
+ """
67
+ Tests that the Grad transform works correctly when the `outputs` parameter is an empty
68
+ `Iterable`.
69
70
+
71
+ a = torch.tensor(1.0, requires_grad=True)
72
+ input = Gradients({})
73
74
+ grad = Grad(outputs=[], inputs=[a])
75
76
+ gradients = grad(input)
77
+ expected_gradients = {a: torch.zeros_like(a)}
78
79
+ assert_tensor_dicts_are_close(gradients, expected_gradients)
80
81
82
def test_retain_graph():
83
"""Tests that the `Grad` transform behaves as expected with the `retain_graph` flag."""
84
0 commit comments