Skip to content

Commit 16e6165

Browse files
committed
Add strides_v2
1 parent 2c23dbe commit 16e6165

1 file changed

Lines changed: 28 additions & 0 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,34 @@ def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int]
184184
]
185185

186186

187+
def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
188+
"""
189+
From a list of physical dimensions corresponding to a virtual dimension, and from the physical
190+
shape, get the stride indicating how moving on each physical dimension makes you move on the
191+
virtual dimension.
192+
193+
Example:
194+
Imagine a vector of size 3, and of value [1, 2, 3].
195+
Imagine a DST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps.
196+
t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix
197+
[[1, 0, 0], [0, 2, 0], [0, 0, 3]]).
198+
When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e.
199+
strides_v2([0, 0], [3]) = 4
200+
In the 2D view, you'd move by 1 row (3 indices) and 1 column (1 index).
201+
202+
Example:
203+
strides_v2([0, 0, 1], [3,4]) # [16, 1]
204+
Moving by 1 on physical dimension 0 makes you move by 16 on the virtual dimension. Moving by
205+
1 on physical dimension 1 makes you move by 1 on the virtual dimension.
206+
"""
207+
208+
strides_v1 = strides_from_p_dims_and_p_shape(p_dims, physical_shape)
209+
result = [0 for _ in range(len(physical_shape))]
210+
for i, d in enumerate(p_dims):
211+
result[d] += strides_v1[i]
212+
return result
213+
214+
187215
def merge_strides(strides: list[list[int]]) -> list[int]:
188216
return sorted({s for stride in strides for s in stride}, reverse=True)
189217

0 commit comments

Comments
 (0)