4646
4747 from loopy .codegen import CodeGenerationState
4848 from loopy .codegen .result import CodeGenerationResult
49+ from loopy .kernel import LoopKernel
4950
5051
5152# {{{ dtype registry wrappers
@@ -456,7 +457,7 @@ def get_opencl_callables():
456457
457458# {{{ symbol mangler
458459
459- def opencl_symbol_mangler (kernel , name ):
460+ def opencl_symbol_mangler (kernel : LoopKernel , name : str ):
460461 # FIXME: should be more picky about exact names
461462 if name .startswith ("FLT_" ):
462463 return NumpyType (np .dtype (np .float32 )), name
@@ -544,25 +545,29 @@ def wrap_in_typecast(self, actual_type, needed_dtype, s):
544545 # CL does not perform implicit conversion from float-type to a bool.
545546 from pymbolic .primitives import Comparison
546547 return Comparison (s , "!=" , 0 )
548+
549+ if needed_dtype == actual_type :
550+ return s
547551
548552 registry = self .codegen_state .ast_builder .target .get_dtype_registry ()
549- if self .codegen_state .target .is_vector_dtype (needed_dtype ):
550- # OpenCL does not let you do explicit vector type casts.
551- # Instead you need to call their function which is of the form
552- # convert_<desttype><n>(src) where desttype is the type you want and n
553+ if self .codegen_state .target .is_vector_dtype (needed_dtype ) and \
554+ self .codegen_state .target .is_vector_dtype (actual_type ):
555+ # OpenCL does not let you do explicit vector type casts between vector
556+ # types. Instead you need to call their function which is of the form
557+ # <desttype> convert_<desttype><n>(src) where n
553558 # is the number of elements in the vector which is the same as in src.
554559 cast = var ("convert_%s" % registry .dtype_to_ctype (needed_dtype ))
555560 return cast (s )
556561
557562 return super ().wrap_in_typecast (actual_type , needed_dtype , s )
558563
559- def map_group_hw_index (self , expr , type_context ):
564+ def map_group_hw_index (self , expr , type_context : str ):
560565 return var ("gid" )(expr .axis )
561566
562- def map_local_hw_index (self , expr , type_context ):
567+ def map_local_hw_index (self , expr , type_context : str ):
563568 return var ("lid" )(expr .axis )
564569
565- def map_variable (self , expr , type_context ):
570+ def map_variable (self , expr , type_context : str ):
566571
567572 if self .codegen_state .vectorization_info :
568573 if self .codegen_state .vectorization_info .iname == expr .name :
@@ -578,7 +583,7 @@ def map_variable(self, expr, type_context):
578583 return Literal (vector_literal )
579584 return super ().map_variable (expr , type_context )
580585
581- def map_if (self , expr , type_context ):
586+ def map_if (self , expr , type_context : str ):
582587 from loopy .types import to_loopy_type
583588 result_type = self .infer_type (expr )
584589 conditional_needed_loopy_type = to_loopy_type (np .bool_ )
0 commit comments