99from ._accumulation import TensorWithJac , accumulate_grads
1010
1111
12- def jac_to_grad (params : Iterable [Tensor ], aggregator : Aggregator , retain_jac : bool = False ) -> None :
12+ def jac_to_grad (
13+ tensors : Iterable [Tensor ], aggregator : Aggregator , retain_jac : bool = False
14+ ) -> None :
1315 """
14- Aggregates the Jacobians stored in the ``.jac`` fields of ``params `` and accumulates the result
16+ Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors `` and accumulates the result
1517 into their ``.grad`` fields.
1618
17- :param params : The parameters whose ``.jac`` fields should be aggregated. All Jacobians must
19+ :param tensors : The tensors whose ``.jac`` fields should be aggregated. All Jacobians must
1820 have the same first dimension (number of outputs).
1921 :param aggregator: The aggregator used to reduce the Jacobians into gradients.
20- :param retain_jac: Whether to preserve the ``.jac`` fields of the parameters .
22+ :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors .
2123 """
2224
23- params_ = list [TensorWithJac ]()
24- for p in params :
25- if not hasattr (p , "jac" ):
25+ tensors_ = list [TensorWithJac ]()
26+ for t in tensors :
27+ if not hasattr (t , "jac" ):
2628 raise ValueError (
2729 "Some `jac` fields were not populated. Did you use `autojac.backward` before"
2830 "calling `jac_to_grad`?"
2931 )
30- p_ = cast (TensorWithJac , p )
31- params_ .append (p_ )
32+ t_ = cast (TensorWithJac , t )
33+ tensors_ .append (t_ )
3234
33- if len (params_ ) == 0 :
35+ if len (tensors_ ) == 0 :
3436 return
3537
36- jacobians = [p .jac for p in params_ ]
38+ jacobians = [t .jac for t in tensors_ ]
3739
3840 if not all ([jacobian .shape [0 ] == jacobians [0 ].shape [0 ] for jacobian in jacobians [1 :]]):
3941 raise ValueError ("All Jacobians should have the same number of rows." )
4042
4143 jacobian_matrix = _unite_jacobians (jacobians )
4244 gradient_vector = aggregator (jacobian_matrix )
43- gradients = _disunite_gradient (gradient_vector , jacobians , params_ )
44- accumulate_grads (params_ , gradients )
45+ gradients = _disunite_gradient (gradient_vector , jacobians , tensors_ )
46+ accumulate_grads (tensors_ , gradients )
4547
4648 if not retain_jac :
47- _free_jacs (params_ )
49+ _free_jacs (tensors_ )
4850
4951
5052def _unite_jacobians (jacobians : list [Tensor ]) -> Tensor :
@@ -54,7 +56,7 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
5456
5557
5658def _disunite_gradient (
57- gradient_vector : Tensor , jacobians : list [Tensor ], params : list [TensorWithJac ]
59+ gradient_vector : Tensor , jacobians : list [Tensor ], tensors : list [TensorWithJac ]
5860) -> list [Tensor ]:
5961 gradient_vectors = []
6062 start = 0
@@ -63,16 +65,16 @@ def _disunite_gradient(
6365 current_gradient_vector = gradient_vector [start :end ]
6466 gradient_vectors .append (current_gradient_vector )
6567 start = end
66- gradients = [g .view (param .shape ) for param , g in zip (params , gradient_vectors , strict = True )]
68+ gradients = [g .view (t .shape ) for t , g in zip (tensors , gradient_vectors , strict = True )]
6769 return gradients
6870
6971
70- def _free_jacs (params : Iterable [TensorWithJac ]) -> None :
72+ def _free_jacs (tensors : Iterable [TensorWithJac ]) -> None :
7173 """
72- Deletes the ``.jac`` field of the provided parameters .
74+ Deletes the ``.jac`` field of the provided tensors .
7375
74- :param params : The parameters whose ``.jac`` fields should be cleared.
76+ :param tensors : The tensors whose ``.jac`` fields should be cleared.
7577 """
7678
77- for p in params :
78- del p .jac
79+ for t in tensors :
80+ del t .jac
0 commit comments