Skip to content

Commit 6ac9b20

Browse files
committed
compiler: Call finalize_args once at the end of the lowering
1 parent 54c2b2a commit 6ac9b20

7 files changed

Lines changed: 67 additions & 24 deletions

File tree

devito/operator/operator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
from devito.operator.registry import operator_selector
3333
from devito.parameters import configuration
3434
from devito.passes import (
35-
Graph, error_mapper, generate_implicit, generate_macros, is_on_device, lower_dtypes,
36-
lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate
35+
Graph, error_mapper, finalize_args, generate_implicit, generate_macros, is_on_device,
36+
lower_dtypes, lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate
3737
)
3838
from devito.symbolics import estimate_cost, subs_op_args
3939
from devito.tools import (
@@ -522,6 +522,9 @@ def _lower_iet(cls, uiet, **kwargs):
522522
# Target-independent optimizations
523523
minimize_symbols(graph)
524524

525+
# Finalize helper signatures after all IET transformations have settled.
526+
finalize_args(graph)
527+
525528
return graph.root, graph
526529

527530
# Read-only properties exposed to the outside world

devito/passes/iet/definitions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def _inject_definitions(self, iet, storage):
444444

445445
return processed, flatten(efuncs)
446446

447-
@iet_pass
447+
@iet_pass(updates_args=True)
448448
def place_definitions(self, iet, globs=None, **kwargs):
449449
"""
450450
Create a new IET where all symbols have been declared, allocated, and
@@ -518,7 +518,7 @@ def place_definitions(self, iet, globs=None, **kwargs):
518518
'globals': as_tuple(globs),
519519
'includes': as_tuple(sorted(storage.includes))}
520520

521-
@iet_pass(updates_args=False)
521+
@iet_pass
522522
def place_casts(self, iet, **kwargs):
523523
"""
524524
Create a new IET with the necessary type casts.
@@ -669,7 +669,7 @@ def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs):
669669

670670
return iet, {'efuncs': efuncs}
671671

672-
@iet_pass(updates_args=False)
672+
@iet_pass
673673
def place_devptr(self, iet, **kwargs):
674674
"""
675675
Transform `iet` such that device pointers are used in DeviceCalls.

devito/passes/iet/engine.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from devito.types.dense import DiscreteFunction
2929
from devito.types.dimension import AbstractIncrDimension, BlockDimension
3030

31-
__all__ = ['Graph', 'iet_pass', 'iet_visit']
31+
__all__ = ['Graph', 'finalize_args', 'iet_pass', 'iet_visit']
3232

3333

3434
class Byproduct:
@@ -128,10 +128,22 @@ def sync_mapper(self):
128128

129129
return found
130130

131-
def apply(self, func, *, updates_args=True, **kwargs):
131+
def apply(self, func, *, updates_args=False, **kwargs):
132132
"""
133-
Apply `func` to all nodes in the Graph. This changes the state of the Graph.
133+
Apply ``func`` to all nodes in the Graph.
134+
135+
Parameters
136+
----------
137+
updates_args : bool, optional
138+
If True, reconcile Callable parameters and Call arguments before
139+
the graph walk and after each changed node. This is only needed by
140+
passes whose transformation logic depends on already-updated
141+
signatures while the pass is still running. Otherwise, argument
142+
reconciliation is intentionally deferred to ``finalize_args``.
134143
"""
144+
if updates_args:
145+
_update_args(self)
146+
135147
dag = create_call_graph(self.root.name, as_hashable(self.efuncs))
136148

137149
# Apply `func`
@@ -161,10 +173,8 @@ def apply(self, func, *, updates_args=True, **kwargs):
161173
efuncs[i] = efunc
162174
efuncs.update(dict([(i.name, i) for i in new_efuncs]))
163175

164-
# Update the parameters / arguments lists if the pass may have
165-
# introduced or removed objects.
166176
if updates_args:
167-
efuncs = update_args(efunc, efuncs, dag)
177+
efuncs = _update_args_efunc(efunc, efuncs, dag)
168178

169179
# Minimize code size
170180
if len(efuncs) > len(self.efuncs):
@@ -209,7 +219,37 @@ def filter(self, key):
209219
)
210220

211221

212-
def iet_pass(func=None, *, updates_args=True):
222+
@timed_pass(name='finalize_args')
223+
def finalize_args(graph):
224+
"""
225+
Finalize Callable parameter lists and Call argument lists across ``graph``.
226+
227+
IET passes may temporarily leave helper signatures stale while introducing
228+
or eliminating symbols. This pass reconciles the whole call graph once,
229+
after lowering has settled.
230+
"""
231+
_update_args(graph)
232+
233+
234+
def _update_args(graph):
235+
dag = create_call_graph(graph.root.name, as_hashable(graph.efuncs))
236+
237+
efuncs = graph.efuncs
238+
for i in dag.topological_sort():
239+
efuncs = _update_args_efunc(efuncs[i], efuncs, dag)
240+
241+
graph.efuncs = efuncs
242+
243+
244+
def iet_pass(func=None, *, updates_args=False):
245+
"""
246+
Decorate an IET pass.
247+
248+
``updates_args=True`` is an opt-in for passes that must observe up-to-date
249+
Callable/Call signatures before and during their own graph walk. Most
250+
passes should leave it False and rely on ``finalize_args`` at the end of
251+
IET lowering.
252+
"""
213253
if func is None:
214254
return partial(iet_pass, updates_args=updates_args)
215255

@@ -702,7 +742,7 @@ def _(i, mapper, sregistry):
702742
mapper[i] = i._rebuild(name=sregistry.make_name(prefix='nthreads'))
703743

704744

705-
def update_args(root, efuncs, dag):
745+
def _update_args_efunc(root, efuncs, dag):
706746
"""
707747
Re-derive the parameters of `root` and apply the changes in cascade through
708748
the `efuncs`.
@@ -800,6 +840,6 @@ def _filter(v, efunc=None):
800840
continue
801841

