Skip to content

Commit 51a579a

Browse files
committed
Fix tests
1 parent 7f2c974 commit 51a579a

1 file changed

Lines changed: 145 additions & 67 deletions

File tree

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 145 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
1414
from torchjd.sparse._structured_sparse_tensor import (
1515
StructuredSparseTensor,
16-
encode_by_order,
1716
fix_ungrouped_dims,
1817
fix_zero_stride_columns,
1918
get_groupings,
@@ -24,7 +23,7 @@ def test_to_dense():
2423
n = 2
2524
m = 3
2625
a = randn_([n, m])
27-
b = StructuredSparseTensor(a, [[0], [1], [1], [0]])
26+
b = StructuredSparseTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]]))
2827
c = b.to_dense()
2928

3029
for i in range(n):
@@ -34,32 +33,56 @@ def test_to_dense():
3433

3534
def test_to_dense2():
3635
a = tensor_([1.0, 2.0, 3.0])
37-
b = StructuredSparseTensor(a, [[0, 0]])
36+
b = StructuredSparseTensor(a, tensor([[4]]))
3837
c = b.to_dense()
3938
expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0])
4039
assert torch.all(torch.eq(c, expected))
4140

4241

4342
@mark.parametrize(
44-
["a_pshape", "a_v_to_ps", "b_pshape", "b_v_to_ps", "a_indices", "b_indices", "output_indices"],
43+
["a_pshape", "a_strides", "b_pshape", "b_strides", "a_indices", "b_indices", "output_indices"],
4544
[
46-
([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3]),
47-
([2, 3, 5], [[0, 1], [2, 0]], [10, 3], [[0], [1]], [0, 1], [1, 2], [0, 2]),
48-
([2, 3], [[0, 1]], [6], [[0]], [0], [0], []),
49-
([6, 2, 3], [[0], [1], [2]], [2, 3], [[0, 1], [0], [1]], [0, 1, 2], [0, 1, 2], [0, 1, 2]),
45+
(
46+
[4, 5],
47+
tensor([[1, 0], [1, 0], [0, 1]]),
48+
[4, 5],
49+
tensor([[1, 0], [0, 1], [0, 1]]),
50+
[0, 1, 2],
51+
[0, 2, 3],
52+
[0, 1, 3],
53+
),
54+
(
55+
[2, 3, 5],
56+
tensor([[3, 1, 0], [1, 0, 2]]),
57+
[10, 3],
58+
tensor([[1, 0], [0, 1]]),
59+
[0, 1],
60+
[1, 2],
61+
[0, 2],
62+
),
63+
([2, 3], tensor([[3, 1]]), [6], tensor([[1]]), [0], [0], []),
64+
(
65+
[6, 2, 3],
66+
tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
67+
[2, 3],
68+
tensor([[3, 1], [1, 0], [0, 1]]),
69+
[0, 1, 2],
70+
[0, 1, 2],
71+
[0, 1, 2],
72+
),
5073
],
5174
)
5275
def test_einsum(
5376
a_pshape: list[int],
54-
a_v_to_ps: list[list[int]],
77+
a_strides: Tensor,
5578
b_pshape: list[int],
56-
b_v_to_ps: list[list[int]],
79+
b_strides: Tensor,
5780
a_indices: list[int],
5881
b_indices: list[int],
5982
output_indices: list[int],
6083
):
61-
a = StructuredSparseTensor(randn_(a_pshape), a_v_to_ps)
62-
b = StructuredSparseTensor(randn_(b_pshape), b_v_to_ps)
84+
a = StructuredSparseTensor(randn_(a_pshape), a_strides)
85+
b = StructuredSparseTensor(randn_(b_pshape), b_strides)
6386

6487
res = einsum((a, a_indices), (b, b_indices), output=output_indices)
6588

@@ -80,15 +103,15 @@ def test_einsum(
80103
)
81104
def test_structured_sparse_tensor_scalar(shape: list[int]):
82105
a = randn_(shape)
83-
b = StructuredSparseTensor(a, [[dim] for dim in range(len(shape))])
106+
b = StructuredSparseTensor(a, torch.eye(len(shape), dtype=torch.int64))
84107

85108
assert_close(a, b.to_dense())
86109

87110

88111
@mark.parametrize("dim", [2, 3, 4, 5, 10])
89112
def test_diag_equivalence(dim: int):
90113
a = randn_([dim])
91-
b = StructuredSparseTensor(a, [[0], [0]])
114+
b = StructuredSparseTensor(a, tensor([[1], [1]]))
92115

93116
diag_a = torch.diag(a)
94117

@@ -98,7 +121,7 @@ def test_diag_equivalence(dim: int):
98121
def test_three_virtual_single_physical():
99122
dim = 10
100123
a = randn_([dim])
101-
b = StructuredSparseTensor(a, [[0], [0], [0]])
124+
b = StructuredSparseTensor(a, tensor([[1], [1], [1]]))
102125

