1+ import itertools
12import operator
23from functools import wraps
34from itertools import accumulate
@@ -62,15 +63,14 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]):
6263 f"v_to_ps elements are not encoded by first appearance. Found { v_to_ps } ."
6364 )
6465
65- if any (len (group ) != 1 for group in get_groupings (v_to_ps )):
66- raise ValueError (f"Dimensions must be maximally grouped. Found { v_to_ps } ." )
67-
6866 self .physical = physical
6967 self .v_to_ps = v_to_ps
7068
7169 # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index
72- pshape = list (self .physical .shape )
73- self .strides = tensor ([strides_v2 (pdims , pshape ) for pdims in self .v_to_ps ])
70+ self .strides = get_strides (list (self .physical .shape ), v_to_ps )
71+
72+ if any (len (group ) != 1 for group in get_groupings_generalized (self .strides )):
73+ raise ValueError (f"Dimensions must be maximally grouped. Found { v_to_ps } ." )
7474
7575 def to_dense (
7676 self , dtype : torch .dtype | None = None , * , masked_grad : bool | None = None
@@ -188,11 +188,18 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
188188 return result
189189
190190
def get_strides(pshape: list[int], v_to_ps: list[list[int]]) -> Tensor:
    """Build the [v_ndim, p_ndim] integer stride matrix for a virtual view.

    Row i holds, for every physical dimension, the stride of virtual
    dimension i as computed by ``strides_v2``.
    """
    rows = [strides_v2(pdims, pshape) for pdims in v_to_ps]
    stride_matrix = torch.tensor(rows, dtype=torch.int64)
    # The explicit reshape guards degenerate inputs (e.g. v_to_ps containing
    # zero elements), where torch.tensor cannot infer the intended 2-D shape.
    return stride_matrix.reshape(len(v_to_ps), len(pshape))
196+
197+
def argmax(iterable):
    """Return the index of the first occurrence of the largest value."""
    return max(enumerate(iterable), key=operator.itemgetter(1))[0]
193200
194201
195- def strides_to_pdims (strides : list [ int ] , physical_shape : list [int ]) -> list [int ]:
202+ def strides_to_pdims (strides : Tensor , physical_shape : list [int ]) -> list [int ]:
196203 """
197204 Given a list of strides, find and return the used physical dimensions.
198205
@@ -207,7 +214,7 @@ def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]
207214 # e.g. strides = [22111, 201000], physical_shape = [10, 2]
208215
209216 pdims = []
210- remaining_strides = [ s for s in strides ]
217+ remaining_strides = strides . clone ()
211218 remaining_numel = (
212219 sum (remaining_strides [i ] * (physical_shape [i ] - 1 ) for i in range (len (physical_shape ))) + 1
213220 )
@@ -253,29 +260,62 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]
253260 return [res [i ] for i in range (len (res ))]
254261
255262
def are_ratios_matching(v: Tensor) -> bool:
    """Whether every non-NaN entry of ``v`` is one and the same integer.

    Examples:
        [3.0, nan, 3.0] => True
        [nan, nan, nan] => True   (vacuously: no constraint violated)
        [3.0, nan, 2.0] => False  (values differ)
        [0.5, 0.5, 0.5] => False  (value is not an integer)
    """
    values = v[~torch.isnan(v)]
    # Integrality: truncating to int and comparing back holds only for
    # integral values.
    is_integral = bool(torch.eq(values.int(), values).all())
    # Constancy: compare everything to the first element; the [:1] slice keeps
    # this well-defined (vacuously True) when `values` is empty.
    is_constant = bool(values.eq(values[:1]).all())
    return is_integral and is_constant
260277
def get_groupings_generalized(strides: Tensor) -> list[list[int]]:
    """Group columns of `strides` whose entries differ only by a constant
    integer factor, as judged by ``are_ratios_matching``.

    Note that float64 has 53 bits of precision, meaning that every integer up
    to 2^53 is represented exactly. Since strides are stored on int64, ratios
    can be of up to 2^64, so this function may fail for stride values between
    2^53 and 2^64.
    """
    fstrides = strides.to(dtype=torch.float64)
    # ratios[v, a, b] == fstrides[v, a] / fstrides[v, b]
    ratios = fstrides.unsqueeze(2) / fstrides.unsqueeze(1)

    n_cols = strides.shape[1]
    # Maps each column id to the set of columns it can be grouped with.
    mergeable = {col: {col} for col in range(n_cols)}
    for a, b in itertools.permutations(range(n_cols), 2):
        if are_ratios_matching(ratios[:, a, b]):
            mergeable[a].update(mergeable[b])
            mergeable[b].update(mergeable[a])

    result = []
    for col, members in mergeable.items():
        ordered = sorted(members)
        # Emit each distinct group exactly once: at its smallest member.
        if col == ordered[0]:
            result.append(ordered)

    return result
273302
def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]:
    """Group physical dimensions that can be merged.

    Column i1 can absorb column i2 when strides[:, i1] equals
    strides[:, i2] * pshape[i2], i.e. one step along i1 advances every virtual
    index exactly as far as a full sweep of i2.
    """
    scaled = strides * torch.tensor(pshape)
    n_cols = strides.shape[1]
    groups = {col: {col} for col in range(n_cols)}
    # root[i] is the representative column of the group currently holding i.
    root = list(range(n_cols))
    for i1, i2 in itertools.combinations(range(n_cols), 2):
        if torch.equal(strides[:, i1], scaled[:, i2]):
            groups[root[i1]].update(groups[root[i2]])
            root[i2] = root[i1]

    merged = [sorted(groups[rep]) for rep in sorted(set(root))]

    if len(merged) != len(pshape):
        # Diagnostic only: some physical dimensions were combined.
        print(f"Combined pshape with the following new columns: {merged}.")

    return merged
279319
280320
281321def longest_common_prefix (l1 : list [int ], l2 : list [int ]) -> list [int ]:
@@ -413,12 +453,16 @@ def new_encoding(d: int) -> int:
def fix_ungrouped_dims(
    physical: Tensor, v_to_ps: list[list[int]]
) -> tuple[Tensor, list[list[int]]]:
    """Merge groupable physical dimensions and rewrite v_to_ps accordingly.

    Returns the reshaped physical tensor together with the recomputed
    virtual-to-physical dimension mapping.
    """
    pshape = list(physical.shape)
    strides = get_strides(pshape, v_to_ps)
    groups = get_groupings(pshape, strides)

    # Collapse each group of physical dims into one dim holding their product.
    merged_shape = [prod([physical.shape[dim] for dim in group]) for group in groups]
    reshaped = physical.reshape(merged_shape)

    # stride_mapping[p, j] == 1 iff old physical dim p is the last member of
    # group j, so strides @ stride_mapping picks, for each new dim j, the
    # stride column of that last member.
    stride_mapping = torch.zeros(physical.ndim, reshaped.ndim, dtype=torch.int64)
    for j, group in enumerate(groups):
        stride_mapping[group[-1], j] = 1

    merged_strides = strides @ stride_mapping
    new_v_to_ps = [strides_to_pdims(row, list(reshaped.shape)) for row in merged_strides]
    return reshaped, new_v_to_ps
422466
423467
424468def make_dst (physical : Tensor , v_to_ps : list [list [int ]]) -> DiagonalSparseTensor :
0 commit comments