Skip to content

Commit 6ea5983

Browse files
committed
Add check of jac shape before assigning to .jac
1 parent 57e5c6d commit 6ea5983

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

src/torchjd/utils/_accumulation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> N
2222
# We do not detach from the computation graph because the value can have grad_fn
2323
# that we want to keep track of (in case it was obtained via create_graph=True and a
2424
# differentiable aggregator).
25+
#
26+
# We also check that the shape is correct to be consistent with torch, which checks that
27+
# the grad shape is correct before assigning it.
28+
29+
if jac.shape[1:] != param.shape:
30+
raise RuntimeError(
31+
f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of "
32+
f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the"
33+
" jacobian are the same size"
34+
)
35+
2536
param.__setattr__("jac", jac)
2637

2738

0 commit comments

Comments
 (0)