Skip to content

Commit b30d74e

Browse files
committed
pyopencl fake np: Cast then/else to a common dtype.
pyopencl.array does not allow array branches with unequal dtypes.
1 parent 6e8f55d commit b30d74e

1 file changed

Lines changed: 25 additions & 0 deletions

File tree

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,31 @@ def where(self, criterion, then, else_):
463463
def where_inner(inner_crit, inner_then, inner_else):
464464
if isinstance(inner_crit, bool | np.bool_):
465465
return inner_then if inner_crit else inner_else
466+
467+
# pyopencl's if_positive does not support then, else branches with
468+
# unequal dtypes -> cast them to a common dtype.
469+
inner_then_dtype = (
470+
inner_then.dtype
471+
if isinstance(inner_then, cl_array.Array)
472+
else np.dtype(type(inner_then))
473+
)
474+
inner_else_dtype = (
475+
inner_else.dtype
476+
if isinstance(inner_else, cl_array.Array)
477+
else np.dtype(type(inner_else))
478+
)
479+
dtype = np.promote_types(inner_then_dtype, inner_else_dtype)
480+
inner_then = (
481+
inner_then.astype(dtype)
482+
if isinstance(inner_then, cl_array.Array)
483+
else dtype.type(inner_then)
484+
)
485+
inner_else = (
486+
inner_else.astype(dtype)
487+
if isinstance(inner_else, cl_array.Array)
488+
else dtype.type(inner_else)
489+
)
490+
466491
return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
467492
queue=self._array_context.queue)
468493

0 commit comments

Comments
 (0)