Skip to content

Commit b340037

Browse files
committed
compiler: Fix lower_aliases' sub-iterators construction
1 parent fa07652 commit b340037

2 files changed

Lines changed: 25 additions & 25 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def lower_aliases(aliases, meta, opt_maxpar, opt_block_temps):
678678
"""
679679
Create a Schedule from an AliasList.
680680
"""
681-
dmapper = make_blockdim_subiter(meta.ispace.itdims)
681+
dmapper = SubIterMapper()
682682

683683
processed = []
684684
for a in aliases:
@@ -727,7 +727,7 @@ def lower_aliases(aliases, meta, opt_maxpar, opt_block_temps):
727727
# populated
728728
if d.is_Block:
729729
if opt_block_temps:
730-
sub_iterators.update(dmapper)
730+
sub_iterators[d] = dmapper.make(d)
731731
writeto.append(interval)
732732
intervals[d] = interval
733733
else:
@@ -778,7 +778,7 @@ def lower_aliases(aliases, meta, opt_maxpar, opt_block_temps):
778778
# deterministic code generation
779779
processed = sorted(processed, key=lambda i: cit(meta.ispace, i.ispace))
780780

781-
return Schedule(*processed, dmapper=dmapper, is_frame=aliases.is_frame)
781+
return Schedule(*processed, dmapper=dict(dmapper), is_frame=aliases.is_frame)
782782

783783

784784
def make_variant(schedule, exprs, mapper):
@@ -1687,24 +1687,25 @@ def _(expr):
16871687
return expr.function, {}
16881688

16891689

1690-
def make_blockdim_subiter(dimensions):
1691-
"""
1692-
Create sub-iterators for the BlockDimension in `dimensions`.
1690+
class SubIterMapper(dict):
16931691

1694-
For example, in `r[xs][ys][z]` both `xs` and `ys` must be initialized such
1695-
that all accesses are within bounds. This requires traversing the hierarchy
1696-
of BlockDimensions to set `xs` (`ys`) in a way that consecutive blocks access
1697-
consecutive regions in `r` (e.g., trivially `xs=0` with `blocklevels=1`;
1698-
`xs=x0_blk1-x0_blk0` with `blocklevels=2`; and so on).
1699-
"""
1700-
depth = max([d._depth for d in dimensions if d.is_Block], default=0)
1701-
if depth <= 1:
1702-
return {}
1692+
def make(self, d):
1693+
"""
1694+
Create a sub-iterator for the BlockDimension `d`.
17031695
1704-
mapper = {}
1705-
for d in dimensions:
1706-
if not d.is_Block or d._depth < depth:
1707-
continue
1696+
For example, in `r[xs][ys][z]` both `xs` and `ys` must be initialized such
1697+
that all accesses are within bounds. This requires traversing the hierarchy
1698+
of BlockDimensions to set `xs` (`ys`) in a way that consecutive blocks access
1699+
consecutive regions in `r` (e.g., trivially `xs=0` with `blocklevels=1`;
1700+
`xs=x0_blk1-x0_blk0` with `blocklevels=2`; and so on).
1701+
"""
1702+
if not d.is_Block:
1703+
raise ValueError(f"Expected BlockDimension, got `{type(d)}`")
1704+
1705+
try:
1706+
return self[d]
1707+
except KeyError:
1708+
pass
17081709

17091710
pd = d.parent
17101711

@@ -1714,8 +1715,7 @@ def make_blockdim_subiter(dimensions):
17141715
else:
17151716
m = 0
17161717

1717-
mapper[d] = IncrDimension(
1718-
f"{d.name}s", d, m, pd.symbolic_size, 1, pd.step
1719-
)
1718+
name = f"{d.name}s"
1719+
di = self[d] = IncrDimension(name, d, m, pd.symbolic_size, 1, pd.step)
17201720

1721-
return mapper
1721+
return di

examples/performance/00_overview.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,7 @@
10501050
"Sometimes it's possible to trade storage for parallelism (i.e., for more parallel dimensions). For this, Devito provides the `cire-maxpar` option which is by default set to:\n",
10511051
"\n",
10521052
"* False on CPU backends\n",
1053-
"* True on GPU backends\n",
1053+
"* 'basic' on GPU backends\n",
10541054
"\n",
10551055
"Let's see what happens when we switch it on"
10561056
]
@@ -1101,7 +1101,7 @@
11011101
}
11021102
],
11031103
"source": [
1104-
"op13_omp = Operator(eq, opt=('cire-sops', {'openmp': True, 'cire-maxpar': True}))\n",
1104+
"op13_omp = Operator(eq, opt=('cire-sops', {'openmp': True, 'cire-maxpar': 'basic'}))\n",
11051105
"print_kernel(op13_omp)"
11061106
]
11071107
},

0 commit comments

Comments
 (0)