@@ -457,13 +457,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
457457 dtypes should be one of the shared_* dtypes strategies.
458458 """
459459 dtype = draw (dtypes )
460- mM = kwds .pop ('mM' , None )
461460 if dh .is_int_dtype (dtype ):
462- if mM is None :
463- m , M = dh . dtype_ranges [ dtype ]
464- else :
465- m , M = mM
466- return draw (integers (m , M ))
461+ m , M = dh . dtype_ranges [ dtype ]
462+ min_value = kwds . get ( 'min_value' , m )
463+ max_value = kwds . get ( 'max_value' , M )
464+
465+ return draw (integers (min_value , max_value ))
467466 elif dtype == bool_dtype :
468467 return draw (booleans ())
469468 elif dtype == float64 :
@@ -593,20 +592,32 @@ def two_mutual_arrays(
593592
594593
595594@composite
596- def array_and_py_scalar (draw , dtypes , mM = None , positive = False ):
595+ def array_and_py_scalar (draw , dtypes , ** kwds ):
597596 """Draw a pair: (array, scalar) or (scalar, array)."""
598597 dtype = draw (sampled_from (dtypes ))
599598
600- scalar_var = draw (scalars (just (dtype ), finite = True , mM = mM ))
601- if positive :
602- assume (scalar_var > 0 )
599+ # draw the scalar: for float arrays, draw a float or an int
600+ if dtype in dh .real_float_dtypes :
601+ scalar_strategy = sampled_from ([xp .int32 , dtype ])
602+ else :
603+ scalar_strategy = just (dtype )
604+ scalar_var = draw (scalars (scalar_strategy , finite = True , ** kwds ))
603605
606+ # draw the array.
607+ # XXX artificially limit the range of values for floats, otherwise value testing is flaky
604608 elements = {}
605609 if dtype in dh .real_float_dtypes :
606- elements = {'allow_nan' : False , 'allow_infinity' : False ,
607- 'min_value' : 1.0 / (2 << 5 ), 'max_value' : 2 << 5 }
608- if positive :
609- elements = {'min_value' : 0 }
610+ elements = {
611+ 'allow_nan' : False ,
612+ 'allow_infinity' : False ,
613+ 'min_value' : kwds .get ('min_value' , 1.0 / (2 << 5 )),
614+ 'max_value' : kwds .get ('max_value' , 2 << 5 )
615+ }
616+ elif dtype in dh .int_dtypes :
617+ elements = {
618+ 'min_value' : kwds .get ('min_value' , None ),
619+ 'max_value' : kwds .get ('max_value' , None )
620+ }
610621 array_var = draw (arrays (dtype , shape = shapes (min_dims = 1 ), elements = elements ))
611622
612623 if draw (booleans ()):
0 commit comments