Skip to content

Commit 51da1de

Browse files
committed
Fix ClassVar self-references in class bodies
In CPython, the class body executes as a function where earlier assignments are available to later ones (e.g. C = A | B where A is a ClassVar defined earlier in the same class). mypyc previously resolved such names via load_global(), looking them up in the module globals dict where they don't exist — causing a KeyError at runtime. Fix by tracking ClassVar names as they're defined during class body processing, and redirecting lookups to the class being built: the type object (py_get_attr) for extension classes, or the class dict (dict_get_item_op) for non-extension classes. This enables patterns like: class Parser: TYPE_TOKENS: ClassVar = {"INT", "VARCHAR"} FUNC_TOKENS: ClassVar = TYPE_TOKENS | {"FUNCTION"}
1 parent 654cac5 commit 51da1de

5 files changed

Lines changed: 216 additions & 87 deletions

File tree

mypyc/irbuild/builder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,14 @@ def __init__(
228228

229229
self.visitor = visitor
230230

231+
# Class body context: tracks ClassVar names defined so far when processing
232+
# a class body, so that intra-class references (e.g. C = A | B where A is
233+
# a ClassVar defined earlier in the same class) can be resolved correctly.
234+
# Without this, mypyc looks up such names in module globals, which fails.
235+
self.class_body_classvars: dict[str, None] = {}
236+
self.class_body_obj: Value | None = None
237+
self.class_body_is_ext: bool = False
238+
231239
# This list operates similarly to a function call stack for nested functions. Whenever a
232240
# function definition begins to be generated, a FuncInfo instance is added to the stack,
233241
# and information about that function (e.g. whether it is nested, its environment class to

mypyc/irbuild/classdef.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,16 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
136136
else:
137137
cls_builder = NonExtClassBuilder(builder, cdef)
138138

139+
# Set up class body context so that intra-class ClassVar references
140+
# (e.g. C = A | B where A is defined earlier in the same class) can be
141+
# resolved from the class being built instead of module globals.
142+
saved_classvars = builder.class_body_classvars
143+
saved_obj = builder.class_body_obj
144+
saved_is_ext = builder.class_body_is_ext
145+
builder.class_body_classvars = {}
146+
builder.class_body_obj = cls_builder.class_body_obj()
147+
builder.class_body_is_ext = ir.is_ext_class
148+
139149
for stmt in cdef.defs.body:
140150
if (
141151
isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef))
@@ -178,13 +188,21 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
178188
# We want to collect class variables in a dictionary for both real
179189
# non-extension classes and fake dataclass ones.
180190
cls_builder.add_attr(lvalue, stmt)
191+
# Track this ClassVar so subsequent class body statements can reference it.
192+
if is_class_var(lvalue) or stmt.is_final_def:
193+
builder.class_body_classvars[lvalue.name] = None
181194

182195
elif isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr):
183196
# Docstring. Ignore
184197
pass
185198
else:
186199
builder.error("Unsupported statement in class body", stmt.line)
187200

201+
# Restore previous class body context (handles nested classes).
202+
builder.class_body_classvars = saved_classvars
203+
builder.class_body_obj = saved_obj
204+
builder.class_body_is_ext = saved_is_ext
205+
188206
# Generate implicit property setters/getters
189207
for name, decl in ir.method_decls.items():
190208
if decl.implicit and decl.is_prop_getter:
@@ -231,12 +249,23 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
231249
def finalize(self, ir: ClassIR) -> None:
232250
"""Perform any final operations to complete the class IR"""
233251

252+
def class_body_obj(self) -> Value | None:
253+
"""Return the object to use for loading class attributes during class body init.
254+
255+
For extension classes, this is the type object. For non-extension classes,
256+
this is the class dict. Returns None if not applicable.
257+
"""
258+
return None
259+
234260

235261
class NonExtClassBuilder(ClassBuilder):
236262
def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
237263
super().__init__(builder, cdef)
238264
self.non_ext = self.create_non_ext_info()
239265

266+
def class_body_obj(self) -> Value | None:
267+
return self.non_ext.dict
268+
240269
def create_non_ext_info(self) -> NonExtClassInfo:
241270
non_ext_bases = populate_non_ext_bases(self.builder, self.cdef)
242271
non_ext_metaclass = find_non_ext_metaclass(self.builder, self.cdef, non_ext_bases)
@@ -292,6 +321,9 @@ def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
292321
# If the class is not decorated, generate an extension class for it.
293322
self.type_obj: Value | None = allocate_class(builder, cdef)
294323

