Skip to content

Commit 3dffd1e

Browse files
committed
Add strides_to_pdims
1 parent e0dc1a7 commit 3dffd1e

1 file changed

Lines changed: 40 additions & 0 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,46 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
212212
return result
213213

214214

215+
def argmax(iterable):
    """Return the index of the first occurrence of the largest value in *iterable*.

    Raises ValueError on an empty iterable, mirroring ``max``.
    """
    best_index = -1
    best_value = None
    for index, value in enumerate(iterable):
        # Strict ">" keeps the first index on ties, matching max()'s behavior.
        if best_index < 0 or value > best_value:
            best_index, best_value = index, value
    if best_index < 0:
        raise ValueError("argmax() arg is an empty sequence")
    return best_index
217+
218+
219+
def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]:
    """
    Given a list of strides, find and return the used physical dimensions.

    Runs in O(n * m), where n = len(physical_shape) = len(strides) and m is the
    number of pdims in the result: each iteration scans all remaining strides
    for their maximum. Keeping the remaining strides in a sorted structure
    would likely bring this down to O((n + m) log n), since the maximum would
    then always sit at the front.
    """

    # Running example: strides = [22111, 201000], physical_shape = [10, 2]

    remaining = list(strides)

    # Number of logical elements spanned by these strides,
    # e.g. 9 * 22111 + 1 * 201000 + 1 = 400000
    numel_left = sum(s * (size - 1) for s, size in zip(remaining, physical_shape)) + 1

    pdims: list[int] = []
    while sum(remaining) > 0:
        # Index of the largest remaining stride (first one on ties), e.g. 1
        current = max(range(len(remaining)), key=remaining.__getitem__)
        pdims.append(current)

        numel_left //= physical_shape[current]  # e.g. 400000 // 2 = 200000
        remaining[current] -= numel_left  # e.g. remaining becomes [22111, 1000]

    return pdims
253+
254+
215255
def merge_strides(strides: list[list[int]]) -> list[int]:
    """Flatten several stride lists into one deduplicated list, sorted descending."""
    unique_strides: set[int] = set()
    for stride in strides:
        unique_strides.update(stride)
    return sorted(unique_strides, reverse=True)
217257

0 commit comments

Comments
 (0)