|
28 | 28 | from devito.types.dense import DiscreteFunction |
29 | 29 | from devito.types.dimension import AbstractIncrDimension, BlockDimension |
30 | 30 |
|
31 | | -__all__ = ['Graph', 'iet_pass', 'iet_visit'] |
| 31 | +__all__ = ['Graph', 'finalize_args', 'iet_pass', 'iet_visit'] |
32 | 32 |
|
33 | 33 |
|
34 | 34 | class Byproduct: |
@@ -128,10 +128,22 @@ def sync_mapper(self): |
128 | 128 |
|
129 | 129 | return found |
130 | 130 |
|
131 | | - def apply(self, func, *, updates_args=True, **kwargs): |
| 131 | + def apply(self, func, *, updates_args=False, **kwargs): |
132 | 132 | """ |
133 | | - Apply `func` to all nodes in the Graph. This changes the state of the Graph. |
| 133 | + Apply ``func`` to all nodes in the Graph. |
| 134 | +
|
| 135 | + Parameters |
| 136 | + ---------- |
| 137 | + updates_args : bool, optional |
| 138 | + If True, reconcile Callable parameters and Call arguments before |
| 139 | + the graph walk and after each changed node. This is only needed by |
| 140 | + passes whose transformation logic depends on already-updated |
| 141 | + signatures while the pass is still running. Otherwise, argument |
| 142 | + reconciliation is intentionally deferred to ``finalize_args``. |
134 | 143 | """ |
| 144 | + if updates_args: |
| 145 | + _update_args(self) |
| 146 | + |
135 | 147 | dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) |
136 | 148 |
|
137 | 149 | # Apply `func` |
@@ -161,10 +173,8 @@ def apply(self, func, *, updates_args=True, **kwargs): |
161 | 173 | efuncs[i] = efunc |
162 | 174 | efuncs.update(dict([(i.name, i) for i in new_efuncs])) |
163 | 175 |
|
164 | | - # Update the parameters / arguments lists if the pass may have |
165 | | - # introduced or removed objects. |
166 | 176 | if updates_args: |
167 | | - efuncs = update_args(efunc, efuncs, dag) |
| 177 | + efuncs = _update_args_efunc(efunc, efuncs, dag) |
168 | 178 |
|
169 | 179 | # Minimize code size |
170 | 180 | if len(efuncs) > len(self.efuncs): |
@@ -209,7 +219,37 @@ def filter(self, key): |
209 | 219 | ) |
210 | 220 |
|
211 | 221 |
|
212 | | -def iet_pass(func=None, *, updates_args=True): |
| 222 | +@timed_pass(name='finalize_args') |
| 223 | +def finalize_args(graph): |
| 224 | + """ |
| 225 | + Finalize Callable parameter lists and Call argument lists across ``graph``. |
| 226 | +
|
| 227 | + IET passes may temporarily leave helper signatures stale while introducing |
| 228 | + or eliminating symbols. This pass reconciles the whole call graph once, |
| 229 | + after lowering has settled. |
| 230 | + """ |
| 231 | + _update_args(graph) |
| 232 | + |
| 233 | + |
| 234 | +def _update_args(graph): |
| 235 | + dag = create_call_graph(graph.root.name, as_hashable(graph.efuncs)) |
| 236 | + |
| 237 | + efuncs = graph.efuncs |
| 238 | + for i in dag.topological_sort(): |
| 239 | + efuncs = _update_args_efunc(efuncs[i], efuncs, dag) |
| 240 | + |
| 241 | + graph.efuncs = efuncs |
| 242 | + |
| 243 | + |
| 244 | +def iet_pass(func=None, *, updates_args=False): |
| 245 | + """ |
| 246 | + Decorate an IET pass. |
| 247 | +
|
| 248 | + ``updates_args=True`` is an opt-in for passes that must observe up-to-date |
| 249 | + Callable/Call signatures before and during their own graph walk. Most |
| 250 | + passes should leave it False and rely on ``finalize_args`` at the end of |
| 251 | + IET lowering. |
| 252 | + """ |
213 | 253 | if func is None: |
214 | 254 | return partial(iet_pass, updates_args=updates_args) |
215 | 255 |
|
@@ -702,7 +742,7 @@ def _(i, mapper, sregistry): |
702 | 742 | mapper[i] = i._rebuild(name=sregistry.make_name(prefix='nthreads')) |
703 | 743 |
|
704 | 744 |
|
705 | | -def update_args(root, efuncs, dag): |
| 745 | +def _update_args_efunc(root, efuncs, dag): |
706 | 746 | """ |
707 | 747 | Re-derive the parameters of `root` and apply the changes in cascade through |
708 | 748 | the `efuncs`. |
@@ -800,6 +840,6 @@ def _filter(v, efunc=None): |
800 | 840 | continue |
801 | 841 |
|
802 | 842 | efuncs[n] = efunc |
803 | | - efuncs = update_args(efunc, efuncs, dag) |
| 843 | + efuncs = _update_args_efunc(efunc, efuncs, dag) |
804 | 844 |
|
805 | 845 | return efuncs |
0 commit comments