Skip to content

Commit 0c27438

Browse files
committed
compiler: Shed reduction-only halos when lowering injections
After lifting lower_sparse_ops out of mpiize, an injection nest is turned into a Call before optimize_halospots runs, so _drop_reduction_halospots can no longer detect that the wrapping HaloSpot's entry for the injected field is reduction-only. The stale entry was therefore left in place and, with save=True (no modulo buffering), the hoist pass propagated the loop iteration variable out of the time loop, producing a reference to an undeclared 'time' variable. Drop those entries at lowering time instead, so that the resulting HaloSpot only carries entries with a genuine read at the IET level.
1 parent 7e07f72 commit 0c27438

2 files changed

Lines changed: 42 additions & 19 deletions

File tree

devito/passes/iet/sparse.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,16 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
6969
continue
7070
groups.setdefault(nest, []).append(expr)
7171

72-
# If a sparse-op nest sits inside a HaloSpot whose halo scheme is
73-
# void (the reduction-only halo got dropped by
74-
# ``_drop_reduction_halospots``), replace the HaloSpot rather than
75-
# just the nest so we don't leave behind an empty HaloSpot — the
76-
# MPI overlap machinery would otherwise try to wrap our Call with
77-
# its own dynamic-args plumbing.
78-
parents = {nest: _enclosing_void_halospot(iet, nest) for nest in groups}
72+
# ``lower_sparse_ops`` runs before ``optimize_halospots``, so the
73+
# halo-exchange optimiser hasn't yet had a chance to drop the
74+
# reduction-only halo entries that the IR scheduler put around an
75+
# injection nest (e.g. an entry for ``u`` at ``loc_indices={time:
76+
# time+1}`` wrapping ``u[time+1] += ...``). Once the nest becomes a
77+
# Call those expressions are no longer visible to
78+
# ``_drop_reduction_halospots``, so we shed those entries here -- and
79+
# if that empties the HaloSpot, replace it whole so the MPI overlap
80+
# machinery doesn't wrap our Call with stale dynamic-args plumbing.
81+
parents = {nest: _enclosing_halospot(iet, nest) for nest in groups}
7982

8083
mapper = {}
8184
efuncs = []
@@ -87,7 +90,26 @@ def lower_sparse_ops(iet, sregistry=None, **kwargs):
8790
efunc = make_callable(sregistry.make_name(prefix=prefix), new_nest)
8891
efuncs.append(efunc)
8992

90-
mapper[parents[nest] or nest] = Call(efunc.name, list(efunc.parameters))
93+
call = Call(efunc.name, list(efunc.parameters))
94+
parent = parents[nest]
95+
if parent is None:
96+
mapper[nest] = call
97+
continue
98+
99+
# Drop fields that the (now-opaque) Call only writes/increments,
100+
# since the wrapping HaloSpot's purpose was to ensure read-side
101+
# coherency for them and the read no longer exists at the IET
102+
# level. Interpolation reads its target field, so its entries
103+
# stay.
104+
reduced = {e.expr.lhs.function for e in exprs
105+
if isinstance(e.expr, InjectionMixin)}
106+
hs = parent.halo_scheme.drop(reduced) if reduced else parent.halo_scheme
107+
if hs.is_void:
108+
mapper[parent] = call
109+
elif hs is parent.halo_scheme:
110+
mapper[nest] = call
111+
else:
112+
mapper[parent] = parent._rebuild(halo_scheme=hs, body=call)
91113

92114
if not mapper:
93115
return iet, {}
@@ -107,14 +129,12 @@ def _find_outer_iteration(iet, expr):
107129
return None
108130

109131

110-
def _enclosing_void_halospot(iet, nest):
132+
def _enclosing_halospot(iet, nest):
111133
"""
112-
Return the HaloSpot directly wrapping ``nest`` if it carries an
113-
empty (void) HaloScheme, otherwise None. Such HaloSpots are leftover
114-
after ``_drop_reduction_halospots`` cleared all entries.
134+
Return the first HaloSpot in ``iet`` that contains ``nest``, or None if there is none.
115135
"""
116136
for hs in FindNodes(HaloSpot).visit(iet):
117-
if hs.is_void and nest in FindNodes(Iteration).visit(hs):
137+
if nest in FindNodes(Iteration).visit(hs):
118138
return hs
119139
return None
120140

tests/test_dse.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2660,7 +2660,9 @@ def test_sparse_const(self):
26602660

26612661
op = Operator(src.interpolate(u))
26622662

2663-
cond = FindNodes(Conditional).visit(op)
2663+
assert len(FindNodes(Conditional).visit(op)) == 0
2664+
print(op._func_table)
2665+
cond = FindNodes(Conditional).visit(op._func_table['interpolate_src0'])
26642666
assert len(cond) == 1
26652667
assert len(cond[0].args['then_body'][0].exprs) == 1
26662668
assert all(e.is_scalar for e in cond[0].args['then_body'][0].exprs)
@@ -2914,12 +2916,12 @@ def test_fullopt(self):
29142916
bns, _ = assert_blocking(op1, {'x0_blk0'}) # due to loop blocking
29152917

29162918
assert summary0[('section0', None)].ops == 55
2917-
assert summary0[('section1', None)].ops == 44
2919+
assert summary0[('section1', None)].ops == 17
29182920
assert np.isclose(summary0[('section0', None)].oi, 3.136, atol=0.001)
29192921

29202922
assert summary1[('section0', None)].ops == 31
2921-
assert summary1[('section1', None)].ops == 88
2922-
assert summary1[('section2', None)].ops == 25
2923+
assert summary1[('section1', None)].ops == 17
2924+
assert summary1[('section2', None)].ops == 0
29232925
assert np.isclose(summary1[('section0', None)].oi, 1.767, atol=0.001)
29242926

29252927
assert np.allclose(u0.data, u1.data, atol=10e-5)
@@ -2966,7 +2968,8 @@ def tti_noopt(self):
29662968

29672969
# Make sure no opts were applied
29682970
op = wavesolver.op_fwd(False)
2969-
assert len(op._func_table) == 0
2971+
# Two funcs, one for src, one for rec
2972+
assert len(op._func_table) == 2
29702973
assert summary[('section0', None)].ops == 753
29712974

29722975
return v, rec
@@ -3024,7 +3027,7 @@ def test_fullopt_w_mpi(self, mode):
30243027

30253028
# Run a quick check to be sure MPI-full-mode code was actually generated
30263029
op = tti_agg.op_fwd(False)
3027-
assert len(op._func_table) == 7
3030+
assert len(op._func_table) == 9
30283031
assert 'pokempi0' in op._func_table
30293032

30303033
@switchconfig(profiling='advanced')

0 commit comments

Comments
 (0)