Skip to content

Commit 3271b3a

Browse files
committed
Improve handling of zero division in mod_c and intdiv_c
* In python, when computing x % 0, we get ZeroDivisionError("division by zero"). mod_c should thus probably behave the same when t2 is the zero vector. * Change ValueError to ZeroDivisionError in intdiv_c * Make mod_c reuse intdiv_c. This makes it raise ZeroDivisionError when called with t2=0.
1 parent 322717a commit 3271b3a

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

src/torchjd/sparse/_linalg.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,12 @@ def mod_c(t1: Tensor, t2: Tensor) -> Tensor:
3737
[8, 12]^T %c [2, 3]^T = [0, 0]^T
3838
[8, 12]^T %c [2, 4]^T = [2, 0]^T
3939
[8, 12]^T %c [3, 3]^T = [2, 6]^T
40-
[8, 12]^T %c [0, 0]^T = [8, 12]^T
4140
[8, 12]^T %c [2, 0]^T = [0, 12]^T
4241
[8, 12]^T %c [0, 2]^T = [8, 0]^T
42+
[8, 12]^T %c [0, 0]^T => ZeroDivisionError
4343
"""
4444

45-
non_zero_indices = torch.nonzero(t2)
46-
if len(non_zero_indices) == 0:
47-
return t1
48-
else:
49-
min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min()
50-
return t1 - min_divider * t2
45+
return t1 - intdiv_c(t1, t2) * t2
5146

5247

5348
def intdiv_c(t1: Tensor, t2: Tensor) -> Tensor:
@@ -63,14 +58,14 @@ def intdiv_c(t1: Tensor, t2: Tensor) -> Tensor:
6358
[8, 12]^T //c [2, 3]^T = 4
6459
[8, 12]^T //c [2, 4]^T = 3
6560
[8, 12]^T //c [3, 3]^T = 2
66-
[8, 12]^T //c [0, 0]^T => ValueError
6761
[8, 12]^T //c [2, 0]^T = 4
6862
[8, 12]^T //c [0, 2]^T = 6
63+
[8, 12]^T //c [0, 0]^T => ZeroDivisionError
6964
"""
7065

7166
non_zero_indices = torch.nonzero(t2)
7267
if len(non_zero_indices) == 0:
73-
raise ValueError("Cannot divide by the zero vector.")
68+
raise ZeroDivisionError("division by zero")
7469
else:
7570
min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min()
7671
return min_divider

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,6 @@ def test_fix_zero_stride_columns(
429429
(tensor([8, 12]), tensor([2, 3]), tensor([0, 0])),
430430
(tensor([8, 12]), tensor([2, 4]), tensor([2, 0])),
431431
(tensor([8, 12]), tensor([3, 3]), tensor([2, 6])),
432-
(tensor([8, 12]), tensor([0, 0]), tensor([8, 12])),
433432
(tensor([8, 12]), tensor([2, 0]), tensor([0, 12])),
434433
(tensor([8, 12]), tensor([0, 2]), tensor([8, 0])),
435434
],
@@ -442,6 +441,11 @@ def test_mod_c(
442441
assert torch.equal(mod_c(t1, t2), expected)
443442

444443

444+
def test_mod_c_by_0_raises():
445+
with raises(ZeroDivisionError):
446+
mod_c(tensor([3, 4]), tensor([0, 0]))
447+
448+
445449
@mark.parametrize(
446450
["t1", "t2", "expected"],
447451
[
@@ -461,5 +465,5 @@ def test_intdiv_c(
461465

462466

463467
def test_intdiv_c_by_0_raises():
464-
with raises(ValueError):
468+
with raises(ZeroDivisionError):
465469
intdiv_c(tensor([3, 4]), tensor([0, 0]))

0 commit comments

Comments
 (0)