Skip to content

Commit ab15e1a

Browse files
committed
Remove get_groupings_generalized
* This function could be useful in the future if, for example, we want to merge two physical dimensions that have some overlap in the virtual dimension.
1 parent 3b9a9a8 commit ab15e1a

1 file changed

Lines changed: 1 addition & 42 deletions

File tree

src/torchjd/sparse/_diagonal_sparse_tensor.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]):
6969
# strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index
7070
self.strides = get_strides(list(self.physical.shape), v_to_ps)
7171

72-
if any(len(group) != 1 for group in get_groupings_generalized(self.strides)):
72+
if any(len(group) != 1 for group in get_groupings(list(self.physical.shape), self.strides)):
7373
raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.")
7474

7575
def to_dense(
@@ -260,47 +260,6 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]
260260
return [res[i] for i in range(len(res))]
261261

262262

263-
def are_ratios_matching(v: Tensor) -> bool:
264-
# Returns a boolean indicating whether all non-nan values in a vector are integer and equal to
265-
# each other.
266-
# Returns a scalar boolean tensor indicating whether all values in v are the same or nan:
267-
# [3.0, nan, 3.0] => True
268-
# [nan, nan, nan] => True
269-
# [3.0, nan, 2.0] => False
270-
# [0.5, 0.5, 0.5] => False
271-
272-
non_nan_values = v[~v.isnan()]
273-
return (
274-
torch.eq(non_nan_values.int(), non_nan_values).all().item()
275-
and non_nan_values.eq(non_nan_values[0:1]).all().item()
276-
)
277-
278-
279-
def get_groupings_generalized(strides: Tensor) -> list[list[int]]:
280-
fstrides = strides.to(dtype=torch.float64)
281-
# Note that float64 has 53 bits of precision, meaning that every integer number up to 2^53 can
282-
# be represented on a float64 without any numerical error. Since strides are stored on int64,
283-
# ratios can be of up to 2^64. This function may thus fail for stride values between 2^53 and
284-
# 2^64.
285-
286-
ratios = torch.div(fstrides.unsqueeze(2), fstrides.unsqueeze(1))
287-
288-
# Mapping from column id to the set of columns with which it can be grouped
289-
groups = {i: {i} for i, column in enumerate(strides.T)}
290-
for i1, i2 in itertools.permutations(range(strides.shape[1]), 2):
291-
if are_ratios_matching(ratios[:, i1, i2]):
292-
groups[i1].update(groups[i2])
293-
groups[i2].update(groups[i1])
294-
295-
new_columns = []
296-
for i, group in groups.items():
297-
sorted_group = sorted(list(group))
298-
if i == sorted_group[0]: # This ensures that the same group is added only once
299-
new_columns.append(sorted_group)
300-
301-
return new_columns
302-
303-
304263
def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]:
305264
strides_time_pshape = strides * tensor(pshape)
306265
groups = {i: {i} for i, column in enumerate(strides.T)}

0 commit comments

Comments
 (0)