Skip to content

Commit 72a2431

Browse files
[mypyc] Fix lambda inside comprehension (#21009)
Fixes mypyc/mypyc#1190 There's a subtle difference between mypyc and CPython when it comes to evaluating comprehensions: - CPython creates an implicit function scope for every comprehension (visible as `MAKE_FUNCTION` etc) - Mypyc inlines comprehensions directly into the enclosing/outer scope This leads to the following bug: When a lambda inside a comprehension tries to capture the loop variable, the closure/env-class machinery fails because there's no scope boundary to chain through. Consider this example with a module level comprehension which currently fails with `UnboundLocalError`: ```Python3 # bug.py d = {name: (lambda: name) for name in ("a", "b")} d["a"]() ``` <br /> Schematically: ```Bash Before (broken): module scope (no env class) └── lambda (needs env class to find 'name') → crash After (fixed): module scope └── comprehension scope (has env class with 'name' attribute) └── lambda (loads 'name' from comprehensions env class) → works ``` <br /> Three failure modes depending on where the comprehension lives: - Module level: UnboundLocalError at runtime -> the lambda can't find the variable - Class level: KeyError compiler crash -> env class setup fails entirely - Function level: Already worked, because the enclosing function provides the env class The fix creates a _lightweight_ synthetic scope (new `FuncInfo` + `env` class) only when a comprehension body contains a lambda, while still inlining the comprehension into the same basic blocks otherwise. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4d6c417 commit 72a2431

File tree

7 files changed

+344
-15
lines changed

7 files changed

+344
-15
lines changed

mypyc/irbuild/builder.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def __init__(
227227
self.nested_fitems = pbv.nested_funcs.keys()
228228
self.fdefs_to_decorators = pbv.funcs_to_decorators
229229
self.module_import_groups = pbv.module_import_groups
230+
self.comprehension_to_fitem = pbv.comprehension_to_fitem
230231

231232
self.singledispatch_impls = singledispatch_impls
232233

@@ -1263,6 +1264,37 @@ def leave(self) -> tuple[list[Register], list[RuntimeArg], list[BasicBlock], RTy
12631264
self.fn_info = self.fn_infos[-1]
12641265
return builder.args, runtime_args, builder.blocks, ret_type, fn_info
12651266

1267+
@contextmanager
1268+
def enter_scope(self, fn_info: FuncInfo) -> Iterator[None]:
1269+
"""Push a lightweight scope for comprehensions.
1270+
1271+
Unlike enter(), this reuses the same LowLevelIRBuilder (same basic
1272+
blocks and registers) but pushes new symtable and fn_info entries
1273+
so that the closure machinery sees a scope boundary.
1274+
"""
1275+
self.builders.append(self.builder)
1276+
# Copy the parent symtable so variables from the enclosing scope
1277+
# (e.g. function parameters used as the comprehension iterable)
1278+
# remain accessible. The comprehension is inlined (same basic blocks
1279+
# and registers), so the parent's register references are still valid.
1280+
self.symtables.append(dict(self.symtables[-1]))
1281+
self.runtime_args.append([])
1282+
self.fn_info = fn_info
1283+
self.fn_infos.append(self.fn_info)
1284+
self.ret_types.append(none_rprimitive)
1285+
self.nonlocal_control.append(BaseNonlocalControl())
1286+
try:
1287+
yield
1288+
finally:
1289+
self.builders.pop()
1290+
self.symtables.pop()
1291+
self.runtime_args.pop()
1292+
self.ret_types.pop()
1293+
self.fn_infos.pop()
1294+
self.nonlocal_control.pop()
1295+
self.builder = self.builders[-1]
1296+
self.fn_info = self.fn_infos[-1]
1297+
12661298
@contextmanager
12671299
def enter_method(
12681300
self,

mypyc/irbuild/callable_class.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,13 @@ def instantiate_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> Value:
224224
# - A generator function: the callable class is instantiated
225225
# from the '__next__' method of the generator class, and hence the
226226
# environment of the generator class is used.
227-
# - Regular function: we use the environment of the original function.
227+
# - Regular function or comprehension scope: we use the environment
228+
# of the original function. Comprehension scopes are inlined (no
229+
# callable class), so they fall into this case despite is_nested.
228230
curr_env_reg = None
229231
if builder.fn_info.is_generator:
230232
curr_env_reg = builder.fn_info.generator_class.curr_env_reg
231-
elif builder.fn_info.is_nested:
233+
elif builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope:
232234
curr_env_reg = builder.fn_info.callable_class.curr_env_reg
233235
elif builder.fn_info.contains_nested:
234236
curr_env_reg = builder.fn_info.curr_env_reg

mypyc/irbuild/context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
is_decorated: bool = False,
2424
in_non_ext: bool = False,
2525
add_nested_funcs_to_env: bool = False,
26+
is_comprehension_scope: bool = False,
2627
) -> None:
2728
self.fitem = fitem
2829
self.name = name
@@ -49,6 +50,11 @@ def __init__(
4950
self.is_decorated = is_decorated
5051
self.in_non_ext = in_non_ext
5152
self.add_nested_funcs_to_env = add_nested_funcs_to_env
53+
# Comprehension scopes are lightweight scope boundaries created when
54+
# a comprehension body contains a lambda. The comprehension is still
55+
# inlined (same basic blocks), but we push a new FuncInfo so the
56+
# closure machinery can capture loop variables through env classes.
57+
self.is_comprehension_scope = is_comprehension_scope
5258

5359
# TODO: add field for ret_type: RType = none_rprimitive
5460

mypyc/irbuild/env_class.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class is generated, the function environment has not yet been
5656
)
5757
env_class.reuse_freed_instance = True
5858
env_class.attributes[SELF_NAME] = RInstance(env_class)
59-
if builder.fn_info.is_nested:
59+
if builder.fn_info.is_nested and builder.fn_infos[-2]._env_class is not None:
6060
# If the function is nested, its environment class must contain an environment
6161
# attribute pointing to its encapsulating functions' environment class.
6262
env_class.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_infos[-2].env_class)
@@ -73,11 +73,14 @@ def finalize_env_class(builder: IRBuilder, prefix: str = "") -> None:
7373

7474
# Iterate through the function arguments and replace local definitions (using registers)
7575
# that were previously added to the environment with references to the function's
76-
# environment class.
77-
if builder.fn_info.is_nested:
78-
add_args_to_env(builder, local=False, base=builder.fn_info.callable_class, prefix=prefix)
79-
else:
80-
add_args_to_env(builder, local=False, base=builder.fn_info, prefix=prefix)
76+
# environment class. Comprehension scopes have no arguments to add.
77+
if not builder.fn_info.is_comprehension_scope:
78+
if builder.fn_info.is_nested:
79+
add_args_to_env(
80+
builder, local=False, base=builder.fn_info.callable_class, prefix=prefix
81+
)
82+
else:
83+
add_args_to_env(builder, local=False, base=builder.fn_info, prefix=prefix)
8184

8285

8386
def instantiate_env_class(builder: IRBuilder) -> Value:
@@ -86,7 +89,7 @@ def instantiate_env_class(builder: IRBuilder) -> Value:
8689
Call(builder.fn_info.env_class.ctor, [], builder.fn_info.fitem.line)
8790
)
8891

89-
if builder.fn_info.is_nested:
92+
if builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope:
9093
builder.fn_info.callable_class._curr_env_reg = curr_env_reg
9194
builder.add(
9295
SetAttr(
@@ -97,7 +100,22 @@ def instantiate_env_class(builder: IRBuilder) -> Value:
97100
)
98101
)
99102
else:
103+
# Top-level functions and comprehension scopes store env reg directly.
100104
builder.fn_info._curr_env_reg = curr_env_reg
105+
# Comprehension scopes link to parent env if it exists.
106+
if (
107+
builder.fn_info.is_nested
108+
and builder.fn_infos[-2]._env_class is not None
109+
and builder.fn_infos[-2]._curr_env_reg is not None
110+
):
111+
builder.add(
112+
SetAttr(
113+
curr_env_reg,
114+
ENV_ATTR_NAME,
115+
builder.fn_infos[-2].curr_env_reg,
116+
builder.fn_info.fitem.line,
117+
)
118+
)
101119

102120
return curr_env_reg
103121

@@ -114,7 +132,7 @@ def load_env_registers(builder: IRBuilder, prefix: str = "") -> None:
114132

115133
fn_info = builder.fn_info
116134
fitem = fn_info.fitem
117-
if fn_info.is_nested:
135+
if fn_info.is_nested and builder.fn_infos[-2]._env_class is not None:
118136
load_outer_envs(builder, fn_info.callable_class)
119137
# If this is a FuncDef, then make sure to load the FuncDef into its own environment
120138
# class so that the function can be called recursively.
@@ -155,7 +173,8 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None:
155173

156174
# Load the first outer environment. This one is special because it gets saved in the
157175
# FuncInfo instance's prev_env_reg field.
158-
if index > 1:
176+
has_outer = index > 1 or (index == 1 and builder.fn_infos[1].contains_nested)
177+
if has_outer and builder.fn_infos[index]._env_class is not None:
159178
# outer_env = builder.fn_infos[index].environment
160179
outer_env = builder.symtables[index]
161180
if isinstance(base, GeneratorClass):
@@ -167,6 +186,8 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None:
167186

168187
# Load the remaining outer environments into registers.
169188
while index > 1:
189+
if builder.fn_infos[index]._env_class is None:
190+
break
170191
# outer_env = builder.fn_infos[index].environment
171192
outer_env = builder.symtables[index]
172193
env_reg = load_outer_env(builder, env_reg, outer_env)
@@ -224,7 +245,9 @@ def add_vars_to_env(builder: IRBuilder, prefix: str = "") -> None:
224245
env_for_func: FuncInfo | ImplicitClass = builder.fn_info
225246
if builder.fn_info.is_generator:
226247
env_for_func = builder.fn_info.generator_class
227-
elif builder.fn_info.is_nested or builder.fn_info.in_non_ext:
248+
elif (
249+
builder.fn_info.is_nested or builder.fn_info.in_non_ext
250+
) and not builder.fn_info.is_comprehension_scope:
228251
env_for_func = builder.fn_info.callable_class
229252

230253
if builder.fn_info.fitem in builder.free_variables:

mypyc/irbuild/expression.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,20 +1164,45 @@ def _visit_display(
11641164

11651165

11661166
# Comprehensions
1167+
#
1168+
# mypyc always inlines comprehensions (the loop body is emitted directly into
1169+
# the enclosing function's IR, no implicit function call like CPython).
1170+
#
1171+
# However, when a comprehension body contains a lambda, we need a lightweight
1172+
# scope boundary so the closure/env-class machinery can see the comprehension
1173+
# as a separate scope. The comprehension is still inlined (same basic blocks
1174+
# and registers), but we push a new FuncInfo and set up an env class so the
1175+
# lambda can capture loop variables through the standard env-class chain.
11671176

11681177

11691178
def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value:
1170-
return translate_list_comprehension(builder, o.generator)
1179+
gen = o.generator
1180+
if gen in builder.comprehension_to_fitem:
1181+
return _translate_comprehension_with_scope(
1182+
builder, gen, lambda: translate_list_comprehension(builder, gen)
1183+
)
1184+
return translate_list_comprehension(builder, gen)
11711185

11721186

11731187
def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value:
1174-
return translate_set_comprehension(builder, o.generator)
1188+
gen = o.generator
1189+
if gen in builder.comprehension_to_fitem:
1190+
return _translate_comprehension_with_scope(
1191+
builder, gen, lambda: translate_set_comprehension(builder, gen)
1192+
)
1193+
return translate_set_comprehension(builder, gen)
11751194

11761195

11771196
def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value:
11781197
if raise_error_if_contains_unreachable_names(builder, o):
11791198
return builder.none()
11801199

1200+
if o in builder.comprehension_to_fitem:
1201+
return _translate_comprehension_with_scope(builder, o, lambda: _dict_comp_body(builder, o))
1202+
return _dict_comp_body(builder, o)
1203+
1204+
1205+
def _dict_comp_body(builder: IRBuilder, o: DictionaryComprehension) -> Value:
11811206
d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line))
11821207
loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async))
11831208

