Skip to content

Commit b60a29d

Browse files
committed
unsquash_pdim_from_strides
1 parent 48de187 commit b60a29d

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,17 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
535535
return aten.cat.default([unwrap_to_dense(t) for t in tensors])
536536

537537

538+
def unsquash_pdim_from_strides(
539+
physical: Tensor, pdim: int, new_pdim_shape: list[int]
540+
) -> tuple[Tensor, Tensor]:
541+
new_shape = list(physical.shape)
542+
new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :]
543+
new_physical = physical.reshape(new_shape)
544+
545+
stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))])
546+
return new_physical, stride_multipliers
547+
548+
538549
def unsquash_pdim(
539550
physical: Tensor, pdim: int, new_pdim_shape: list[int]
540551
) -> tuple[Tensor, list[list[int]]]:

0 commit comments

Comments
 (0)