Skip to content

Commit d2cc887

Browse files
committed
compiler: Remove dangerous updates_args feature
1 parent ffc1dc9 commit d2cc887

5 files changed

Lines changed: 13 additions & 46 deletions

File tree

devito/passes/iet/asynchrony.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def pthreadify(graph, **kwargs):
3838
AsyncMeta = namedtuple('AsyncMeta', 'sdata threads init shutdown')
3939

4040

41-
@iet_pass(updates_args=True)
41+
@iet_pass
4242
def lower_async_objs(iet, **kwargs):
4343
# Different actions depending on the Callable type
4444
iet, efuncs = _lower_async_objs(iet, **kwargs)

devito/passes/iet/definitions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def _inject_definitions(self, iet, storage):
444444

445445
return processed, flatten(efuncs)
446446

447-
@iet_pass(updates_args=True)
447+
@iet_pass
448448
def place_definitions(self, iet, globs=None, **kwargs):
449449
"""
450450
Create a new IET where all symbols have been declared, allocated, and

devito/passes/iet/engine.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,14 @@ def sync_mapper(self):
127127

128128
return found
129129

130-
def apply(self, func, *, updates_args=False, **kwargs):
130+
def apply(self, func, **kwargs):
131131
"""
132132
Apply ``func`` to all nodes in the Graph.
133133
134-
Parameters
135-
----------
136-
updates_args : bool, optional
137-
If True, reconcile Callable parameters and Call arguments before
138-
the graph walk and after each changed node. This is only needed by
139-
passes whose transformation logic depends on already-updated
140-
signatures while the pass is still running. Otherwise, argument
141-
reconciliation is intentionally deferred to ``finalize_args``.
134+
Callable parameters and Call arguments are reconciled before the graph
135+
walk, after each changed node, and after the pass has completed.
142136
"""
143-
if updates_args:
144-
_update_args(self)
137+
_update_args(self)
145138

146139
dag = create_call_graph(self.root.name, as_hashable(self.efuncs))
147140

@@ -172,8 +165,7 @@ def apply(self, func, *, updates_args=False, **kwargs):
172165
efuncs[i] = efunc
173166
efuncs.update(dict([(i.name, i) for i in new_efuncs]))
174167

175-
if updates_args:
176-
efuncs = _update_args_efunc(efunc, efuncs, dag)
168+
efuncs = _update_args_efunc(efunc, efuncs, dag)
177169

178170
# Minimize code size
179171
if len(efuncs) > len(self.efuncs):
@@ -182,6 +174,7 @@ def apply(self, func, *, updates_args=False, **kwargs):
182174
efuncs = reuse_efuncs(self.root, efuncs, self.sregistry)
183175

184176
self.efuncs = efuncs
177+
_update_args(self)
185178

186179
# Uniqueness
187180
self.includes = filter_ordered(self.includes)
@@ -240,24 +233,19 @@ def _update_args(graph):
240233
graph.efuncs = efuncs
241234

242235

243-
def iet_pass(func=None, *, updates_args=False):
236+
def iet_pass(func=None):
244237
"""
245238
Decorate an IET pass.
246-
247-
``updates_args=True`` is an opt-in for passes that must observe up-to-date
248-
Callable/Call signatures before and during their own graph walk. Most
249-
passes should leave it False and rely on ``finalize_args`` at the end of
250-
IET lowering.
251239
"""
252240
if func is None:
253-
return partial(iet_pass, updates_args=updates_args)
241+
return iet_pass
254242

255243
if isinstance(func, tuple):
256244
assert len(func) == 2 and func[0] is iet_visit
257245
call = lambda graph: graph.visit
258246
func = func[1]
259247
else:
260-
call = lambda graph: partial(graph.apply, updates_args=updates_args)
248+
call = lambda graph: graph.apply
261249

262250
@wraps(func)
263251
def wrapper(*args, **kwargs):

devito/passes/iet/linearization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def linearize(graph, **kwargs):
4646
linearization(graph, key=key, tracker=tracker, **kwargs)
4747

4848

49-
@iet_pass(updates_args=True)
49+
@iet_pass
5050
def linearization(iet, key=None, tracker=None, **kwargs):
5151
"""
5252
Carry out the actual work of `linearize`.

tests/test_iet.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
ElementalFunction, FindSymbols, Iteration, KernelLaunch, Lambda, List, Switch,
1515
Transformer, filter_iterations, make_efunc, retrieve_iteration_tree
1616
)
17-
from devito.passes.iet import engine as iet_engine
18-
from devito.passes.iet.engine import Graph, iet_pass
17+
from devito.passes.iet.engine import Graph
1918
from devito.passes.iet.languages.C import CDataManager
2019
from devito.symbolics import (
2120
FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, Macro, SizeOf,
@@ -540,26 +539,6 @@ def test_complex_array():
540539
"float _Complex **restrict a_vec __attribute__ ((aligned (64)));"
541540

542541

543-
def test_iet_pass_does_not_update_args(monkeypatch):
544-
x = Symbol(name='x')
545-
y = Symbol(name='y')
546-
547-
foo = Callable('foo', DummyExpr(x, y), 'void', parameters=(x, y))
548-
graph = Graph(foo)
549-
550-
@iet_pass
551-
def inject_expr(iet):
552-
body = iet.body._rebuild(body=iet.body.body + (DummyExpr(x, x),))
553-
return iet._rebuild(body=body), {}
554-
555-
monkeypatch.setattr(iet_engine, '_update_args',
556-
lambda *args, **kwargs: pytest.fail("_update_args called"))
557-
558-
inject_expr(graph)
559-
560-
assert graph.root.parameters is foo.parameters
561-
562-
563542
def test_special_array_definition():
564543

565544
class MyArray(Array):

0 commit comments

Comments
 (0)