324+
def class_body_obj(self) -> Value | None:
325+
return self.type_obj
326+
295327
def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool:
296328
"""Controls whether to skip generating a default for an attribute."""
297329
return False

mypyc/irbuild/expression.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,19 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
198198
else:
199199
return builder.read(builder.get_assignment_target(expr, for_read=True), expr.line)
200200

201+
# If we're evaluating a class body and this name is a ClassVar defined earlier
202+
# in the same class, load it from the class being built (type object for ext classes,
203+
# class dict for non-ext classes) instead of module globals.
204+
if builder.class_body_obj is not None and expr.name in builder.class_body_classvars:
205+
if builder.class_body_is_ext:
206+
return builder.py_get_attr(builder.class_body_obj, expr.name, expr.line)
207+
else:
208+
return builder.primitive_op(
209+
dict_get_item_op,
210+
[builder.class_body_obj, builder.load_str(expr.name)],
211+
expr.line,
212+
)
213+
201214
return builder.load_global(expr)
202215

203216

mypyc/test-data/irbuild-classes.test

Lines changed: 86 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -220,48 +220,48 @@ def __top_level__():
220220
r24 :: object
221221
r25 :: str
222222
r26, r27 :: object
223-
r28, r29 :: bool
224-
r30 :: str
225-
r31 :: tuple
226-
r32 :: i32
227-
r33 :: bit
228-
r34 :: dict
229-
r35 :: str
230-
r36 :: i32
231-
r37 :: bit
232-
r38 :: bool
233-
r39 :: object
234-
r40 :: str
235-
r41, r42 :: object
236-
r43 :: str
237-
r44 :: tuple
238-
r45 :: i32
239-
r46 :: bit
240-
r47 :: dict
241-
r48 :: str
242-
r49 :: i32
243-
r50 :: bit
244-
r51 :: bool
245-
r52, r53 :: object
246-
r54 :: dict
247-
r55 :: str
248-
r56 :: object
249-
r57 :: dict
250-
r58 :: str
251-
r59, r60 :: object
252-
r61 :: tuple
253-
r62 :: str
254-
r63, r64 :: object
255-
r65, r66 :: bool
256-
r67, r68 :: str
257-
r69 :: tuple
258-
r70 :: i32
259-
r71 :: bit
260-
r72 :: dict
261-
r73 :: str
262-
r74 :: i32
263-
r75 :: bit
264-
r76 :: bool
223+
r28 :: bool
224+
r29 :: str
225+
r30 :: tuple
226+
r31 :: i32
227+
r32 :: bit
228+
r33 :: dict
229+
r34 :: str
230+
r35 :: i32
231+
r36 :: bit
232+
r37 :: bool
233+
r38 :: object
234+
r39 :: str
235+
r40, r41 :: object
236+
r42 :: str
237+
r43 :: tuple
238+
r44 :: i32
239+
r45 :: bit
240+
r46 :: dict
241+
r47 :: str
242+
r48 :: i32
243+
r49 :: bit
244+
r50 :: bool
245+
r51, r52 :: object
246+
r53 :: dict
247+
r54 :: str
248+
r55 :: object
249+
r56 :: dict
250+
r57 :: str
251+
r58, r59 :: object
252+
r60 :: tuple
253+
r61 :: str
254+
r62, r63 :: object
255+
r64 :: bool
256+
r65, r66 :: str
257+
r67 :: tuple
258+
r68 :: i32
259+
r69 :: bit
260+
r70 :: dict
261+
r71 :: str
262+
r72 :: i32
263+
r73 :: bit
264+
r74 :: bool
265265
L0:
266266
r0 = builtins :: module
267267
r1 = load_address _Py_NoneStruct
@@ -304,51 +304,50 @@ L2:
304304
r31 = PyObject_SetAttr(r27, r29, r30)
305305
r32 = r31 >= 0 :: signed
306306
__main__.C = r27 :: type
307-
r34 = __main__.globals :: static
308-
r35 = 'C'
309-
r36 = PyDict_SetItem(r34, r35, r27)
310-
r37 = r36 >= 0 :: signed
311-
r38 = CPy_InitSubclass(r27)
312-
r39 = <error> :: object
313-
r40 = '__main__'
314-
r41 = __main__.S_template :: type
315-
r42 = CPyType_FromTemplate(r41, r39, r40)
316-
r43 = '__mypyc_attrs__'
317-
r44 = CPyTuple_LoadEmptyTupleConstant()
318-
r45 = PyObject_SetAttr(r42, r43, r44)
319-
r46 = r45 >= 0 :: signed
320-
__main__.S = r42 :: type
321-
r47 = __main__.globals :: static
322-
r48 = 'S'
323-
r49 = PyDict_SetItem(r47, r48, r42)
324-
r50 = r49 >= 0 :: signed
325-
r51 = CPy_InitSubclass(r42)
326-
r52 = __main__.C :: type
327-
r53 = __main__.S :: type
328-
r54 = __main__.globals :: static
329-
r55 = 'Generic'
330-
r56 = CPyDict_GetItem(r54, r55)
331-
r57 = __main__.globals :: static
332-
r58 = 'T'
333-
r59 = CPyDict_GetItem(r57, r58)
334-
r60 = PyObject_GetItem(r56, r59)
335-
r61 = PyTuple_Pack(3, r52, r53, r60)
336-
r62 = '__main__'
337-
r63 = __main__.D_template :: type
338-
r64 = CPyType_FromTemplate(r63, r61, r62)
339-
r65 = D_trait_vtable_setup()
340-
r66 = D_coroutine_setup(r64)
341-
r67 = '__mypyc_attrs__'
342-
r68 = '__dict__'
343-
r69 = PyTuple_Pack(1, r68)
344-
r70 = PyObject_SetAttr(r64, r67, r69)
345-
r71 = r70 >= 0 :: signed
346-
__main__.D = r64 :: type
347-
r72 = __main__.globals :: static
348-
r73 = 'D'
349-
r74 = PyDict_SetItem(r72, r73, r64)
350-
r75 = r74 >= 0 :: signed
351-
r76 = CPy_InitSubclass(r64)
307+
r33 = __main__.globals :: static
308+
r34 = 'C'
309+
r35 = PyDict_SetItem(r33, r34, r27)
310+
r36 = r35 >= 0 :: signed
311+
r37 = CPy_InitSubclass(r27)
312+
r38 = <error> :: object
313+
r39 = '__main__'
314+
r40 = __main__.S_template :: type
315+
r41 = CPyType_FromTemplate(r40, r38, r39)
316+
r42 = '__mypyc_attrs__'
317+
r43 = CPyTuple_LoadEmptyTupleConstant()
318+
r44 = PyObject_SetAttr(r41, r42, r43)
319+
r45 = r44 >= 0 :: signed
320+
__main__.S = r41 :: type
321+
r46 = __main__.globals :: static
322+
r47 = 'S'
323+
r48 = PyDict_SetItem(r46, r47, r41)
324+
r49 = r48 >= 0 :: signed
325+
r50 = CPy_InitSubclass(r41)
326+
r51 = __main__.C :: type
327+
r52 = __main__.S :: type
328+
r53 = __main__.globals :: static
329+
r54 = 'Generic'
330+
r55 = CPyDict_GetItem(r53, r54)
331+
r56 = __main__.globals :: static
332+
r57 = 'T'
333+
r58 = CPyDict_GetItem(r56, r57)
334+
r59 = PyObject_GetItem(r55, r58)
335+
r60 = PyTuple_Pack(3, r51, r52, r59)
336+
r61 = '__main__'
337+
r62 = __main__.D_template :: type
338+
r63 = CPyType_FromTemplate(r62, r60, r61)
339+
r64 = D_trait_vtable_setup()
340+
r65 = '__mypyc_attrs__'
341+
r66 = '__dict__'
342+
r67 = PyTuple_Pack(1, r66)
343+
r68 = PyObject_SetAttr(r63, r65, r67)
344+
r69 = r68 >= 0 :: signed
345+
__main__.D = r63 :: type
346+
r70 = __main__.globals :: static
347+
r71 = 'D'
348+
r72 = PyDict_SetItem(r70, r71, r63)
349+
r73 = r72 >= 0 :: signed
350+
r74 = CPy_InitSubclass(r63)
352351
return 1
353352

