Skip to content

Commit 65416bf

Browse files
committed
Add mod_c
1 parent 96c54e4 commit 65416bf

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

src/torchjd/sparse/_linalg.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,29 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
2222

2323
# TODO: Verify that the round operation cannot fail
2424
return X_rounded.to(torch.int64)
25+
26+
27+
def mod_c(t1: Tensor, t2: Tensor) -> Tensor:
28+
"""
29+
Computes the combined modulo r = t1 %c t2, such that
30+
t1 = k * t2 + r with k any integer and
31+
0 <= r[i] <= t1[i] for all i.
32+
33+
:param t1: Non-negative integer vector.
34+
:param t2: Non-negative integer vector.
35+
36+
Examples:
37+
[8, 12]^T %c [2, 3]^T = [0, 0]^T
38+
[8, 12]^T %c [2, 4]^T = [2, 0]^T
39+
[8, 12]^T %c [3, 3]^T = [2, 6]^T
40+
[8, 12]^T %c [0, 0]^T = [8, 12]^T
41+
[8, 12]^T %c [2, 0]^T = [0, 12]^T
42+
[8, 12]^T %c [0, 2]^T = [8, 0]^T
43+
"""
44+
45+
non_zero_indices = torch.nonzero(t2)
46+
if len(non_zero_indices) == 0:
47+
return t1
48+
else:
49+
min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min()
50+
return t1 - min_divider * t2

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
1414
from torchjd.sparse._coalesce import fix_zero_stride_columns
15+
from torchjd.sparse._linalg import mod_c
1516
from torchjd.sparse._structured_sparse_tensor import (
1617
StructuredSparseTensor,
1718
fix_ungrouped_dims,
@@ -420,3 +421,22 @@ def test_fix_zero_stride_columns(
420421
physical, strides = fix_zero_stride_columns(physical, strides)
421422
assert torch.equal(physical, expected_physical)
422423
assert torch.equal(strides, expected_strides)
424+
425+
426+
@mark.parametrize(
427+
["t1", "t2", "expected"],
428+
[
429+
(tensor([8, 12]), tensor([2, 3]), tensor([0, 0])),
430+
(tensor([8, 12]), tensor([2, 4]), tensor([2, 0])),
431+
(tensor([8, 12]), tensor([3, 3]), tensor([2, 6])),
432+
(tensor([8, 12]), tensor([0, 0]), tensor([8, 12])),
433+
(tensor([8, 12]), tensor([2, 0]), tensor([0, 12])),
434+
(tensor([8, 12]), tensor([0, 2]), tensor([8, 0])),
435+
],
436+
)
437+
def test_mod_c(
438+
t1: Tensor,
439+
t2: Tensor,
440+
expected: Tensor,
441+
):
442+
assert torch.equal(mod_c(t1, t2), expected)

0 commit comments

Comments
 (0)