@@ -690,6 +690,31 @@ def binary_param_assert_against_refimpl(
690690 )
691691
692692
693+ def _convert_scalars_helper (x1 , x2 ):
694+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
695+
696+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697+ and all arguments converted to arrays.
698+
699+ dtypes are separate to help distinguishing between
700+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701+ """
702+ if dh .is_scalar (x1 ):
703+ in_dtypes = [x2 .dtype ]
704+ in_shapes = [x2 .shape ]
705+ x1a , x2a = xp .asarray (x1 ), x2
706+ elif dh .is_scalar (x2 ):
707+ in_dtypes = [x1 .dtype ]
708+ in_shapes = [x1 .shape ]
709+ x1a , x2a = x1 , xp .asarray (x2 )
710+ else :
711+ in_dtypes = [x1 .dtype , x2 .dtype ]
712+ in_shapes = [x1 .shape , x2 .shape ]
713+ x1a , x2a = x1 , x2
714+
715+ return in_dtypes , in_shapes , (x1a , x2a )
716+
717+
693718@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
694719@given (data = st .data ())
695720def test_abs (ctx , data ):
@@ -1468,13 +1493,27 @@ def test_maximum(x1, x2):
14681493 binary_assert_against_refimpl ("maximum" , x1 , x2 , out , max , strict_check = True )
14691494
14701495
1496+ def _assert_correctness_binary (name , in_dtypes , in_shapes , in_arrs , out ):
1497+ x1a , x2a = in_arrs
1498+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype )
1499+ ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
1500+ binary_assert_against_refimpl (name , x1a , x2a , out , min , strict_check = True )
1501+
1502+
14711503@pytest .mark .min_version ("2023.12" )
14721504@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14731505def test_minimum (x1 , x2 ):
14741506 out = xp .minimum (x1 , x2 )
1475- ph .assert_dtype ("minimum" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1476- ph .assert_result_shape ("minimum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1477- binary_assert_against_refimpl ("minimum" , x1 , x2 , out , min , strict_check = True )
1507+ _assert_correctness_binary ("minimum" , [x1 .dtype , x2 .dtype ], [x1 .shape , x2 .shape ], (x1 , x2 ), out )
1508+
1509+
1510+ @pytest .mark .min_version ("2024.12" )
1511+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
1512+ def test_minimum_with_scalars (x1x2 ):
1513+ x1 , x2 = x1x2
1514+ out = xp .minimum (x1 , x2 )
1515+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1516+ _assert_correctness_binary ("minimum" , in_dtypes , in_shapes , (x1a , x2a ), out )
14781517
14791518
14801519@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
0 commit comments