Skip to content

Commit 5062b06

Browse files
committed
Add concat implementation for when the physical dimension on which to concatenate has to be added
1 parent 5a25233 commit 5062b06

1 file changed

Lines changed: 21 additions & 6 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -549,13 +549,28 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
549549
assert len(indices) <= 1
550550

551551
if len(indices) == 0:
552-
# TODO: create new physical dimension on which we'll concatenate
553-
raise NotImplementedError()
554-
555-
pdim = indices[0][0]
552+
# Add a physical dimension pdim on which we can concatenate the physicals such that this
553+
# translates into a concatenation of the virtuals on virtual dimension dim.
554+
555+
# Stride-based representation:
556+
# new_stride_column = torch.zeros(ref_tensor.ndim, dtype=torch.int)
557+
# new_stride_column[dim] = ref_virtual_dim_size
558+
559+
pdim = ref_tensor.physical.ndim
560+
new_v_to_ps = [[d for d in pdims] for pdims in ref_tensor.v_to_ps]
561+
new_v_to_ps[dim] = [pdim] + new_v_to_ps[dim]
562+
new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps)
563+
source = list(range(len(destination)))
564+
physicals = [t.physical.unsqueeze(-1).movedim(source, destination) for t in tensors_]
565+
else:
566+
# Such a physical dimension already exists. Note that an alternative implementation would be
567+
# to simply always add the physical dimension, and squash it if it ends up being not needed.
568+
physicals = [t.physical for t in tensors_]
569+
pdim = indices[0][0]
570+
new_v_to_ps = ref_tensor.v_to_ps
556571

557-
new_physical = aten.cat.default([t.physical for t in tensors_], dim=pdim)
558-
return DiagonalSparseTensor(new_physical, ref_tensor.v_to_ps)
572+
new_physical = aten.cat.default(physicals, dim=pdim)
573+
return DiagonalSparseTensor(new_physical, new_v_to_ps)
559574

560575

561576
def unsquash_pdim_from_strides(

0 commit comments

Comments
 (0)