We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 0bcf1c0 commit 57e5c6dCopy full SHA for 57e5c6d
1 file changed
src/torchjd/utils/_jac_to_grad.py
@@ -37,7 +37,8 @@ def jac_to_grad(
37
38
jacobians = [p.jac for p in params_]
39
40
- # TODO: check that the Jacobian shapes match
+ if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]):
41
+ raise ValueError("All Jacobians should have the same number of rows.")
42
43
jacobian_matrix = _unite_jacobians(jacobians)
44
gradient_vector = aggregator(jacobian_matrix)
0 commit comments