Skip to content

Commit 57e5c6d

Browse files
committed
Add check that the number of rows of the jacobians is consistant
1 parent 0bcf1c0 commit 57e5c6d

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/torchjd/utils/_jac_to_grad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def jac_to_grad(
3737

3838
jacobians = [p.jac for p in params_]
3939

40-
# TODO: check that the Jacobian shapes match
40+
if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]):
41+
raise ValueError("All Jacobians should have the same number of rows.")
4142

4243
jacobian_matrix = _unite_jacobians(jacobians)
4344
gradient_vector = aggregator(jacobian_matrix)

0 commit comments

Comments
 (0)