Skip to content

Commit 97cceae

Browse files
committed
ENH: apply_where: add kwargs support
1 parent f2541dc commit 97cceae

2 files changed

Lines changed: 70 additions & 20 deletions

File tree

src/array_api_extra/_lib/_funcs.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08
4141
f2: Callable[..., Array],
4242
/,
4343
*,
44+
kwargs: dict[str, Array] | None = None,
4445
xp: ModuleType | None = None,
4546
) -> Array: ...
4647

@@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08
5354
/,
5455
*,
5556
fill_value: Array | complex,
57+
kwargs: dict[str, Array] | None = None,
5658
xp: ModuleType | None = None,
5759
) -> Array: ...
5860

@@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02
6567
/,
6668
*,
6769
fill_value: Array | complex | None = None,
70+
kwargs: dict[str, Array] | None = None,
6871
xp: ModuleType | None = None,
6972
) -> Array:
7073
"""
@@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02
9194
It does not need to be scalar; it needs however to be broadcastable with
9295
`cond` and `args`.
9396
Mutually exclusive with `f2`. You must provide one or the other.
97+
kwargs : dict of str : Array pairs
98+
Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with
99+
`cond`.
94100
xp : array_namespace, optional
95101
The standard-compatible namespace for `cond` and `args`. Default: infer.
96102
@@ -129,6 +135,12 @@ def apply_where( # numpydoc ignore=PR01,PR02
129135
args_ = list(args) if isinstance(args, tuple) else [args]
130136
del args
131137

138+
kwargs_ = {} if kwargs is None else kwargs
139+
kwkeys = list(kwargs_.keys())
140+
nargs = len(args_)
141+
args_ = [*args_, *kwargs_.values()]
142+
del kwargs
143+
132144
xp = array_namespace(cond, fill_value, *args_) if xp is None else xp
133145

134146
if isinstance(fill_value, int | float | complex | NoneType):
@@ -139,8 +151,19 @@ def apply_where( # numpydoc ignore=PR01,PR02
139151
if is_dask_namespace(xp):
140152
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
141153
# map_blocks doesn't descend into tuples of Arrays
142-
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
143-
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
154+
return xp.map_blocks(
155+
_apply_where, cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=meta_xp
156+
)
157+
158+
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
159+
# jax.jit does not support assignment by boolean mask
160+
return xp.where(
161+
cond,
162+
f1(*args_[:nargs], **kwargs_),
163+
f2(*args_[:nargs], **kwargs_) if f2 is not None else fill_value,
164+
)
165+
166+
return _apply_where(cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=xp)
144167

145168

146169
def _apply_where( # numpydoc ignore=PR01,RT01
@@ -149,15 +172,18 @@ def _apply_where( # numpydoc ignore=PR01,RT01
149172
f2: Callable[..., Array] | None,
150173
fill_value: Array | int | float | complex | bool | None,
151174
*args: Array,
175+
kwkeys: list[str],
152176
xp: ModuleType,
153177
) -> Array:
154178
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
155179

156-
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
157-
# jax.jit does not support assignment by boolean mask
158-
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
180+
nargs = len(args) - len(kwkeys)
181+
kwargs = dict(zip(kwkeys, args[nargs:], strict=True))
182+
args = args[:nargs]
159183

160-
temp1 = f1(*(arr[cond] for arr in args))
184+
temp1 = f1(
185+
*(arr[cond] for arr in args), **{key: val[cond] for key, val in kwargs.items()}
186+
)
161187

162188
if f2 is None:
163189
dtype = xp.result_type(temp1, fill_value)
@@ -167,7 +193,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01
167193
out = xp.astype(fill_value, dtype, copy=True)
168194
else:
169195
ncond = ~cond
170-
temp2 = f2(*(arr[ncond] for arr in args))
196+
temp2 = f2(
197+
*(arr[ncond] for arr in args),
198+
**{key: val[ncond] for key, val in kwargs.items()},
199+
)
171200
dtype = xp.result_type(temp1, temp2)
172201
out = xp.empty_like(cond, dtype=dtype)
173202
out = at(out, ncond).set(temp2)

tests/test_funcs.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,14 @@ def test_device(self, xp: ModuleType, device: Device):
207207
# The xp and library fixtures are not regenerated between hypothesis iterations
208208
suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
209209
# JAX can take a long time to initialize on the first call
210-
deadline=None,
210+
# deadline=None,
211+
phases=[hypothesis.Phase.generate], # disables shrinking/simplification
212+
max_examples=50, # small number of tests
213+
deadline=50, # ms per example (optional speed limit)
211214
)
212215
@given(
213216
n_arrays=st.integers(min_value=1, max_value=3),
217+
n_kwarrays=st.integers(min_value=1, max_value=3),
214218
rng_seed=st.integers(min_value=1000000000, max_value=9999999999),
215219
dtype=npst.floating_dtypes(sizes=(32, 64)),
216220
p=st.floats(min_value=0, max_value=1),
@@ -219,6 +223,7 @@ def test_device(self, xp: ModuleType, device: Device):
219223
def test_hypothesis(
220224
self,
221225
n_arrays: int,
226+
n_kwarrays: int,
222227
rng_seed: int,
223228
dtype: np.dtype[Any],
224229
p: float,
@@ -233,9 +238,13 @@ def test_hypothesis(
233238
):
234239
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
235240

236-
mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays + 1, min_side=0)
241+
mbs = npst.mutually_broadcastable_shapes(
242+
num_shapes=1 + n_arrays + n_kwarrays, min_side=0
243+
)
237244
input_shapes, _ = data.draw(mbs)
238-
cond_shape, *shapes = input_shapes
245+
cond_shape = input_shapes[0]
246+
shapes = input_shapes[1 : 1 + n_arrays]
247+
kwshapes = input_shapes[1 + n_arrays :]
239248

240249
# cupy/cupy#8382
241250
# https://github.com/jax-ml/jax/issues/26658
@@ -257,22 +266,34 @@ def test_hypothesis(
257266
for shape in shapes
258267
)
259268

260-
def f1(*args: Array) -> Array:
261-
return cast(Array, sum(args))
269+
kwargs = {
270+
str(n): xp.asarray(
271+
data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements))
272+
)
273+
for n, shape in enumerate(kwshapes)
274+
}
275+
kwkeys = kwargs.keys()
276+
277+
def f1(*args: Array, **kwargs: dict[str, Array]) -> Array:
278+
assert set(kwargs.keys()) == set(kwkeys)
279+
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
280+
return cast(Array, sum(args_kwargs))
262281

263-
def f2(*args: Array) -> Array:
264-
return cast(Array, sum(args) / 2)
282+
def f2(*args: Array, **kwargs: dict[str, Array]) -> Array:
283+
assert set(kwargs.keys()) == set(kwkeys)
284+
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
285+
return cast(Array, sum(args_kwargs) / 2)
265286

266287
rng = np.random.default_rng(rng_seed)
267288
cond = xp.asarray(rng.random(size=cond_shape) > p)
268289

269-
res1 = apply_where(cond, arrays, f1, fill_value=fill_value)
270-
res2 = apply_where(cond, arrays, f1, f2)
271-
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value)
290+
res1 = apply_where(cond, arrays, f1, fill_value=fill_value, kwargs=kwargs)
291+
res2 = apply_where(cond, arrays, f1, f2, kwargs=kwargs)
292+
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value, kwargs=kwargs)
272293

273-
ref1 = xp.where(cond, f1(*arrays), fill_value)
274-
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))
275-
ref3 = xp.where(cond, f1(*arrays), float_fill_value)
294+
ref1 = xp.where(cond, f1(*arrays, **kwargs), fill_value)
295+
ref2 = xp.where(cond, f1(*arrays, **kwargs), f2(*arrays, **kwargs))
296+
ref3 = xp.where(cond, f1(*arrays, **kwargs), float_fill_value)
276297

277298
xp_assert_close(res1, ref1, rtol=2e-16)
278299
xp_assert_equal(res2, ref2)

0 commit comments

Comments
 (0)