1313from torchjd .sparse ._aten_function_overrides .shape import unsquash_pdim
1414from torchjd .sparse ._structured_sparse_tensor import (
1515 StructuredSparseTensor ,
16- encode_by_order ,
1716 fix_ungrouped_dims ,
1817 fix_zero_stride_columns ,
1918 get_groupings ,
@@ -24,7 +23,7 @@ def test_to_dense():
2423 n = 2
2524 m = 3
2625 a = randn_ ([n , m ])
27- b = StructuredSparseTensor (a , [[ 0 ], [1 ], [1 ], [0 ]])
26+ b = StructuredSparseTensor (a , tensor ([[ 1 , 0 ], [0 , 1 ], [0 , 1 ], [1 , 0 ]]) )
2827 c = b .to_dense ()
2928
3029 for i in range (n ):
@@ -34,32 +33,56 @@ def test_to_dense():
3433
3534def test_to_dense2 ():
3635 a = tensor_ ([1.0 , 2.0 , 3.0 ])
37- b = StructuredSparseTensor (a , [[ 0 , 0 ]] )
36+ b = StructuredSparseTensor (a , tensor ([[ 4 ]]) )
3837 c = b .to_dense ()
3938 expected = tensor_ ([1.0 , 0.0 , 0.0 , 0.0 , 2.0 , 0.0 , 0.0 , 0.0 , 3.0 ])
4039 assert torch .all (torch .eq (c , expected ))
4140
4241
4342@mark .parametrize (
44- ["a_pshape" , "a_v_to_ps " , "b_pshape" , "b_v_to_ps " , "a_indices" , "b_indices" , "output_indices" ],
43+ ["a_pshape" , "a_strides " , "b_pshape" , "b_strides " , "a_indices" , "b_indices" , "output_indices" ],
4544 [
46- ([4 , 5 ], [[0 ], [0 ], [1 ]], [4 , 5 ], [[0 ], [1 ], [1 ]], [0 , 1 , 2 ], [0 , 2 , 3 ], [0 , 1 , 3 ]),
47- ([2 , 3 , 5 ], [[0 , 1 ], [2 , 0 ]], [10 , 3 ], [[0 ], [1 ]], [0 , 1 ], [1 , 2 ], [0 , 2 ]),
48- ([2 , 3 ], [[0 , 1 ]], [6 ], [[0 ]], [0 ], [0 ], []),
49- ([6 , 2 , 3 ], [[0 ], [1 ], [2 ]], [2 , 3 ], [[0 , 1 ], [0 ], [1 ]], [0 , 1 , 2 ], [0 , 1 , 2 ], [0 , 1 , 2 ]),
45+ (
46+ [4 , 5 ],
47+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
48+ [4 , 5 ],
49+ tensor ([[1 , 0 ], [0 , 1 ], [0 , 1 ]]),
50+ [0 , 1 , 2 ],
51+ [0 , 2 , 3 ],
52+ [0 , 1 , 3 ],
53+ ),
54+ (
55+ [2 , 3 , 5 ],
56+ tensor ([[3 , 1 , 0 ], [1 , 0 , 2 ]]),
57+ [10 , 3 ],
58+ tensor ([[1 , 0 ], [0 , 1 ]]),
59+ [0 , 1 ],
60+ [1 , 2 ],
61+ [0 , 2 ],
62+ ),
63+ ([2 , 3 ], tensor ([[3 , 1 ]]), [6 ], tensor ([[1 ]]), [0 ], [0 ], []),
64+ (
65+ [6 , 2 , 3 ],
66+ tensor ([[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]),
67+ [2 , 3 ],
68+ tensor ([[3 , 1 ], [1 , 0 ], [0 , 1 ]]),
69+ [0 , 1 , 2 ],
70+ [0 , 1 , 2 ],
71+ [0 , 1 , 2 ],
72+ ),
5073 ],
5174)
5275def test_einsum (
5376 a_pshape : list [int ],
54- a_v_to_ps : list [ list [ int ]] ,
77+ a_strides : Tensor ,
5578 b_pshape : list [int ],
56- b_v_to_ps : list [ list [ int ]] ,
79+ b_strides : Tensor ,
5780 a_indices : list [int ],
5881 b_indices : list [int ],
5982 output_indices : list [int ],
6083):
61- a = StructuredSparseTensor (randn_ (a_pshape ), a_v_to_ps )
62- b = StructuredSparseTensor (randn_ (b_pshape ), b_v_to_ps )
84+ a = StructuredSparseTensor (randn_ (a_pshape ), a_strides )
85+ b = StructuredSparseTensor (randn_ (b_pshape ), b_strides )
6386
6487 res = einsum ((a , a_indices ), (b , b_indices ), output = output_indices )
6588
@@ -80,15 +103,15 @@ def test_einsum(
80103)
81104def test_structured_sparse_tensor_scalar (shape : list [int ]):
82105 a = randn_ (shape )
83- b = StructuredSparseTensor (a , [[ dim ] for dim in range (len (shape ))] )
106+ b = StructuredSparseTensor (a , torch . eye (len (shape ), dtype = torch . int64 ) )
84107
85108 assert_close (a , b .to_dense ())
86109
87110
88111@mark .parametrize ("dim" , [2 , 3 , 4 , 5 , 10 ])
89112def test_diag_equivalence (dim : int ):
90113 a = randn_ ([dim ])
91- b = StructuredSparseTensor (a , [[ 0 ], [0 ]] )
114+ b = StructuredSparseTensor (a , tensor ([[ 1 ], [1 ]]) )
92115
93116 diag_a = torch .diag (a )
94117
@@ -98,7 +121,7 @@ def test_diag_equivalence(dim: int):
98121def test_three_virtual_single_physical ():
99122 dim = 10
100123 a = randn_ ([dim ])
101- b = StructuredSparseTensor (a , [[ 0 ], [0 ], [0 ]] )
124+ b = StructuredSparseTensor (a , tensor ([[ 1 ], [1 ], [1 ]]) )
102125
103126 expected = zeros_ ([dim , dim , dim ])
104127 for i in range (dim ):
@@ -111,7 +134,7 @@ def test_three_virtual_single_physical():
111134def test_pointwise (func ):
112135 dim = 10
113136 a = randn_ ([dim ])
114- b = StructuredSparseTensor (a , [[ 0 ], [0 ]] )
137+ b = StructuredSparseTensor (a , tensor ([[ 1 ], [1 ]]) )
115138 c = b .to_dense ()
116139 res = func (b )
117140 assert isinstance (res , StructuredSparseTensor )
@@ -123,7 +146,7 @@ def test_pointwise(func):
123146def test_inplace_pointwise (func ):
124147 dim = 10
125148 a = randn_ ([dim ])
126- b = StructuredSparseTensor (a , [[ 0 ], [0 ]] )
149+ b = StructuredSparseTensor (a , tensor ([[ 1 ], [1 ]]) )
127150 c = b .to_dense ()
128151 func (b )
129152 assert isinstance (b , StructuredSparseTensor )
@@ -135,69 +158,114 @@ def test_inplace_pointwise(func):
135158def test_unary (func ):
136159 dim = 10
137160 a = randn_ ([dim ])
138- b = StructuredSparseTensor (a , [[ 0 ], [0 ]] )
161+ b = StructuredSparseTensor (a , tensor ([[ 1 ], [1 ]]) )
139162 c = b .to_dense ()
140163
141164 res = func (b )
142165 assert_close (res .to_dense (), func (c ))
143166
144167
145168@mark .parametrize (
146- ["physical_shape" , "v_to_ps " , "target_shape" , "expected_physical_shape" , "expected_v_to_ps " ],
169+ ["physical_shape" , "strides " , "target_shape" , "expected_physical_shape" , "expected_strides " ],
147170 [
148- ([2 , 3 ], [[0 ], [0 ], [1 ]], [2 , 2 , 3 ], [2 , 3 ], [[0 ], [0 ], [1 ]]), # no change of shape
149- ([2 , 3 ], [[0 ], [0 , 1 ]], [2 , 6 ], [2 , 3 ], [[0 ], [0 , 1 ]]), # no change of shape
150- ([2 , 3 ], [[0 ], [0 ], [1 ]], [2 , 6 ], [2 , 3 ], [[0 ], [0 , 1 ]]), # squashing 2 dims
151- ([2 , 3 ], [[0 ], [0 , 1 ]], [2 , 2 , 3 ], [2 , 3 ], [[0 ], [0 ], [1 ]]), # unsquashing into 2 dims
152- ([2 , 3 ], [[0 , 0 , 1 ]], [2 , 6 ], [2 , 3 ], [[0 ], [0 , 1 ]]), # unsquashing into 2 dims
153- ([2 , 3 ], [[0 ], [0 ], [1 ]], [12 ], [2 , 3 ], [[0 , 0 , 1 ]]), # squashing 3 dims
154- ([2 , 3 ], [[0 , 0 , 1 ]], [2 , 2 , 3 ], [2 , 3 ], [[0 ], [0 ], [1 ]]), # unsquashing into 3 dims
155- ([4 ], [[0 ], [0 ]], [2 , 2 , 4 ], [2 , 2 ], [[0 ], [1 ], [0 , 1 ]]), # unsquashing physical dim
156- ([4 ], [[0 ], [0 ]], [4 , 2 , 2 ], [2 , 2 ], [[0 , 1 ], [0 ], [1 ]]), # unsquashing physical dim
157- ([2 , 3 , 4 ], [[0 ], [0 ], [1 ], [2 ]], [4 , 12 ], [2 , 12 ], [[0 , 0 ], [1 ]]), # world boss
158- ([2 , 12 ], [[0 , 0 ], [1 ]], [2 , 2 , 3 , 4 ], [2 , 3 , 4 ], [[0 ], [0 ], [1 ], [2 ]]), # world boss
171+ (
172+ [2 , 3 ],
173+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
174+ [2 , 2 , 3 ],
175+ [2 , 3 ],
176+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
177+ ), # no change of shape
178+ (
179+ [2 , 3 ],
180+ tensor ([[1 , 0 ], [3 , 1 ]]),
181+ [2 , 6 ],
182+ [2 , 3 ],
183+ tensor ([[1 , 0 ], [3 , 1 ]]),
184+ ), # no change of shape
185+ (
186+ [2 , 3 ],
187+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
188+ [2 , 6 ],
189+ [2 , 3 ],
190+ tensor ([[1 , 0 ], [3 , 1 ]]),
191+ ), # squashing 2 dims
192+ (
193+ [2 , 3 ],
194+ tensor ([[1 , 0 ], [3 , 1 ]]),
195+ [2 , 2 , 3 ],
196+ [2 , 3 ],
197+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
198+ ), # unsquashing into 2 dims
199+ (
200+ [2 , 3 ],
201+ tensor ([[9 , 1 ]]),
202+ [2 , 6 ],
203+ [2 , 3 ],
204+ tensor ([[1 , 0 ], [3 , 1 ]]),
205+ ), # unsquashing into 2 dims
206+ (
207+ [2 , 3 ],
208+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
209+ [12 ],
210+ [2 , 3 ],
211+ tensor ([[9 , 1 ]]),
212+ ), # squashing 3 dims
213+ (
214+ [2 , 3 ],
215+ tensor ([[9 , 1 ]]),
216+ [2 , 2 , 3 ],
217+ [2 , 3 ],
218+ tensor ([[1 , 0 ], [1 , 0 ], [0 , 1 ]]),
219+ ), # unsquashing into 3 dims
220+ (
221+ [4 ],
222+ tensor ([[1 ], [1 ]]),
223+ [2 , 2 , 4 ],
224+ [2 , 2 ],
225+ tensor ([[1 , 0 ], [0 , 1 ], [2 , 1 ]]),
226+ ), # unsquashing physical dim
227+ (
228+ [4 ],
229+ tensor ([[1 ], [1 ]]),
230+ [4 , 2 , 2 ],
231+ [2 , 2 ],
232+ tensor ([[2 , 1 ], [1 , 0 ], [0 , 1 ]]),
233+ ), # unsquashing physical dim
234+ (
235+ [2 , 3 , 4 ],
236+ tensor ([[1 , 0 , 0 ], [1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]),
237+ [4 , 12 ],
238+ [2 , 12 ],
239+ tensor ([[3 , 0 ], [0 , 1 ]]),
240+ ), # world boss
241+ (
242+ [2 , 12 ],
243+ tensor ([[3 , 0 ], [0 , 1 ]]),
244+ [2 , 2 , 3 , 4 ],
245+ [2 , 3 , 4 ],
246+ tensor ([[1 , 0 , 0 ], [1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]),
247+ ), # world boss
159248 ],
160249)
161250def test_view (
162251 physical_shape : list [int ],
163- v_to_ps : list [ list [ int ]] ,
252+ strides : Tensor ,
164253 target_shape : list [int ],
165254 expected_physical_shape : list [int ],
166- expected_v_to_ps : list [ list [ int ]] ,
255+ expected_strides : Tensor ,
167256):
168257 a = randn_ (tuple (physical_shape ))
169- t = StructuredSparseTensor (a , v_to_ps )
258+ t = StructuredSparseTensor (a , strides )
170259
171260 result = aten .view .default (t , target_shape )
172261 expected = t .to_dense ().reshape (target_shape )
173262
174263 assert isinstance (result , StructuredSparseTensor )
175264 assert list (result .physical .shape ) == expected_physical_shape
176- assert result .v_to_ps == expected_v_to_ps
265+ assert torch . equal ( result .strides , expected_strides )
177266 assert torch .all (torch .eq (result .to_dense (), expected ))
178267
179268
180- @mark .parametrize (
181- ["input" , "expected_output" , "expected_destination" ],
182- [
183- ([0 , 1 , 0 , 2 , 1 , 3 ], [0 , 1 , 0 , 2 , 1 , 3 ], [0 , 1 , 2 , 3 ]), # trivial
184- ([1 , 0 , 3 , 2 , 1 ], [0 , 1 , 2 , 3 , 0 ], [1 , 0 , 3 , 2 ]),
185- ([1 , 0 , 3 , 2 ], [0 , 1 , 2 , 3 ], [1 , 0 , 3 , 2 ]),
186- ([0 , 2 , 0 , 1 ], [0 , 1 , 0 , 2 ], [0 , 2 , 1 ]),
187- ([1 , 0 , 0 , 1 ], [0 , 1 , 1 , 0 ], [1 , 0 ]),
188- ],
189- )
190- def test_encode_by_order (
191- input : list [int ],
192- expected_output : list [int ],
193- expected_destination : list [int ],
194- ):
195- output , destination = encode_by_order (input )
196-
197- assert output == expected_output
198- assert destination == expected_destination
199-
200-
201269@mark .parametrize (
202270 ["pshape" , "strides" , "expected" ],
203271 [
@@ -214,24 +282,34 @@ def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[
214282
215283
216284@mark .parametrize (
217- ["physical_shape" , "v_to_ps " , "expected_physical_shape" , "expected_v_to_ps " ],
285+ ["physical_shape" , "strides " , "expected_physical_shape" , "expected_strides " ],
218286 [
219- ([3 , 4 , 5 ], [[0 , 1 , 2 ], [2 , 0 , 1 ], [2 ]], [12 , 5 ], [[0 , 1 ], [1 , 0 ], [1 ]]),
220- ([32 , 20 , 8 ], [[0 ], [1 , 0 ], [2 ]], [32 , 20 , 8 ], [[0 ], [1 , 0 ], [2 ]]),
221- ([3 , 3 , 4 ], [[0 , 1 ], [1 , 2 ]], [3 , 3 , 4 ], [[0 , 1 ], [1 , 2 ]]),
287+ (
288+ [3 , 4 , 5 ],
289+ tensor ([[20 , 5 , 1 ], [4 , 1 , 12 ], [0 , 0 , 1 ]]),
290+ [12 , 5 ],
291+ tensor ([[5 , 1 ], [1 , 12 ], [0 , 1 ]]),
292+ ),
293+ (
294+ [32 , 20 , 8 ],
295+ tensor ([[1 , 0 , 0 ], [1 , 32 , 0 ], [0 , 0 , 1 ]]),
296+ [32 , 20 , 8 ],
297+ tensor ([[1 , 0 , 0 ], [1 , 32 , 0 ], [0 , 0 , 1 ]]),
298+ ),
299+ ([3 , 3 , 4 ], tensor ([[3 , 1 , 0 ], [0 , 4 , 1 ]]), [3 , 3 , 4 ], tensor ([[3 , 1 , 0 ], [0 , 4 , 1 ]])),
222300 ],
223301)
224302def test_fix_ungrouped_dims (
225303 physical_shape : list [int ],
226- v_to_ps : list [ list [ int ]] ,
304+ strides : Tensor ,
227305 expected_physical_shape : list [int ],
228- expected_v_to_ps : list [ list [ int ]] ,
306+ expected_strides : Tensor ,
229307):
230308 physical = randn_ (physical_shape )
231- fixed_physical , fixed_v_to_ps = fix_ungrouped_dims (physical , v_to_ps )
309+ fixed_physical , fixed_strides = fix_ungrouped_dims (physical , strides )
232310
233311 assert list (fixed_physical .shape ) == expected_physical_shape
234- assert fixed_v_to_ps == expected_v_to_ps
312+ assert torch . equal ( fixed_strides , expected_strides )
235313
236314
237315@mark .parametrize (
@@ -265,15 +343,15 @@ def test_unsquash_pdim(
265343@mark .parametrize (
266344 ["sst_args" , "dim" ],
267345 [
268- ([([3 ], [[ 0 ], [0 ]]), ([3 ], [[ 0 ], [0 ]] )], 1 ),
269- ([([3 , 2 ], [[ 0 ], [1 , 0 ]]), ([3 , 2 ], [[ 0 ], [1 , 0 ]] )], 1 ),
346+ ([([3 ], tensor ([[ 1 ], [1 ]])) , ([3 ], tensor ([[ 1 ], [1 ]]) )], 1 ),
347+ ([([3 , 2 ], tensor ([[ 1 , 0 ], [1 , 3 ]])) , ([3 , 2 ], tensor ([[ 1 , 0 ], [1 , 3 ]]) )], 1 ),
270348 ],
271349)
272350def test_concatenate (
273- sst_args : list [tuple [list [int ], list [ list [ int ]] ]],
351+ sst_args : list [tuple [list [int ], Tensor ]],
274352 dim : int ,
275353):
276- tensors = [StructuredSparseTensor (randn_ (pshape ), v_to_ps ) for pshape , v_to_ps in sst_args ]
354+ tensors = [StructuredSparseTensor (randn_ (pshape ), strides ) for pshape , strides in sst_args ]
277355 res = aten .cat .default (tensors , dim )
278356 expected = aten .cat .default ([t .to_dense () for t in tensors ], dim )
279357
0 commit comments