@@ -212,6 +212,46 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
212212 return result
213213
214214
def argmax(iterable):
    """Return the position of the largest item in ``iterable``.

    When several items compare equal to the maximum, the earliest position is
    returned. An empty iterable raises ``ValueError`` (propagated from ``max``).
    """
    # Pair every item with its position, then let max() compare on the item only.
    best_position, _best_item = max(enumerate(iterable), key=lambda pair: pair[1])
    return best_position
217+
218+
def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]:
    """
    Given a list of strides, find and return the used physical dimensions.

    ``strides[i]`` is the stride associated with physical dimension ``i``, whose size is
    ``physical_shape[i]``. The returned list contains physical dimension indices in the order
    in which they are peeled off, always taking the largest remaining stride first (the lowest
    index wins on ties). The input lists are not modified.

    Example: ``strides = [22111, 201000]`` and ``physical_shape = [10, 2]`` give
    ``[1, 0, 0, 1, 0, 0, 0]``:

    * the element count starts at ``9 * 22111 + 1 * 201000 + 1 = 400000``;
    * dimension 1 holds the largest stride, so the count becomes ``400000 // 2 = 200000`` and
      the remaining strides become ``[22111, 1000]``;
    * the same step repeats until every remaining stride is zero.

    This algorithm runs in O(n * m) with n the number of physical dimensions (i.e.
    len(physical_shape) and len(strides)), and with m the number of pdims in the result.
    Keeping the remaining strides in a sorted structure would avoid rescanning the whole list
    at every iteration and should bring this down to roughly O((n + m) log(n)).

    Raises:
        ValueError: if the strides cannot be decomposed against ``physical_shape``, i.e. the
            remaining element count drops to zero while some stride is still left. Without
            this check the loop would subtract zero forever and never terminate.
    """
    remaining_strides = list(strides)

    # Offset of the last element plus one, i.e. the total number of addressed elements.
    remaining_numel = 1 + sum(
        remaining_strides[i] * (size - 1) for i, size in enumerate(physical_shape)
    )

    pdims = []
    while sum(remaining_strides) > 0:
        # Index of the largest remaining stride; max() keeps the first one on ties.
        current_pdim = max(range(len(remaining_strides)), key=remaining_strides.__getitem__)

        # Element count spanned by one step along the dimension being peeled off.
        remaining_numel //= physical_shape[current_pdim]
        if remaining_numel == 0:
            # Subtracting zero would leave the strides unchanged: the loop could never end.
            raise ValueError(
                f"strides {strides} are inconsistent with physical_shape {physical_shape}"
            )

        pdims.append(current_pdim)
        remaining_strides[current_pdim] -= remaining_numel

    return pdims
253+
254+
def merge_strides(strides: list[list[int]]) -> list[int]:
    """Flatten several stride lists into one list of distinct values, largest first."""
    distinct: set[int] = set()
    for stride_list in strides:
        # Duplicates across (and within) the input lists collapse here.
        distinct.update(stride_list)

    merged = list(distinct)
    merged.sort(reverse=True)
    return merged
217257
0 commit comments