802842
efuncs[n] = efunc
803-
efuncs = update_args(efunc, efuncs, dag)
843+
efuncs = _update_args_efunc(efunc, efuncs, dag)
804844

805845
return efuncs

devito/passes/iet/instrument.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def instrument(graph, **kwargs):
2727
sync_sections(graph, **kwargs)
2828

2929

30-
@iet_pass(updates_args=False)
30+
@iet_pass
3131
def track_subsections(iet, **kwargs):
3232
"""
3333
Add sub-Sections to the `profiler`. Sub-Sections include:
@@ -122,7 +122,7 @@ def instrument_sections(iet, **kwargs):
122122
return piet, {'headers': headers}
123123

124124

125-
@iet_pass(updates_args=False)
125+
@iet_pass
126126
def sync_sections(iet, langbb=None, profiler=None, **kwargs):
127127
"""
128128
Wrap sections within global barriers if deemed necessary by the profiler.

devito/passes/iet/linearization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def linearize(graph, **kwargs):
4646
linearization(graph, key=key, tracker=tracker, **kwargs)
4747

4848

49-
@iet_pass
49+
@iet_pass(updates_args=True)
5050
def linearization(iet, key=None, tracker=None, **kwargs):
5151
"""
5252
Carry out the actual work of `linearize`.

devito/passes/iet/misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
]
3030

3131

32-
@iet_pass(updates_args=False)
32+
@iet_pass
3333
def avoid_denormals(iet, platform=None, **kwargs):
3434
"""
3535
Introduce nodes in the Iteration/Expression tree that will expand to C
@@ -60,7 +60,7 @@ def avoid_denormals(iet, platform=None, **kwargs):
6060
return iet, {'includes': ('xmmintrin.h', 'pmmintrin.h')}
6161

6262

63-
@iet_pass(updates_args=False)
63+
@iet_pass
6464
def hoist_prodders(iet):
6565
"""
6666
Move Prodders within the outer levels of an Iteration tree.
@@ -151,7 +151,7 @@ def generate_macros(graph, **kwargs):
151151
_generate_macros(graph, tracker={}, **kwargs)
152152

153153

154-
@iet_pass(updates_args=False)
154+
@iet_pass
155155
def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs):
156156
# Derive the Macros necessary for the FIndexeds
157157
iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs)

tests/test_iet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,20 +540,20 @@ def test_complex_array():
540540
"float _Complex **restrict a_vec __attribute__ ((aligned (64)));"
541541

542542

543-
def test_iet_pass_skip_update_args(monkeypatch):
543+
def test_iet_pass_does_not_update_args(monkeypatch):
544544
x = Symbol(name='x')
545545
y = Symbol(name='y')
546546

547547
foo = Callable('foo', DummyExpr(x, y), 'void', parameters=(x, y))
548548
graph = Graph(foo)
549549

550-
@iet_pass(updates_args=False)
550+
@iet_pass
551551
def inject_expr(iet):
552552
body = iet.body._rebuild(body=iet.body.body + (DummyExpr(x, x),))
553553
return iet._rebuild(body=body), {}
554554

555-
monkeypatch.setattr(iet_engine, 'update_args',
556-
lambda *args, **kwargs: pytest.fail("update_args called"))
555+
monkeypatch.setattr(iet_engine, '_update_args',
556+
lambda *args, **kwargs: pytest.fail("_update_args called"))
557557

558558
inject_expr(graph)
559559

0 commit comments

Comments
 (0)