Skip to content

Commit e711be3

Browse files
authored
[mypyc] Fix undefined attribute in nested coroutines (#20654)
When creating a callable class for a generator, mypyc creates an attribute in its env class that points at the compiled function so it can be called recursively https://github.com/python/mypy/blob/544b97ec296e16e4c96a42a1b81d5dcd841cf232/mypyc/irbuild/env_class.py#L263 The attribute is prefixed by `__mypyc_generator_attribute__` in case the function name clashes with an attribute generated by mypyc internally. However, in other places the attribute is currently not prefixed for async functions which leads to mypyc generating both the prefixed and unprefixed versions, eg. for a function `wrapper_async` the env class will have both `_wrapper_async` and `___mypyc_generator_attribute__wrapper_async` attributes with the second one potentially undefined. This ends up in a runtime crash when another nested function is called from a nested coroutine. The issue is fixed by prefixing the attribute for coroutines in addition to generators so that only the prefixed attribute is in the env class.
1 parent 544b97e commit e711be3

3 files changed

Lines changed: 53 additions & 3 deletions

File tree

mypyc/irbuild/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def get_assignment_target(
657657
# refers to the newly defined variable in that environment class. Add the
658658
# target to the table containing class environment variables, as well as the
659659
# current environment.
660-
if self.fn_info.is_generator:
660+
if self.fn_info.is_generator or self.fn_info.is_coroutine:
661661
return self.add_var_to_env_class(
662662
symbol,
663663
reg_type,

mypyc/irbuild/env_class.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ def add_args_to_env(
200200
builder.add_local_reg(Var(bitmap_name(i)), bitmap_rprimitive, is_arg=True)
201201
else:
202202
for arg in args:
203-
if is_free_variable(builder, arg.variable) or fn_info.is_generator:
203+
if (
204+
is_free_variable(builder, arg.variable)
205+
or fn_info.is_generator
206+
or fn_info.is_coroutine
207+
):
204208
rtype = builder.type_to_rtype(arg.variable.type)
205209
assert base is not None, "base cannot be None for adding nonlocal args"
206210
builder.add_var_to_env_class(
@@ -240,7 +244,7 @@ def add_vars_to_env(builder: IRBuilder, prefix: str = "") -> None:
240244
# the same name and signature across conditional blocks
241245
# will generate different callable classes, so the callable
242246
# class that gets instantiated must be generic.
243-
if nested_fn.is_generator:
247+
if nested_fn.is_generator or nested_fn.is_coroutine:
244248
prefix = GENERATOR_ATTRIBUTE_PREFIX
245249
builder.add_var_to_env_class(
246250
nested_fn, object_rprimitive, env_for_func, reassign=False, prefix=prefix

mypyc/test-data/run-async.test

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,3 +1799,49 @@ for i in range(10):
17991799
from typing import Any, Generator
18001800

18011801
def run(x: object) -> object: ...
1802+
1803+
[case testNestedCoroutineCallsAnotherNestedFunction]
1804+
import asyncio
1805+
import functools
1806+
import inspect
1807+
from typing import Any, Callable, TypeVar, cast
1808+
1809+
F = TypeVar("F", bound=Callable[..., Any])
1810+
1811+
1812+
def mult(x: int) -> Callable[[F], F]:
1813+
def decorate(fn: F) -> F:
1814+
def get_multiplier() -> int:
1815+
return x
1816+
1817+
if inspect.iscoroutinefunction(fn):
1818+
@functools.wraps(fn)
1819+
async def wrapper_async(*args, **kwargs) -> Any:
1820+
return get_multiplier() * await fn(*args, **kwargs)
1821+
wrapper = wrapper_async
1822+
else:
1823+
@functools.wraps(fn)
1824+
def wrapper_non_async(*args, **kwargs) -> Any:
1825+
return get_multiplier() * fn(*args, **kwargs)
1826+
wrapper = wrapper_non_async
1827+
1828+
return cast(F, wrapper)
1829+
1830+
return decorate
1831+
1832+
@mult(3)
1833+
def identity(x: int):
1834+
return x
1835+
1836+
@mult(5)
1837+
async def async_identity(x: int):
1838+
return x
1839+
1840+
def test_nested_coroutine_calls_another_nested_function():
1841+
assert identity(1) == 3
1842+
assert asyncio.run(async_identity(2)) == 10
1843+
1844+
[file asyncio/__init__.pyi]
1845+
from typing import Any, Generator
1846+
1847+
def run(x: object) -> object: ...

0 commit comments

Comments
 (0)