22from torch import Tensor , tensor
33from torch .ops import aten # type: ignore
44
5- from torchjd .sparse import DiagonalSparseTensor
6- from torchjd . sparse . _diagonal_sparse_tensor import (
5+ from torchjd .sparse . _structured_sparse_tensor import (
6+ StructuredSparseTensor ,
77 p_to_vs_from_v_to_ps ,
8- to_diagonal_sparse_tensor ,
98 to_most_efficient_tensor ,
9+ to_structured_sparse_tensor ,
1010)
1111
1212
# NOTE(review): this block is a truncated unified-diff rendering — old/new line numbers
# are fused into the text, `-`/`+` markers remain, and a hunk gap (`@@` line below)
# elides part of the body. The `+` lines carry the intended post-rename
# (StructuredSparseTensor) code; the `-` lines are the removed DiagonalSparseTensor
# version. Recover the full function from version control before editing: the
# else-branch for t1 and the symmetric int/float handling of t2 (original lines 25-30)
# are not visible here.
1313def prepare_for_elementwise_op (
1414 t1 : Tensor | int | float , t2 : Tensor | int | float
15- ) -> tuple [DiagonalSparseTensor , DiagonalSparseTensor ]:
15+ ) -> tuple [StructuredSparseTensor , StructuredSparseTensor ]:
1616 """
17- Prepares two DSTs of the same shape from two args, one of those being a DST , and the other being
18- a DST , Tensor, int or float.
17+ Prepares two SSTs of the same shape from two args, one of those being a SST , and the other being
18+ a SST , Tensor, int or float.
1919 """
2020
21- assert isinstance (t1 , DiagonalSparseTensor ) or isinstance (t2 , DiagonalSparseTensor )
21+ assert isinstance (t1 , StructuredSparseTensor ) or isinstance (t2 , StructuredSparseTensor )
2222
2323 if isinstance (t1 , int ) or isinstance (t1 , float ):
2424 t1_ = tensor (t1 , device = t2 .device )
# NOTE(review): hunk gap — original lines 25-30 elided by the diff view.
@@ -31,52 +31,52 @@ def prepare_for_elementwise_op(
3131 t2_ = t2
3232
3333 t1_ , t2_ = aten .broadcast_tensors .default ([t1_ , t2_ ])
34- t1_ = to_diagonal_sparse_tensor (t1_ )
35- t2_ = to_diagonal_sparse_tensor (t2_ )
34+ t1_ = to_structured_sparse_tensor (t1_ )
35+ t2_ = to_structured_sparse_tensor (t2_ )
3636
3737 return t1_ , t2_
3838
3939
@StructuredSparseTensor.implements(aten.mul.Tensor)
def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """
    Element-wise multiplication with broadcasting.

    At least one of ``t1`` / ``t2`` must be a StructuredSparseTensor; the other may be
    a Tensor, int or float (enforced inside ``prepare_for_elementwise_op``).
    """
    # Broadcast both operands to a common shape and convert them to SSTs.
    t1_, t2_ = prepare_for_elementwise_op(t1, t2)
    # Element-wise product expressed as an einsum in which every dimension appears
    # in both inputs and in the output (no contraction).
    all_dims = list(range(t1_.ndim))
    return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
4646
4747
@StructuredSparseTensor.implements(aten.div.Tensor)
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """
    Element-wise division with broadcasting.

    Implemented as multiplication of ``t1`` by the reciprocal of the divisor's stored
    (physical) values.
    """
    t1_, t2_ = prepare_for_elementwise_op(t1, t2)
    # NOTE(review): only the stored physical values are inverted; logical entries that
    # are implicit zeros of the structure are not turned into inf/nan as dense division
    # would — presumably intentional, TODO confirm against aten.div semantics.
    t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps)
    all_dims = list(range(t1_.ndim))
    return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
5454
5555
@StructuredSparseTensor.implements(aten.mul.Scalar)
def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
    """Multiply a StructuredSparseTensor by a scalar, returning a new SST."""
    # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor.
    # Need to check that.
    assert isinstance(t, StructuredSparseTensor)
    # Scaling only the physical values preserves the sparsity structure (v_to_ps).
    new_physical = aten.mul.Scalar(t.physical, scalar)
    return StructuredSparseTensor(new_physical, t.v_to_ps)
6464
6565
@StructuredSparseTensor.implements(aten.add.Tensor)
def add_Tensor(
    t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
) -> StructuredSparseTensor:
    """
    Compute ``t1 + alpha * t2`` element-wise with broadcasting.

    Only supported when both operands end up with the same sparsity structure
    (identical ``v_to_ps``); raises NotImplementedError otherwise.
    """
    t1_, t2_ = prepare_for_elementwise_op(t1, t2)

    if t1_.v_to_ps != t2_.v_to_ps:
        # Adding tensors with different structures would require merging/densifying
        # the structures, which is not implemented.
        raise NotImplementedError()

    # Same structure: the sum is just the sum of the physical values.
    new_physical = t1_.physical + t2_.physical * alpha
    return StructuredSparseTensor(new_physical, t1_.v_to_ps)
