Skip to content

Commit 322717a

Browse files
committed
Add intdiv_c
1 parent 65416bf commit 322717a

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

src/torchjd/sparse/_linalg.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
2727
def mod_c(t1: Tensor, t2: Tensor) -> Tensor:
2828
"""
2929
Computes the combined modulo r = t1 %c t2, such that
30-
t1 = k * t2 + r with k any integer and
30+
t1 = d * t2 + r with d = t1 //c t2 and
3131
0 <= r[i] <= t1[i] for all i.
3232
3333
:param t1: Non-negative integer vector.
@@ -48,3 +48,29 @@ def mod_c(t1: Tensor, t2: Tensor) -> Tensor:
4848
else:
4949
min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min()
5050
return t1 - min_divider * t2
51+
52+
53+
def intdiv_c(t1: Tensor, t2: Tensor) -> Tensor:
54+
"""
55+
Computes the combined integer division d = t1 // t2, such that
56+
t1 = d * t2 + r with r = t1 %c t2
57+
0 <= r[i] <= t1[i] for all i.
58+
59+
:param t1: Non-negative integer vector.
60+
:param t2: Non-negative integer vector.
61+
62+
Examples:
63+
[8, 12]^T //c [2, 3]^T = 4
64+
[8, 12]^T //c [2, 4]^T = 3
65+
[8, 12]^T //c [3, 3]^T = 2
66+
[8, 12]^T //c [0, 0]^T => ValueError
67+
[8, 12]^T //c [2, 0]^T = 4
68+
[8, 12]^T //c [0, 2]^T = 6
69+
"""
70+
71+
non_zero_indices = torch.nonzero(t2)
72+
if len(non_zero_indices) == 0:
73+
raise ValueError("Cannot divide by the zero vector.")
74+
else:
75+
min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min()
76+
return min_divider

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from pytest import mark
2+
from pytest import mark, raises
33
from torch import Tensor, tensor
44
from torch.ops import aten # type: ignore
55
from torch.testing import assert_close
@@ -12,7 +12,7 @@
1212
)
1313
from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
1414
from torchjd.sparse._coalesce import fix_zero_stride_columns
15-
from torchjd.sparse._linalg import mod_c
15+
from torchjd.sparse._linalg import intdiv_c, mod_c
1616
from torchjd.sparse._structured_sparse_tensor import (
1717
StructuredSparseTensor,
1818
fix_ungrouped_dims,
@@ -440,3 +440,26 @@ def test_mod_c(
440440
expected: Tensor,
441441
):
442442
assert torch.equal(mod_c(t1, t2), expected)
443+
444+
445+
@mark.parametrize(
446+
["t1", "t2", "expected"],
447+
[
448+
(tensor([8, 12]), tensor([2, 3]), 4),
449+
(tensor([8, 12]), tensor([2, 4]), 3),
450+
(tensor([8, 12]), tensor([3, 3]), 2),
451+
(tensor([8, 12]), tensor([2, 0]), 4),
452+
(tensor([8, 12]), tensor([0, 2]), 6),
453+
],
454+
)
455+
def test_intdiv_c(
456+
t1: Tensor,
457+
t2: Tensor,
458+
expected: Tensor,
459+
):
460+
assert intdiv_c(t1, t2) == expected
461+
462+
463+
def test_intdiv_c_by_0_raises():
464+
with raises(ValueError):
465+
intdiv_c(tensor([3, 4]), tensor([0, 0]))

0 commit comments

Comments
 (0)