Skip to content

Commit b2cb717

Browse files
committed
Tweak scheduling of scalar aliases in the presence of guards
1 parent 870a160 commit b2cb717

3 files changed

Lines changed: 74 additions & 41 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,13 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
125125
variants = []
126126
for mapper in self._generate(cgroup, exclude):
127127
# Clusters -> AliasList
128-
found = collect(mapper.extracted, meta.ispace, self.opt_minstorage)
128+
found = collect(mapper.extracted, meta, self.opt_minstorage)
129129
exprs, aliases = self._choose(found, cgroup, mapper)
130130

131131
# AliasList -> Schedule
132132
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
133133

134-
variants.append(Variant(schedule, exprs))
134+
variants.append(make_variant(schedule, exprs, mapper))
135135

136136
if not variants:
137137
return []
@@ -282,8 +282,6 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None):
282282

283283
class CireInvariants(CireTransformerLegacy, Queue):
284284

285-
_q_guards_in_key = True
286-
287285
def __init__(self, sregistry, options, platform):
288286
super().__init__(sregistry, options, platform)
289287

@@ -511,7 +509,7 @@ def _cbk_search2(self, expr, rank):
511509
}
512510

513511

514-
def collect(extracted, ispace, minstorage):
512+
def collect(extracted, meta, minstorage):
515513
"""
516514
Find groups of aliasing expressions.
517515
@@ -575,11 +573,11 @@ def collect(extracted, ispace, minstorage):
575573

576574
group.append(u)
577575
unseen.remove(u)
578-
group = Group(group, ispace=ispace)
576+
group = Group(group, ispace=meta.ispace)
579577

580578
k = group.dimensions_translated if minstorage else group.dimensions
581-
582579
k = frozenset(d for d in k if not d.is_NonlinearDerived)
580+
583581
mapper.setdefault(k, []).append(group)
584582

585583
aliases = AliasList()
@@ -657,8 +655,9 @@ def collect(extracted, ispace, minstorage):
657655

658656
# Compute the alias score
659657
na = g.naliases
660-
nr = nredundants(ispace, pivot)
658+
nr = nredundants(meta.ispace, pivot)
661659
score = estimate_cost(pivot, True)*((na - 1) + nr)
660+
662661
aliases.add(pivot, aliaseds, list(mapper), distances, score)
663662

664663
return aliases
@@ -728,8 +727,9 @@ def lower_aliases(aliases, meta, maxpar):
728727
m = i.dim.symbolic_min - i.dim.parent.symbolic_min
729728
else:
730729
m = 0
731-
d = dmapper[i.dim] = IncrDimension(f"{i.dim.name}s", i.dim, m,
732-
dd.symbolic_size, 1, dd.step)
730+
d = dmapper[i.dim] = IncrDimension(
731+
f"{i.dim.name}s", i.dim, m, dd.symbolic_size, 1, dd.step
732+
)
733733
sub_iterators[i.dim] = d
734734
else:
735735
d = i.dim
@@ -745,6 +745,11 @@ def lower_aliases(aliases, meta, maxpar):
745745
# The alias write-to space
746746
writeto = IterationSpace(IntervalGroup(writeto), sub_iterators)
747747

748+
# Avoid scalar aliases in the presence of guards, since hoisting them
749+
# might cause scope issues (see `test_dse.py::TestAliases::test_split_cond`)
750+
if not writeto and meta.guards:
751+
continue
752+
748753
# The alias iteration space
749754
ispace = IterationSpace(IntervalGroup(intervals, meta.ispace.relations),
750755
meta.ispace.sub_iterators,
@@ -764,6 +769,34 @@ def lower_aliases(aliases, meta, maxpar):
764769
return Schedule(*processed, dmapper=dmapper, is_frame=aliases.is_frame)
765770

766771

772+
def make_variant(schedule, exprs, mapper):
    """
    Build the Variant for a given Schedule and its associated expressions.

    Not every entry of `mapper` necessarily made it into `schedule`: some
    aliases may have been discarded along the way. For those, the original
    sub-expression (the `mapper` key) is put back into `exprs` in place of
    the value it had been mapped to.
    """
    # Everything that is still aliased by some entry of the Schedule
    kept = flatten(sa.aliaseds for sa in schedule)

    reinstate = {}
    for original, extracted in mapper.items():
        if extracted in kept:
            # Still scheduled, nothing to undo
            continue

        if not isinstance(extracted, dict):
            reinstate[extracted] = original
            continue

        # Nested mapper, e.g.
        # `mapper = {u[t0, x+3, y+3] + u[t0, x+3, y+4]:
        #            {u[t0, x+3, y+4]: None, u[t0, x+3, y+3]: dummy0}}`
        # Only handled when exactly one of its values was not retained;
        # otherwise the unpacking raises and the entry is left untouched
        try:
            (dropped,) = [i for i in extracted.values() if i not in kept]
        except ValueError:
            continue
        reinstate[dropped] = original

    return Variant(schedule, [uxreplace(e, reinstate) for e in exprs])
798+
799+
767800
def optimize_schedule_rotations(schedule, sregistry):
768801
"""
769802
Transform the schedule such that the tensor temporaries "rotate" along
@@ -1493,7 +1526,7 @@ def nredundants(ispace, expr):
14931526
redundant if it defines an iteration space for `expr` while not appearing
14941527
among its free symbols. Note that the converse isn't generally true: there
14951528
could be a Dimension that does not appear in the free symbols while defining
1496-
a non-redundant iteration space (e.g., a BlockDimension).
1529+
a non-redundant iteration space (e.g., a BlockDimension or a reduction).
14971530
"""
14981531
iterated = {i.dim for i in ispace}
14991532
used = {i for i in expr.free_symbols if i.is_Dimension}

tests/test_dimension.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from devito.types import Array, StencilDimension, Symbol
2323
from devito.types.basic import Scalar
2424
from devito.types.dimension import AffineIndexAccessFunction, Thickness
25+
from devito.types.misc import Temp
2526

2627

2728
class TestIndexAccessFunction:
@@ -2130,9 +2131,10 @@ def test_topofusion_w_subdims_conddims(self):
21302131
assert exprs[0].write is h
21312132

21322133
exprs = FindNodes(Expression).visit(bns['x2_blk0'])
2133-
assert len(exprs) == 2
2134-
assert exprs[0].write is fsave
2135-
assert exprs[1].write is gsave
2134+
assert len(exprs) == 3
2135+
assert isinstance(exprs[0].expr.lhs, Temp)
2136+
assert exprs[1].write is fsave
2137+
assert exprs[2].write is gsave
21362138

21372139
def test_topofusion_w_subdims_conddims_v2(self):
21382140
"""
@@ -2163,9 +2165,10 @@ def test_topofusion_w_subdims_conddims_v2(self):
21632165
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0'})
21642166
assert len(FindNodes(Expression).visit(bns['x0_blk0'])) == 3
21652167
exprs = FindNodes(Expression).visit(bns['x1_blk0'])
2166-
assert len(exprs) == 2
2167-
assert exprs[0].write is fsave
2168-
assert exprs[1].write is gsave
2168+
assert len(exprs) == 3
2169+
assert isinstance(exprs[0].expr.lhs, Temp)
2170+
assert exprs[1].write is fsave
2171+
assert exprs[2].write is gsave
21692172

