Skip to content

Commit 840d035

Browse files
committed
Fix concat for cases where it has to densify
1 parent 112931f commit 840d035

2 files changed

Lines changed: 21 additions & 4 deletions

File tree

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor:
127127
f"basis. Found the following tensors:\n{[repr(t) for t in tensors_]} and the following "
128128
f"dim: {dim}."
129129
)
130+
if any(t.physical.shape != ref_tensor.physical.shape for t in tensors_[1:]):
131+
# This can happen in the following example:
132+
# t1 = SLT([1 2 3], [[2]])
133+
# t2 = SLT([4 5 6 7], [[2]])
134+
# The expected result would be 1 0 2 0 3 4 0 5 0 6 0 7, but this is not representable
135+
# efficiently as an SLT (because there is no 0 between 3 and 4, and both physicals have a
136+
# different shape so we can't just stack them).
137+
138+
# TODO: Maybe a partial densify is possible rather than a full densify.
139+
print_fallback(aten.cat.default, (tensors, dim), {})
140+
return aten.cat.default([unwrap_to_dense(t) for t in tensors])
130141

131142
# We need to try to find the (pretty sure it either does not exist or is unique) physical
132143
# dimension that makes us only move on virtual dimension dim. It also needs to be such that

tests/unit/sparse/test_sparse_latticed_tensor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,21 +332,27 @@ def test_get_column_indices(source: list[int], destination: list[int], ndim: int
332332

333333

334334
@mark.parametrize(
335-
["slt_args", "dim"],
335+
["slt_args", "dim", "expected_densify"],
336336
[
337-
([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1),
338-
([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1),
337+
([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1, False),
338+
([([3], tensor([[2]])), ([4], tensor([[2]]))], 0, True),
339+
([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1, False),
339340
],
340341
)
341342
def test_concatenate(
342343
slt_args: list[tuple[list[int], Tensor]],
343344
dim: int,
345+
expected_densify: bool,
344346
):
345347
tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in slt_args]
346348
res = aten.cat.default(tensors, dim)
347349
expected = aten.cat.default([t.to_dense() for t in tensors], dim)
348350

349-
assert isinstance(res, SparseLatticedTensor)
351+
if expected_densify:
352+
assert not isinstance(res, SparseLatticedTensor)
353+
else:
354+
assert isinstance(res, SparseLatticedTensor)
355+
350356
assert torch.all(torch.eq(res.to_dense(), expected))
351357

352358

0 commit comments

Comments
 (0)