@@ -549,13 +549,28 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
549549 assert len (indices ) <= 1
550550
551551 if len (indices ) == 0 :
552- # TODO: create new physical dimension on which we'll concatenate
553- raise NotImplementedError ()
554-
555- pdim = indices [0 ][0 ]
552+ # Add a physical dimension pdim on which we can concatenate the physicals such that this
553+ # translates into a concatenation of the virtuals on virtual dimension dim.
554+
555+ # Stride-based representation:
556+ # new_stride_column = torch.zeros(ref_tensor.ndim, dtype=torch.int)
557+ # new_stride_column[dim] = ref_virtual_dim_size
558+
559+ pdim = ref_tensor .physical .ndim
560+ new_v_to_ps = [[d for d in pdims ] for pdims in ref_tensor .v_to_ps ]
561+ new_v_to_ps [dim ] = [pdim ] + new_v_to_ps [dim ]
562+ new_v_to_ps , destination = encode_v_to_ps (new_v_to_ps )
563+ source = list (range (len (destination )))
564+ physicals = [t .physical .unsqueeze (- 1 ).movedim (source , destination ) for t in tensors_ ]
565+ else :
566+ # Such a physical dimension already exists. Note that an alternative implementation would be
567+ # to simply always add the physical dimension, and squash it if it ends up being not needed.
568+ physicals = [t .physical for t in tensors_ ]
569+ pdim = indices [0 ][0 ]
570+ new_v_to_ps = ref_tensor .v_to_ps
556571
557- new_physical = aten .cat .default ([ t . physical for t in tensors_ ] , dim = pdim )
558- return DiagonalSparseTensor (new_physical , ref_tensor . v_to_ps )
572+ new_physical = aten .cat .default (physicals , dim = pdim )
573+ return DiagonalSparseTensor (new_physical , new_v_to_ps )
559574
560575
561576def unsquash_pdim_from_strides (
0 commit comments