Skip to content

Commit f3e2a97

Browse files
authored
[mypyc] Fix __init_subclass__ running before ClassVar instantiations (#20916)
Fixes mypyc/mypyc#1086 We also ran into this in SQLGlot; To compile our AST we swapped any metaclasses for `__init_subclass__` in which we dynamically set some attributes in it like so: ```Python3 class Expression: arg_types: ClassVar[...] = {} def __init_subclass__(cls, **kwargs: object) -> None: cls.required_args = {k for k, v in cls.arg_types.items() if v} class Select(Expression): arg_types: ClassVar[...] = {...} ``` I introduced the `InitSubclass` primitive to allow for fine-grained control of when this gets called and moved it from `ExtClassBuilder::__init__` (i.e `allocate_class > CPyType_FromTemplate`) to `ExtClassBuilder::finalize`
1 parent 32f0502 commit f3e2a97

File tree

7 files changed

+133
-72
lines changed

7 files changed

+133
-72
lines changed

mypyc/irbuild/classdef.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
import_op,
8181
not_implemented_op,
8282
py_calc_meta_op,
83+
py_init_subclass_op,
8384
pytype_from_template_op,
8485
type_object_op,
8586
)
@@ -290,7 +291,7 @@ class ExtClassBuilder(ClassBuilder):
290291
def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
291292
super().__init__(builder, cdef)
292293
# If the class is not decorated, generate an extension class for it.
293-
self.type_obj: Value | None = allocate_class(builder, cdef)
294+
self.type_obj: Value = allocate_class(builder, cdef)
294295

295296
def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool:
296297
"""Controls whether to skip generating a default for an attribute."""
@@ -315,6 +316,9 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
315316
self.builder.init_final_static(lvalue, value, self.cdef.name)
316317