103126
expected = zeros_([dim, dim, dim])
104127
for i in range(dim):
@@ -111,7 +134,7 @@ def test_three_virtual_single_physical():
111134
def test_pointwise(func):
112135
dim = 10
113136
a = randn_([dim])
114-
b = StructuredSparseTensor(a, [[0], [0]])
137+
b = StructuredSparseTensor(a, tensor([[1], [1]]))
115138
c = b.to_dense()
116139
res = func(b)
117140
assert isinstance(res, StructuredSparseTensor)
@@ -123,7 +146,7 @@ def test_pointwise(func):
123146
def test_inplace_pointwise(func):
124147
dim = 10
125148
a = randn_([dim])
126-
b = StructuredSparseTensor(a, [[0], [0]])
149+
b = StructuredSparseTensor(a, tensor([[1], [1]]))
127150
c = b.to_dense()
128151
func(b)
129152
assert isinstance(b, StructuredSparseTensor)
@@ -135,69 +158,114 @@ def test_inplace_pointwise(func):
135158
def test_unary(func):
136159
dim = 10
137160
a = randn_([dim])
138-
b = StructuredSparseTensor(a, [[0], [0]])
161+
b = StructuredSparseTensor(a, tensor([[1], [1]]))
139162
c = b.to_dense()
140163

141164
res = func(b)
142165
assert_close(res.to_dense(), func(c))
143166

144167