21702173
def test_topofusion_w_subdims_conddims_v3(self):
21712174
"""
@@ -2200,9 +2203,10 @@ def test_topofusion_w_subdims_conddims_v3(self):
22002203
assert exprs[1].write is g
22012204

22022205
exprs = FindNodes(Expression).visit(bns['x2_blk0'])
2203-
assert len(exprs) == 2
2204-
assert exprs[0].write is fsave
2205-
assert exprs[1].write is gsave
2206+
assert len(exprs) == 3
2207+
assert isinstance(exprs[0].expr.lhs, Temp)
2208+
assert exprs[1].write is fsave
2209+
assert exprs[2].write is gsave
22062210

22072211
# Additional nest due to anti-dependence
22082212
exprs = FindNodes(Expression).visit(bns['x1_blk0'])

tests/test_dse.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
Conditional, DummyEq, Expression, FindNodes, FindSymbols, Iteration,
2020
ParallelIteration, retrieve_iteration_tree
2121
)
22-
from devito.passes.clusters.aliases import collect
22+
from devito.passes.clusters.aliases import AliasKey, collect
2323
from devito.passes.clusters.factorization import collect_nested
2424
from devito.passes.iet.parpragma import VExpanded
2525
from devito.symbolics import ( # noqa
@@ -423,8 +423,9 @@ def test_collection(self, exprs, expected):
423423

424424
extracted = {i.rhs: i.lhs for i in exprs}
425425
ispace = exprs[0].ispace
426+
meta = AliasKey(ispace, None, None, None, None)
426427

427-
aliases = collect(extracted, ispace, False)
428+
aliases = collect(extracted, meta, False)
428429
aliases.filter(lambda a: a.score > 0)
429430

430431
assert len(aliases) == len(expected)
@@ -2553,15 +2554,15 @@ def test_invariants_with_conditional(self):
25532554

25542555
op = Operator(eqn, opt='advanced')
25552556

2556-
assert_structure(op, ['t', 't,fd', 't,fd,x,y'], 't,fd,x,y')
2557+
assert_structure(op, ['t', 't,fd,x,y'], 't,fd,x,y')
25572558
# Make sure it compiles
25582559
_ = op.cfunction
25592560

25602561
# Check hoisting for time invariant
25612562
eqn = Eq(u, u - (cos(time_sub * factor * f) * sin(g) * uf))
25622563

25632564
op = Operator(eqn, opt='advanced')
2564-
assert_structure(op, ['x,y', 't', 't,fd', 't,fd,x,y'], 'x,y,t,fd,x,y')
2565+
assert_structure(op, ['x,y', 't', 't,fd,x,y'], 'x,y,t,fd,x,y')
25652566
# Make sure it compiles
25662567
_ = op.cfunction
25672568

@@ -2705,10 +2706,9 @@ def test_split_cond(self):
27052706

27062707
cond = FindNodes(Conditional).visit(op)
27072708
assert len(cond) == 3
2708-
# Each guard should have its own alias for cos(time)
2709-
assert 'float r0 = cos(time);' in str(body0(op))
2709+
# No aliases in this case due to guards
27102710
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2711-
assert len(scalars) == 2
2711+
assert len(scalars) == 0
27122712

27132713
def test_split_cond_multi_alias(self):
27142714
grid = Grid((11, 11))
@@ -2728,11 +2728,9 @@ def test_split_cond_multi_alias(self):
27282728

27292729
cond = FindNodes(Conditional).visit(op)
27302730
assert len(cond) == 3
2731-
# Each guard should have its own aliases for cos(time) and sin(time)
2732-
assert 'const float r0 = sin(time) + cos(time)' in str(body0(op))
2733-
assert 'const float r1 = cos(time);' in str(body0(op))
2731+
# No aliases in this case due to guards
27342732
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2735-
assert len(scalars) == 3
2733+
assert len(scalars) == 0
27362734

27372735
def test_multi_cond_no_split(self):
27382736
grid = Grid((11, 11))
@@ -2758,7 +2756,7 @@ def test_multi_cond_no_split(self):
27582756
)
27592757

27602758
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2761-
assert len(scalars) == 3
2759+
assert len(scalars) == 0
27622760

27632761
def test_alias_with_conditional(self):
27642762
grid = Grid((11, 11))
@@ -2779,9 +2777,9 @@ def test_alias_with_conditional(self):
27792777
cond = FindNodes(Conditional).visit(op)
27802778
assert len(cond) == 3
27812779

2782-
# Each guard should have its own alias for cos(time/ctf)
2780+
# No aliases in this case due to guards
27832781
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2784-
assert len(scalars) == 2
2782+
assert len(scalars) == 0
27852783

27862784
def test_scalar_alias_interp(self):
27872785
grid = Grid(shape=(11, 11))
@@ -2825,9 +2823,9 @@ def test_scalar_with_cond_access(self):
28252823
cond = FindNodes(Conditional).visit(op)
28262824
assert len(cond) == 3
28272825

2828-
# # Each guard should have its own alias for cos/sin(f1[time-2])
2826+
# The guards prevent some aliases from being hoisted out
28292827
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2830-
assert len(scalars) == 3
2828+
assert len(scalars) == 0
28312829

28322830
assert_structure(
28332831
op,
@@ -2855,21 +2853,19 @@ def test_scalar_with_cond_tinvariant(self):
28552853

28562854
cond = FindNodes(Conditional).visit(op)
28572855
assert len(cond) == 1
2858-
# One for each 1/dt 1/dt**2
2856+
# One for 1/dt, while 1/dt**2 isn't hoisted out due to the guard
28592857
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2860-
assert len(scalars) == 2
2858+
assert len(scalars) == 1
28612859

28622860
assert_structure(
28632861
op,
28642862
['t,x,y', 't', 't,x,y'],
28652863
'txyxy'
28662864
)
28672865

2868-
# Both aliases should be hoisted outside the time loop
2866+
# The 1/dt alias should be hoisted outside the time loop
28692867
assert str(body0(op).body[0]) == 'const float r0 = 1.0F/dt;'
28702868
assert not body0(op).body[0].ispace
2871-
assert str(body0(op).body[1]) == 'const float r1 = 1.0F/(dt*dt);'
2872-
assert not body0(op).body[1].ispace
28732869

28742870

28752871
class TestIsoAcoustic:

0 commit comments

Comments
 (0)