Skip to content

Commit 5a25233

Browse files
committed
Add basic implementation of cat_default for when all strides match and the pdim on which we concatenate already exists.
1 parent 9ede7ec commit 5a25233

1 file changed

Lines changed: 26 additions & 3 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import wraps
33
from itertools import accumulate
44
from math import prod
5+
from typing import cast
56

67
import torch
78
from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros
@@ -530,9 +531,31 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
530531
print_fallback(aten.cat.default, (tensors, dim), {})
531532
return aten.cat.default([unwrap_to_dense(t) for t in tensors])
532533

533-
else:
534-
# TODO: efficient implementation when all tensors are sparse
535-
return aten.cat.default([unwrap_to_dense(t) for t in tensors])
534+
tensors_ = [cast(DiagonalSparseTensor, t) for t in tensors]
535+
ref_tensor = tensors_[0]
536+
ref_strides = ref_tensor.strides
537+
if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]):
538+
raise NotImplementedError()
539+
540+
# We need to try to find the (pretty sure it either does not exist or is unique) physical
541+
# dimension that makes us only move on virtual dimension dim. It also needs to be such that
542+
# traversing it entirely brings us exactly to the end of virtual dimension dim.
543+
544+
ref_virtual_dim_size = ref_tensor.shape[dim]
545+
indices = torch.argwhere(
546+
torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
547+
& torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
548+
)
549+
assert len(indices) <= 1
550+
551+
if len(indices) == 0:
552+
# TODO: create new physical dimension on which we'll concatenate
553+
raise NotImplementedError()
554+
555+
pdim = indices[0][0]
556+
557+
new_physical = aten.cat.default([t.physical for t in tensors_], dim=pdim)
558+
return DiagonalSparseTensor(new_physical, ref_tensor.v_to_ps)
536559

537560

538561
def unsquash_pdim_from_strides(

0 commit comments

Comments
 (0)