Skip to content

Commit 99e8bea

Browse files
authored
refactor(autojac): Remove useless checks in Jac (#276)
* Remove check that is impossible to fail in _differentiate * Remove check that is impossible to fail in _extract_sub_matrices * Remove check that is impossible to fail in _reshape_matrices
1 parent a2e77c7 commit 99e8bea

File tree

1 file changed

+1
-20
lines changed
  • src/torchjd/autojac/_transform

1 file changed

+1
-20
lines changed

src/torchjd/autojac/_transform/jac.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,7 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
4343
if len(inputs) == 0:
4444
return tuple()
4545

46-
n_outputs = len(outputs)
47-
if len(jac_outputs) != n_outputs:
48-
raise ValueError(
49-
"Parameters `outputs` and `jac_outputs` should be sequences of the same length."
50-
f"Found `len(outputs) = {n_outputs}` and `len(jac_outputs) = {len(jac_outputs)}`."
51-
)
52-
53-
if n_outputs == 0:
46+
if len(outputs) == 0:
5447
return tuple(
5548
[
5649
torch.empty((0,) + input.shape, device=input.device, dtype=input.dtype)
@@ -123,22 +116,10 @@ def _get_jac_matrix_chunk(
123116

124117
def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
125118
cumulative_lengths = [*accumulate(lengths)]
126-
127-
if cumulative_lengths[-1] != matrix.shape[1]:
128-
raise ValueError(
129-
"The sum of the provided lengths should be equal to the number of columns in the "
130-
"provided matrix."
131-
)
132-
133119
start_indices = [0] + cumulative_lengths[:-1]
134120
end_indices = cumulative_lengths
135121
return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]
136122

137123

138124
def _reshape_matrices(matrices: Sequence[Tensor], shapes: Sequence[Size]) -> Sequence[Tensor]:
139-
if len(matrices) != len(shapes):
140-
raise ValueError(
141-
"Parameters `matrices` and `shapes` should contain the same number of elements."
142-
)
143-
144125
return [matrix.view((matrix.shape[0],) + shape) for matrix, shape in zip(matrices, shapes)]

0 commit comments

Comments
 (0)