@@ -821,10 +821,30 @@ def test_atan(x):
821821@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
822822def test_atan2 (x1 , x2 ):
823823 out = xp .atan2 (x1 , x2 )
824- ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
825- ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
826- refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
827- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
824+ _assert_correctness_binary (
825+ "atan" ,
826+ cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2 ,
827+ in_dtypes = [x1 .dtype , x2 .dtype ],
828+ in_shapes = [x1 .shape , x2 .shape ],
829+ in_arrs = [x1 , x2 ],
830+ out = out ,
831+ )
832+
833+
834+ @pytest .mark .min_version ("2024.12" )
835+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
836+ def test_atan2_with_scalars (x1x2 ):
837+ x1 , x2 = x1x2
838+ out = xp .atan2 (x1 , x2 )
839+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
840+ _assert_correctness_binary (
841+ "atan2" ,
842+ cmath .atan2 if x1a .dtype in dh .complex_dtypes else math .atan2 ,
843+ in_dtypes = in_dtypes ,
844+ in_shapes = in_shapes ,
845+ in_arrs = [x1a , x2a ],
846+ out = out ,
847+ )
828848
829849
830850@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1290,11 +1310,31 @@ def test_greater_equal(ctx, data):
12901310@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
12911311def test_hypot (x1 , x2 ):
12921312 out = xp .hypot (x1 , x2 )
1293- ph .assert_dtype ("hypot" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1294- ph .assert_result_shape ("hypot" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1295- binary_assert_against_refimpl ("hypot" , x1 , x2 , out , math .hypot )
1313+ _assert_correctness_binary (
1314+ "hypot" ,
1315+ math .hypot ,
1316+ in_dtypes = [x1 .dtype , x2 .dtype ],
1317+ in_shapes = [x1 .shape , x2 .shape ],
1318+ in_arrs = [x1 , x2 ],
1319+ out = out
1320+ )
12961321
12971322
1323+ @pytest .mark .min_version ("2024.12" )
1324+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
1325+ def test_hypot_with_scalars (x1x2 ):
1326+ x1 , x2 = x1x2
1327+ out = xp .hypot (x1 , x2 )
1328+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1329+ _assert_correctness_binary (
1330+ "hypot" ,
1331+ math .hypot ,
1332+ in_dtypes = in_dtypes ,
1333+ in_shapes = in_shapes ,
1334+ in_arrs = (x1a , x2a ),
1335+ out = out
1336+ )
1337+
12981338
12991339@pytest .mark .min_version ("2022.12" )
13001340@pytest .mark .skipif (hh .complex_dtypes .is_empty , reason = "no complex data types to draw from" )
@@ -1443,12 +1483,34 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14431483 raise OverflowError
14441484
14451485
1486+ @pytest .mark .min_version ("2023.12" )
14461487@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14471488def test_logaddexp (x1 , x2 ):
14481489 out = xp .logaddexp (x1 , x2 )
1449- ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1450- ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1451- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1490+ _assert_correctness_binary (
1491+ "logaddexp" ,
1492+ logaddexp_refimpl ,
1493+ in_dtypes = [x1 .dtype , x2 .dtype ],
1494+ in_shapes = [x1 .shape , x2 .shape ],
1495+ in_arrs = [x1 , x2 ],
1496+ out = out
1497+ )
1498+
1499+
1500+ @pytest .mark .min_version ("2024.12" )
1501+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
1502+ def test_logaddexp_with_scalars (x1x2 ):
1503+ x1 , x2 = x1x2
1504+ out = xp .logaddexp (x1 , x2 )
1505+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1506+ _assert_correctness_binary (
1507+ "logaddexp" ,
1508+ logaddexp_refimpl ,
1509+ in_dtypes = in_dtypes ,
1510+ in_shapes = in_shapes ,
1511+ in_arrs = (x1a , x2a ),
1512+ out = out
1513+ )
14521514
14531515
14541516@given (hh .arrays (dtype = xp .bool , shape = hh .shapes ()))
0 commit comments