Skip to content

Commit 5cbab08

Browse files
committed
update interface of unsquash_pdim
1 parent 66c2210 commit 5cbab08

2 files changed

Lines changed: 63 additions & 29 deletions

File tree

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,34 +69,59 @@ def infer_shape(shape: list[int], numel: int) -> list[int]:
6969
return [inferred if s == -1 else s for s in shape]
7070

7171

72-
def unsquash_pdim_from_strides(
73-
physical: Tensor, pdim: int, new_pdim_shape: list[int]
72+
def unsquash_pdim(
73+
physical: Tensor, strides: Tensor, pdim: int, new_pdim_shape: list[int]
7474
) -> tuple[Tensor, Tensor]:
75-
new_shape = list(physical.shape)
76-
new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :]
77-
new_physical = physical.reshape(new_shape)
78-
79-
stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))])
80-
return new_physical, stride_multipliers
75+
"""
76+
EXAMPLE:
77+
78+
physical = [
79+
[1, 2, 3, 4, 5, 6],
80+
[7, 8, 9, 10, 11, 12],
81+
[13, 14, 15, 16, 17, 18],
82+
]
83+
strides = [
84+
[1, 1],
85+
[0, 2],
86+
]
87+
88+
dim = 1
89+
shape = [2, 3]
90+
91+
new_physical = [[
92+
[1, 2, 3],
93+
[4, 5, 6],
94+
], [
95+
[7, 8, 9],
96+
[10, 11, 12],
97+
], [
98+
[13, 14, 15],
99+
[16, 17, 18],
100+
]]
101+
102+
new_strides = [
103+
[1, 3, 1],
104+
[0, 6, 2]
105+
"""
81106

107+
# TODO: handle working with multiple dimensions at once
82108

83-
def unsquash_pdim(
84-
physical: Tensor, pdim: int, new_pdim_shape: list[int]
85-
) -> tuple[Tensor, list[list[int]]]:
86-
new_shape = list(physical.shape)
87-
new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :]
109+
old_shape = list(physical.shape)
110+
new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :]
88111
new_physical = physical.reshape(new_shape)
89112

90-
def new_encoding_fn(d: int) -> list[int]:
91-
if d < pdim:
92-
return [d]
93-
elif d > pdim:
94-
return [d + len(new_pdim_shape) - 1]
95-
else:
96-
return [pdim + i for i in range(len(new_pdim_shape))]
113+
stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))])
114+
115+
new_strides = torch.concat(
116+
[
117+
strides[:, :pdim],
118+
torch.outer(strides[:, pdim], stride_multipliers),
119+
strides[:, pdim + 1 :],
120+
],
121+
dim=1,
122+
)
97123

98-
new_encoding = [new_encoding_fn(d) for d in range(len(physical.shape))]
99-
return new_physical, new_encoding
124+
return new_physical, new_strides
100125

101126

102127
@impl(aten._unsafe_view.default)

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -315,29 +315,38 @@ def test_fix_ungrouped_dims(
315315
@mark.parametrize(
316316
[
317317
"physical_shape",
318+
"strides",
318319
"pdim",
319320
"new_pdim_shape",
320321
"expected_physical_shape",
321-
"expected_new_encoding",
322+
"expected_strides",
322323
],
323324
[
324-
([4], 0, [4], [4], [[0]]), # trivial
325-
([4], 0, [2, 2], [2, 2], [[0, 1]]),
326-
([3, 4, 5], 1, [2, 1, 1, 2], [3, 2, 1, 1, 2, 5], [[0], [1, 2, 3, 4], [5]]),
325+
([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial
326+
([4], tensor([[1], [2]]), 0, [2, 2], [2, 2], tensor([[2, 1], [4, 2]])),
327+
(
328+
[3, 4, 5],
329+
tensor([[1, 2, 0], [1, 0, 1], [0, 1, 1]]),
330+
1,
331+
[2, 1, 1, 2],
332+
[3, 2, 1, 1, 2, 5],
333+
tensor([[1, 4, 4, 4, 2, 0], [1, 0, 0, 0, 0, 1], [0, 2, 2, 2, 1, 1]]),
334+
),
327335
],
328336
)
329337
def test_unsquash_pdim(
330338
physical_shape: list[int],
339+
strides: Tensor,
331340
pdim: int,
332341
new_pdim_shape: list[int],
333342
expected_physical_shape: list[int],
334-
expected_new_encoding: list[list[int]],
343+
expected_strides: Tensor,
335344
):
336345
physical = randn_(physical_shape)
337-
new_physical, new_encoding = unsquash_pdim(physical, pdim, new_pdim_shape)
346+
new_physical, new_strides = unsquash_pdim(physical, strides, pdim, new_pdim_shape)
338347

339348
assert list(new_physical.shape) == expected_physical_shape
340-
assert new_encoding == expected_new_encoding
349+
assert torch.equal(new_strides, expected_strides)
341350

342351

343352
@mark.parametrize(

0 commit comments

Comments
 (0)