Skip to content

Commit 0580097

Browse files
authored
Merge pull request #410 from ev-br/scalars_stragegy
ENH: test binops with `float_array, int_scalar` combinations
2 parents b00ac41 + 05e89f3 commit 0580097

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
457457
dtypes should be one of the shared_* dtypes strategies.
458458
"""
459459
dtype = draw(dtypes)
460-
mM = kwds.pop('mM', None)
461460
if dh.is_int_dtype(dtype):
462-
if mM is None:
463-
m, M = dh.dtype_ranges[dtype]
464-
else:
465-
m, M = mM
466-
return draw(integers(m, M))
461+
m, M = dh.dtype_ranges[dtype]
462+
min_value = kwds.get('min_value', m)
463+
max_value = kwds.get('max_value', M)
464+
465+
return draw(integers(min_value, max_value))
467466
elif dtype == bool_dtype:
468467
return draw(booleans())
469468
elif dtype == float64:
@@ -593,20 +592,32 @@ def two_mutual_arrays(
593592

594593

595594
@composite
596-
def array_and_py_scalar(draw, dtypes, mM=None, positive=False):
595+
def array_and_py_scalar(draw, dtypes, **kwds):
597596
"""Draw a pair: (array, scalar) or (scalar, array)."""
598597
dtype = draw(sampled_from(dtypes))
599598

600-
scalar_var = draw(scalars(just(dtype), finite=True, mM=mM))
601-
if positive:
602-
assume (scalar_var > 0)
599+
# draw the scalar: for float arrays, draw a float or an int
600+
if dtype in dh.real_float_dtypes:
601+
scalar_strategy = sampled_from([xp.int32, dtype])
602+
else:
603+
scalar_strategy = just(dtype)
604+
scalar_var = draw(scalars(scalar_strategy, finite=True, **kwds))
603605

606+
# draw the array.
607+
# XXX artificially limit the range of values for floats, otherwise value testing is flaky
604608
elements={}
605609
if dtype in dh.real_float_dtypes:
606-
elements = {'allow_nan': False, 'allow_infinity': False,
607-
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
608-
if positive:
609-
elements = {'min_value': 0}
610+
elements = {
611+
'allow_nan': False,
612+
'allow_infinity': False,
613+
'min_value': kwds.get('min_value', 1.0 / (2<<5)),
614+
'max_value': kwds.get('max_value', 2<<5)
615+
}
616+
elif dtype in dh.int_dtypes:
617+
elements = {
618+
'min_value': kwds.get('min_value', None),
619+
'max_value': kwds.get('max_value', None)
620+
}
610621
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
611622

612623
if draw(booleans()):

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2246,7 +2246,7 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
22462246
],
22472247
ids=lambda func_data: func_data[0] # use names for test IDs
22482248
)
2249-
@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3)))
2249+
@given(x1x2=hh.array_and_py_scalar([xp.int32], min_value=1, max_value=3))
22502250
def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
22512251
func_name, refimpl, kwargs, expected = func_data
22522252
# repack the refimpl

0 commit comments

Comments
 (0)