Skip to content

Commit ce9231b

Browse files
committed
Rename retain_jacs to retain_jac
1 parent b1aaee9 commit ce9231b

1 file changed

Lines changed: 3 additions & 5 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,17 +9,15 @@
99
from ._accumulation import TensorWithJac, accumulate_grads
1010

1111

12-
def jac_to_grad(
13-
params: Iterable[Tensor], aggregator: Aggregator, retain_jacs: bool = False
14-
) -> None:
12+
def jac_to_grad(params: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False) -> None:
1513
"""
1614
Aggregates the Jacobians stored in the ``.jac`` fields of ``params`` and accumulates the result
1715
into their ``.grad`` fields.
1816
1917
:param params: The parameters whose ``.jac`` fields should be aggregated. All Jacobians must
2018
have the same first dimension (number of outputs).
2119
:param aggregator: The aggregator used to reduce the Jacobians into gradients.
22-
:param retain_jacs: Whether to preserve the ``.jac`` fields of the parameters.
20+
:param retain_jac: Whether to preserve the ``.jac`` fields of the parameters.
2321
"""
2422

2523
params_ = list[TensorWithJac]()
@@ -45,7 +43,7 @@ def jac_to_grad(
4543
gradients = _disunite_gradient(gradient_vector, jacobians, params_)
4644
accumulate_grads(params_, gradients)
4745

48-
if not retain_jacs:
46+
if not retain_jac:
4947
_free_jacs(params_)
5048

5149

0 commit comments

Comments (0)