@@ -184,6 +184,34 @@ def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int]
184184 ]
185185
186186
187+ def strides_v2 (p_dims : list [int ], physical_shape : list [int ]) -> list [int ]:
188+ """
189+ From a list of physical dimensions corresponding to a virtual dimension, and from the physical
190+ shape, get the stride indicating how moving on each physical dimension makes you move on the
191+ virtual dimension.
192+
193+ Example:
194+ Imagine a vector of size 3, and of value [1, 2, 3].
195+ Imagine a DST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps.
196+ t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix
197+ [[1, 0, 0], [0, 2, 0], [0, 0, 3]]).
198+ When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e.
199+ strides_v2([0, 0], [3]) = 4
200+ In the 2D view, you'd move by 1 row (3 indices) and 1 column (1 index).
201+
202+ Example:
203+ strides_v2([0, 0, 1], [3,4]) # [16, 1]
204+ Moving by 1 on physical dimension 0 makes you move by 16 on the virtual dimension. Moving by
205+ 1 on physical dimension 1 makes you move by 1 on the virtual dimension.
206+ """
207+
208+ strides_v1 = strides_from_p_dims_and_p_shape (p_dims , physical_shape )
209+ result = [0 for _ in range (len (physical_shape ))]
210+ for i , d in enumerate (p_dims ):
211+ result [d ] += strides_v1 [i ]
212+ return result
213+
214+
187215def merge_strides (strides : list [list [int ]]) -> list [int ]:
188216 return sorted ({s for stride in strides for s in stride }, reverse = True )
189217
0 commit comments