@@ -127,21 +127,14 @@ def sync_mapper(self):
127127
128128 return found
129129
130- def apply (self , func , * , updates_args = False , * *kwargs ):
130+ def apply (self , func , ** kwargs ):
131131 """
132132 Apply ``func`` to all nodes in the Graph.
133133
134- Parameters
135- ----------
136- updates_args : bool, optional
137- If True, reconcile Callable parameters and Call arguments before
138- the graph walk and after each changed node. This is only needed by
139- passes whose transformation logic depends on already-updated
140- signatures while the pass is still running. Otherwise, argument
141- reconciliation is intentionally deferred to ``finalize_args``.
134+ Callable parameters and Call arguments are reconciled before the graph
135+ walk, after each changed node, and after the pass has completed.
142136 """
143- if updates_args :
144- _update_args (self )
137+ _update_args (self )
145138
146139 dag = create_call_graph (self .root .name , as_hashable (self .efuncs ))
147140
@@ -172,8 +165,7 @@ def apply(self, func, *, updates_args=False, **kwargs):
172165 efuncs [i ] = efunc
173166 efuncs .update (dict ([(i .name , i ) for i in new_efuncs ]))
174167
175- if updates_args :
176- efuncs = _update_args_efunc (efunc , efuncs , dag )
168+ efuncs = _update_args_efunc (efunc , efuncs , dag )
177169
178170 # Minimize code size
179171 if len (efuncs ) > len (self .efuncs ):
@@ -182,6 +174,7 @@ def apply(self, func, *, updates_args=False, **kwargs):
182174 efuncs = reuse_efuncs (self .root , efuncs , self .sregistry )
183175
184176 self .efuncs = efuncs
177+ _update_args (self )
185178
186179 # Uniqueness
187180 self .includes = filter_ordered (self .includes )
@@ -240,24 +233,19 @@ def _update_args(graph):
240233 graph .efuncs = efuncs
241234
242235
243- def iet_pass (func = None , * , updates_args = False ):
236+ def iet_pass (func = None ):
244237 """
245238 Decorate an IET pass.
246-
247- ``updates_args=True`` is an opt-in for passes that must observe up-to-date
248- Callable/Call signatures before and during their own graph walk. Most
249- passes should leave it False and rely on ``finalize_args`` at the end of
250- IET lowering.
251239 """
252240 if func is None :
253- return partial ( iet_pass , updates_args = updates_args )
241+ return iet_pass
254242
255243 if isinstance (func , tuple ):
256244 assert len (func ) == 2 and func [0 ] is iet_visit
257245 call = lambda graph : graph .visit
258246 func = func [1 ]
259247 else :
260- call = lambda graph : partial ( graph .apply , updates_args = updates_args )
248+ call = lambda graph : graph .apply
261249
262250 @wraps (func )
263251 def wrapper (* args , ** kwargs ):
0 commit comments