Skip to content

Commit 1394395

Browse files
committed
Fix error message in _check_expects_grad
1 parent 87b66f8 commit 1394395

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/torchjd/autojac/_accumulation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TensorWithJac(Tensor):
1616

1717
def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
1818
for param, jac in zip(params, jacobians, strict=True):
19-
_check_expects_grad(param)
19+
_check_expects_grad(param, field_name=".jac")
2020
# We that the shape is correct to be consistent with torch, that checks that the grad
2121
# shape is correct before assigning it.
2222
if jac.shape[1:] != param.shape:
@@ -43,17 +43,17 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No
4343

4444
def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
4545
for param, grad in zip(params, gradients, strict=True):
46-
_check_expects_grad(param)
46+
_check_expects_grad(param, field_name=".grad")
4747
if hasattr(param, "grad") and param.grad is not None:
4848
param.grad += grad
4949
else:
5050
param.grad = grad
5151

5252

53-
def _check_expects_grad(tensor: Tensor) -> None:
53+
def _check_expects_grad(tensor: Tensor, field_name: str) -> None:
5454
if not _expects_grad(tensor):
5555
raise ValueError(
56-
"Cannot populate the .grad field of a Tensor that does not satisfy:"
56+
f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n"
5757
"`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
5858
)
5959

0 commit comments

Comments
 (0)