Skip to content

Commit 9b108ec

Browse files
committed
Update interp tests to consider when xp.sin(fx) uses float16 precision
1 parent dbb7199 commit 9b108ec

File tree

2 files changed

+84
-29
lines changed

2 files changed

+84
-29
lines changed

dpnp/tests/third_party/cupy/math_tests/test_misc.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,27 @@
99

1010

1111
class TestMisc:
12+
@staticmethod
13+
def _interp_atol(_result_dtype, dtype_x=None, **_kwargs):
14+
"""Compute absolute tolerance based on intermediate computation dtype.
15+
16+
Args:
17+
_result_dtype: Output dtype (unused - we check input dtype instead)
18+
dtype_x: Input dtype for fx coordinates
19+
_kwargs: Additional test parameters (unused)
20+
21+
When dtype_x is int8/uint8/float16, xp.sin(fx) uses float16 precision,
22+
so we need relaxed tolerance even if the final result is upcasted to float64.
23+
Float16 has ~3 decimal digits of precision, hence atol=1e-3.
24+
"""
25+
if dtype_x is not None:
26+
if numpy.dtype(dtype_x).type in (
27+
numpy.int8,
28+
numpy.uint8,
29+
numpy.float16,
30+
):
31+
return 1e-3
32+
return 1e-5
1233

1334
@testing.for_all_dtypes()
1435
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
@@ -401,17 +422,22 @@ def test_real_if_close_with_float_tol_false(self, xp, dtype):
401422

402423
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
403424
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
404-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
425+
@testing.numpy_cupy_allclose(
426+
atol=_interp_atol, type_check=has_support_aspect64()
427+
)
405428
def test_interp(self, xp, dtype_y, dtype_x):
406429
# interpolate at points on and outside the boundaries
430+
# tolerance is automatically adjusted based on dtype_x via resolver
407431
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
408432
fx = xp.asarray([1, 3, 5, 7, 9], dtype=dtype_x)
409433
fy = xp.sin(fx).astype(dtype_y)
410434
return xp.interp(x, fx, fy)
411435

412436
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
413437
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
414-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
438+
@testing.numpy_cupy_allclose(
439+
atol=_interp_atol, type_check=has_support_aspect64()
440+
)
415441
def test_interp_period(self, xp, dtype_y, dtype_x):
416442
# interpolate at points on and outside the boundaries
417443
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -421,7 +447,9 @@ def test_interp_period(self, xp, dtype_y, dtype_x):
421447

