@@ -16,7 +16,7 @@ class TensorWithJac(Tensor):
1616
1717def accumulate_jacs (params : Iterable [Tensor ], jacobians : Iterable [Tensor ]) -> None :
1818 for param , jac in zip (params , jacobians , strict = True ):
19- _check_expects_grad (param )
19+ _check_expects_grad (param , field_name = ".jac" )
2020 # We that the shape is correct to be consistent with torch, that checks that the grad
2121 # shape is correct before assigning it.
2222 if jac .shape [1 :] != param .shape :
@@ -43,17 +43,17 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No
4343
4444def accumulate_grads (params : Iterable [Tensor ], gradients : Iterable [Tensor ]) -> None :
4545 for param , grad in zip (params , gradients , strict = True ):
46- _check_expects_grad (param )
46+ _check_expects_grad (param , field_name = ".grad" )
4747 if hasattr (param , "grad" ) and param .grad is not None :
4848 param .grad += grad
4949 else :
5050 param .grad = grad
5151
5252
53- def _check_expects_grad (tensor : Tensor ) -> None :
53+ def _check_expects_grad (tensor : Tensor , field_name : str ) -> None :
5454 if not _expects_grad (tensor ):
5555 raise ValueError (
56- "Cannot populate the .grad field of a Tensor that does not satisfy:"
56+ f "Cannot populate the { field_name } field of a Tensor that does not satisfy:\n "
5757 "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
5858 )
5959
0 commit comments