7777
7878
# NOTE(review): this block is a fragmentary unified-diff rendering of `einsum` — five
# `@@` hunk headers below elide large spans of the body (original lines ~83-88, 96-103,
# 111-135, 143-149, 155-185, including the nested `group_indices` and `unique_int`
# helpers). Do NOT attempt to edit this function from this view; recover the complete
# post-image from version control first. The `+` lines carry the intended
# StructuredSparseTensor code.
79- def einsum (* args : tuple [DiagonalSparseTensor , list [int ]], output : list [int ]) -> Tensor :
79+ def einsum (* args : tuple [StructuredSparseTensor , list [int ]], output : list [int ]) -> Tensor :
8080
8181 # First part of the algorithm, determine how to cluster physical indices as well as the common
8282 # p_shapes corresponding to matching v_dims. Second part translates to physical einsum.
# NOTE(review): hunk gap — lines elided by the diff view.
@@ -89,7 +89,7 @@ def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) ->
8989 # get unique indices
9090 # map output indices (there can be splits)
9191 # call physical einsum
92- # build resulting dst
92+ # build resulting sst
9393
9494 # OVER
9595
# NOTE(review): hunk gap — lines elided by the diff view.
@@ -104,7 +104,7 @@ def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) ->
104104 # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension.
105105 # For this reason, an index is decomposed into sub-indices that are then independently
106106 # clustered.
107- # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l],
107+ # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l],
108108 # We will consider three indices (i, 0), (i, 1) and (i, 2).
109109 # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then
110110 # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in
# NOTE(review): hunk gap — the nested helper `group_indices` is elided by the diff view.
@@ -136,7 +136,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None:
136136 tensors = list [Tensor ]()
137137 indices_to_n_pdims = dict [int , int ]()
138138 for t , indices in args :
139- assert isinstance (t , DiagonalSparseTensor )
139+ assert isinstance (t , StructuredSparseTensor )
140140 tensors .append (t .physical )
141141 for ps , index in zip (t .v_to_ps , indices ):
142142 if index in indices_to_n_pdims :
# NOTE(review): hunk gap — lines elided by the diff view.
@@ -150,7 +150,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None:
150150 group_indices ([(indices [i ], sub_i ) for i , sub_i in indices_ ])
151151 # record the physical dimensions, index[v] for v in vs will end-up mapping to the same
152152 # final dimension as they were just clustered, so we can take the first, which exists as
153- # t is a valid DST .
153+ # t is a valid SST .
154154 new_indices_pair .append ([(indices [vs [0 ][0 ]], vs [0 ][1 ]) for vs in p_to_vs ])
155155
156156 current = 0
# NOTE(review): hunk gap — the nested helper `unique_int` and most of the second phase
# are elided by the diff view; only the final return statement is visible.
@@ -186,52 +186,52 @@ def unique_int(pair: tuple[int, int]) -> int:
186186 return to_most_efficient_tensor (physical , v_to_ps )
187187
188188
@StructuredSparseTensor.implements(aten.bmm.default)
def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """
    Batched matrix multiplication where at least one operand is a StructuredSparseTensor.

    Both operands must be 3-D with matching batch size (dim 0) and a matching inner
    dimension (``mat1.shape[2] == mat2.shape[1]``).
    """
    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
    assert (
        mat1.ndim == 3
        and mat2.ndim == 3
        and mat1.shape[0] == mat2.shape[0]
        and mat1.shape[2] == mat2.shape[1]
    )

    mat1_ = to_structured_sparse_tensor(mat1)
    mat2_ = to_structured_sparse_tensor(mat2)

    # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical
    # dimension sizes decompositions. If not, can reshape to common decomposition?
    # bmm = contract over the shared inner dim (2), keep batch (0) and outer dims (1, 3).
    return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3])
205205
206206
@StructuredSparseTensor.implements(aten.mm.default)
def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """
    Matrix multiplication where at least one operand is a StructuredSparseTensor.

    Both operands must be 2-D with ``mat1.shape[1] == mat2.shape[0]``.
    """
    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
    assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]

    mat1_ = to_structured_sparse_tensor(mat1)
    mat2_ = to_structured_sparse_tensor(mat2)

    # mm = contract over the shared inner dim (1), keep the outer dims (0, 2).
    return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2])
216216
217217
@StructuredSparseTensor.implements(aten.mean.default)
def mean_default(t: StructuredSparseTensor) -> Tensor:
    """
    Mean over all logical elements of ``t``.

    The numerator sums only the stored physical values (entries outside the structure
    are presumably implicit zeros — consistent with ``sum_default`` below), while the
    divisor ``t.numel()`` is the logical element count.
    """
    assert isinstance(t, StructuredSparseTensor)
    return aten.sum.default(t.physical) / t.numel()
222222
223223
@StructuredSparseTensor.implements(aten.sum.default)
def sum_default(t: StructuredSparseTensor) -> Tensor:
    """
    Sum over all logical elements of ``t``.

    Equal to the sum of the stored physical values, since entries outside the
    structure are presumably implicit zeros.
    """
    assert isinstance(t, StructuredSparseTensor)
    return aten.sum.default(t.physical)
228228
229229
# NOTE(review): this block is a unified-diff rendering AND is truncated at the end of
# the visible source — the body continues past the last line shown here (only the
# `dtype` guard is visible). The `+` lines carry the intended StructuredSparseTensor
# code. Recover the full function from version control before editing.
230- @DiagonalSparseTensor .implements (aten .sum .dim_IntList )
230+ @StructuredSparseTensor .implements (aten .sum .dim_IntList )
231231def sum_dim_IntList (
232- t : DiagonalSparseTensor , dim : list [int ], keepdim : bool = False , dtype = None
232+ t : StructuredSparseTensor , dim : list [int ], keepdim : bool = False , dtype = None
233233) -> Tensor :
234- assert isinstance (t , DiagonalSparseTensor )
234+ assert isinstance (t , StructuredSparseTensor )
235235
236236 if dtype :
237237 raise NotImplementedError ()
0 commit comments