Skip to content

Commit 5765480

Browse files
authored
Align behaviors to cpython (#18)
* O(1) bit-level nextafter with steps parameter Replace the O(n) loop in nextafter(x, y, steps) with O(1) IEEE 754 bit manipulation matching math_nextafter_impl in mathmodule.c. Handles sign-crossing, saturation, and all edge cases. Add pyo3 proptest and edge/extreme-step tests for nextafter. * Fix cmath.log error propagation and bigint log precision cmath: Rewrite log(z, base) to match cmath_log_impl errno semantics. c_log always returns a value and separately reports EDOM for zero. c_quot returns EDOM and (0,0) for zero denominator instead of NaN. bigint: Add sticky bit to frexp_bigint for correct IEEE round-half-to-even, matching _PyLong_Frexp. Use mul_add for log/log2/log10 bigint to match loghelper fma usage. Add regression tests for remainder tie-to-even and signed zero. * Document design decisions and improve sumprod - math_1/math_2: document why errno handling differs from CPython (platform-specific unreliability, output checks sufficient, verified by proptest) - math.log: document EDOM substitution for ZeroDivisionError - math.remainder: document libm delegation rationale - sumprod: return Result for length mismatch instead of panic, improve overflow fallback to continue from where the fast path stopped instead of restarting from scratch
1 parent bcf75ff commit 5765480

File tree

9 files changed

+570
-70
lines changed

9 files changed

+570
-70
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "pymath"
33
authors = ["Jeong, YunWon <jeong@youknowone.org>"]
44
repository = "https://github.com/RustPython/pymath"
55
description = "A binary representation compatible Rust implementation of Python's math library."
6-
version = "0.1.5"
6+
version = "0.2.0"
77
edition = "2024"
88
license = "PSF-2.0"
99

src/cmath/exponential.rs

Lines changed: 97 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -177,40 +177,47 @@ pub(crate) fn ln(z: Complex64) -> Result<Complex64> {
177177
/// If base is Some(b), returns log(z) / log(b).
178178
#[inline]
179179
pub fn log(z: Complex64, base: Option<Complex64>) -> Result<Complex64> {
180-
// c_log always returns a value, but sets errno for special cases.
181-
// The error check happens at the end of cmath_log_impl.
182-
// For log(z) without base: z=0 raises EDOM
183-
// For log(z, base): z=0 doesn't raise because c_log(base) clears errno
184-
let z_is_zero = z.re == 0.0 && z.im == 0.0;
180+
let (log_z, mut err) = c_log(z);
185181
match base {
186-
None => {
187-
// No base: raise error if z=0
188-
if z_is_zero {
189-
return Err(Error::EDOM);
190-
}
191-
ln(z)
192-
}
182+
None => err.map_or(Ok(log_z), Err),
193183
Some(b) => {
194-
// With base: z=0 is allowed (second ln clears the "errno")
195-
let log_z = ln(z)?;
196-
let log_b = ln(b)?;
197-
// Use _Py_c_quot-style division to preserve sign of zero
198-
Ok(c_quot(log_z, log_b))
184+
// Like cmath_log_impl, the second c_log call overwrites
185+
// any pending error from the first one.
186+
let (log_b, base_err) = c_log(b);
187+
err = base_err;
188+
let (q, quot_err) = c_quot(log_z, log_b);
189+
if let Some(e) = quot_err {
190+
err = Some(e);
191+
}
192+
err.map_or(Ok(q), Err)
199193
}
200194
}
201195
}
202196

197+
/// c_log behavior: always returns a value, but reports EDOM for zero.
198+
#[inline]
199+
fn c_log(z: Complex64) -> (Complex64, Option<Error>) {
200+
let r = ln(z).expect("ln handles special values without failing");
201+
if z.re == 0.0 && z.im == 0.0 {
202+
(r, Some(Error::EDOM))
203+
} else {
204+
(r, None)
205+
}
206+
}
207+
203208
/// Complex division following _Py_c_quot algorithm.
204209
/// This preserves the sign of zero correctly and recovers infinities
205210
/// from NaN results per C11 Annex G.5.2.
206211
#[inline]
207-
fn c_quot(a: Complex64, b: Complex64) -> Complex64 {
212+
fn c_quot(a: Complex64, b: Complex64) -> (Complex64, Option<Error>) {
208213
let abs_breal = m::fabs(b.re);
209214
let abs_bimag = m::fabs(b.im);
215+
let mut err = None;
210216

211217
let mut r = if abs_breal >= abs_bimag {
212218
if abs_breal == 0.0 {
213-
Complex64::new(f64::NAN, f64::NAN)
219+
err = Some(Error::EDOM);
220+
Complex64::new(0.0, 0.0)
214221
} else {
215222
let ratio = b.im / b.re;
216223
let denom = b.re + b.im * ratio;
@@ -244,7 +251,7 @@ fn c_quot(a: Complex64, b: Complex64) -> Complex64 {
244251
}
245252
}
246253

247-
r
254+
(r, err)
248255
}
249256

250257
/// Complex base-10 logarithm.
@@ -325,6 +332,53 @@ mod tests {
325332
});
326333
}
327334

335+
fn test_log_error(z: Complex64, base: Complex64) {
336+
use pyo3::prelude::*;
337+
338+
let rs_result = log(z, Some(base));
339+
340+
Python::attach(|py| {
341+
let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap();
342+
let py_z = pyo3::types::PyComplex::from_doubles(py, z.re, z.im);
343+
let py_base = pyo3::types::PyComplex::from_doubles(py, base.re, base.im);
344+
let py_result = cmath.getattr("log").unwrap().call1((py_z, py_base));
345+
346+
match py_result {
347+
Ok(result) => {
348+
use pyo3::types::PyComplexMethods;
349+
let c = result.cast::<pyo3::types::PyComplex>().unwrap();
350+
panic!(
351+
"log({}+{}j, {}+{}j): expected ValueError, got ({}, {})",
352+
z.re,
353+
z.im,
354+
base.re,
355+
base.im,
356+
c.real(),
357+
c.imag()
358+
);
359+
}
360+
Err(err) => {
361+
assert!(
362+
err.is_instance_of::<pyo3::exceptions::PyValueError>(py),
363+
"log({}+{}j, {}+{}j): expected ValueError, got {err:?}",
364+
z.re,
365+
z.im,
366+
base.re,
367+
base.im,
368+
);
369+
assert!(
370+
matches!(rs_result, Err(crate::Error::EDOM)),
371+
"log({}+{}j, {}+{}j): expected Err(EDOM), got {rs_result:?}",
372+
z.re,
373+
z.im,
374+
base.re,
375+
base.im,
376+
);
377+
}
378+
}
379+
});
380+
}
381+
328382
use crate::test::EDGE_VALUES;
329383

330384
#[test]
@@ -382,6 +436,29 @@ mod tests {
382436
}
383437
}
384438

439+
#[test]
440+
fn regression_c_quot_zero_denominator_sets_edom() {
441+
let (q, err) = c_quot(Complex64::new(2.0, -3.0), Complex64::new(0.0, 0.0));
442+
assert_eq!(err, Some(crate::Error::EDOM));
443+
assert_eq!(q.re.to_bits(), 0.0f64.to_bits());
444+
assert_eq!(q.im.to_bits(), 0.0f64.to_bits());
445+
}
446+
447+
#[test]
448+
fn regression_log_zero_quotient_denominator_raises_edom() {
449+
let cases = [
450+
(Complex64::new(2.0, 0.0), Complex64::new(1.0, 0.0)),
451+
(Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)),
452+
(Complex64::new(2.0, 0.0), Complex64::new(0.0, 0.0)),
453+
(Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)),
454+
(Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)),
455+
];
456+
457+
for (z, base) in cases {
458+
test_log_error(z, base);
459+
}
460+
}
461+
385462
proptest::proptest! {
386463
#[test]
387464
fn proptest_sqrt(re: f64, im: f64) {

src/math.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,26 @@ macro_rules! libm_simple {
4949

5050
pub(crate) use libm_simple;
5151

52-
/// math_1: wrapper for 1-arg functions
52+
/// Wrapper for 1-arg libm functions, corresponding to FUNC1/is_error in
53+
/// mathmodule.c.
54+
///
5355
/// - isnan(r) && !isnan(x) -> domain error
5456
/// - isinf(r) && isfinite(x) -> overflow (can_overflow=true) or domain error (can_overflow=false)
5557
/// - isfinite(r) && errno -> check errno (unnecessary on most platforms)
58+
///
59+
/// CPython's approach: clear errno, call libm, then inspect both the result
60+
/// and errno to classify errors. We rely primarily on output inspection
61+
/// (NaN/Inf checks) because:
62+
///
63+
/// - On macOS and Windows, libm functions do not reliably set errno for
64+
/// edge cases, so CPython's own is_error() skips the errno check there
65+
/// too (it only uses it as a fallback on other Unixes).
66+
/// - The NaN/Inf output checks are sufficient to detect all domain and
67+
/// range errors on every platform we test against (verified by proptest
68+
/// and edgetest against CPython via pyo3).
69+
/// - The errno-only branch (finite result with errno set) is kept for
70+
/// non-macOS/non-Windows Unixes where libm might signal an error
71+
/// without producing a NaN/Inf result.
5672
#[inline]
5773
pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate::Result<f64> {
5874
crate::err::set_errno(0);
@@ -75,9 +91,17 @@ pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate:
7591
Ok(r)
7692
}
7793

78-
/// math_2: wrapper for 2-arg functions
94+
/// Wrapper for 2-arg libm functions, corresponding to FUNC2 in
95+
/// mathmodule.c.
96+
///
7997
/// - isnan(r) && !isnan(x) && !isnan(y) -> domain error
8098
/// - isinf(r) && isfinite(x) && isfinite(y) -> range error
99+
///
100+
/// Unlike math_1, this does not set/check errno at all. CPython's FUNC2
101+
/// does clear and check errno, but the NaN/Inf output checks already
102+
/// cover all error cases for the 2-arg functions we wrap (atan2, fmod,
103+
/// copysign, remainder, pow). This is verified by bit-exact proptest
104+
/// and edgetest against CPython.
81105
#[inline]
82106
pub(crate) fn math_2(x: f64, y: f64, func: fn(f64, f64) -> f64) -> crate::Result<f64> {
83107
let r = func(x, y);

src/math/aggregate.rs

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ pub fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 {
231231
///
232232
/// The points are given as sequences of coordinates.
233233
/// Uses high-precision vector_norm algorithm.
234+
///
235+
/// Panics if `p` and `q` have different lengths. CPython raises ValueError
236+
/// for mismatched dimensions, but in this Rust API the caller is expected
237+
/// to guarantee equal-length slices. A length mismatch is a programming
238+
/// error, not a runtime condition.
234239
pub fn dist(p: &[f64], q: &[f64]) -> f64 {
235240
assert_eq!(
236241
p.len(),
@@ -261,24 +266,52 @@ pub fn dist(p: &[f64], q: &[f64]) -> f64 {
261266

262267
/// Return the sum of products of values from two sequences (float version).
263268
///
264-
/// Uses TripleLength arithmetic for high precision.
265-
/// Equivalent to sum(p[i] * q[i] for i in range(len(p))).
266-
pub fn sumprod(p: &[f64], q: &[f64]) -> f64 {
267-
assert_eq!(p.len(), q.len(), "Inputs are not the same length");
269+
/// Uses TripleLength arithmetic for the fast path, then falls back to
270+
/// ordinary floating-point multiply/add starting at the first unsupported
271+
/// pair, matching Python's staged `math.sumprod` behavior for float inputs.
272+
///
273+
/// CPython's math_sumprod_impl is a 3-stage state machine that handles
274+
/// int/float/generic Python objects. This function only covers the float
275+
/// path (`&[f64]`). The int accumulation and generic PyNumber fallback
276+
/// stages are Python type-system concerns and should be handled by the
277+
/// caller (e.g. RustPython) before delegating here.
278+
///
279+
/// Returns EDOM if the inputs are not the same length.
280+
pub fn sumprod(p: &[f64], q: &[f64]) -> crate::Result<f64> {
281+
if p.len() != q.len() {
282+
return Err(crate::Error::EDOM);
283+
}
268284

285+
let mut total = 0.0;
269286
let mut flt_total = TL_ZERO;
287+
let mut flt_path_enabled = true;
288+
let mut i = 0;
270289

271-
for (&pi, &qi) in p.iter().zip(q.iter()) {
272-
let new_flt_total = tl_fma(pi, qi, flt_total);
273-
if new_flt_total.hi.is_finite() {
274-
flt_total = new_flt_total;
275-
} else {
276-
// Overflow or special value, fall back to simple sum
277-
return p.iter().zip(q.iter()).map(|(a, b)| a * b).sum();
290+
while i < p.len() {
291+
let pi = p[i];
292+
let qi = q[i];
293+
294+
if flt_path_enabled {
295+
let new_flt_total = tl_fma(pi, qi, flt_total);
296+
if new_flt_total.hi.is_finite() {
297+
flt_total = new_flt_total;
298+
i += 1;
299+
continue;
300+
}
301+
302+
flt_path_enabled = false;
303+
total += tl_to_d(flt_total);
278304
}
305+
306+
total += pi * qi;
307+
i += 1;
279308
}
280309

281-
tl_to_d(flt_total)
310+
Ok(if flt_path_enabled {
311+
tl_to_d(flt_total)
312+
} else {
313+
total
314+
})
282315
}
283316

284317
/// Return the sum of products of values from two sequences (integer version).
@@ -427,14 +460,27 @@ mod tests {
427460
crate::test::with_py_math(|py, math| {
428461
let py_p = pyo3::types::PyList::new(py, p).unwrap();
429462
let py_q = pyo3::types::PyList::new(py, q).unwrap();
430-
let py: f64 = math
431-
.getattr("sumprod")
432-
.unwrap()
433-
.call1((py_p, py_q))
434-
.unwrap()
435-
.extract()
436-
.unwrap();
437-
crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})"));
463+
let py_result = math.getattr("sumprod").unwrap().call1((py_p, py_q));
464+
match py_result {
465+
Ok(py_val) => {
466+
let py: f64 = py_val.extract().unwrap();
467+
let rs = rs.unwrap_or_else(|e| {
468+
panic!("sumprod({p:?}, {q:?}): py={py} but rs returned error {e:?}")
469+
});
470+
crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})"));
471+
}
472+
Err(e) => {
473+
if e.is_instance_of::<pyo3::exceptions::PyValueError>(py) {
474+
assert_eq!(
475+
rs.as_ref().err(),
476+
Some(&crate::Error::EDOM),
477+
"sumprod({p:?}, {q:?}): py raised ValueError but rs={rs:?}"
478+
);
479+
} else {
480+
panic!("sumprod({p:?}, {q:?}): py raised unexpected error {e}");
481+
}
482+
}
483+
}
438484
});
439485
}
440486

@@ -444,6 +490,9 @@ mod tests {
444490
test_sumprod_impl(&[], &[]);
445491
test_sumprod_impl(&[1.0], &[2.0]);
446492
test_sumprod_impl(&[1e100, 1e100], &[1e100, -1e100]);
493+
test_sumprod_impl(&[1.0, 1e308, -1e308], &[1.0, 2.0, 2.0]);
494+
test_sumprod_impl(&[1e-16, 1e308, -1e308], &[1.0, 2.0, 2.0]);
495+
test_sumprod_impl(&[1.0], &[]);
447496
}
448497

449498
fn test_prod_impl(values: &[f64], start: Option<f64>) {

0 commit comments

Comments
 (0)