|
2 | 2 | from functools import wraps |
3 | 3 | from itertools import accumulate |
4 | 4 | from math import prod |
| 5 | +from typing import cast |
5 | 6 |
|
6 | 7 | import torch |
7 | 8 | from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros |
@@ -530,9 +531,31 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: |
530 | 531 | print_fallback(aten.cat.default, (tensors, dim), {}) |
531 | 532 | return aten.cat.default([unwrap_to_dense(t) for t in tensors]) |
532 | 533 |
|
533 | | - else: |
534 | | - # TODO: efficient implementation when all tensors are sparse |
535 | | - return aten.cat.default([unwrap_to_dense(t) for t in tensors]) |
| 534 | + tensors_ = [cast(DiagonalSparseTensor, t) for t in tensors] |
| 535 | + ref_tensor = tensors_[0] |
| 536 | + ref_strides = ref_tensor.strides |
| 537 | + if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): |
| 538 | + raise NotImplementedError() |
| 539 | + |
| 540 | + # We need to try to find the (pretty sure it either does not exist or is unique) physical |
| 541 | + # dimension that makes us only move on virtual dimension dim. It also needs to be such that |
| 542 | + # traversing it entirely brings us exactly to the end of virtual dimension dim. |
| 543 | + |
| 544 | + ref_virtual_dim_size = ref_tensor.shape[dim] |
| 545 | + indices = torch.argwhere( |
| 546 | + torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) |
| 547 | + & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) |
| 548 | + ) |
| 549 | + assert len(indices) <= 1 |
| 550 | + |
| 551 | + if len(indices) == 0: |
| 552 | + # TODO: create new physical dimension on which we'll concatenate |
| 553 | + raise NotImplementedError() |
| 554 | + |
| 555 | + pdim = indices[0][0] |
| 556 | + |
| 557 | + new_physical = aten.cat.default([t.physical for t in tensors_], dim=pdim) |
| 558 | + return DiagonalSparseTensor(new_physical, ref_tensor.v_to_ps) |
536 | 559 |
|
537 | 560 |
|
538 | 561 | def unsquash_pdim_from_strides( |
|
0 commit comments