317318
def finalize(self, ir: ClassIR) -> None:
319+
# Call __init_subclass__ after class attributes have been set
320+
self.builder.call_c(py_init_subclass_op, [self.type_obj], self.cdef.line)
321+
318322
attrs_with_defaults, default_assignments = find_attr_initializers(
319323
self.builder, self.cdef, self.skip_attr_default
320324
)

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,7 @@ PyObject *CPyType_FromTemplate(PyObject *template_,
920920
PyObject *CPyType_FromTemplateWrapper(PyObject *template_,
921921
PyObject *orig_bases,
922922
PyObject *modname);
923+
bool CPy_InitSubclass(PyObject *type);
923924
int CPyDataclass_SleightOfHand(PyObject *dataclass_dec, PyObject *tp,
924925
PyObject *dict, PyObject *annotations,
925926
PyObject *dataclass_type);

mypyc/lib-rt/misc_ops.c

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,6 @@ PyObject *CPyType_FromTemplate(PyObject *template,
303303
if (PyObject_SetAttr((PyObject *)t, mypyc_interned_str.__module__, modname) < 0)
304304
goto error;
305305

306-
if (init_subclass((PyTypeObject *)t, NULL))
307-
goto error;
308-
309306
Py_XDECREF(dummy_class);
310307

311308
// Unlike the tp_doc slots of most other object, a heap type's tp_doc
@@ -338,6 +335,16 @@ PyObject *CPyType_FromTemplate(PyObject *template,
338335
return NULL;
339336
}
340337

338+
// Call __init_subclass__ on the appropriate base class of type.
339+
// This is separated from CPyType_FromTemplate so that class attributes
340+
// can be set before __init_subclass__ is called.
341+
bool CPy_InitSubclass(PyObject *type) {
342+
if (init_subclass((PyTypeObject *)type, NULL)) {
343+
return false;
344+
}
345+
return true;
346+
}
347+
341348
static int _CPy_UpdateObjFromDict(PyObject *obj, PyObject *dict)
342349
{
343350
Py_ssize_t pos = 0;

mypyc/primitives/misc_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,15 @@
239239
error_kind=ERR_MAGIC,
240240
)
241241

242+
# Call __init_subclass__ on a type. Separated from CPyType_FromTemplate
243+
# so that class attributes can be set before __init_subclass__ is called.
244+
py_init_subclass_op = custom_op(
245+
arg_types=[object_rprimitive],
246+
return_type=bool_rprimitive,
247+
c_function_name="CPy_InitSubclass",
248+
error_kind=ERR_FALSE,
249+
)
250+
242251
# Create a dataclass from an extension class. See
243252
# CPyDataclass_SleightOfHand for more docs.
244253
dataclass_sleight_of_hand = custom_op(

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class object:
4242
__class__: type
4343
def __new__(cls) -> Self: pass
4444
def __init__(self) -> None: pass
45+
def __init_subclass__(cls, **kwargs: object) -> None: pass
4546
def __eq__(self, x: object) -> bool: pass
4647
def __ne__(self, x: object) -> bool: pass
4748
def __str__(self) -> str: pass

mypyc/test-data/irbuild-classes.test

Lines changed: 74 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -229,36 +229,39 @@ def __top_level__():
229229
r35 :: str
230230
r36 :: i32
231231
r37 :: bit
232-
r38 :: object
233-
r39 :: str
234-
r40, r41 :: object
235-
r42 :: str
236-
r43 :: tuple
237-
r44 :: i32
238-
r45 :: bit
239-
r46 :: dict
240-
r47 :: str
241-
r48 :: i32
242-
r49 :: bit
243-
r50, r51 :: object
244-
r52 :: dict
245-
r53 :: str
246-
r54 :: object
247-
r55 :: dict
248-
r56 :: str
249-
r57, r58 :: object
250-
r59 :: tuple
251-
r60 :: str
252-
r61, r62 :: object
253-
r63, r64 :: bool
254-
r65, r66 :: str
255-
r67 :: tuple
256-
r68 :: i32
257-
r69 :: bit
258-
r70 :: dict
259-
r71 :: str
260-
r72 :: i32
261-
r73 :: bit
232+
r38 :: bool
233+
r39 :: object
234+
r40 :: str
235+
r41, r42 :: object
236+
r43 :: str
237+
r44 :: tuple
238+
r45 :: i32
239+
r46 :: bit
240+
r47 :: dict
241+
r48 :: str
242+
r49 :: i32
243+
r50 :: bit
244+
r51 :: bool
245+
r52, r53 :: object
246+
r54 :: dict
247+
r55 :: str
248+
r56 :: object
249+
r57 :: dict
250+
r58 :: str
251+
r59, r60 :: object
252+
r61 :: tuple
253+
r62 :: str
254+
r63, r64 :: object
255+
r65, r66 :: bool
256+
r67, r68 :: str
257+
r69 :: tuple
258+
r70 :: i32
259+
r71 :: bit
260+
r72 :: dict
261+
r73 :: str
262+
r74 :: i32
263+
r75 :: bit
264+
r76 :: bool
262265
L0:
263266
r0 = builtins :: module
264267
r1 = load_address _Py_NoneStruct
@@ -306,44 +309,47 @@ L2:
306309
r35 = 'C'
307310
r36 = PyDict_SetItem(r34, r35, r27)
308311
r37 = r36 >= 0 :: signed
309-
r38 = <error> :: object
310-
r39 = '__main__'
311-
r40 = __main__.S_template :: type
312-
r41 = CPyType_FromTemplate(r40, r38, r39)
313-
r42 = '__mypyc_attrs__'
314-
r43 = CPyTuple_LoadEmptyTupleConstant()
315-
r44 = PyObject_SetAttr(r41, r42, r43)
316-
r45 = r44 >= 0 :: signed
317-
__main__.S = r41 :: type
318-
r46 = __main__.globals :: static
319-
r47 = 'S'
320-
r48 = PyDict_SetItem(r46, r47, r41)
321-
r49 = r48 >= 0 :: signed
322-
r50 = __main__.C :: type
323-
r51 = __main__.S :: type
324-
r52 = __main__.globals :: static
325-
r53 = 'Generic'
326-
r54 = CPyDict_GetItem(r52, r53)
327-
r55 = __main__.globals :: static
328-
r56 = 'T'
329-
r57 = CPyDict_GetItem(r55, r56)
330-
r58 = PyObject_GetItem(r54, r57)
331-
r59 = PyTuple_Pack(3, r50, r51, r58)
332-
r60 = '__main__'
333-
r61 = __main__.D_template :: type
334-
r62 = CPyType_FromTemplate(r61, r59, r60)
335-
r63 = D_trait_vtable_setup()
336-
r64 = D_coroutine_setup(r62)
337-
r65 = '__mypyc_attrs__'
338-
r66 = '__dict__'
339-
r67 = PyTuple_Pack(1, r66)
340-
r68 = PyObject_SetAttr(r62, r65, r67)
341-
r69 = r68 >= 0 :: signed
342-
__main__.D = r62 :: type
343-
r70 = __main__.globals :: static
344-
r71 = 'D'
345-
r72 = PyDict_SetItem(r70, r71, r62)
346-
r73 = r72 >= 0 :: signed
312+
r38 = CPy_InitSubclass(r27)
313+
r39 = <error> :: object
314+
r40 = '__main__'
315+
r41 = __main__.S_template :: type
316+
r42 = CPyType_FromTemplate(r41, r39, r40)
317+
r43 = '__mypyc_attrs__'
318+
r44 = CPyTuple_LoadEmptyTupleConstant()
319+
r45 = PyObject_SetAttr(r42, r43, r44)
320+
r46 = r45 >= 0 :: signed
321+
__main__.S = r42 :: type
322+
r47 = __main__.globals :: static
323+
r48 = 'S'
324+
r49 = PyDict_SetItem(r47, r48, r42)
325+
r50 = r49 >= 0 :: signed
326+
r51 = CPy_InitSubclass(r42)
327+
r52 = __main__.C :: type
328+
r53 = __main__.S :: type
329+
r54 = __main__.globals :: static
330+
r55 = 'Generic'
331+
r56 = CPyDict_GetItem(r54, r55)
332+
r57 = __main__.globals :: static
333+
r58 = 'T'
334+
r59 = CPyDict_GetItem(r57, r58)
335+
r60 = PyObject_GetItem(r56, r59)
336+
r61 = PyTuple_Pack(3, r52, r53, r60)
337+
r62 = '__main__'
338+
r63 = __main__.D_template :: type
339+
r64 = CPyType_FromTemplate(r63, r61, r62)
340+
r65 = D_trait_vtable_setup()
341+
r66 = D_coroutine_setup(r64)
342+
r67 = '__mypyc_attrs__'
343+
r68 = '__dict__'
344+
r69 = PyTuple_Pack(1, r68)
345+
r70 = PyObject_SetAttr(r64, r67, r69)
346+
r71 = r70 >= 0 :: signed
347+
__main__.D = r64 :: type
348+
r72 = __main__.globals :: static
349+
r73 = 'D'
350+
r74 = PyDict_SetItem(r72, r73, r64)
351+
r75 = r74 >= 0 :: signed
352+
r76 = CPy_InitSubclass(r64)
347353
return 1
348354

349355
[case testIsInstance]

mypyc/test-data/run-classes.test

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,39 @@ assert f() == 10
11131113
A.x = 200
11141114
assert f() == 200
11151115

1116+
[case testInitSubclassWithClassVar]
1117+
from typing import ClassVar
1118+
1119+
class Base:
1120+
name: ClassVar[str] = "base"
1121+
required: ClassVar[int] = -1
1122+
1123+
def __init_subclass__(cls, **kwargs: object) -> None:
1124+
cls.required = len(cls.name)
1125+
1126+
class Child(Base):
1127+
name: ClassVar[str] = "child"
1128+
1129+
class GrandChild(Child):
1130+
name: ClassVar[str] = "grandchild"
1131+
1132+
class NoOverride(Base):
1133+
pass
1134+
1135+
[file driver.py]
1136+
from native import Base, Child, GrandChild, NoOverride
1137+
1138+
# __init_subclass__ should see the subclass's own ClassVar values
1139+
assert Child.name == "child"
1140+
assert Child.required == 5, f"expected 5, got {Child.required}"
1141+
1142+
assert GrandChild.name == "grandchild"
1143+
assert GrandChild.required == 10, f"expected 10, got {GrandChild.required}"
1144+
1145+
# No override should use inherited value
1146+
assert NoOverride.name == "base"
1147+
assert NoOverride.required == 4, f"expected 4, got {NoOverride.required}"
1148+
11161149
[case testDefaultVars]
11171150
from typing import Optional
11181151
class A:

0 commit comments

Comments
 (0)