Skip to content

Commit bfe5024

Browse files
authored
[mypyc] Fix crash on multiple nested decorated functions with same name (#20666)
This happens when a nested function has different definitions in different condition branches. mypyc currently crashes in this case when calling `get_func_target` in `transform_decorator` https://github.com/python/mypy/blob/master/mypyc/irbuild/function.py#L118. The `FuncDef` objects for the definitions after the first one will have the `original_def` attribute set. This makes `get_func_target` lookup this original def but it doesn't find it when the original def is a decorator because registers are only created for `FuncDef`s. Extracting a `FuncDef` from the original decorator and looking it up instead fixes the crash.
1 parent 94a3cf6 commit bfe5024

3 files changed

Lines changed: 107 additions & 7 deletions

File tree

mypyc/irbuild/function.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,9 +852,11 @@ def get_func_target(builder: IRBuilder, fdef: FuncDef) -> AssignmentTarget:
852852
If the function was not already defined somewhere, then define it
853853
and add it to the current environment.
854854
"""
855-
if fdef.original_def:
855+
if orig := fdef.original_def:
856+
if isinstance(orig, Decorator):
857+
orig = orig.func
856858
# Get the target associated with the previously defined FuncDef.
857-
return builder.lookup(fdef.original_def)
859+
return builder.lookup(orig)
858860

859861
if builder.fn_info.is_generator or builder.fn_info.add_nested_funcs_to_env:
860862
return builder.lookup(fdef)

mypyc/test-data/run-async.test

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,7 +1809,7 @@ from typing import Any, Callable, TypeVar, cast
18091809
F = TypeVar("F", bound=Callable[..., Any])
18101810

18111811

1812-
def mult(x: int) -> Callable[[F], F]:
1812+
def mult_different_wrapper_names(x: int) -> Callable[[F], F]:
18131813
def decorate(fn: F) -> F:
18141814
def get_multiplier() -> int:
18151815
return x
@@ -1829,18 +1829,47 @@ def mult(x: int) -> Callable[[F], F]:
18291829

18301830
return decorate
18311831

1832-
@mult(3)
1832+
def mult_same_wrapper_names(x: int) -> Callable[[F], F]:
1833+
def decorate(fn: F) -> F:
1834+
def get_multiplier() -> int:
1835+
return x
1836+
1837+
if inspect.iscoroutinefunction(fn):
1838+
@functools.wraps(fn)
1839+
async def wrapper(*args, **kwargs) -> Any:
1840+
return get_multiplier() * await fn(*args, **kwargs)
1841+
else:
1842+
@functools.wraps(fn)
1843+
def wrapper(*args, **kwargs) -> Any:
1844+
return get_multiplier() * fn(*args, **kwargs)
1845+
1846+
return cast(F, wrapper)
1847+
1848+
return decorate
1849+
1850+
@mult_different_wrapper_names(3)
18331851
def identity(x: int):
18341852
return x
18351853

1836-
@mult(5)
1854+
@mult_different_wrapper_names(5)
18371855
async def async_identity(x: int):
18381856
return x
18391857

1858+
@mult_same_wrapper_names(2)
1859+
def times_two(x: int):
1860+
return x * 2
1861+
1862+
@mult_same_wrapper_names(4)
1863+
async def async_times_two(x: int):
1864+
return x * 2
1865+
18401866
def test_nested_coroutine_calls_another_nested_function():
18411867
assert identity(1) == 3
18421868
assert asyncio.run(async_identity(2)) == 10
18431869

1870+
assert times_two(3) == 12
1871+
assert asyncio.run(async_times_two(4)) == 32
1872+
18441873
[file asyncio/__init__.pyi]
18451874
from typing import Any, Generator
18461875

mypyc/test-data/run-functions.test

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,62 @@ def if_else(flag: int) -> str:
154154
return 'if_else.inner: third definition'
155155
return inner()
156156

157+
def wrap(f: Callable[[], str]):
158+
def inner():
159+
return 'wrapped ' + f()
160+
return inner
161+
162+
def if_else_all_decorated(flag: int) -> str:
163+
def dummy_function() -> str:
164+
return 'if_else.dummy_function'
165+
166+
if flag < 0:
167+
@wrap
168+
def inner() -> str:
169+
return 'if_else.inner: first definition'
170+
elif flag > 0:
171+
@wrap
172+
def inner() -> str:
173+
return 'if_else.inner: second definition'
174+
else:
175+
@wrap
176+
def inner() -> str:
177+
return 'if_else.inner: third definition'
178+
return inner()
179+
180+
def if_else_first_decorated(flag: int) -> str:
181+
def dummy_function() -> str:
182+
return 'if_else.dummy_function'
183+
184+
if flag < 0:
185+
@wrap
186+
def inner() -> str:
187+
return 'if_else.inner: first definition'
188+
elif flag > 0:
189+
def inner() -> str:
190+
return 'if_else.inner: second definition'
191+
else:
192+
def inner() -> str:
193+
return 'if_else.inner: third definition'
194+
return inner()
195+
196+
def if_else_all_but_first_decorated(flag: int) -> str:
197+
def dummy_function() -> str:
198+
return 'if_else.dummy_function'
199+
200+
if flag < 0:
201+
def inner() -> str:
202+
return 'if_else.inner: first definition'
203+
elif flag > 0:
204+
@wrap
205+
def inner() -> str:
206+
return 'if_else.inner: second definition'
207+
else:
208+
@wrap
209+
def inner() -> str:
210+
return 'if_else.inner: third definition'
211+
return inner()
212+
157213
def for_loop() -> int:
158214
def dummy_function() -> str:
159215
return 'for_loop.dummy_function'
@@ -235,8 +291,9 @@ toplevel_lambda = lambda x: 10 + global_upvar + x
235291

236292
[file driver.py]
237293
from native import (
238-
a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, triple, if_else, for_loop, while_loop,
239-
free_vars, lambdas, outer, inner, A, toplevel_lambda
294+
a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, triple, if_else, if_else_all_decorated,
295+
if_else_first_decorated, if_else_all_but_first_decorated,
296+
for_loop, while_loop, free_vars, lambdas, outer, inner, A, toplevel_lambda
240297
)
241298

242299
assert a()() == None
@@ -266,6 +323,18 @@ assert if_else(-1) == 'if_else.inner: first definition'
266323
assert if_else(1) == 'if_else.inner: second definition'
267324
assert if_else(0) == 'if_else.inner: third definition'
268325

326+
assert if_else_all_decorated(-1) == 'wrapped if_else.inner: first definition'
327+
assert if_else_all_decorated(1) == 'wrapped if_else.inner: second definition'
328+
assert if_else_all_decorated(0) == 'wrapped if_else.inner: third definition'
329+
330+
assert if_else_first_decorated(-1) == 'wrapped if_else.inner: first definition'
331+
assert if_else_first_decorated(1) == 'if_else.inner: second definition'
332+
assert if_else_first_decorated(0) == 'if_else.inner: third definition'
333+
334+
assert if_else_all_but_first_decorated(-1) == 'if_else.inner: first definition'
335+
assert if_else_all_but_first_decorated(1) == 'wrapped if_else.inner: second definition'
336+
assert if_else_all_but_first_decorated(0) == 'wrapped if_else.inner: third definition'
337+
269338
assert for_loop() == 3
270339
assert while_loop() == 3
271340

0 commit comments

Comments
 (0)