Skip to content

Commit 674f6ad

Browse files
committed
Improve error message and usage example of jac_to_grad
1 parent c13a75b commit 674f6ad

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def jac_to_grad(
3535
3636
>>> import torch
3737
>>>
38-
>>> from torchjd.autojac import backward, jac_to_grad
3938
>>> from torchjd.aggregation import UPGrad
39+
>>> from torchjd.autojac import backward, jac_to_grad
4040
>>>
4141
>>> param = torch.tensor([1., 2.], requires_grad=True)
4242
>>> # Compute arbitrary quantities that are function of param
@@ -48,16 +48,16 @@ def jac_to_grad(
4848
>>> param.grad
4949
tensor([-1., 1.])
5050
51-
The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of
51+
The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
5252
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
5353
"""
5454

5555
tensors_ = list[TensorWithJac]()
5656
for t in tensors:
5757
if not hasattr(t, "jac"):
5858
raise ValueError(
59-
"Some `jac` fields were not populated. Did you use `autojac.backward` before"
60-
"calling `jac_to_grad`?"
59+
"Some `jac` fields were not populated. Did you use `autojac.backward` or "
60+
"`autojac.mtl_backward` before calling `jac_to_grad`?"
6161
)
6262
t_ = cast(TensorWithJac, t)
6363
tensors_.append(t_)

0 commit comments

Comments (0)