Skip to content

Commit 97fcd3c

Browse files
committed
Modify the typecast for vector dtypes.
1 parent 2d430fc commit 97fcd3c

2 files changed

Lines changed: 49 additions & 34 deletions

File tree

loopy/target/c/codegen/expression.py

Lines changed: 8 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -130,8 +130,8 @@ def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):
130130
if actual_type != needed_type:
131131
registry = self.codegen_state.ast_builder.target.get_dtype_registry()
132132
cast = var("(%s) " % registry.dtype_to_ctype(needed_type))
133-
134133
return cast(s)
134+
135135
return s
136136

137137
def rec(self, expr, type_context=None, needed_type: LoopyType | None = None): # type: ignore[override]
@@ -283,10 +283,12 @@ def make_var(name):
283283
if (
284284
isinstance(ary, (ConstantArg, ArrayArg)) or
285285
(isinstance(ary, TemporaryVariable) and ary.base_storage)):
286-
# unsubscripted global args are pointers if they are inputs
287-
result = self.make_subscript(ary,
288-
make_var(access_info.array_name),
289-
0)
286+
# unsubscripted global args are pointers
287+
result = self.make_subscript(
288+
ary,
289+
make_var(access_info.array_name),
290+
(0,))
291+
290292
else:
291293
# unsubscripted temp vars are scalars
292294
# (unless they use base_storage)
@@ -412,37 +414,9 @@ def map_remainder(self, expr, type_context):
412414
def map_if(self, expr, type_context):
413415
from loopy.types import to_loopy_type
414416
result_type = self.infer_type(expr)
415-
conditional_needed_loopy_type = to_loopy_type(np.bool_)
416-
if self.codegen_state.vectorization_info:
417-
from loopy.codegen import UnvectorizableError
418-
from loopy.expression import VectorizabilityChecker
419-
checker = VectorizabilityChecker(self.codegen_state.kernel,
420-
self.codegen_state.vectorization_info.iname,
421-
self.codegen_state.vectorization_info.length)
422-
423-
try:
424-
is_vector = checker(expr)
425-
426-
if is_vector:
427-
"""
428-
We could have a vector literal here.
429-
So we may need to type cast the condition.
430-
OpenCL specification states that for ( c ? a : b)
431-
to be vectorized appropriately c must have the same
432-
number of elements in the vector as that of a and b.
433-
Also each element must have the same number of bits,
434-
and c must be an integral type.
435-
"""
436-
index_type = to_loopy_type(np.int64)
437-
if type_context == "f":
438-
index_type = to_loopy_type(np.int32)
439-
conditional_needed_loopy_type = to_loopy_type(self.codegen_state.target.vector_dtype(index_type,
440-
self.codegen_state.vectorization_info.length))
441-
except UnvectorizableError:
442-
pass
443417
return type(expr)(
444418
self.rec(expr.condition, type_context,
445-
conditional_needed_loopy_type),
419+
to_loopy_type(np.bool_)),
446420
self.rec(expr.then, type_context, result_type),
447421
self.rec(expr.else_, type_context, result_type),
448422
)

loopy/target/opencl.py

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -577,6 +577,47 @@ def map_variable(self, expr, type_context):
577577
",".join([f"{i}" for i in range(vector_length)]) + "))"
578578
return Literal(vector_literal)
579579
return super().map_variable(expr, type_context)
580+
581+
def map_if(self, expr, type_context):
582+
from loopy.types import to_loopy_type
583+
result_type = self.infer_type(expr)
584+
conditional_needed_loopy_type = to_loopy_type(np.bool_)
585+
if self.codegen_state.vectorization_info:
586+
from loopy.codegen import UnvectorizableError
587+
from loopy.expression import VectorizabilityChecker
588+
checker = VectorizabilityChecker(self.codegen_state.kernel,
589+
self.codegen_state.vectorization_info.iname,
590+
self.codegen_state.vectorization_info.length)
591+
592+
try:
593+
is_vector = checker(expr)
594+
595+
if is_vector:
596+
"""
597+
We could have a vector literal here.
598+
So we may need to type cast the condition.
599+
OpenCL specification states that for ( c ? a : b)
600+
to be vectorized appropriately c must have the same
601+
number of elements in the vector as that of a and b.
602+
Also each element must have the same number of bits,
603+
and c must be an integral type.
604+
"""
605+
index_type = to_loopy_type(np.int64)
606+
if type_context == "f":
607+
index_type = to_loopy_type(np.int32)
608+
length = self.codegen_state.vectorization_info.length
609+
vector_type = self.codegen_state.target.vector_dtype(index_type,
610+
length)
611+
conditional_needed_loopy_type = to_loopy_type(vector_type)
612+
except UnvectorizableError:
613+
pass
614+
615+
return type(expr)(
616+
self.rec(expr.condition, type_context,
617+
conditional_needed_loopy_type),
618+
self.rec(expr.then, type_context, result_type),
619+
self.rec(expr.else_, type_context, result_type),
620+
)
580621
# }}}
581622

582623

0 commit comments

Comments
 (0)