@@ -13,12 +13,16 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
             param_ = cast(TensorWithJac, param)
             param_.jac += jac
         else:
-            # TODO: this could be a serious memory issue
-            # We clone the value because we do not want subsequent accumulations to also affect
-            # this value (in case it is still used outside). We do not detach from the
-            # computation graph because the value can have grad_fn that we want to keep track of
-            # (in case it was obtained via create_graph=True and a differentiable aggregator).
-            param.__setattr__("jac", jac.clone())
+            # We do not clone the value, in order to save memory and time. As a result,
+            # subsequent in-place modifications of param.jac (e.g. later accumulations) will
+            # also affect the corresponding element of jacobians, and outside changes to that
+            # element will also affect param.jac. To be safe, the values of jacobians should
+            # therefore not be used anymore after being passed to this function.
+            #
+            # We do not detach from the computation graph because the value can have a grad_fn
+            # that we want to keep track of (in case it was obtained via create_graph=True and
+            # a differentiable aggregator).
+            param.__setattr__("jac", jac)
 
 
 def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
@@ -27,7 +31,7 @@ def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
         if hasattr(param, "grad") and param.grad is not None:
             param.grad += grad
         else:
-            param.grad = grad.clone()
+            param.grad = grad
 
 
 def _check_expects_grad(tensor: Tensor) -> None:
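
The caveat in the new comment is easy to see concretely. Below is a minimal sketch of the aliasing behavior, assuming standard PyTorch tensor semantics; param and grad here are hypothetical stand-ins for illustration, not objects from this module:

import torch

param = torch.zeros(3, requires_grad=True)
grad = torch.ones(3)

# First accumulation: plain assignment, as in the new code (no clone), so
# param.grad and the caller's grad tensor now share the same storage.
param.grad = grad

# An outside in-place change to the caller's tensor is visible on param.grad...
grad.add_(1.0)
print(param.grad)  # tensor([2., 2., 2.])

# ...and a subsequent in-place accumulation writes through the other way.
param.grad += torch.ones(3)
print(grad)  # tensor([3., 3., 3.])

This is the trade-off the new comments describe: skipping the clone avoids one allocation and copy per tensor, at the cost that callers must treat the passed-in gradients and jacobians as consumed after the call.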