Skip to content

Commit 61c1906

Browse files
committed
Ensure object in sys.modules is the native module
1 parent 641951c commit 61c1906

5 files changed

Lines changed: 94 additions & 30 deletions

File tree

mypyc/irbuild/builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
BITMAP_BITS,
6363
EXT_SUFFIX,
6464
GENERATOR_ATTRIBUTE_PREFIX,
65+
MODULE_PREFIX,
6566
SELF_NAME,
6667
TEMP_ATTR_NAME,
6768
shared_lib_name,
@@ -83,6 +84,7 @@
8384
InitStatic,
8485
Integer,
8586
IntOp,
87+
LoadAddress,
8688
LoadGlobal,
8789
LoadStatic,
8890
MethodCall,
@@ -115,6 +117,7 @@
115117
is_tagged,
116118
is_tuple_rprimitive,
117119
none_rprimitive,
120+
object_pointer_rprimitive,
118121
object_rprimitive,
119122
str_rprimitive,
120123
)
@@ -491,6 +494,12 @@ def gen_import(self, module: str, line: int) -> None:
491494
exec_func = self.add(
492495
LoadGlobal(c_pointer_rprimitive, f"CPyExec_{exported_name(module)}")
493496
)
497+
module_static = self.add(
498+
LoadAddress(
499+
object_pointer_rprimitive,
500+
f"{MODULE_PREFIX}{exported_name(module + '__internal')}",
501+
)
502+
)
494503
group_name = self.mapper.group_map.get(self.module_name)
495504
if group_name is not None:
496505
shared_lib_mod_name = shared_lib_name(group_name)
@@ -509,6 +518,7 @@ def gen_import(self, module: str, line: int) -> None:
509518
self.load_str(module, line),
510519
init_only_func,
511520
exec_func,
521+
module_static,
512522
shared_lib_file,
513523
ext_suffix,
514524
Integer(1 if is_pkg else 0, c_pyssize_t_rprimitive),

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,7 @@ PyObject *CPyImport_GetNativeAttrs(PyObject *mod_id, PyObject *names, PyObject *
964964
PyObject *CPyImport_ImportNative(PyObject *module_name,
965965
PyObject *(*init_only_fn)(void),
966966
int (*exec_fn)(PyObject *),
967+
CPyModule **module_static,
967968
PyObject *shared_lib_file, PyObject *ext_suffix,
968969
Py_ssize_t is_package);
969970

mypyc/lib-rt/misc_ops.c

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,10 +1410,12 @@ static int CPyImport_SetModuleSpec(PyObject *modobj, PyObject *module_name,
14101410
PyObject *CPyImport_ImportNative(PyObject *module_name,
14111411
PyObject *(*init_only_fn)(void),
14121412
int (*exec_fn)(PyObject *),
1413+
CPyModule **module_static,
14131414
PyObject *shared_lib_file, PyObject *ext_suffix,
14141415
Py_ssize_t is_package) {
14151416
PyObject *parent_module = NULL;
14161417
PyObject *child_name = NULL;
1418+
PyObject *exc_type, *exc_val, *exc_tb;
14171419
Py_ssize_t name_len = PyUnicode_GetLength(module_name);
14181420
if (name_len < 0) {
14191421
return NULL;
@@ -1449,27 +1451,45 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
14491451
return NULL;
14501452
}
14511453

1452-
PyObject *modobj = init_only_fn();
1453-
if (modobj == NULL) {
1454+
PyObject *existing = PyDict_GetItemWithError(module_dict, module_name);
1455+
if (existing != NULL) {
1456+
if (*module_static != NULL) {
1457+
if (existing == (PyObject *)*module_static) {
1458+
Py_INCREF(existing);
1459+
Py_XDECREF(parent_module);
1460+
Py_XDECREF(child_name);
1461+
return existing;
1462+
}
1463+
PyErr_Format(PyExc_ImportError,
1464+
"native module '%U' in sys.modules was replaced after initialization",
1465+
module_name);
1466+
Py_XDECREF(parent_module);
1467+
Py_XDECREF(child_name);
1468+
return NULL;
1469+
}
1470+
}
1471+
if (PyErr_Occurred()) {
14541472
Py_XDECREF(parent_module);
14551473
Py_XDECREF(child_name);
14561474
return NULL;
14571475
}
14581476

1459-
// Check if the module was already imported (e.g. via CPyInit_* from the
1460-
// standard import machinery, or by a previous CPyImport_ImportNative call).
1461-
int already_imported = PyDict_Contains(module_dict, module_name);
1462-
if (already_imported < 0) {
1477+
PyObject *modobj = init_only_fn();
1478+
if (modobj == NULL) {
14631479
Py_XDECREF(parent_module);
14641480
Py_XDECREF(child_name);
1465-
Py_DECREF(modobj);
14661481
return NULL;
14671482
}
14681483

1469-
if (!already_imported) {
1470-
if (PyObject_SetItem(module_dict, module_name, modobj) < 0) {
1471-
goto fail;
1472-
}
1484+
if (PyObject_SetItem(module_dict, module_name, modobj) < 0) {
1485+
goto fail;
1486+
}
1487+
1488+
if (*module_static != (CPyModule *)modobj) {
1489+
PyErr_Format(PyExc_ImportError,
1490+
"native module '%U' was initialized inconsistently",
1491+
module_name);
1492+
goto fail;
14731493
}
14741494

14751495
// Set __package__ before executing the module body so it is available
@@ -1512,15 +1532,6 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
15121532
goto fail;
15131533
}
15141534

1515-
// If the module was already imported, skip executing the module body
1516-
// (it was already executed). This handles circular imports and modules
1517-
// first imported via the standard Python import machinery (CPyInit_*).
1518-
if (already_imported) {
1519-
Py_XDECREF(parent_module);
1520-
Py_XDECREF(child_name);
1521-
return modobj;
1522-
}
1523-
15241535
// Now execute the module body, with __file__ and __package__ already set.
15251536
if (exec_fn(modobj) != 0) {
15261537
goto fail;
@@ -1537,15 +1548,12 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
15371548
return modobj;
15381549

15391550
fail:
1540-
// Clean up on failure: if we added the module to sys.modules, remove it
1541-
// so that a subsequent import attempt will retry initialization.
1542-
if (!already_imported) {
1543-
PyObject *exc_type, *exc_val, *exc_tb;
1544-
PyErr_Fetch(&exc_type, &exc_val, &exc_tb);
1545-
PyObject_DelItem(module_dict, module_name);
1546-
PyErr_Clear();
1547-
PyErr_Restore(exc_type, exc_val, exc_tb);
1548-
}
1551+
// Clean up on failure so that a subsequent import attempt will retry
1552+
// initialization.
1553+
PyErr_Fetch(&exc_type, &exc_val, &exc_tb);
1554+
PyObject_DelItem(module_dict, module_name);
1555+
PyErr_Clear();
1556+
PyErr_Restore(exc_type, exc_val, exc_tb);
15491557
Py_XDECREF(parent_module);
15501558
Py_XDECREF(child_name);
15511559
Py_DECREF(modobj);

mypyc/primitives/misc_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@
139139

140140
# Import a native same-group module directly via C-level init/exec functions.
141141
native_import_op = custom_op(
142-
# (module name, init-only function, exec function, shared lib __file__, ext suffix, is_package)
142+
# (module name, init-only function, exec function, module static,
143+
# shared lib __file__, ext suffix, is_package)
143144
arg_types=[
144145
str_rprimitive,
145146
c_pointer_rprimitive,
146147
c_pointer_rprimitive,
148+
object_pointer_rprimitive,
147149
object_rprimitive,
148150
str_rprimitive,
149151
c_pyssize_t_rprimitive,

mypyc/test-data/run-multimodule.test

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,49 @@ else:
369369
assert not hasattr(other_pkg, "other_sub")
370370
assert "other_pkg.other_sub" not in sys.modules
371371

372+
[case testNativeImportUsesExistingNativeSysModulesObject]
373+
# separate: [(["native.py", "other_top.py"], "testgroup")]
374+
import other_top
375+
376+
def f() -> int:
377+
return other_top.value
378+
[file other_top.py]
379+
value = 42
380+
[file driver.py]
381+
import native
382+
import other_top
383+
384+
assert native.other_top is other_top
385+
assert native.f() == 42
386+
387+
[case testNativeImportRejectsForeignSysModulesObject]
388+
# separate: [(["native.py", "other_top.py"], "testgroup")]
389+
import other_top
390+
391+
def f() -> int:
392+
return other_top.value
393+
[file other_top.py]
394+
value = 42
395+
[file driver.py]
396+
import sys
397+
import types
398+
399+
import other_top
400+
401+
fake = types.ModuleType("other_top")
402+
fake.value = -1
403+
sys.modules["other_top"] = fake
404+
405+
try:
406+
import native
407+
except ImportError as e:
408+
assert "other_top" in str(e)
409+
assert "replaced after initialization" in str(e)
410+
else:
411+
assert False, "import native should fail"
412+
413+
assert sys.modules["other_top"] is fake
414+
372415
[case testMultiModuleFastpaths]
373416
[file other_main.py]
374417

0 commit comments

Comments
 (0)