Skip to content

Commit 744e57f

Browse files
authored
Fix test tolerances for float16 precision in math tests (#2828)
This PR fixes test failures when testing with `DPNP_TEST_ALL_INT_TYPES=1` against conda-forge's NumPy, where float16 precision is used in various scenarios requiring relaxed tolerances.
1 parent 5e5dc24 commit 744e57f

File tree

5 files changed

+117
-43
lines changed

5 files changed

+117
-43
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
8181
* Fixed an issue causing an exception in `dpnp.geomspace` and `dpnp.logspace` when called with explicit `device` keyword but any input array is allocated on another device [#2723](https://github.com/IntelPython/dpnp/pull/2723)
8282
* Fixed `.data.ptr` property on array views to correctly return the pointer to the view's data location instead of the base allocation pointer [#2812](https://github.com/IntelPython/dpnp/pull/2812)
8383
* Resolved an issue with strides calculation in `dpnp.diagonal` to return correct values for empty diagonals [#2814](https://github.com/IntelPython/dpnp/pull/2814)
84+
* Fixed test tolerance issues for float16 intermediate precision that became visible when testing against conda-forge's NumPy [#2828](https://github.com/IntelPython/dpnp/pull/2828)
8485

8586
### Security
8687

dpnp/tests/third_party/cupy/math_tests/test_explog.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,24 @@
66

77

88
class TestExplog:
9+
# rtol=1e-3 is used to pass the test when dtype is int8/uint8
10+
# for such a case, output dtype is float16
11+
_rtol_dict = {numpy.float16: 1e-3, "default": 1e-7}
912

1013
@testing.for_all_dtypes()
11-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
14+
@testing.numpy_cupy_allclose(
15+
rtol=_rtol_dict, atol=1e-5, type_check=has_support_aspect64()
16+
)
1217
def check_unary(self, name, xp, dtype, no_complex=False):
1318
if no_complex:
1419
if numpy.dtype(dtype).kind == "c":
1520
return xp.array(True)
1621
a = testing.shaped_arange((2, 3), xp, dtype)
1722
return getattr(xp, name)(a)
1823

19-
# rtol=1e-3 is added for dpnp to pass the test when dtype is int8/unint8
20-
# for such a case, output dtype is float16
2124
@testing.for_all_dtypes()
2225
@testing.numpy_cupy_allclose(
23-
rtol=1e-3, atol=1e-5, type_check=has_support_aspect64()
26+
rtol=_rtol_dict, atol=1e-5, type_check=has_support_aspect64()
2427
)
2528
def check_binary(self, name, xp, dtype, no_complex=False):
2629
if no_complex:

dpnp/tests/third_party/cupy/math_tests/test_hyperbolic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,22 @@
77

88

99
class TestHyperbolic(unittest.TestCase):
10+
# rtol=1e-2 is used to pass the test when dtype is int8/uint8
11+
# for such a case, output dtype is float16
12+
_rtol_dict = {numpy.float16: 1e-2, "default": 1e-7}
1013

1114
@testing.for_all_dtypes()
1215
@testing.numpy_cupy_allclose(
13-
atol={numpy.float16: 1e-3, "default": 1e-5},
16+
rtol=_rtol_dict,
17+
atol=1e-5,
1418
type_check=has_support_aspect64(),
1519
)
1620
def check_unary(self, name, xp, dtype):
1721
a = testing.shaped_arange((2, 3), xp, dtype)
1822
return getattr(xp, name)(a)
1923

2024
@testing.for_dtypes(["e", "f", "d"])
21-
@testing.numpy_cupy_allclose(atol={numpy.float16: 1e-3, "default": 1e-5})
25+
@testing.numpy_cupy_allclose(rtol=_rtol_dict, atol=1e-5)
2226
def check_unary_unit(self, name, xp, dtype):
2327
a = xp.array([0.2, 0.4, 0.6, 0.8], dtype=dtype)
2428
return getattr(xp, name)(a)
@@ -36,7 +40,7 @@ def test_arcsinh(self):
3640
self.check_unary("arcsinh")
3741

3842
@testing.for_dtypes(["e", "f", "d"])
39-
@testing.numpy_cupy_allclose(atol={numpy.float16: 1e-3, "default": 1e-5})
43+
@testing.numpy_cupy_allclose(rtol=_rtol_dict, atol=1e-5)
4044
def test_arccosh(self, xp, dtype):
4145
a = xp.array([1, 2, 3], dtype=dtype)
4246
return xp.arccosh(a)

dpnp/tests/third_party/cupy/math_tests/test_misc.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,27 @@
99

1010

1111
class TestMisc:
12+
@staticmethod
13+
def _interp_atol(_result_dtype, dtype_x=None, **_kwargs):
14+
"""Compute absolute tolerance based on intermediate computation dtype.
15+
16+
Args:
17+
_result_dtype: Output dtype (unused - we check input dtype instead)
18+
dtype_x: Input dtype for fx coordinates
19+
_kwargs: Additional test parameters (unused)
20+
21+
When dtype_x is int8/uint8/float16, xp.sin(fx) uses float16 precision,
22+
so we need relaxed tolerance even if the final result is upcasted to float64.
23+
Float16 has ~3 decimal digits of precision, hence atol=1e-3.
24+
"""
25+
if dtype_x is not None:
26+
if numpy.dtype(dtype_x).type in (
27+
numpy.int8,
28+
numpy.uint8,
29+
numpy.float16,
30+
):
31+
return 1e-3
32+
return 1e-5
1233

1334
@testing.for_all_dtypes()
1435
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
@@ -401,17 +422,22 @@ def test_real_if_close_with_float_tol_false(self, xp, dtype):
401422

402423
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
403424
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
404-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
425+
@testing.numpy_cupy_allclose(
426+
atol=_interp_atol, type_check=has_support_aspect64()
427+
)
405428
def test_interp(self, xp, dtype_y, dtype_x):
406429
# interpolate at points on and outside the boundaries
430+
# tolerance is automatically adjusted based on dtype_x via resolver
407431
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
408432
fx = xp.asarray([1, 3, 5, 7, 9], dtype=dtype_x)
409433
fy = xp.sin(fx).astype(dtype_y)
410434
return xp.interp(x, fx, fy)
411435

412436
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
413437
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
414-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
438+
@testing.numpy_cupy_allclose(
439+
atol=_interp_atol, type_check=has_support_aspect64()
440+
)
415441
def test_interp_period(self, xp, dtype_y, dtype_x):
416442
# interpolate at points on and outside the boundaries
417443
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -421,7 +447,9 @@ def test_interp_period(self, xp, dtype_y, dtype_x):
421447

422448
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
423449
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
424-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
450+
@testing.numpy_cupy_allclose(
451+
atol=_interp_atol, type_check=has_support_aspect64()
452+
)
425453
def test_interp_left_right(self, xp, dtype_y, dtype_x):
426454
# interpolate at points on and outside the boundaries
427455
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -434,7 +462,9 @@ def test_interp_left_right(self, xp, dtype_y, dtype_x):
434462
@testing.with_requires("numpy>=1.17.0")
435463
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
436464
@testing.for_dtypes("efdFD", name="dtype_y")
437-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
465+
@testing.numpy_cupy_allclose(
466+
atol=_interp_atol, type_check=has_support_aspect64()
467+
)
438468
def test_interp_nan_fy(self, xp, dtype_y, dtype_x):
439469
# interpolate at points on and outside the boundaries
440470
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -446,7 +476,9 @@ def test_interp_nan_fy(self, xp, dtype_y, dtype_x):
446476
@testing.with_requires("numpy>=1.17.0")
447477
@testing.for_float_dtypes(name="dtype_x")
448478
@testing.for_dtypes("efdFD", name="dtype_y")
449-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
479+
@testing.numpy_cupy_allclose(
480+
atol=_interp_atol, type_check=has_support_aspect64()
481+
)
450482
def test_interp_nan_fx(self, xp, dtype_y, dtype_x):
451483
# interpolate at points on and outside the boundaries
452484
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -458,7 +490,9 @@ def test_interp_nan_fx(self, xp, dtype_y, dtype_x):
458490
@testing.with_requires("numpy>=1.17.0")
459491
@testing.for_float_dtypes(name="dtype_x")
460492
@testing.for_dtypes("efdFD", name="dtype_y")
461-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
493+
@testing.numpy_cupy_allclose(
494+
atol=_interp_atol, type_check=has_support_aspect64()
495+
)
462496
def test_interp_nan_x(self, xp, dtype_y, dtype_x):
463497
# interpolate at points on and outside the boundaries
464498
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -470,7 +504,9 @@ def test_interp_nan_x(self, xp, dtype_y, dtype_x):
470504
@testing.with_requires("numpy>=1.17.0")
471505
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
472506
@testing.for_dtypes("efdFD", name="dtype_y")
473-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
507+
@testing.numpy_cupy_allclose(
508+
atol=_interp_atol, type_check=has_support_aspect64()
509+
)
474510
def test_interp_inf_fy(self, xp, dtype_y, dtype_x):
475511
# interpolate at points on and outside the boundaries
476512
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -482,7 +518,9 @@ def test_interp_inf_fy(self, xp, dtype_y, dtype_x):
482518
@testing.with_requires("numpy>=1.17.0")
483519
@testing.for_float_dtypes(name="dtype_x")
484520
@testing.for_dtypes("efdFD", name="dtype_y")
485-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
521+
@testing.numpy_cupy_allclose(
522+
atol=_interp_atol, type_check=has_support_aspect64()
523+
)
486524
def test_interp_inf_fx(self, xp, dtype_y, dtype_x):
487525
# interpolate at points on and outside the boundaries
488526
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -494,7 +532,9 @@ def test_interp_inf_fx(self, xp, dtype_y, dtype_x):
494532
@testing.with_requires("numpy>=1.17.0")
495533
@testing.for_float_dtypes(name="dtype_x")
496534
@testing.for_dtypes("efdFD", name="dtype_y")
497-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
535+
@testing.numpy_cupy_allclose(
536+
atol=_interp_atol, type_check=has_support_aspect64()
537+
)
498538
def test_interp_inf_x(self, xp, dtype_y, dtype_x):
499539
# interpolate at points on and outside the boundaries
500540
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -505,7 +545,9 @@ def test_interp_inf_x(self, xp, dtype_y, dtype_x):
505545

506546
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
507547
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
508-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
548+
@testing.numpy_cupy_allclose(
549+
atol=_interp_atol, type_check=has_support_aspect64()
550+
)
509551
def test_interp_size1(self, xp, dtype_y, dtype_x):
510552
# interpolate at points on and outside the boundaries
511553
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -518,7 +560,9 @@ def test_interp_size1(self, xp, dtype_y, dtype_x):
518560
@testing.with_requires("numpy>=1.17.0")
519561
@testing.for_float_dtypes(name="dtype_x")
520562
@testing.for_dtypes("efdFD", name="dtype_y")
521-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
563+
@testing.numpy_cupy_allclose(
564+
atol=_interp_atol, type_check=has_support_aspect64()
565+
)
522566
def test_interp_inf_to_nan(self, xp, dtype_y, dtype_x):
523567
# from NumPy's test_non_finite_inf
524568
x = xp.asarray([0.5], dtype=dtype_x)

dpnp/tests/third_party/cupy/testing/_loops.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def test_func(*args, **kw):
410410
numpy_r = numpy_r[mask]
411411

412412
if not skip:
413-
check_func(cupy_r, numpy_r)
413+
check_func(cupy_r, numpy_r, **kw)
414414

415415
return test_func
416416

@@ -469,6 +469,9 @@ def _convert_output_to_ndarray(c_out, n_out, sp_name, check_sparse_format):
469469

470470
def _check_tolerance_keys(rtol, atol):
471471
def _check(tol):
472+
if callable(tol):
473+
# Callable tolerance is allowed
474+
return
472475
if isinstance(tol, dict):
473476
for k in tol.keys():
474477
if type(k) is type:
@@ -486,9 +489,13 @@ def _check(tol):
486489
_check(atol)
487490

488491

489-
def _resolve_tolerance(type_check, result, rtol, atol):
492+
def _resolve_tolerance(type_check, result, rtol, atol, **test_kwargs):
490493
def _resolve(dtype, tol):
491-
if isinstance(tol, dict):
494+
if callable(tol):
495+
# Support callable tolerance that can inspect test kwargs
496+
return tol(dtype, **test_kwargs)
497+
elif isinstance(tol, dict):
498+
# Original dict lookup logic
492499
tol1 = tol.get(dtype.type)
493500
if tol1 is None:
494501
tol1 = tol.get("default")
@@ -523,13 +530,15 @@ def numpy_cupy_allclose(
523530
"""Decorator that checks NumPy results and CuPy ones are close.
524531
525532
Args:
526-
rtol(float or dict): Relative tolerance. Besides a float value, a
527-
dictionary that maps a dtypes to a float value can be supplied to
528-
adjust tolerance per dtype. If the dictionary has ``'default'``
529-
string as its key, its value is used as the default tolerance in
530-
case any dtype keys do not match.
531-
atol(float or dict): Absolute tolerance. Besides a float value, a
532-
dictionary can be supplied as ``rtol``.
533+
rtol(float, dict, or callable): Relative tolerance. Can be:
534+
- A float value
535+
- A dictionary that maps dtypes to float values. If the dictionary
536+
has ``'default'`` string as its key, its value is used as the
537+
default tolerance in case any dtype keys do not match.
538+
- A callable with signature ``(dtype, **test_kwargs)`` that returns
539+
a float. This allows dynamic tolerance based on test parameters
540+
like input dtypes.
541+
atol(float, dict, or callable): Absolute tolerance. Same options as ``rtol``.
533542
err_msg(str): The error message to be printed in case of failure.
534543
verbose(bool): If ``True``, the conflicting values are
535544
appended to the error message.
@@ -583,10 +592,17 @@ def numpy_cupy_allclose(
583592
# "must be supplied as float."
584593
# )
585594

586-
def check_func(c, n):
587-
rtol1, atol1 = _resolve_tolerance(type_check, c, rtol, atol)
595+
def check_func(cupy_result, numpy_result, **test_kwargs):
596+
rtol1, atol1 = _resolve_tolerance(
597+
type_check, cupy_result, rtol, atol, **test_kwargs
598+
)
588599
_array.assert_allclose(
589-
c, n, rtol1, atol1, err_msg=err_msg, verbose=verbose
600+
cupy_result,
601+
numpy_result,
602+
rtol1,
603+
atol1,
604+
err_msg=err_msg,
605+
verbose=verbose,
590606
)
591607

592608
return _make_decorator(
@@ -641,8 +657,10 @@ def numpy_cupy_array_almost_equal(
641657
.. seealso:: :func:`cupy.testing.assert_array_almost_equal`
642658
"""
643659

644-
def check_func(x, y):
645-
_array.assert_array_almost_equal(x, y, decimal, err_msg, verbose)
660+
def check_func(cupy_result, numpy_result, **test_kwargs):
661+
_array.assert_array_almost_equal(
662+
cupy_result, numpy_result, decimal, err_msg, verbose
663+
)
646664

647665
return _make_decorator(
648666
check_func, name, type_check, False, accept_error, sp_name, scipy_name
@@ -684,8 +702,8 @@ def numpy_cupy_array_almost_equal_nulp(
684702
.. seealso:: :func:`cupy.testing.assert_array_almost_equal_nulp`
685703
"""
686704

687-
def check_func(x, y):
688-
_array.assert_array_almost_equal_nulp(x, y, nulp)
705+
def check_func(cupy_result, numpy_result, **test_kwargs):
706+
_array.assert_array_almost_equal_nulp(cupy_result, numpy_result, nulp)
689707

690708
return _make_decorator(
691709
check_func,
@@ -738,8 +756,8 @@ def numpy_cupy_array_max_ulp(
738756
739757
"""
740758

741-
def check_func(x, y):
742-
_array.assert_array_max_ulp(x, y, maxulp, dtype)
759+
def check_func(cupy_result, numpy_result, **test_kwargs):
760+
_array.assert_array_max_ulp(cupy_result, numpy_result, maxulp, dtype)
743761

744762
return _make_decorator(
745763
check_func, name, type_check, False, accept_error, sp_name, scipy_name
@@ -787,9 +805,13 @@ def numpy_cupy_array_equal(
787805
.. seealso:: :func:`cupy.testing.assert_array_equal`
788806
"""
789807

790-
def check_func(x, y):
808+
def check_func(cupy_result, numpy_result, **test_kwargs):
791809
_array.assert_array_equal(
792-
x, y, err_msg, verbose, strides_check=strides_check
810+
cupy_result,
811+
numpy_result,
812+
err_msg,
813+
verbose,
814+
strides_check=strides_check,
793815
)
794816

795817
return _make_decorator(
@@ -826,8 +848,8 @@ def numpy_cupy_array_list_equal(
826848
DeprecationWarning,
827849
)
828850

829-
def check_func(x, y):
830-
_array.assert_array_equal(x, y, err_msg, verbose)
851+
def check_func(cupy_result, numpy_result, **test_kwargs):
852+
_array.assert_array_equal(cupy_result, numpy_result, err_msg, verbose)
831853

832854
return _make_decorator(
833855
check_func, name, False, False, False, sp_name, scipy_name
@@ -871,8 +893,8 @@ def numpy_cupy_array_less(
871893
.. seealso:: :func:`cupy.testing.assert_array_less`
872894
"""
873895

874-
def check_func(x, y):
875-
_array.assert_array_less(x, y, err_msg, verbose)
896+
def check_func(cupy_result, numpy_result, **test_kwargs):
897+
_array.assert_array_less(cupy_result, numpy_result, err_msg, verbose)
876898

877899
return _make_decorator(
878900
check_func, name, type_check, False, accept_error, sp_name, scipy_name

0 commit comments

Comments
 (0)