Skip to content

Commit f91e308

Browse files
committed
compiler: restrict nested parallelism to supported compilers (intel)
1 parent 7397940 commit f91e308

3 files changed

Lines changed: 57 additions & 2 deletions

File tree

devito/passes/iet/languages/openmp.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from sympy import And, Ne, Not
66

77
from devito.arch import AMDGPUX, INTELGPUX, NVIDIAX, PVC
8-
from devito.arch.compiler import CustomCompiler, GNUCompiler, NvidiaCompiler
8+
from devito.arch.compiler import (
9+
CustomCompiler, GNUCompiler, IntelCompiler, NvidiaCompiler, OneapiCompiler
10+
)
911
from devito.ir import (
1012
Call, Conditional, DeviceCall, FindSymbols, List, ParallelBlock, PointerCast, Pragma,
1113
Prodder, While
@@ -276,6 +278,16 @@ def _support_complex_reduction(cls, compiler):
276278
# GCC doesn't support complex reduction
277279
return not isinstance(compiler, GNUCompiler)
278280

281+
@classmethod
def _support_nested_parallelism(cls, compiler):
    """
    Tell whether nested parallelism is supported for `compiler`.

    True only for `IntelCompiler` and `OneapiCompiler` instances. When
    `compiler` is a `CustomCompiler`, the object returned by its `_base()`
    is the one that gets checked.
    """
    # For a CustomCompiler, check what `_base()` returns instead
    target = compiler._base() if isinstance(compiler, CustomCompiler) else compiler
    return isinstance(target, (IntelCompiler, OneapiCompiler))
290+
279291

280292
class Ompizer(AbstractOmpizer):
281293
langbb = OmpBB

devito/passes/iet/parpragma.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def _support_array_reduction(cls, compiler):
5454
def _support_complex_reduction(cls, compiler):
5555
return False
5656

57+
@classmethod
def _support_nested_parallelism(cls, compiler):
    """
    Tell whether nested parallelism may be used with `compiler`.

    This implementation always answers no, whatever `compiler` is.
    """
    return False
60+
5761
@property
def simd_reg_nbytes(self):
    """The `simd_reg_nbytes` value of the attached platform."""
    platform = self.platform
    return platform.simd_reg_nbytes
@@ -344,7 +348,8 @@ def _make_guard(self, parregion):
344348

345349
def _make_nested_partree(self, partree):
346350
# Apply heuristic
347-
if self.nhyperthreads <= self.nested:
351+
if self.nhyperthreads <= self.nested or \
352+
not self._support_nested_parallelism(self.compiler):
348353
return partree
349354

350355
# Note: there might be multiple sub-trees amenable to nested parallelism,

tests/test_dle.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
PrecomputedSparseTimeFunction, ReduceMax, ReduceMin, ReduceMinMax, SpaceDimension,
1212
SparseTimeFunction, SubDimension, TimeFunction, configuration, cos, dimensions, info
1313
)
14+
from devito.arch.compiler import IntelCompiler, OneapiCompiler
1415
from devito.exceptions import InvalidArgument
1516
from devito.ir.iet import (
1617
Expression, FindNodes, IsPerfectIteration, Iteration, retrieve_iteration_tree
@@ -1461,3 +1462,40 @@ def test_collapsing_w_wo_halo(self, exprs, collapsed, scheduling):
14611462

14621463
assert iterations[1].pragmas[0].ccode.value ==\
14631464
"".join([ompfor_string, scheduling_string])
1465+
1466+
@skipif('device')
def test_nested_parallelism_support(self):
    """
    Nested `omp parallel for` pragmas must only show up when the configured
    compiler is an Intel one (`IntelCompiler` or `OneapiCompiler`); in all
    cases the generated code is expected to compile.
    """
    grid = Grid(shape=(10, 10, 10))

    f = Function(name='f', grid=grid, space_order=4)
    v = TimeFunction(name="v", grid=grid, space_order=4)
    v1 = TimeFunction(name="v1", grid=grid, space_order=4)

    f.data_with_halo[:] = 0.5
    v.data_with_halo[:] = 1.
    v1.data_with_halo[:] = 1.

    eqn = Eq(v.forward, (v.dx * (1 + 2*f) * f).dx)
    op = Operator(eqn, opt=('advanced', {'openmp': True, 'par-nested': 0}))

    block_nests, _ = assert_blocking(op, {'x0_blk0'})
    itrees = retrieve_iteration_tree(block_nests['x0_blk0'])
    assert len(itrees) == 2
    first, second = itrees[0], itrees[1]

    # Check the omp pragmas: the outermost loop is always parallelized
    outer_pragma = first[0].pragmas[0].ccode.value
    assert outer_pragma == 'omp for collapse(2) schedule(dynamic,1)'

    compiler = configuration['compiler']
    if isinstance(compiler, (IntelCompiler, OneapiCompiler)):
        # Supports nested parallelism: both trees carry the same nested pragma
        nested_pragma = \
            '#pragma omp parallel for collapse(2) schedule(dynamic,1)'\
            ' num_threads(nthreads_nested)'
        for tree in (first, second):
            assert tree[2].pragmas[0].ccode.value == nested_pragma
    else:
        # Most compilers don't support nested parallelism
        for tree in (first, second):
            assert not tree[2].pragmas

    # Should compile properly
    op.cfunction  # noqa: B018

0 commit comments

Comments
 (0)