
Commit 3b9a9a8

Revamp grouping detection
1 parent 5062b06 commit 3b9a9a8

2 files changed: 84 additions & 41 deletions
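At a glance, grouping detection now operates on the integer stride matrix rather than on v_to_ps directly: __init__ builds self.strides with the new get_strides helper and validates it with get_groupings_generalized(strides), while fix_ungrouped_dims reshapes the physical tensor according to get_groupings(pshape, strides). A minimal standalone sketch of the merge criterion used by the new get_groupings (toy data chosen purely for illustration, not taken from this commit):

import torch

# Toy data (illustration only): physical shape (10, 2), one virtual dimension
# that walks the physical tensor in row-major order.
pshape = [10, 2]
strides = torch.tensor([[2, 1]])  # v_index = 2 * p0 + 1 * p1

# Merge criterion sketched from the new get_groupings:
# physical dim i can absorb dim j when strides[:, i] == strides[:, j] * pshape[j].
i, j = 0, 1
print(torch.equal(strides[:, i], strides[:, j] * pshape[j]))  # True: (10, 2) flattens to (20,)

Intuitively, the check says the two physical dimensions index memory contiguously for every virtual dimension, so they can be flattened into a single dimension of size pshape[i] * pshape[j], consistent with the "v_index = strides @ p_index" comment in __init__.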

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 75 additions & 31 deletions
@@ -1,3 +1,4 @@
+import itertools
 import operator
 from functools import wraps
 from itertools import accumulate
@@ -62,15 +63,14 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]):
                 f"v_to_ps elements are not encoded by first appearance. Found {v_to_ps}."
             )
 
-        if any(len(group) != 1 for group in get_groupings(v_to_ps)):
-            raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.")
-
         self.physical = physical
         self.v_to_ps = v_to_ps
 
         # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index
-        pshape = list(self.physical.shape)
-        self.strides = tensor([strides_v2(pdims, pshape) for pdims in self.v_to_ps])
+        self.strides = get_strides(list(self.physical.shape), v_to_ps)
+
+        if any(len(group) != 1 for group in get_groupings_generalized(self.strides)):
+            raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.")
 
     def to_dense(
         self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None
@@ -188,11 +188,18 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
     return result
 
 
+def get_strides(pshape: list[int], v_to_ps: list[list[int]]) -> Tensor:
+    strides = torch.tensor([strides_v2(pdims, pshape) for pdims in v_to_ps], dtype=torch.int64)
+
+    # It's sometimes necessary to reshape: when v_to_ps contains 0 element for instance.
+    return strides.reshape(len(v_to_ps), len(pshape))
+
+
 def argmax(iterable):
     return max(enumerate(iterable), key=lambda x: x[1])[0]
 
 
-def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]:
+def strides_to_pdims(strides: Tensor, physical_shape: list[int]) -> list[int]:
     """
     Given a list of strides, find and return the used physical dimensions.
 
@@ -207,7 +214,7 @@ def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]
     # e.g. strides = [22111, 201000], physical_shape = [10, 2]
 
     pdims = []
-    remaining_strides = [s for s in strides]
+    remaining_strides = strides.clone()
     remaining_numel = (
         sum(remaining_strides[i] * (physical_shape[i] - 1) for i in range(len(physical_shape))) + 1
     )
@@ -253,29 +260,62 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]
     return [res[i] for i in range(len(res))]
 
 
-def get_groupings(v_to_ps: list[list[int]]) -> list[list[int]]:
-    """Example: [[0, 1, 2], [2, 0, 1], [2]] => [[0, 1], [2]]"""
+def are_ratios_matching(v: Tensor) -> bool:
+    # Returns a boolean indicating whether all non-nan values in a vector are integer and equal to
+    # each other.
+    # Returns a scalar boolean tensor indicating whether all values in v are the same or nan:
+    # [3.0, nan, 3.0] => True
+    # [nan, nan, nan] => True
+    # [3.0, nan, 2.0] => False
+    # [0.5, 0.5, 0.5] => False
 
-    mapping = dict[int, list[int]]()
+    non_nan_values = v[~v.isnan()]
+    return (
+        torch.eq(non_nan_values.int(), non_nan_values).all().item()
+        and non_nan_values.eq(non_nan_values[0:1]).all().item()
+    )
 
-    for p_dims in v_to_ps:
-        for i, p_dim in enumerate(p_dims):
-            if p_dim not in mapping:
-                mapping[p_dim] = p_dims[i:]
-            else:
-                mapping[p_dim] = longest_common_prefix(mapping[p_dim], p_dims[i:])
 
-    groups = []
-    visited_is = set()
-    for i, group in mapping.items():
-        if i in visited_is:
-            continue
+def get_groupings_generalized(strides: Tensor) -> list[list[int]]:
+    fstrides = strides.to(dtype=torch.float64)
+    # Note that float64 has 53 bits of precision, meaning that every integer number up to 2^53 can
+    # be represented on a float64 without any numerical error. Since strides are stored on int64,
+    # ratios can be of up to 2^64. This function may thus fail for stride values between 2^53 and
+    # 2^64.
+
+    ratios = torch.div(fstrides.unsqueeze(2), fstrides.unsqueeze(1))
+
+    # Mapping from column id to the set of columns with which it can be grouped
+    groups = {i: {i} for i, column in enumerate(strides.T)}
+    for i1, i2 in itertools.permutations(range(strides.shape[1]), 2):
+        if are_ratios_matching(ratios[:, i1, i2]):
+            groups[i1].update(groups[i2])
+            groups[i2].update(groups[i1])
+
+    new_columns = []
+    for i, group in groups.items():
+        sorted_group = sorted(list(group))
+        if i == sorted_group[0]:  # This ensures that the same group is added only once
+            new_columns.append(sorted_group)
+
+    return new_columns
 
-        available_dims = set(group) - visited_is
-        groups.append(list(available_dims))
-        visited_is.update(set(group))
 
-    return groups
+def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]:
+    strides_time_pshape = strides * tensor(pshape)
+    groups = {i: {i} for i, column in enumerate(strides.T)}
+    group_ids = [i for i in range(len(strides.T))]
+    for i1, i2 in itertools.combinations(range(strides.shape[1]), 2):
+        if torch.equal(strides[:, i1], strides_time_pshape[:, i2]):
+            groups[group_ids[i1]].update(groups[group_ids[i2]])
+            group_ids[i2] = group_ids[i1]
+
+    new_columns = [sorted(groups[group_id]) for group_id in sorted(set(group_ids))]
+
+    if len(new_columns) != len(pshape):
+        print(f"Combined pshape with the following new columns: {new_columns}.")
+
+    return new_columns
 
 
 def longest_common_prefix(l1: list[int], l2: list[int]) -> list[int]:
@@ -413,12 +453,16 @@ def new_encoding(d: int) -> int:
 def fix_ungrouped_dims(
     physical: Tensor, v_to_ps: list[list[int]]
 ) -> tuple[Tensor, list[list[int]]]:
-    groups = get_groupings(v_to_ps)
-    physical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups])
-    mapping = {group[0]: i for i, group in enumerate(groups)}
-    new_v_to_ps = [[mapping[i] for i in dims if i in mapping] for dims in v_to_ps]
-
-    return physical, new_v_to_ps
+    strides = get_strides(list(physical.shape), v_to_ps)
+    groups = get_groupings(list(physical.shape), strides)
+    nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups])
+    stride_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64)
+    for j, group in enumerate(groups):
+        stride_mapping[group[-1], j] = 1
+
+    new_strides = strides @ stride_mapping
+    new_v_to_ps = [strides_to_pdims(stride, list(nphysical.shape)) for stride in new_strides]
+    return nphysical, new_v_to_ps
 
 
 def make_dst(physical: Tensor, v_to_ps: list[list[int]]) -> DiagonalSparseTensor:
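
The comment block in get_groupings_generalized explains the float64 precision caveat; the ratio test itself, are_ratios_matching, accepts a pair of stride columns when all non-nan ratios are the same integer. A rough standalone mirror of that check (duplicated here for illustration instead of importing the module), reproducing the examples listed in the diff's comments:

import torch

def ratios_match(v: torch.Tensor) -> bool:
    # Mirror of are_ratios_matching from this diff: all non-nan ratios must be
    # equal integers (nan arises from 0 / 0 when both strides are zero).
    non_nan = v[~v.isnan()]
    return (
        torch.eq(non_nan.int(), non_nan).all().item()
        and non_nan.eq(non_nan[0:1]).all().item()
    )

print(ratios_match(torch.tensor([3.0, float("nan"), 3.0])))  # True
print(ratios_match(torch.tensor([3.0, float("nan"), 2.0])))  # False
print(ratios_match(torch.tensor([0.5, 0.5, 0.5])))           # False: equal, but not an integer ratio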

tests/unit/sparse/test_diagonal_sparse_tensor.py

Lines changed: 9 additions & 10 deletions
@@ -194,19 +194,18 @@ def test_encode_by_order(
 
 
 @mark.parametrize(
-    ["v_to_ps", "expected_groupings"],
+    ["pshape", "strides", "expected"],
     [
-        ([[0, 1, 2], [2, 0, 1], [2]], [[0, 1], [2]]),
-        ([[0, 1, 0, 1]], [[0, 1]]),
-        ([[0, 1, 0, 1, 2]], [[0, 1], [2]]),
-        ([[0, 0]], [[0, 0]]),
-        ([[0, 1], [1, 2]], [[0], [1], [2]]),
+        (
+            [[32, 2, 3, 4, 5]],
+            torch.tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 60, 20, 5, 1]]),
+            [[0], [1, 2, 3, 4]],
+        )
     ],
 )
-def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[int]]):
-    groupings = get_groupings(v_to_ps)
-
-    assert groupings == expected_groupings
+def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[list[int]]):
+    result = get_groupings(pshape, strides)
+    assert result == expected
 
 
 @mark.parametrize(
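
The expected groups [[0], [1, 2, 3, 4]] in the new parametrization can be checked by hand with the stride-times-shape criterion from get_groupings; a short standalone trace (duplicating the logic rather than importing it, with the physical shape written as a flat list):

import torch

pshape = [32, 2, 3, 4, 5]
strides = torch.tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 60, 20, 5, 1]])

# Column i1 can absorb column i2 when strides[:, i1] == strides[:, i2] * pshape[i2].
scaled = strides * torch.tensor(pshape)
for i1 in range(strides.shape[1]):
    for i2 in range(i1 + 1, strides.shape[1]):
        if torch.equal(strides[:, i1], scaled[:, i2]):
            print(f"columns {i1} and {i2} can be grouped")
# Prints the pairs (1, 2), (2, 3) and (3, 4), which chain into the group [1, 2, 3, 4],
# while column 0 stays alone, matching the expected [[0], [1, 2, 3, 4]].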