354353
[case testIsInstance]

mypyc/test-data/run-classes.test

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5681,3 +5681,80 @@ from native import Concrete
56815681
c = Concrete()
56825682
assert c.value() == 42
56835683
assert c.derived() == 42
5684+
5685+
[case testClassVarSelfReference]
5686+
# ClassVar initializers that reference other ClassVars from the same class.
5687+
# In CPython, the class body executes as a function where earlier assignments
5688+
# are available to later ones. mypyc must replicate this by loading from the
5689+
# class being built (type object for ext classes, class dict for non-ext)
5690+
# instead of module globals.
5691+
from typing import ClassVar, Dict, Set
5692+
5693+
class Ext:
5694+
A: ClassVar[Set[int]] = {1, 2, 3}
5695+
B: ClassVar[Set[int]] = {4, 5, 6}
5696+
C: ClassVar[Set[int]] = A | B
5697+
5698+
class ExtChained:
5699+
X: ClassVar[Set[int]] = {1, 2}
5700+
Y: ClassVar[Set[int]] = X | {3}
5701+
Z: ClassVar[Set[int]] = Y | {4}
5702+
5703+
class ExtDict:
5704+
BASE: ClassVar[Dict[str, int]] = {"a": 1, "b": 2}
5705+
EXTENDED: ClassVar[Dict[str, int]] = {**BASE, "c": 3}
5706+
5707+
class ExtSub(Ext):
5708+
E: ClassVar[Set[int]] = {7, 8}
5709+
5710+
[file driver.py]
5711+
from native import Ext, ExtChained, ExtDict, ExtSub
5712+
5713+
assert Ext.A == {1, 2, 3}
5714+
assert Ext.B == {4, 5, 6}
5715+
assert Ext.C == {1, 2, 3, 4, 5, 6}
5716+
5717+
assert ExtChained.X == {1, 2}
5718+
assert ExtChained.Y == {1, 2, 3}
5719+
assert ExtChained.Z == {1, 2, 3, 4}
5720+
5721+
assert ExtDict.BASE == {"a": 1, "b": 2}
5722+
assert ExtDict.EXTENDED == {"a": 1, "b": 2, "c": 3}
5723+
5724+
assert ExtSub.C == {1, 2, 3, 4, 5, 6}
5725+
assert ExtSub.E == {7, 8}
5726+
5727+
[case testClassVarSelfReferenceNonExt]
5728+
# Same as testClassVarSelfReference but for non-extension classes
5729+
# (e.g. decorated classes or classes with allow_interpreted_subclasses).
5730+
from typing import ClassVar, Dict, Set
5731+
from mypy_extensions import mypyc_attr
5732+
5733+
@mypyc_attr(allow_interpreted_subclasses=True)
5734+
class NonExt:
5735+
A: ClassVar[Set[str]] = {"a", "b"}
5736+
B: ClassVar[Set[str]] = {"c"}
5737+
C: ClassVar[Set[str]] = A | B
5738+
5739+
@mypyc_attr(allow_interpreted_subclasses=True)
5740+
class NonExtDict:
5741+
BASE: ClassVar[Dict[str, int]] = {"x": 1}
5742+
EXTENDED: ClassVar[Dict[str, int]] = {**BASE, "y": 2}
5743+
5744+
@mypyc_attr(allow_interpreted_subclasses=True)
5745+
class NonExtChained:
5746+
X: ClassVar[Set[int]] = {10}
5747+
Y: ClassVar[Set[int]] = X | {20}
5748+
Z: ClassVar[Set[int]] = Y | {30}
5749+
5750+
[file driver.py]
5751+
from native import NonExt, NonExtDict, NonExtChained
5752+
5753+
assert NonExt.A == {"a", "b"}
5754+
assert NonExt.B == {"c"}
5755+
assert NonExt.C == {"a", "b", "c"}
5756+
5757+
assert NonExtDict.BASE == {"x": 1}
5758+
assert NonExtDict.EXTENDED == {"x": 1, "y": 2}
5759+
5760+
assert NonExtChained.Z == {10, 20, 30}

0 commit comments

Comments
 (0)