Skip to content

Commit 9770d6a

Browse files
authored
Merge pull request #10 from d-v-b/chore/numpy-correctness
fix: handle error in nearest-away rounding mode
2 parents ea80ed4 + f446386 commit 9770d6a

2 files changed

Lines changed: 20 additions & 3 deletions

File tree

src/cast_value/impl/_numpy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ def _cast_float(
174174
case "towards-negative":
175175
use_other = other_wide < ne_wide
176176
case "nearest-away":
177-
use_other = np.abs(other_wide) > np.abs(ne_wide)
177+
d_ne = np.abs(ne_wide - src)
178+
d_other = np.abs(other_wide - src)
179+
# Pick the closer candidate; break ties away from zero
180+
use_other = (d_other < d_ne) | (
181+
(d_other == d_ne) & (np.abs(other_wide) > np.abs(ne_wide))
182+
)
178183

179184
corrected = result.copy()
180185
indices = np.where(inexact)[0]

tests/test_core.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,12 +1201,16 @@ def _expected_for_rounding_mode(
12011201
hi: float,
12021202
mode: str,
12031203
ne_val: float,
1204+
original: float | None = None,
12041205
) -> float:
12051206
"""Return the expected result of rounding a value that lies between *lo* and *hi*.
12061207
12071208
*ne_val* is the nearest-even result from numpy (used only for the
12081209
``"nearest-even"`` case, since reproducing IEEE 754 tie-breaking from
12091210
scratch is error-prone).
1211+
1212+
*original* is the source value before narrowing (needed for ``nearest-away``
1213+
to determine which candidate is closer).
12101214
"""
12111215
match mode:
12121216
case "nearest-even":
@@ -1218,6 +1222,14 @@ def _expected_for_rounding_mode(
12181222
case "towards-negative":
12191223
return lo
12201224
case "nearest-away":
1225+
assert original is not None
1226+
d_lo = abs(lo - original)
1227+
d_hi = abs(hi - original)
1228+
if d_lo < d_hi:
1229+
return lo
1230+
if d_hi < d_lo:
1231+
return hi
1232+
# Tie: pick the one farther from zero
12211233
return hi if abs(hi) >= abs(lo) else lo
12221234
raise ValueError(mode) # pragma: no cover
12231235

@@ -1253,7 +1265,7 @@ def _make_float_rounding_cases() -> list[Expect]:
12531265

12541266
for mode in _ROUNDING_MODES:
12551267
expected_val = _expected_for_rounding_mode(
1256-
float(lo), float(hi), mode, float(ne)
1268+
float(lo), float(hi), mode, float(ne), original=float(val)
12571269
)
12581270
sign_label = "pos" if sign > 0 else "neg"
12591271
cases.append(
@@ -1306,7 +1318,7 @@ def _make_int_to_float_rounding_cases() -> list[Expect]:
13061318
hi = max(ne, other, key=float)
13071319

13081320
expected_val = _expected_for_rounding_mode(
1309-
float(lo), float(hi), mode, float(ne)
1321+
float(lo), float(hi), mode, float(ne), original=float(val)
13101322
)
13111323
sign_label = "pos" if sign > 0 else "neg"
13121324
cases.append(

0 commit comments

Comments
 (0)