145168
@mark.parametrize(
146-
["physical_shape", "v_to_ps", "target_shape", "expected_physical_shape", "expected_v_to_ps"],
169+
["physical_shape", "strides", "target_shape", "expected_physical_shape", "expected_strides"],
147170
[
148-
([2, 3], [[0], [0], [1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # no change of shape
149-
([2, 3], [[0], [0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # no change of shape
150-
([2, 3], [[0], [0], [1]], [2, 6], [2, 3], [[0], [0, 1]]), # squashing 2 dims
151-
([2, 3], [[0], [0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 2 dims
152-
([2, 3], [[0, 0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # unsquashing into 2 dims
153-
([2, 3], [[0], [0], [1]], [12], [2, 3], [[0, 0, 1]]), # squashing 3 dims
154-
([2, 3], [[0, 0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 3 dims
155-
([4], [[0], [0]], [2, 2, 4], [2, 2], [[0], [1], [0, 1]]), # unsquashing physical dim
156-
([4], [[0], [0]], [4, 2, 2], [2, 2], [[0, 1], [0], [1]]), # unsquashing physical dim
157-
([2, 3, 4], [[0], [0], [1], [2]], [4, 12], [2, 12], [[0, 0], [1]]), # world boss
158-
([2, 12], [[0, 0], [1]], [2, 2, 3, 4], [2, 3, 4], [[0], [0], [1], [2]]), # world boss
171+
(
172+
[2, 3],
173+
tensor([[1, 0], [1, 0], [0, 1]]),
174+
[2, 2, 3],
175+
[2, 3],
176+
tensor([[1, 0], [1, 0], [0, 1]]),
177+
), # no change of shape
178+
(
179+
[2, 3],
180+
tensor([[1, 0], [3, 1]]),
181+
[2, 6],
182+
[2, 3],
183+
tensor([[1, 0], [3, 1]]),
184+
), # no change of shape
185+
(
186+
[2, 3],
187+
tensor([[1, 0], [1, 0], [0, 1]]),
188+
[2, 6],
189+
[2, 3],
190+
tensor([[1, 0], [3, 1]]),
191+
), # squashing 2 dims
192+
(
193+
[2, 3],
194+
tensor([[1, 0], [3, 1]]),
195+
[2, 2, 3],
196+
[2, 3],
197+
tensor([[1, 0], [1, 0], [0, 1]]),
198+
), # unsquashing into 2 dims
199+
(
200+
[2, 3],
201+
tensor([[9, 1]]),
202+
[2, 6],
203+
[2, 3],
204+
tensor([[1, 0], [3, 1]]),
205+
), # unsquashing into 2 dims
206+
(
207+
[2, 3],
208+
tensor([[1, 0], [1, 0], [0, 1]]),
209+
[12],
210+
[2, 3],
211+
tensor([[9, 1]]),
212+
), # squashing 3 dims
213+
(
214+
[2, 3],
215+
tensor([[9, 1]]),
216+
[2, 2, 3],
217+
[2, 3],
218+
tensor([[1, 0], [1, 0], [0, 1]]),
219+
), # unsquashing into 3 dims
220+
(
221+
[4],
222+
tensor([[1], [1]]),
223+
[2, 2, 4],
224+
[2, 2],
225+
tensor([[1, 0], [0, 1], [2, 1]]),
226+
), # unsquashing physical dim
227+
(
228+
[4],
229+
tensor([[1], [1]]),
230+
[4, 2, 2],
231+
[2, 2],
232+
tensor([[2, 1], [1, 0], [0, 1]]),
233+
), # unsquashing physical dim
234+
(
235+
[2, 3, 4],
236+
tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]),
237+
[4, 12],
238+
[2, 12],
239+
tensor([[3, 0], [0, 1]]),
240+
), # world boss
241+
(
242+
[2, 12],
243+
tensor([[3, 0], [0, 1]]),
244+
[2, 2, 3, 4],
245+
[2, 3, 4],
246+
tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]),
247+
), # world boss
159248
],
160249
)
161250
def test_view(
162251
physical_shape: list[int],
163-
v_to_ps: list[list[int]],
252+
strides: Tensor,
164253
target_shape: list[int],
165254
expected_physical_shape: list[int],
166-
expected_v_to_ps: list[list[int]],
255+
expected_strides: Tensor,
167256
):
168257
a = randn_(tuple(physical_shape))
169-
t = StructuredSparseTensor(a, v_to_ps)
258+
t = StructuredSparseTensor(a, strides)
170259

171260
result = aten.view.default(t, target_shape)
172261
expected = t.to_dense().reshape(target_shape)
173262

174263
assert isinstance(result, StructuredSparseTensor)
175264
assert list(result.physical.shape) == expected_physical_shape
176-
assert result.v_to_ps == expected_v_to_ps
265+
assert torch.equal(result.strides, expected_strides)
177266
assert torch.all(torch.eq(result.to_dense(), expected))
178267

179268

180-
@mark.parametrize(
181-
["input", "expected_output", "expected_destination"],
182-
[
183-
([0, 1, 0, 2, 1, 3], [0, 1, 0, 2, 1, 3], [0, 1, 2, 3]), # trivial
184-
([1, 0, 3, 2, 1], [0, 1, 2, 3, 0], [1, 0, 3, 2]),
185-
([1, 0, 3, 2], [0, 1, 2, 3], [1, 0, 3, 2]),
186-
([0, 2, 0, 1], [0, 1, 0, 2], [0, 2, 1]),
187-
([1, 0, 0, 1], [0, 1, 1, 0], [1, 0]),
188-
],
189-
)
190-
def test_encode_by_order(
191-
input: list[int],
192-
expected_output: list[int],
193-
expected_destination: list[int],
194-
):
195-
output, destination = encode_by_order(input)
196-
197-
assert output == expected_output
198-
assert destination == expected_destination
199-
200-
201269
@mark.parametrize(
202270
["pshape", "strides", "expected"],
203271
[
@@ -214,24 +282,34 @@ def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[
214282

215283

216284
@mark.parametrize(
217-
["physical_shape", "v_to_ps", "expected_physical_shape", "expected_v_to_ps"],
285+
["physical_shape", "strides", "expected_physical_shape", "expected_strides"],
218286
[
219-
([3, 4, 5], [[0, 1, 2], [2, 0, 1], [2]], [12, 5], [[0, 1], [1, 0], [1]]),
220-
([32, 20, 8], [[0], [1, 0], [2]], [32, 20, 8], [[0], [1, 0], [2]]),
221-
([3, 3, 4], [[0, 1], [1, 2]], [3, 3, 4], [[0, 1], [1, 2]]),
287+
(
288+
[3, 4, 5],
289+
tensor([[20, 5, 1], [4, 1, 12], [0, 0, 1]]),
290+
[12, 5],
291+
tensor([[5, 1], [1, 12], [0, 1]]),
292+
),
293+
(
294+
[32, 20, 8],
295+
tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]),
296+
[32, 20, 8],
297+
tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]),
298+
),
299+
([3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]]), [3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]])),
222300
],
223301
)
224302
def test_fix_ungrouped_dims(
225303
physical_shape: list[int],
226-
v_to_ps: list[list[int]],
304+
strides: Tensor,
227305
expected_physical_shape: list[int],
228-
expected_v_to_ps: list[list[int]],
306+
expected_strides: Tensor,
229307
):
230308
physical = randn_(physical_shape)
231-
fixed_physical, fixed_v_to_ps = fix_ungrouped_dims(physical, v_to_ps)
309+
fixed_physical, fixed_strides = fix_ungrouped_dims(physical, strides)
232310

233311
assert list(fixed_physical.shape) == expected_physical_shape
234-
assert fixed_v_to_ps == expected_v_to_ps
312+
assert torch.equal(fixed_strides, expected_strides)
235313

236314

237315
@mark.parametrize(
@@ -265,15 +343,15 @@ def test_unsquash_pdim(
265343
@mark.parametrize(
266344
["sst_args", "dim"],
267345
[
268-
([([3], [[0], [0]]), ([3], [[0], [0]])], 1),
269-
([([3, 2], [[0], [1, 0]]), ([3, 2], [[0], [1, 0]])], 1),
346+
([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1),
347+
([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1),
270348
],
271349
)
272350
def test_concatenate(
273-
sst_args: list[tuple[list[int], list[list[int]]]],
351+
sst_args: list[tuple[list[int], Tensor]],
274352
dim: int,
275353
):
276-
tensors = [StructuredSparseTensor(randn_(pshape), v_to_ps) for pshape, v_to_ps in sst_args]
354+
tensors = [StructuredSparseTensor(randn_(pshape), strides) for pshape, strides in sst_args]
277355
res = aten.cat.default(tensors, dim)
278356
expected = aten.cat.default([t.to_dense() for t in tensors], dim)
279357

0 commit comments

Comments
 (0)