Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import sys
import threading
import time
import types
import weakref
from contextlib import ExitStack, contextmanager, suppress
from math import inf, nan
from typing import TYPE_CHECKING, NoReturn, TypeVar
from typing import TYPE_CHECKING, Generic, NoReturn, TypeVar
from unittest import mock

import outcome
Expand Down Expand Up @@ -2301,9 +2300,14 @@ async def test_Task_custom_sleep_data() -> None:
assert task.custom_sleep_data is None


@types.coroutine
def async_yield(value: T) -> Generator[T, None, None]:
yield value
class AsyncYield(Generic[T]):
"""Yields a value when awaited."""

def __init__(self, value: T) -> None:
self.value = value

def __await__(self) -> Generator[T, None, None]:
yield self.value


async def test_permanently_detach_coroutine_object() -> None:
Expand All @@ -2322,7 +2326,7 @@ async def detachable_coroutine(
_core.permanently_detach_coroutine_object,
task_outcome,
)
await async_yield(yield_value)
await AsyncYield(yield_value)

async with _core.open_nursery() as nursery:
nursery.start_soon(detachable_coroutine, outcome.Value(None), "I'm free!")
Expand Down Expand Up @@ -2380,8 +2384,8 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: # pragma: no cover
got = await _core.temporarily_detach_coroutine_object(abort_fn)
assert got == "not trio!"

await async_yield(1)
await async_yield(2)
await AsyncYield(1)
await AsyncYield(2)

with pytest.raises(RuntimeError) as excinfo:
await _core.reattach_detached_coroutine_object(
Expand Down
54 changes: 27 additions & 27 deletions src/trio/_core/_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from __future__ import annotations

import enum
import types

# Jedi gets mad in test_static_tool_sees_class_members if we use collections Callable
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Union, cast
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Union

import attrs
import outcome

from . import _run

if TYPE_CHECKING:
from collections.abc import Awaitable, Generator
from collections.abc import Generator

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -45,30 +44,31 @@ class PermanentlyDetachCoroutineObject:
type[CancelShieldedCheckpoint],
WaitTaskRescheduled,
PermanentlyDetachCoroutineObject,
object,
object, # For reattach_detached_coroutine_object(), a foreign loop's value.
]


# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
# with @types.coroutine, then it's even awaitable. However, it's still not a
# real async function: in particular, it isn't recognized by
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
def _real_async_yield(
obj: MessageType,
) -> Generator[MessageType, None, None]:
return (yield obj)
class _AsyncYield:
"""Helper for the bottommost 'yield'.

You can't use 'yield' inside an async function, so implement an awaitable object to do so.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use yield inside an async function! It just doesn't do what we want.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite true.

Since this isn't a real async function, it isn't recognized by inspect.iscoroutinefunction,
and it doesn't trigger the unawaited coroutine tracking machinery. Since our traps are public
APIs, we make them real async functions, and then this helper takes care of the actual yield.
"""

def __init__(self, message: MessageType) -> None:
self.message = message

def __await__(
self,
) -> Generator[MessageType, outcome.Outcome[object], outcome.Outcome[object]]:
"""To suspend we yield one of several messages.

# Real yield value is from trio's main loop, but type checkers can't
# understand that, so we cast it to make type checkers understand.
_async_yield = cast(
"Callable[[MessageType], Awaitable[outcome.Outcome[object]]]",
_real_async_yield,
)
The event loop sends back an outcome, which we return to our awaiter (a trap function)
to handle.
"""
return (yield self.message)


async def cancel_shielded_checkpoint() -> None:
Expand All @@ -84,7 +84,7 @@ async def cancel_shielded_checkpoint() -> None:
await trio.lowlevel.checkpoint()

"""
(await _async_yield(CancelShieldedCheckpoint)).unwrap()
(await _AsyncYield(CancelShieldedCheckpoint)).unwrap()


# Return values for abort functions
Expand Down Expand Up @@ -205,7 +205,7 @@ def abort(inner_raise_cancel):
above about how you should use a higher-level API if at all possible?

"""
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
return (await _AsyncYield(WaitTaskRescheduled(abort_func))).unwrap()


async def permanently_detach_coroutine_object(
Expand Down Expand Up @@ -238,7 +238,7 @@ async def permanently_detach_coroutine_object(
raise RuntimeError(
"can't permanently detach a coroutine object with open nurseries",
)
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
return await _AsyncYield(PermanentlyDetachCoroutineObject(final_outcome))


async def temporarily_detach_coroutine_object(
Expand Down Expand Up @@ -276,7 +276,7 @@ async def temporarily_detach_coroutine_object(
uses to resume the coroutine.

"""
return await _async_yield(WaitTaskRescheduled(abort_func))
return await _AsyncYield(WaitTaskRescheduled(abort_func))


async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None:
Expand Down Expand Up @@ -306,5 +306,5 @@ async def reattach_detached_coroutine_object(task: Task, yield_value: object) ->
if not task.coro.cr_running:
raise RuntimeError("given task does not match calling coroutine")
_run.reschedule(task, outcome.Value("reattaching"))
value = await _async_yield(yield_value)
value = await _AsyncYield(yield_value)
assert value == outcome.Value("reattaching")
Loading