@@ -69,7 +69,7 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]):
6969 # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index
7070 self .strides = get_strides (list (self .physical .shape ), v_to_ps )
7171
72- if any (len (group ) != 1 for group in get_groupings_generalized ( self .strides )):
72+ if any (len (group ) != 1 for group in get_groupings ( list ( self . physical . shape ), self .strides )):
7373 raise ValueError (f"Dimensions must be maximally grouped. Found { v_to_ps } ." )
7474
7575 def to_dense (
@@ -260,47 +260,6 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]
260260 return [res [i ] for i in range (len (res ))]
261261
262262
263- def are_ratios_matching (v : Tensor ) -> bool :
264- # Returns a boolean indicating whether all non-nan values in a vector are integer and equal to
265- # each other.
266- # Returns a scalar boolean tensor indicating whether all values in v are the same or nan:
267- # [3.0, nan, 3.0] => True
268- # [nan, nan, nan] => True
269- # [3.0, nan, 2.0] => False
270- # [0.5, 0.5, 0.5] => False
271-
272- non_nan_values = v [~ v .isnan ()]
273- return (
274- torch .eq (non_nan_values .int (), non_nan_values ).all ().item ()
275- and non_nan_values .eq (non_nan_values [0 :1 ]).all ().item ()
276- )
277-
278-
279- def get_groupings_generalized (strides : Tensor ) -> list [list [int ]]:
280- fstrides = strides .to (dtype = torch .float64 )
281- # Note that float64 has 53 bits of precision, meaning that every integer number up to 2^53 can
282- # be represented on a float64 without any numerical error. Since strides are stored on int64,
283- # ratios can be of up to 2^64. This function may thus fail for stride values between 2^53 and
284- # 2^64.
285-
286- ratios = torch .div (fstrides .unsqueeze (2 ), fstrides .unsqueeze (1 ))
287-
288- # Mapping from column id to the set of columns with which it can be grouped
289- groups = {i : {i } for i , column in enumerate (strides .T )}
290- for i1 , i2 in itertools .permutations (range (strides .shape [1 ]), 2 ):
291- if are_ratios_matching (ratios [:, i1 , i2 ]):
292- groups [i1 ].update (groups [i2 ])
293- groups [i2 ].update (groups [i1 ])
294-
295- new_columns = []
296- for i , group in groups .items ():
297- sorted_group = sorted (list (group ))
298- if i == sorted_group [0 ]: # This ensures that the same group is added only once
299- new_columns .append (sorted_group )
300-
301- return new_columns
302-
303-
304263def get_groupings (pshape : list [int ], strides : Tensor ) -> list [list [int ]]:
305264 strides_time_pshape = strides * tensor (pshape )
306265 groups = {i : {i } for i , column in enumerate (strides .T )}
0 commit comments