422448
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
423449
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
424-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
450+
@testing.numpy_cupy_allclose(
451+
atol=_interp_atol, type_check=has_support_aspect64()
452+
)
425453
def test_interp_left_right(self, xp, dtype_y, dtype_x):
426454
# interpolate at points on and outside the boundaries
427455
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -434,7 +462,9 @@ def test_interp_left_right(self, xp, dtype_y, dtype_x):
434462
@testing.with_requires("numpy>=1.17.0")
435463
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
436464
@testing.for_dtypes("efdFD", name="dtype_y")
437-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
465+
@testing.numpy_cupy_allclose(
466+
atol=_interp_atol, type_check=has_support_aspect64()
467+
)
438468
def test_interp_nan_fy(self, xp, dtype_y, dtype_x):
439469
# interpolate at points on and outside the boundaries
440470
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -446,7 +476,9 @@ def test_interp_nan_fy(self, xp, dtype_y, dtype_x):
446476
@testing.with_requires("numpy>=1.17.0")
447477
@testing.for_float_dtypes(name="dtype_x")
448478
@testing.for_dtypes("efdFD", name="dtype_y")
449-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
479+
@testing.numpy_cupy_allclose(
480+
atol=_interp_atol, type_check=has_support_aspect64()
481+
)
450482
def test_interp_nan_fx(self, xp, dtype_y, dtype_x):
451483
# interpolate at points on and outside the boundaries
452484
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -458,7 +490,9 @@ def test_interp_nan_fx(self, xp, dtype_y, dtype_x):
458490
@testing.with_requires("numpy>=1.17.0")
459491
@testing.for_float_dtypes(name="dtype_x")
460492
@testing.for_dtypes("efdFD", name="dtype_y")
461-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
493+
@testing.numpy_cupy_allclose(
494+
atol=_interp_atol, type_check=has_support_aspect64()
495+
)
462496
def test_interp_nan_x(self, xp, dtype_y, dtype_x):
463497
# interpolate at points on and outside the boundaries
464498
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -470,7 +504,9 @@ def test_interp_nan_x(self, xp, dtype_y, dtype_x):
470504
@testing.with_requires("numpy>=1.17.0")
471505
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
472506
@testing.for_dtypes("efdFD", name="dtype_y")
473-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
507+
@testing.numpy_cupy_allclose(
508+
atol=_interp_atol, type_check=has_support_aspect64()
509+
)
474510
def test_interp_inf_fy(self, xp, dtype_y, dtype_x):
475511
# interpolate at points on and outside the boundaries
476512
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -482,7 +518,9 @@ def test_interp_inf_fy(self, xp, dtype_y, dtype_x):
482518
@testing.with_requires("numpy>=1.17.0")
483519
@testing.for_float_dtypes(name="dtype_x")
484520
@testing.for_dtypes("efdFD", name="dtype_y")
485-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
521+
@testing.numpy_cupy_allclose(
522+
atol=_interp_atol, type_check=has_support_aspect64()
523+
)
486524
def test_interp_inf_fx(self, xp, dtype_y, dtype_x):
487525
# interpolate at points on and outside the boundaries
488526
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -494,7 +532,9 @@ def test_interp_inf_fx(self, xp, dtype_y, dtype_x):
494532
@testing.with_requires("numpy>=1.17.0")
495533
@testing.for_float_dtypes(name="dtype_x")
496534
@testing.for_dtypes("efdFD", name="dtype_y")
497-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
535+
@testing.numpy_cupy_allclose(
536+
atol=_interp_atol, type_check=has_support_aspect64()
537+
)
498538
def test_interp_inf_x(self, xp, dtype_y, dtype_x):
499539
# interpolate at points on and outside the boundaries
500540
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -505,7 +545,9 @@ def test_interp_inf_x(self, xp, dtype_y, dtype_x):
505545

506546
@testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True)
507547
@testing.for_all_dtypes(name="dtype_y", no_bool=True)
508-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
548+
@testing.numpy_cupy_allclose(
549+
atol=_interp_atol, type_check=has_support_aspect64()
550+
)
509551
def test_interp_size1(self, xp, dtype_y, dtype_x):
510552
# interpolate at points on and outside the boundaries
511553
x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x)
@@ -518,7 +560,9 @@ def test_interp_size1(self, xp, dtype_y, dtype_x):
518560
@testing.with_requires("numpy>=1.17.0")
519561
@testing.for_float_dtypes(name="dtype_x")
520562
@testing.for_dtypes("efdFD", name="dtype_y")
521-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
563+
@testing.numpy_cupy_allclose(
564+
atol=_interp_atol, type_check=has_support_aspect64()
565+
)
522566
def test_interp_inf_to_nan(self, xp, dtype_y, dtype_x):
523567
# from NumPy's test_non_finite_inf
524568
x = xp.asarray([0.5], dtype=dtype_x)

dpnp/tests/third_party/cupy/testing/_loops.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def test_func(*args, **kw):
410410
numpy_r = numpy_r[mask]
411411

412412
if not skip:
413-
check_func(cupy_r, numpy_r)
413+
check_func(cupy_r, numpy_r, **kw)
414414

415415
return test_func
416416

@@ -469,6 +469,9 @@ def _convert_output_to_ndarray(c_out, n_out, sp_name, check_sparse_format):
469469

470470
def _check_tolerance_keys(rtol, atol):
471471
def _check(tol):
472+
if callable(tol):
473+
# Callable tolerance is allowed
474+
return
472475
if isinstance(tol, dict):
473476
for k in tol.keys():
474477
if type(k) is type:
@@ -486,9 +489,13 @@ def _check(tol):
486489
_check(atol)
487490

488491

