@@ -219,7 +219,8 @@ def split_loop_set_across_work_items(
219219 callables : CallablesTable ,
220220 loop_set : LoopSet ,
221221 iname_to_approx_length : Mapping [str , float | int ],
222- max_device_compute_units : int ,
222+ max_device_compute_units : int , * ,
223+ is_intel_cl : bool = False ,
223224) -> lp .LoopKernel :
224225 # Could possibly do something fancier that also includes the individual inner
225226 # loops in the loop set, but for now just looking at the inames shared between
@@ -260,6 +261,14 @@ def split_loop_set_across_work_items(
260261 iname_to_approx_length [iname ],
261262 - outer_iname_pos [iname ])))
262263
264+ if is_intel_cl :
265+ # The Intel CPU OpenCL runtime corrupts the host heap on the 2D-tiled
266+ # kernels produced when parallelizing two inames (a work-group axis plus
267+ # two work-item axes). Keep only the largest loop (a non-reduction iname
268+ # whenever one is present) so we emit the 1D (g.0 + l.0) parallelization,
269+ # which the runtime handles correctly.
270+ inames_to_parallelize = inames_to_parallelize [- 1 :]
271+
263272 vng = kernel .get_var_name_generator ()
264273
265274 if len (inames_to_parallelize ) == 0 :
@@ -454,7 +463,8 @@ def split_iteration_domain_across_work_items_for_single_kernel(
454463 kernel : lp .LoopKernel ,
455464 callables : CallablesTable ,
456465 max_device_compute_units : int , * ,
457- single_launch_config : bool = False ) -> lp .LoopKernel :
466+ single_launch_config : bool = False ,
467+ is_intel_cl : bool = False ) -> lp .LoopKernel :
458468 if single_launch_config :
459469 raise NotImplementedError ("single_launch_config==True isn't implemented yet." )
460470
@@ -467,15 +477,16 @@ def split_iteration_domain_across_work_items_for_single_kernel(
467477 for loop_set in loop_sets :
468478 kernel = split_loop_set_across_work_items (
469479 kernel , callables , loop_set , iname_to_approx_length ,
470- max_device_compute_units )
480+ max_device_compute_units , is_intel_cl = is_intel_cl )
471481
472482 return kernel
473483
474484
475485def split_iteration_domain_across_work_items (
476486 t_unit : lp .TranslationUnit ,
477487 max_device_compute_units : int , * ,
478- single_launch_config : bool = False ) -> lp .TranslationUnit :
488+ single_launch_config : bool = False ,
489+ is_intel_cl : bool = False ) -> lp .TranslationUnit :
479490 """
480491 Tag inames in *t_unit* with work-group/work-item axes so that each disjoint
481492 loop set is parallelized across the device. Loops are split based on their
@@ -486,7 +497,8 @@ def split_iteration_domain_across_work_items(
486497 return split_iteration_domain_across_work_items_for_single_kernel (
487498 t_unit , t_unit .callables_table ,
488499 max_device_compute_units = max_device_compute_units ,
489- single_launch_config = single_launch_config )
500+ single_launch_config = single_launch_config ,
501+ is_intel_cl = is_intel_cl )
490502
491503# }}}
492504
@@ -604,14 +616,15 @@ def add_gbarrier_between_disjoint_loop_sets(
604616
605617def parallelize_disjoint_loop_sets (
606618 t_unit : lp .TranslationUnit ,
607- max_device_compute_units : int ) -> lp .TranslationUnit :
619+ max_device_compute_units : int , * ,
620+ is_intel_cl : bool = False ) -> lp .TranslationUnit :
608621 """
609622 Parallelize *t_unit* by tagging the inames of each disjoint loop set with
610623 work-group and work-item axes and enforcing ordering between dependent
611624 loop sets.
612625 """
613626 t_unit = split_iteration_domain_across_work_items (
614- t_unit , max_device_compute_units )
627+ t_unit , max_device_compute_units , is_intel_cl = is_intel_cl )
615628 t_unit = add_gbarrier_between_disjoint_loop_sets (t_unit )
616629 return t_unit
617630
0 commit comments