Skip to content

Commit edb2253

Browse files
committed
try tweaking parallelization on intel
1 parent f443406 commit edb2253

2 files changed

Lines changed: 27 additions & 8 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,8 +990,14 @@ def _parallelize_across_device(
990990
parallelize_disjoint_loop_sets,
991991
)
992992

993+
dev = self.queue.device
994+
# The Intel CPU OpenCL runtime corrupts the host heap on some kernels
995+
# produced by the default parallelization. Detect Intel platforms by name
996+
# (this is not CPU-specific) so the parallelization can work around it.
997+
is_intel_cl = "intel" in dev.platform.name.lower()
998+
993999
t_unit = parallelize_disjoint_loop_sets(
994-
t_unit, self.queue.device.max_compute_units)
1000+
t_unit, dev.max_compute_units, is_intel_cl=is_intel_cl)
9951001

9961002
# FIXME: Is this something that this abstract-ish
9971003
# PytatoParallelPyOpenCLArrayContext class should be calling, or should it

arraycontext/impl/pytato/parallelize.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ def split_loop_set_across_work_items(
219219
callables: CallablesTable,
220220
loop_set: LoopSet,
221221
iname_to_approx_length: Mapping[str, float | int],
222-
max_device_compute_units: int,
222+
max_device_compute_units: int, *,
223+
is_intel_cl: bool = False,
223224
) -> lp.LoopKernel:
224225
# Could possibly do something fancier that also includes the individual inner
225226
# loops in the loop set, but for now just looking at the inames shared between
@@ -260,6 +261,14 @@ def split_loop_set_across_work_items(
260261
iname_to_approx_length[iname],
261262
-outer_iname_pos[iname])))
262263

264+
if is_intel_cl:
265+
# The Intel CPU OpenCL runtime corrupts the host heap on the 2D-tiled
266+
# kernels produced when parallelizing two inames (a work-group axis plus
267+
# two work-item axes). Keep only the largest loop (a non-reduction iname
268+
# whenever one is present) so we emit the 1D (g.0 + l.0) parallelization,
269+
# which the runtime handles correctly.
270+
inames_to_parallelize = inames_to_parallelize[-1:]
271+
263272
vng = kernel.get_var_name_generator()
264273

265274
if len(inames_to_parallelize) == 0:
@@ -454,7 +463,8 @@ def split_iteration_domain_across_work_items_for_single_kernel(
454463
kernel: lp.LoopKernel,
455464
callables: CallablesTable,
456465
max_device_compute_units: int, *,
457-
single_launch_config: bool = False) -> lp.LoopKernel:
466+
single_launch_config: bool = False,
467+
is_intel_cl: bool = False) -> lp.LoopKernel:
458468
if single_launch_config:
459469
raise NotImplementedError("single_launch_config==True isn't implemented yet.")
460470

@@ -467,15 +477,16 @@ def split_iteration_domain_across_work_items_for_single_kernel(
467477
for loop_set in loop_sets:
468478
kernel = split_loop_set_across_work_items(
469479
kernel, callables, loop_set, iname_to_approx_length,
470-
max_device_compute_units)
480+
max_device_compute_units, is_intel_cl=is_intel_cl)
471481

472482
return kernel
473483

474484

475485
def split_iteration_domain_across_work_items(
476486
t_unit: lp.TranslationUnit,
477487
max_device_compute_units: int, *,
478-
single_launch_config: bool = False) -> lp.TranslationUnit:
488+
single_launch_config: bool = False,
489+
is_intel_cl: bool = False) -> lp.TranslationUnit:
479490
"""
480491
Tag inames in *t_unit* with work-group/work-item axes so that each disjoint
481492
loop set is parallelized across the device. Loops are split based on their
@@ -486,7 +497,8 @@ def split_iteration_domain_across_work_items(
486497
return split_iteration_domain_across_work_items_for_single_kernel(
487498
t_unit, t_unit.callables_table,
488499
max_device_compute_units=max_device_compute_units,
489-
single_launch_config=single_launch_config)
500+
single_launch_config=single_launch_config,
501+
is_intel_cl=is_intel_cl)
490502

491503
# }}}
492504

@@ -604,14 +616,15 @@ def add_gbarrier_between_disjoint_loop_sets(
604616

605617
def parallelize_disjoint_loop_sets(
606618
t_unit: lp.TranslationUnit,
607-
max_device_compute_units: int) -> lp.TranslationUnit:
619+
max_device_compute_units: int, *,
620+
is_intel_cl: bool = False) -> lp.TranslationUnit:
608621
"""
609622
Parallelize *t_unit* by tagging the inames of each disjoint loop set with
610623
work-group and work-item axes and enforcing ordering between dependent
611624
loop sets.
612625
"""
613626
t_unit = split_iteration_domain_across_work_items(
614-
t_unit, max_device_compute_units)
627+
t_unit, max_device_compute_units, is_intel_cl=is_intel_cl)
615628
t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit)
616629
return t_unit
617630

0 commit comments

Comments
 (0)