Skip to content

Commit 179bdfd

Browse files
authored
test: Fix using wrong backward interface (#513)
1 parent 7e3637e commit 179bdfd

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/utils/forward_backwards.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torchjd.aggregation import Aggregator, Weighting
1212
from torchjd.autogram import Engine
1313
from torchjd.autojac import backward
14+
from torchjd.autojac._jac_to_grad import jac_to_grad
1415

1516

1617
def autograd_forward_backward(
@@ -29,7 +30,8 @@ def autojac_forward_backward(
2930
aggregator: Aggregator,
3031
) -> None:
3132
losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
32-
backward(losses, aggregator=aggregator)
33+
backward(losses)
34+
jac_to_grad(model.parameters(), aggregator)
3335

3436

3537
def autograd_gramian_forward_backward(

0 commit comments

Comments
 (0)