Skip to content

Commit 4debd8a

Browse files
committed
TST: apply_where: improve tests per review suggestions
1 parent 1acee7a commit 4debd8a

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

tests/test_funcs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def test_device(self, xp: ModuleType, device: Device):
210210
deadline=None,
211211
)
212212
@given(
213-
n_arrays=st.integers(min_value=1, max_value=3),
214-
n_kwarrays=st.integers(min_value=1, max_value=3),
213+
n_arrays=st.integers(min_value=0, max_value=3),
214+
n_kwarrays=st.integers(min_value=0, max_value=3),
215215
rng_seed=st.integers(min_value=1000000000, max_value=9999999999),
216216
dtype=npst.floating_dtypes(sizes=(32, 64)),
217217
p=st.floats(min_value=0, max_value=1),
@@ -235,6 +235,7 @@ def test_hypothesis(
235235
):
236236
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
237237

238+
_ = hypothesis.assume(n_arrays + n_kwarrays > 0)
238239
mbs = npst.mutually_broadcastable_shapes(
239240
num_shapes=1 + n_arrays + n_kwarrays, min_side=0
240241
)
@@ -264,20 +265,20 @@ def test_hypothesis(
264265
)
265266

266267
kwargs = {
267-
str(n): xp.asarray(
268+
f"kw{n}": xp.asarray(
268269
data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements))
269270
)
270271
for n, shape in enumerate(kwshapes)
271272
}
272273
kwkeys = kwargs.keys()
273274

274275
def f1(*args: Array, **kwargs: dict[str, Array]) -> Array:
275-
assert set(kwargs.keys()) == set(kwkeys)
276+
assert kwargs.keys() == kwkeys
276277
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
277278
return cast(Array, sum(args_kwargs))
278279

279280
def f2(*args: Array, **kwargs: dict[str, Array]) -> Array:
280-
assert set(kwargs.keys()) == set(kwkeys)
281+
assert kwargs.keys() == kwkeys
281282
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
282283
return cast(Array, sum(args_kwargs) / 2)
283284

0 commit comments

Comments
 (0)