@@ -1190,6 +1215,31 @@ def gen_inner_stmts() -> None:
11901215
return builder.read(d, o.line)
11911216

11921217

1218+
def _translate_comprehension_with_scope(
1219+
builder: IRBuilder,
1220+
node: GeneratorExpr | DictionaryComprehension,
1221+
gen_body: Callable[[], Value],
1222+
) -> Value:
1223+
"""Wrap a comprehension body with a lightweight scope for closure capture."""
1224+
from mypyc.irbuild.context import FuncInfo
1225+
from mypyc.irbuild.env_class import add_vars_to_env, finalize_env_class, setup_env_class
1226+
1227+
comprehension_fdef = builder.comprehension_to_fitem[node]
1228+
fn_info = FuncInfo(
1229+
fitem=comprehension_fdef,
1230+
name=comprehension_fdef.name,
1231+
is_nested=True,
1232+
contains_nested=True,
1233+
is_comprehension_scope=True,
1234+
)
1235+
1236+
with builder.enter_scope(fn_info):
1237+
setup_env_class(builder)
1238+
finalize_env_class(builder)
1239+
add_vars_to_env(builder)
1240+
return gen_body()
1241+
1242+
11931243
# Misc
11941244

11951245

@@ -1206,6 +1256,16 @@ def get_arg(arg: Expression | None) -> Value:
12061256

12071257
def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
12081258
builder.warning("Treating generator comprehension as list", o.line)
1259+
if o in builder.comprehension_to_fitem:
1260+
return builder.primitive_op(
1261+
iter_op,
1262+
[
1263+
_translate_comprehension_with_scope(
1264+
builder, o, lambda: translate_list_comprehension(builder, o)
1265+
)
1266+
],
1267+
o.line,
1268+
)
12091269
return builder.primitive_op(iter_op, [translate_list_comprehension(builder, o)], o.line)
12101270

12111271

0 commit comments

Comments
 (0)