diff --git a/asyncer/_main.py b/asyncer/_main.py index 1fa778ea..c00788a0 100644 --- a/asyncer/_main.py +++ b/asyncer/_main.py @@ -2,9 +2,11 @@ import sys from collections.abc import Awaitable, Callable, Coroutine from importlib import import_module +from types import TracebackType from typing import ( Any, Generic, + Literal, ParamSpec, TypeVar, ) @@ -169,6 +171,16 @@ async def __aenter__(self) -> "TaskGroup": # pragma: nocover """Enter the task group context and allow starting new tasks.""" return await super().__aenter__() # type: ignore + # This is only for the return type annotation, but it won't really be called + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: # pragma: nocover + """Exit the task group context once all tasks are completed.""" + return await super().__aexit__(exc_type, exc_value, exc_tb) # type: ignore + def create_task_group() -> "TaskGroup": """