@@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08
4141 f2 : Callable [..., Array ],
4242 / ,
4343 * ,
44+ kwargs : dict [str , Array ] | None = None ,
4445 xp : ModuleType | None = None ,
4546) -> Array : ...
4647
@@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08
5354 / ,
5455 * ,
5556 fill_value : Array | complex ,
57+ kwargs : dict [str , Array ] | None = None ,
5658 xp : ModuleType | None = None ,
5759) -> Array : ...
5860
@@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02
6567 / ,
6668 * ,
6769 fill_value : Array | complex | None = None ,
70+ kwargs : dict [str , Array ] | None = None ,
6871 xp : ModuleType | None = None ,
6972) -> Array :
7073 """
@@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02
9194 It does not need to be scalar; it needs however to be broadcastable with
9295 `cond` and `args`.
9396 Mutually exclusive with `f2`. You must provide one or the other.
97+ kwargs : dict of str : Array pairs
98+ Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with
99+ `cond`.
94100 xp : array_namespace, optional
95101 The standard-compatible namespace for `cond` and `args`. Default: infer.
96102
@@ -129,6 +135,12 @@ def apply_where( # numpydoc ignore=PR01,PR02
129135 args_ = list (args ) if isinstance (args , tuple ) else [args ]
130136 del args
131137
138+ kwargs_ = {} if kwargs is None else kwargs
139+ kwkeys = list (kwargs_ .keys ())
140+ nargs = len (args_ )
141+ args_ = [* args_ , * kwargs_ .values ()]
142+ del kwargs
143+
132144 xp = array_namespace (cond , fill_value , * args_ ) if xp is None else xp
133145
134146 if isinstance (fill_value , int | float | complex | NoneType ):
@@ -139,8 +151,19 @@ def apply_where( # numpydoc ignore=PR01,PR02
139151 if is_dask_namespace (xp ):
140152 meta_xp = meta_namespace (cond , fill_value , * args_ , xp = xp )
141153 # map_blocks doesn't descend into tuples of Arrays
142- return xp .map_blocks (_apply_where , cond , f1 , f2 , fill_value , * args_ , xp = meta_xp )
143- return _apply_where (cond , f1 , f2 , fill_value , * args_ , xp = xp )
154+ return xp .map_blocks (
155+ _apply_where , cond , f1 , f2 , fill_value , * args_ , kwkeys = kwkeys , xp = meta_xp
156+ )
157+
158+ if not capabilities (xp , device = _compat .device (cond ))["boolean indexing" ]:
159+ # jax.jit does not support assignment by boolean mask
160+ return xp .where (
161+ cond ,
162+ f1 (* args_ [:nargs ], ** kwargs_ ),
163+ f2 (* args_ [:nargs ], ** kwargs_ ) if f2 is not None else fill_value ,
164+ )
165+
166+ return _apply_where (cond , f1 , f2 , fill_value , * args_ , kwkeys = kwkeys , xp = xp )
144167
145168
146169def _apply_where ( # numpydoc ignore=PR01,RT01
@@ -149,15 +172,18 @@ def _apply_where( # numpydoc ignore=PR01,RT01
149172 f2 : Callable [..., Array ] | None ,
150173 fill_value : Array | int | float | complex | bool | None ,
151174 * args : Array ,
175+ kwkeys : list [str ],
152176 xp : ModuleType ,
153177) -> Array :
154178 """Helper of `apply_where`. On Dask, this runs on a single chunk."""
155179
156- if not capabilities ( xp , device = _compat . device ( cond ))[ "boolean indexing" ]:
157- # jax.jit does not support assignment by boolean mask
158- return xp . where ( cond , f1 ( * args ), f2 ( * args ) if f2 is not None else fill_value )
180+ nargs = len ( args ) - len ( kwkeys )
181+ kwargs = dict ( zip ( kwkeys , args [ nargs :], strict = True ))
182+ args = args [: nargs ]
159183
160- temp1 = f1 (* (arr [cond ] for arr in args ))
184+ temp1 = f1 (
185+ * (arr [cond ] for arr in args ), ** {key : val [cond ] for key , val in kwargs .items ()}
186+ )
161187
162188 if f2 is None :
163189 dtype = xp .result_type (temp1 , fill_value )
@@ -167,7 +193,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01
167193 out = xp .astype (fill_value , dtype , copy = True )
168194 else :
169195 ncond = ~ cond
170- temp2 = f2 (* (arr [ncond ] for arr in args ))
196+ temp2 = f2 (
197+ * (arr [ncond ] for arr in args ),
198+ ** {key : val [ncond ] for key , val in kwargs .items ()},
199+ )
171200 dtype = xp .result_type (temp1 , temp2 )
172201 out = xp .empty_like (cond , dtype = dtype )
173202 out = at (out , ncond ).set (temp2 )
0 commit comments