diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 2ff570c9..ef320bab 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -116,11 +116,37 @@ def transform(cls, f): log().warning("could not find rewritten function %s in code object", f.__name__) return f - f.__code__ = new_f_code_o - for name, val in cls.rewrite_globals.items(): f.__globals__[name] = val + # AST transformers may inject references to new names (e.g. + # scf_if_dispatch, scf_if_collect_results) inside the kernel or + # its rewriter-generated sub-functions (__then_N, __else_N). + # Because enclosing_mod creates a closure layer, these unresolved + # names become free vars (LOAD_DEREF) rather than globals. This + # causes new_f_code_o.co_freevars to have more entries than the + # original f.__closure__. Direct f.__code__ assignment would + # raise ValueError, so we build a new function with a matching + # closure instead. + if f.__closure__ and new_f_code_o.co_freevars != f.__code__.co_freevars: + old_cells = {name: cell for name, cell + in zip(f.__code__.co_freevars, f.__closure__)} + new_closure = [] + for var in new_f_code_o.co_freevars: + if var in old_cells: + new_closure.append(old_cells[var]) + else: + # Create a cell whose value comes from globals + cell = (lambda v: lambda: v)(f.__globals__.get(var)).__closure__[0] + new_closure.append(cell) + new_f = types.FunctionType( + new_f_code_o, f.__globals__, f.__name__, + f.__defaults__, tuple(new_closure), + ) + new_f.__kwdefaults__ = f.__kwdefaults__ + return new_f + + f.__code__ = new_f_code_o return f