We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 45c068a commit 6c3dd75Copy full SHA for 6c3dd75
src/torchjd/autojac/_transform/jac.py
@@ -116,13 +116,6 @@ def _get_jac_matrix_chunk(
116
117
def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
118
cumulative_lengths = [*accumulate(lengths)]
119
-
120
- if cumulative_lengths[-1] != matrix.shape[1]:
121
- raise ValueError(
122
- "The sum of the provided lengths should be equal to the number of columns in the "
123
- "provided matrix."
124
- )
125
126
start_indices = [0] + cumulative_lengths[:-1]
127
end_indices = cumulative_lengths
128
return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]
0 commit comments