Skip to content

Commit e0dc1a7

Browse files
committed
Add test_concatenate
1 parent 59cf10b commit e0dc1a7

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

tests/unit/sparse/test_diagonal_sparse_tensor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,22 @@ def test_unsquash_pdim(
248248

249249
assert list(new_physical.shape) == expected_physical_shape
250250
assert new_encoding == expected_new_encoding
251+
252+
253+
@mark.parametrize(
254+
["dst_args", "dim"],
255+
[
256+
([([3, 4], [[0], [0, 1]]), ([3, 3, 4], [[0, 1], [1, 2]])], 0),
257+
([([3, 12], [[0, 1], [0]]), ([9, 4], [[0, 1], [0]])], 1),
258+
],
259+
)
260+
def test_concatenate(
261+
dst_args: list[tuple[list[int], list[list[int]]]],
262+
dim: int,
263+
):
264+
tensors = [DiagonalSparseTensor(randn_(pshape), v_to_ps) for pshape, v_to_ps in dst_args]
265+
res = aten.cat.default(tensors, dim)
266+
expected = aten.cat.default([t.to_dense() for t in tensors], dim)
267+
268+
assert isinstance(res, DiagonalSparseTensor)
269+
assert torch.all(torch.eq(res.to_dense(), expected))

0 commit comments

Comments
 (0)