Commit 66c2210

One-line view_default.
The trick is that remaining_cT_S does not really depend on stride_row: the floor division gives the same result whether or not stride_row is subtracted first. So remaining_cT_S can be precomputed by pre-dividing, and the modulo can likewise be applied to every row in parallel. There is therefore no need for a for-loop anymore; everything can be computed at once.
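A minimal sketch (not part of the commit; the toy values x, shape and c_prime are made up) of why the loop collapses into one expression, assuming non-negative integer entries as produced by c @ S: for non-negative integers, (x - x % m) // m == x // m, so the running division depends only on the product of the moduli processed so far, and every row can be computed independently.

# Hypothetical check that the removed loop and the new vectorized
# one-liner compute the same strides, for non-negative integer inputs.
import torch

x = torch.arange(24)              # stand-in for c @ S
shape = [2, 3, 4]                 # target view shape

# Loop version (mirrors the removed code).
remaining = x.clone()
rows = []
for modulo in shape[::-1]:
    row = remaining % modulo
    rows.append(row)
    # For non-negative ints, (remaining - row) // modulo == remaining // modulo,
    # so the subtraction is redundant and all divisions can be precomputed.
    remaining = (remaining - row) // modulo
loop_strides = torch.stack(rows[::-1], dim=0)

# Vectorized version (mirrors the added one-liner); c_prime[i] = prod(shape[i+1:]).
c_prime = torch.tensor([12, 4, 1])
vectorized_strides = (x.unsqueeze(0) // c_prime.unsqueeze(1)) % torch.tensor(shape).unsqueeze(1)

assert torch.equal(loop_strides, vectorized_strides)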
1 parent fac9c72 commit 66c2210

1 file changed

Lines changed: 2 additions & 11 deletions

src/torchjd/sparse/_aten_function_overrides/shape.py

@@ -49,17 +49,8 @@ def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor:
     S = t.strides
     vshape = list(t.shape)
     c = _reverse_cumulative_product(vshape)
-    remaining_cT_S = c @ S
-
-    stride_rows = list[Tensor]()
-    for modulo in shape[::-1]:
-        stride_row = remaining_cT_S % modulo
-        stride_rows.append(stride_row)
-        remaining_cT_S = (remaining_cT_S - stride_row) // modulo
-        # I think we could skip the - stride_row because the floor div will handle it for us, but it
-        # will make code harder to understand.
-
-    new_strides = torch.stack(stride_rows[::-1], dim=0)
+    c_prime = _reverse_cumulative_product(shape)
+    new_strides = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1)
     return to_most_efficient_tensor(t.physical, new_strides)