Skip to content

Commit 3565340

Browse files
committed
pyopencl fake np: Cast then/else to a common dtype.
pyopencl.array does not allow array branches with unequal dtypes.
1 parent 6e8f55d commit 3565340

1 file changed

Lines changed: 31 additions & 1 deletion

File tree

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,39 @@ def absolute(self, a):
460460
# {{{ sorting, searching, and counting
461461

462462
def where(self, criterion, then, else_):
463-
def where_inner(inner_crit, inner_then, inner_else):
463+
464+
def where_inner(
465+
inner_crit: ArrayOrScalar,
466+
inner_then: ArrayOrScalar,
467+
inner_else: ArrayOrScalar,
468+
) -> ArrayOrScalar:
464469
if isinstance(inner_crit, bool | np.bool_):
465470
return inner_then if inner_crit else inner_else
471+
472+
# pyopencl's if_positive does not support then, else branches with
473+
# unequal dtypes -> cast them to a common dtype.
474+
inner_then_dtype = (
475+
inner_then.dtype
476+
if isinstance(inner_then, cl_array.Array)
477+
else np.dtype(type(inner_then))
478+
)
479+
inner_else_dtype = (
480+
inner_else.dtype
481+
if isinstance(inner_else, cl_array.Array)
482+
else np.dtype(type(inner_else))
483+
)
484+
dtype = np.promote_types(inner_then_dtype, inner_else_dtype)
485+
inner_then = (
486+
inner_then.astype(dtype)
487+
if isinstance(inner_then, cl_array.Array)
488+
else dtype.type(inner_then)
489+
)
490+
inner_else = (
491+
inner_else.astype(dtype)
492+
if isinstance(inner_else, cl_array.Array)
493+
else dtype.type(inner_else)
494+
)
495+
466496
return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
467497
queue=self._array_context.queue)
468498

0 commit comments

Comments
 (0)