@@ -460,9 +460,39 @@ def absolute(self, a):
460460 # {{{ sorting, searching, and counting
461461
462462 def where (self , criterion , then , else_ ):
463- def where_inner (inner_crit , inner_then , inner_else ):
463+
464+ def where_inner (
465+ inner_crit : ArrayOrScalar ,
466+ inner_then : ArrayOrScalar ,
467+ inner_else : ArrayOrScalar ,
468+ ) -> ArrayOrScalar :
464469 if isinstance (inner_crit , bool | np .bool_ ):
465470 return inner_then if inner_crit else inner_else
471+
472+ # pyopencl's if_positive does not support then, else branches with
473+ # unequal dtypes -> cast them to a common dtype.
474+ inner_then_dtype = (
475+ inner_then .dtype
476+ if isinstance (inner_then , cl_array .Array )
477+ else np .dtype (type (inner_then ))
478+ )
479+ inner_else_dtype = (
480+ inner_else .dtype
481+ if isinstance (inner_else , cl_array .Array )
482+ else np .dtype (type (inner_else ))
483+ )
484+ dtype = np .promote_types (inner_then_dtype , inner_else_dtype )
485+ inner_then = (
486+ inner_then .astype (dtype )
487+ if isinstance (inner_then , cl_array .Array )
488+ else dtype .type (inner_then )
489+ )
490+ inner_else = (
491+ inner_else .astype (dtype )
492+ if isinstance (inner_else , cl_array .Array )
493+ else dtype .type (inner_else )
494+ )
495+
466496 return cl_array .if_positive (inner_crit != 0 , inner_then , inner_else ,
467497 queue = self ._array_context .queue )
468498
0 commit comments