Skip to content

Commit d1fd07a

Browse files
committed
Just return the value if we do not need to typecast.
1 parent 97fcd3c commit d1fd07a

1 file changed

Lines changed: 14 additions & 9 deletions

File tree

loopy/target/opencl.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
from loopy.codegen import CodeGenerationState
4848
from loopy.codegen.result import CodeGenerationResult
49+
from loopy.kernel import LoopKernel
4950

5051

5152
# {{{ dtype registry wrappers
@@ -456,7 +457,7 @@ def get_opencl_callables():
456457

457458
# {{{ symbol mangler
458459

459-
def opencl_symbol_mangler(kernel, name):
460+
def opencl_symbol_mangler(kernel: LoopKernel, name: str):
460461
# FIXME: should be more picky about exact names
461462
if name.startswith("FLT_"):
462463
return NumpyType(np.dtype(np.float32)), name
@@ -544,25 +545,29 @@ def wrap_in_typecast(self, actual_type, needed_dtype, s):
544545
# CL does not perform implicit conversion from float-type to a bool.
545546
from pymbolic.primitives import Comparison
546547
return Comparison(s, "!=", 0)
548+
549+
if needed_dtype == actual_type:
550+
return s
547551

548552
registry = self.codegen_state.ast_builder.target.get_dtype_registry()
549-
if self.codegen_state.target.is_vector_dtype(needed_dtype):
550-
# OpenCL does not let you do explicit vector type casts.
551-
# Instead you need to call their function which is of the form
552-
# convert_<desttype><n>(src) where desttype is the type you want and n
553+
if self.codegen_state.target.is_vector_dtype(needed_dtype) and \
554+
self.codegen_state.target.is_vector_dtype(actual_type):
555+
# OpenCL does not let you do explicit vector type casts between vector
556+
# types. Instead you need to call their function which is of the form
557+
# <desttype> convert_<desttype><n>(src) where n
553558
# is the number of elements in the vector which is the same as in src.
554559
cast = var("convert_%s" % registry.dtype_to_ctype(needed_dtype))
555560
return cast(s)
556561

557562
return super().wrap_in_typecast(actual_type, needed_dtype, s)
558563

559-
def map_group_hw_index(self, expr, type_context):
564+
def map_group_hw_index(self, expr, type_context: str):
560565
return var("gid")(expr.axis)
561566

562-
def map_local_hw_index(self, expr, type_context):
567+
def map_local_hw_index(self, expr, type_context: str):
563568
return var("lid")(expr.axis)
564569

565-
def map_variable(self, expr, type_context):
570+
def map_variable(self, expr, type_context: str):
566571

567572
if self.codegen_state.vectorization_info:
568573
if self.codegen_state.vectorization_info.iname == expr.name:
@@ -578,7 +583,7 @@ def map_variable(self, expr, type_context):
578583
return Literal(vector_literal)
579584
return super().map_variable(expr, type_context)
580585

581-
def map_if(self, expr, type_context):
586+
def map_if(self, expr, type_context: str):
582587
from loopy.types import to_loopy_type
583588
result_type = self.infer_type(expr)
584589
conditional_needed_loopy_type = to_loopy_type(np.bool_)

0 commit comments

Comments
 (0)