489-
def _resolve_tolerance(type_check, result, rtol, atol):
492+
def _resolve_tolerance(type_check, result, rtol, atol, **test_kwargs):
490493
def _resolve(dtype, tol):
491-
if isinstance(tol, dict):
494+
if callable(tol):
495+
# Support callable tolerance that can inspect test kwargs
496+
return tol(dtype, **test_kwargs)
497+
elif isinstance(tol, dict):
498+
# Original dict lookup logic
492499
tol1 = tol.get(dtype.type)
493500
if tol1 is None:
494501
tol1 = tol.get("default")
@@ -523,13 +530,15 @@ def numpy_cupy_allclose(
523530
"""Decorator that checks NumPy results and CuPy ones are close.
524531
525532
Args:
526-
rtol(float or dict): Relative tolerance. Besides a float value, a
527-
dictionary that maps a dtypes to a float value can be supplied to
528-
adjust tolerance per dtype. If the dictionary has ``'default'``
529-
string as its key, its value is used as the default tolerance in
530-
case any dtype keys do not match.
531-
atol(float or dict): Absolute tolerance. Besides a float value, a
532-
dictionary can be supplied as ``rtol``.
533+
rtol(float, dict, or callable): Relative tolerance. Can be:
534+
- A float value
535+
- A dictionary that maps dtypes to float values. If the dictionary
536+
has ``'default'`` string as its key, its value is used as the
537+
default tolerance in case any dtype keys do not match.
538+
- A callable with signature ``(dtype, **test_kwargs)`` that returns
539+
a float. This allows dynamic tolerance based on test parameters
540+
like input dtypes.
541+
atol(float, dict, or callable): Absolute tolerance. Same options as ``rtol``.
533542
err_msg(str): The error message to be printed in case of failure.
534543
verbose(bool): If ``True``, the conflicting values are
535544
appended to the error message.
@@ -583,8 +592,10 @@ def numpy_cupy_allclose(
583592
# "must be supplied as float."
584593
# )
585594

586-
def check_func(c, n):
587-
rtol1, atol1 = _resolve_tolerance(type_check, c, rtol, atol)
595+
def check_func(c, n, **test_kwargs):
596+
rtol1, atol1 = _resolve_tolerance(
597+
type_check, c, rtol, atol, **test_kwargs
598+
)
588599
_array.assert_allclose(
589600
c, n, rtol1, atol1, err_msg=err_msg, verbose=verbose
590601
)
@@ -641,7 +652,7 @@ def numpy_cupy_array_almost_equal(
641652
.. seealso:: :func:`cupy.testing.assert_array_almost_equal`
642653
"""
643654

644-
def check_func(x, y):
655+
def check_func(x, y, **test_kwargs):
645656
_array.assert_array_almost_equal(x, y, decimal, err_msg, verbose)
646657

647658
return _make_decorator(
@@ -684,7 +695,7 @@ def numpy_cupy_array_almost_equal_nulp(
684695
.. seealso:: :func:`cupy.testing.assert_array_almost_equal_nulp`
685696
"""
686697

687-
def check_func(x, y):
698+
def check_func(x, y, **test_kwargs):
688699
_array.assert_array_almost_equal_nulp(x, y, nulp)
689700

690701
return _make_decorator(
@@ -738,7 +749,7 @@ def numpy_cupy_array_max_ulp(
738749
739750
"""
740751

741-
def check_func(x, y):
752+
def check_func(x, y, **test_kwargs):
742753
_array.assert_array_max_ulp(x, y, maxulp, dtype)
743754

744755
return _make_decorator(
@@ -787,7 +798,7 @@ def numpy_cupy_array_equal(
787798
.. seealso:: :func:`cupy.testing.assert_array_equal`
788799
"""
789800

790-
def check_func(x, y):
801+
def check_func(x, y, **test_kwargs):
791802
_array.assert_array_equal(
792803
x, y, err_msg, verbose, strides_check=strides_check
793804
)
@@ -826,7 +837,7 @@ def numpy_cupy_array_list_equal(
826837
DeprecationWarning,
827838
)
828839

829-
def check_func(x, y):
840+
def check_func(x, y, **test_kwargs):
830841
_array.assert_array_equal(x, y, err_msg, verbose)
831842

832843
return _make_decorator(
@@ -871,7 +882,7 @@ def numpy_cupy_array_less(
871882
.. seealso:: :func:`cupy.testing.assert_array_less`
872883
"""
873884

874-
def check_func(x, y):
885+
def check_func(x, y, **test_kwargs):
875886
_array.assert_array_less(x, y, err_msg, verbose)
876887

877888
return _make_decorator(

0 commit comments

Comments
 (0)