Skip to content

Commit d27e3fa

Browse files
authored
Fix modulo operation between floatingpoints and integers (#207)
1 parent bfed1db commit d27e3fa

4 files changed

Lines changed: 46 additions & 24 deletions

File tree

CHANGELOG.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Changelog
88
=========
99

10+
0.18.1 (2026-04-22)
11+
-------------------
12+
13+
**Bug fix**
14+
15+
- Fixed a regression that forbade using the modulo operator between integers and floats.
16+
17+
1018
0.18.0 (2026-04-07)
1119
-------------------
1220

ndonnx/_typed_array/onnx.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,12 @@ def __mul__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
15251525
def __rmul__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
15261526
return self._apply(other, op.mul, forward=False, result_type=TyArrayNumber)
15271527

1528+
def __rmod__(self, other) -> TyArrayBase:
1529+
if isinstance(other, int | float):
1530+
b, a = promote(self, other)
1531+
return a.__mod__(b)
1532+
return NotImplemented
1533+
15281534
def __pow__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
15291535
return self._apply(other, op.pow, forward=True, result_type=TyArrayNumber)
15301536

@@ -1697,15 +1703,16 @@ def __or__(self, other) -> TyArrayBase:
16971703
def __ror__(self, other) -> TyArrayBase:
16981704
return self._apply_int_only(other, op.bitwise_or, forward=False)
16991705

1700-
def __mod__(self, other) -> TyArrayInteger:
1701-
return self._apply_int_only(
1702-
other, lambda a, b: op.mod(a, b, fmod=0), forward=True
1703-
)
1706+
def __mod__(self, other) -> TyArrayBase:
1707+
if isinstance(other, type(self)):
1708+
var = op.mod(self._var, other._var, fmod=0)
1709+
return safe_cast(type(self), _var_to_tyarray(var))
17041710

1705-
def __rmod__(self, other) -> TyArrayInteger:
1706-
return self._apply_int_only(
1707-
other, lambda a, b: op.mod(a, b, fmod=0), forward=False
1708-
)
1711+
if isinstance(other, TyArrayNumber | int | float):
1712+
a, b = promote(self, other)
1713+
return a.__mod__(b)
1714+
1715+
return NotImplemented
17091716

17101717
def __xor__(self, other) -> TyArrayBase:
17111718
return self._apply_int_only(other, op.bitwise_xor, forward=True)
@@ -1884,33 +1891,30 @@ def __floordiv__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
18841891
return NotImplemented
18851892

18861893
def __mod__(self, other) -> TyArrayBase:
1887-
if isinstance(other, TyArrayFloating | float):
1894+
if isinstance(other, type(self)):
18881895
# This function is complicated for two reasons:
18891896
# 1. The ONNX standard is undefined if dividend is 0, but the array-api is not.
18901897
# 2. The array-api follows the Python semantics, which are rather odd.
1891-
a, b = promote(self, other)
1892-
var = op.mod(a._var, b._var, fmod=1)
1898+
var = op.mod(self._var, other._var, fmod=1)
18931899
mod = safe_cast(TyArrayFloating, _var_to_tyarray(var))
18941900
# NOTE: onnxruntime appears to have a bug where the sign
18951901
# of zeros is only preserved if they are on the
18961902
# false-branch!
18971903
# TODO: File a bug!
1898-
fixed_mod = where((b < 0) == (mod < 0), mod, mod + b)
1899-
fixed_zeros = where(b > 0, const(0.0, mod.dtype), const(-0.0, mod.dtype))
1904+
fixed_mod = where((other < 0) == (mod < 0), mod, mod + other)
1905+
fixed_zeros = where(
1906+
other > 0, const(0.0, mod.dtype), const(-0.0, mod.dtype)
1907+
)
19001908
return where(
1901-
safe_cast(TyArrayBool, ~((mod == 0.0) & (b != 0.0))),
1909+
safe_cast(TyArrayBool, ~((mod == 0.0) & (other != 0.0))),
19021910
fixed_mod,
19031911
fixed_zeros,
19041912
)
1913+
if isinstance(other, TyArrayNumber | int | float):
1914+
a, b = promote(self, other)
1915+
return a.__mod__(b)
19051916

1906-
return super().__mod__(other)
1907-
1908-
def __rmod__(self, other) -> TyArrayBase:
1909-
if isinstance(other, TyArrayFloating | float):
1910-
b, a = promote(self, other)
1911-
var = op.mod(a._var, b._var, fmod=1)
1912-
return safe_cast(TyArrayFloating, _var_to_tyarray(var))
1913-
return super().__mod__(other)
1917+
return NotImplemented
19141918

19151919
def __ndx_logaddexp__(self, x2: TyArrayBase | int | float, /) -> TyArrayFloating:
19161920
if isinstance(x2, TyArrayNumber | int | float):

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ ignore = [
4949
"N803", # https://docs.astral.sh/ruff/rules/invalid-argument-name
5050
"N806", # https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
5151
"E501", # https://docs.astral.sh/ruff/faq/#is-the-ruff-linter-compatible-with-black
52-
"UP038", # https://github.com/astral-sh/ruff/issues/7871
5352
"N807", # Free functions may start/end with dunders __array_namespace_info__
5453
"UP007",
5554
]

tests/test_elementwise.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,15 @@ def do(npx):
1717
return x_arr % y_arr
1818

1919
np.testing.assert_array_equal(do(ndx).unwrap_numpy(), do(np))
20-
assert do(ndx).unwrap_numpy() == x % y
20+
assert do(ndx).unwrap_numpy() == x % y # Compare to Python result
21+
22+
23+
@pytest.mark.parametrize("a,b", [(12.3, 3), (12, 3.3)])
24+
def test_mod_float_int(a, b):
25+
a1 = np.asarray(a)
26+
a2 = np.asarray(b)
27+
28+
def do(npx):
29+
return npx.asarray(a1) % npx.asarray(a2)
30+
31+
np.testing.assert_array_equal(do(ndx).unwrap_numpy(), do(np))

0 commit comments

Comments
 (0)