 from ._accumulation import TensorWithJac, accumulate_grads
 
 
-def jac_to_grad(
-    params: Iterable[Tensor], aggregator: Aggregator, retain_jacs: bool = False
-) -> None:
+def jac_to_grad(params: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False) -> None:
     """
     Aggregates the Jacobians stored in the ``.jac`` fields of ``params`` and accumulates the result
     into their ``.grad`` fields.
 
     :param params: The parameters whose ``.jac`` fields should be aggregated. All Jacobians must
         have the same first dimension (number of outputs).
     :param aggregator: The aggregator used to reduce the Jacobians into gradients.
-    :param retain_jacs: Whether to preserve the ``.jac`` fields of the parameters.
+    :param retain_jac: Whether to preserve the ``.jac`` fields of the parameters.
     """
 
     params_ = list[TensorWithJac]()
@@ -45,7 +43,7 @@ def jac_to_grad(
     gradients = _disunite_gradient(gradient_vector, jacobians, params_)
     accumulate_grads(params_, gradients)
 
-    if not retain_jacs:
+    if not retain_jac:
         _free_jacs(params_)
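For context, the renamed `retain_jac` flag controls whether the `.jac` fields survive the call. A minimal sketch of the aggregate-then-accumulate behavior described in the docstring, using plain Python lists in place of real tensors (the `Param` class, the mean aggregation, and `jac_to_grad_sketch` itself are illustrative assumptions, not the library's API):

```python
class Param:
    """Toy stand-in for a tensor with optional .jac and .grad fields."""
    def __init__(self, jac):
        self.jac = jac    # list of per-output gradient rows (the "Jacobian")
        self.grad = None  # accumulated gradient, or None


def jac_to_grad_sketch(params, retain_jac=False):
    # All Jacobians must share the same first dimension (number of outputs).
    n_outputs = len(params[0].jac)
    assert all(len(p.jac) == n_outputs for p in params)
    for p in params:
        # Aggregate the rows into a single gradient; here a simple mean
        # stands in for whatever reduction the Aggregator performs.
        aggregated = [sum(col) / n_outputs for col in zip(*p.jac)]
        # Accumulate into .grad rather than overwrite, mirroring autograd.
        if p.grad is None:
            p.grad = aggregated
        else:
            p.grad = [g + a for g, a in zip(p.grad, aggregated)]
        if not retain_jac:
            p.jac = None  # free the Jacobian unless the caller opted to keep it


p = Param(jac=[[1.0, 2.0], [3.0, 4.0]])  # 2 outputs, 2-dim parameter
jac_to_grad_sketch([p])
# p.grad is the mean over outputs: [2.0, 3.0]; p.jac has been freed
```

Passing `retain_jac=True` would leave `p.jac` intact for a later call, which is the whole point of the flag this diff renames.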