Skip to content

Commit 92876d0

Browse files
committed
Revert "try tweaking parallelization on intel"
This reverts commit edb2253.
1 parent edb2253 commit 92876d0

2 files changed

Lines changed: 8 additions & 27 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -990,14 +990,8 @@ def _parallelize_across_device(
990990
parallelize_disjoint_loop_sets,
991991
)
992992

993-
dev = self.queue.device
994-
# The Intel CPU OpenCL runtime corrupts the host heap on some kernels
995-
# produced by the default parallelization; detect it here so the
996-
# parallelization can be tweaked to work around the issue.
997-
is_intel_cl = "intel" in dev.platform.name.lower()
998-
999993
t_unit = parallelize_disjoint_loop_sets(
1000-
t_unit, dev.max_compute_units, is_intel_cl=is_intel_cl)
994+
t_unit, self.queue.device.max_compute_units)
1001995

1002996
# FIXME: Is this something that this abstract-ish
1003997
# PytatoParallelPyOpenCLArrayContext class should be calling, or should it

arraycontext/impl/pytato/parallelize.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ def split_loop_set_across_work_items(
219219
callables: CallablesTable,
220220
loop_set: LoopSet,
221221
iname_to_approx_length: Mapping[str, float | int],
222-
max_device_compute_units: int, *,
223-
is_intel_cl: bool = False,
222+
max_device_compute_units: int,
224223
) -> lp.LoopKernel:
225224
# Could possibly do something fancier that also includes the individual inner
226225
# loops in the loop set, but for now just looking at the inames shared between
@@ -261,14 +260,6 @@ def split_loop_set_across_work_items(
261260
iname_to_approx_length[iname],
262261
-outer_iname_pos[iname])))
263262

264-
if is_intel_cl:
265-
# The Intel CPU OpenCL runtime corrupts the host heap on the 2D-tiled
266-
# kernels produced when parallelizing two inames (a work-group axis plus
267-
# two work-item axes). Keep only the largest loop (a non-reduction iname
268-
# whenever one is present) so we emit the 1D (g.0 + l.0) parallelization,
269-
# which the runtime handles correctly.
270-
inames_to_parallelize = inames_to_parallelize[-1:]
271-
272263
vng = kernel.get_var_name_generator()
273264

274265
if len(inames_to_parallelize) == 0:
@@ -463,8 +454,7 @@ def split_iteration_domain_across_work_items_for_single_kernel(
463454
kernel: lp.LoopKernel,
464455
callables: CallablesTable,
465456
max_device_compute_units: int, *,
466-
single_launch_config: bool = False,
467-
is_intel_cl: bool = False) -> lp.LoopKernel:
457+
single_launch_config: bool = False) -> lp.LoopKernel:
468458
if single_launch_config:
469459
raise NotImplementedError("single_launch_config==True isn't implemented yet.")
470460

@@ -477,16 +467,15 @@ def split_iteration_domain_across_work_items_for_single_kernel(
477467
for loop_set in loop_sets:
478468
kernel = split_loop_set_across_work_items(
479469
kernel, callables, loop_set, iname_to_approx_length,
480-
max_device_compute_units, is_intel_cl=is_intel_cl)
470+
max_device_compute_units)
481471

482472
return kernel
483473

484474

485475
def split_iteration_domain_across_work_items(
486476
t_unit: lp.TranslationUnit,
487477
max_device_compute_units: int, *,
488-
single_launch_config: bool = False,
489-
is_intel_cl: bool = False) -> lp.TranslationUnit:
478+
single_launch_config: bool = False) -> lp.TranslationUnit:
490479
"""
491480
Tag inames in *t_unit* with work-group/work-item axes so that each disjoint
492481
loop set is parallelized across the device. Loops are split based on their
@@ -497,8 +486,7 @@ def split_iteration_domain_across_work_items(
497486
return split_iteration_domain_across_work_items_for_single_kernel(
498487
t_unit, t_unit.callables_table,
499488
max_device_compute_units=max_device_compute_units,
500-
single_launch_config=single_launch_config,
501-
is_intel_cl=is_intel_cl)
489+
single_launch_config=single_launch_config)
502490

503491
# }}}
504492

@@ -616,15 +604,14 @@ def add_gbarrier_between_disjoint_loop_sets(
616604

617605
def parallelize_disjoint_loop_sets(
618606
t_unit: lp.TranslationUnit,
619-
max_device_compute_units: int, *,
620-
is_intel_cl: bool = False) -> lp.TranslationUnit:
607+
max_device_compute_units: int) -> lp.TranslationUnit:
621608
"""
622609
Parallelize *t_unit* by tagging the inames of each disjoint loop set with
623610
work-group and work-item axes and enforcing ordering between dependent
624611
loop sets.
625612
"""
626613
t_unit = split_iteration_domain_across_work_items(
627-
t_unit, max_device_compute_units, is_intel_cl=is_intel_cl)
614+
t_unit, max_device_compute_units)
628615
t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit)
629616
return t_unit
630617

0 commit comments

Comments
 (0)