diff --git a/CHANGELOG.md b/CHANGELOG.md index 70d7eff6a2..9bb6c04ccc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -521,6 +521,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2905](https://github.com/Pycord-Development/pycord/pull/2905)) - `view=None` in various methods causing an AttributeError. ([#2915](https://github.com/Pycord-Development/pycord/pull/2915)) +- Fixed Async I/O errors that could be raised when using `Client.run`. + ([#2645](https://github.com/Pycord-Development/pycord/pull/2645)) - `View.message` being `None` when it had not been interacted with yet. ([#2916](https://github.com/Pycord-Development/pycord/pull/2916)) - Fixed a crash when processing message edit events while message cache was disabled. diff --git a/discord/bot.py b/discord/bot.py index 5ba5f9b5df..74ee6e4172 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -1396,7 +1396,7 @@ def before_invoke(self, coro): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro @@ -1428,7 +1428,7 @@ def after_invoke(self, coro): The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro diff --git a/discord/client.py b/discord/client.py index 211532d68d..7c9704599a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -26,8 +26,9 @@ from __future__ import annotations import asyncio +import contextlib +import inspect import logging -import signal import sys import traceback from types import TracebackType @@ -79,6 +80,7 @@ from .channel import ( DMChannel, ) + from .ext.tasks import Loop as TaskLoop from .interactions import Interaction from .member import Member from .message import Message @@ -130,12 +132,39 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: loop.close() +class LoopTaskSet: + def __init__(self) -> None: + self.tasks: set[TaskLoop] = set() + self.client: Client | None = None + + def add_loop(self, loop: TaskLoop) -> None: + if self.client is not None: + running = asyncio.get_running_loop() + loop.loop = running + loop.start() + else: + self.tasks.add(loop) + + def start(self, client: Client) -> None: + self.client = client + for task in self.tasks: + loop = client.loop + task.loop = loop + task.start() + + class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. A number of options can be passed to the :class:`Client`. + .. container:: operations + + .. describe:: async with x + + Asynchronously initializes the client. + Parameters ----------- max_messages: Optional[:class:`int`] @@ -236,6 +265,8 @@ class Client: The event loop that the client uses for asynchronous operations. """ + _pending_loops = LoopTaskSet() + def __init__( self, *, @@ -244,9 +275,12 @@ def __init__( ): # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + + if loop is None: + with contextlib.suppress(RuntimeError): + loop = asyncio.get_running_loop() + + self._loop: asyncio.AbstractEventLoop | None = loop self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = ( {} ) @@ -262,7 +296,7 @@ def __init__( proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, - loop=self.loop, + loop=self._loop, ) self._handlers: dict[str, Callable] = {"ready": self._handle_ready} @@ -274,25 +308,46 @@ def __init__( self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count - self._closed: bool = False + self._closed: asyncio.Event = asyncio.Event() + self._closing_task: asyncio.Lock = asyncio.Lock() self._ready: asyncio.Event = asyncio.Event() self._connection._get_websocket = self._get_websocket self._connection._get_client = lambda: self self._event_handlers: dict[str, list[Coro]] = {} + self._setup_done: asyncio.Event = asyncio.Event() + self._setup_lock: asyncio.Lock = asyncio.Lock() warn_if_voice_dependencies_missing() # Used to hard-reference tasks so they don't get garbage collected (discarded with done_callbacks) self._tasks = set() - async def __aenter__(self) -> Client: - loop = asyncio.get_running_loop() - self.loop = loop - self.http.loop = loop - self._connection.loop = loop + async def _async_setup(self) -> None: + async with self._setup_lock: + if self._setup_done.is_set(): + return - self._ready = asyncio.Event() + if self._loop is None: + try: + l = asyncio.get_running_loop() + except RuntimeError: + # No event loop was found, this should not happen + # because entering on this context manager means a + # loop is already active, but we need to handle it + # anyways just to prevent future errors. + l = asyncio.new_event_loop() + + self._loop = l + self.http.loop = l + self._connection.loop = l + + self._ready = asyncio.Event() + self._closed = asyncio.Event() + self._setup_done.set() + + async def __aenter__(self) -> Client: + await self._async_setup() return self async def __aexit__( @@ -317,13 +372,28 @@ def _get_state(self, **options: Any) -> ConnectionState: handlers=self._handlers, hooks=self._hooks, http=self.http, - loop=self.loop, + loop=self._loop, **options, ) def _handle_ready(self) -> None: self._ready.set() + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The event loop that the client uses for asynchronous operations.""" + if self._loop is None: + raise RuntimeError("loop is not set") + return self._loop + + @loop.setter + def loop(self, value: asyncio.AbstractEventLoop) -> None: + if not isinstance(value, asyncio.AbstractEventLoop): + raise TypeError( + f"expected a AbstractEventLoop object, got {value.__class__.__name__!r} instead" + ) + self._loop = value + @property def latency(self) -> float: """Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. If no websocket @@ -481,7 +551,6 @@ def _schedule_event( return task def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug("Dispatching event %s", event) method = f"on_{event}" listeners = self._listeners.get(event) @@ -668,6 +737,7 @@ async def login(self, token: str) -> None: f"token must be of type str, not {token.__class__.__name__}" ) + await self._async_setup() _log.info("logging in using static token") data = await self.http.static_login(token.strip()) @@ -698,6 +768,8 @@ async def connect(self, *, reconnect: bool = True) -> None: The WebSocket connection has been terminated. """ + await self._async_setup() + backoff = ExponentialBackoff() ws_params = { "initial": True, @@ -776,23 +848,25 @@ async def close(self) -> None: Closes the connection to Discord. """ - if self._closed: - return + async with self._closing_task: + if self.is_closed(): + return - await self.http.close() - self._closed = True + await self.http.close() - for voice in self.voice_clients: - try: - await voice.disconnect(force=True) - except Exception: - # if an error happens during disconnects, disregard it. - pass + for voice in self.voice_clients: + try: + await voice.disconnect(force=True) + except Exception: + # if an error happens during disconnects, disregard it. + pass - if self.ws is not None and self.ws.open: - await self.ws.close(code=1000) + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) - self._ready.clear() + self._ready.clear() + self._closed.set() + self._setup_done.clear() def clear(self) -> None: """Clears the internal state of the bot. @@ -801,8 +875,9 @@ def clear(self) -> None: and :meth:`is_ready` both return ``False`` along with the bot's internal cache cleared. """ - self._closed = False + self._closed.clear() self._ready.clear() + self._setup_done.clear() self._connection.clear() self.http.recreate() @@ -819,7 +894,12 @@ async def start(self, token: str, *, reconnect: bool = True) -> None: await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args: Any, **kwargs: Any) -> None: + def run( + self, + token: str, + *, + reconnect: bool = True, + ) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -830,12 +910,20 @@ def run(self, *args: Any, **kwargs: Any) -> None: Roughly Equivalent to: :: try: - loop.run_until_complete(start(*args, **kwargs)) + asyncio.run(start(token)) except KeyboardInterrupt: - loop.run_until_complete(close()) - # cancel all tasks lingering - finally: - loop.close() + return + + Parameters + ---------- + token: :class:`str` + The authentication token. Do not prefix this token with + anything as the library will do it for you. + reconnect: :class:`bool` + If we should attempt reconnecting to the gateway, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). .. warning:: @@ -843,47 +931,36 @@ def run(self, *args: Any, **kwargs: Any) -> None: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop - - try: - loop.add_signal_handler(signal.SIGINT, loop.stop) - loop.add_signal_handler(signal.SIGTERM, loop.stop) - except (NotImplementedError, RuntimeError): - pass async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() + async with self: + await self.start(token=token, reconnect=reconnect) - def stop_loop_on_completion(f): - loop.stop() + try: + run = self.loop.run_until_complete + requires_cleanup = True + except RuntimeError: + run = asyncio.run + requires_cleanup = False - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) try: - loop.run_forever() - except KeyboardInterrupt: - _log.info("Received signal to terminate bot and event loop.") + run(runner()) finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info("Cleaning up tasks.") - _cleanup_loop(loop) + # Ensure the bot is closed + if not self.is_closed(): + self.loop.run_until_complete(self.close()) - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + # asyncio.run automatically does the cleanup tasks, so if we use + # it we don't need to clean up the tasks. + if requires_cleanup: + _log.info("Cleaning up tasks.") + _cleanup_loop(self.loop) # properties def is_closed(self) -> bool: """Indicates if the WebSocket connection is closed.""" - return self._closed + return self._closed.is_set() @property def activity(self) -> ActivityTypes | None: @@ -1388,7 +1465,7 @@ async def my_message(message): pass if not name.startswith("on_"): raise ValueError("The 'name' parameter must start with 'on_'") - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Listeners must be coroutines") if name in self._event_handlers: @@ -1468,7 +1545,7 @@ def decorator(func: Coro) -> Coro: self.add_listener(func, name) return func - if asyncio.iscoroutinefunction(name): + if inspect.iscoroutinefunction(name): coro = name name = coro.__name__ return decorator(coro) @@ -1503,7 +1580,7 @@ async def on_ready(): print('Ready!') """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("event registered must be a coroutine function") setattr(self, coro.__name__, coro) diff --git a/discord/commands/core.py b/discord/commands/core.py index 0f721f15a0..18ac446a34 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -500,7 +500,7 @@ def error(self, coro): The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The error handler must be a coroutine.") self.on_error = coro @@ -529,7 +529,7 @@ def before_invoke(self, coro): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro @@ -554,7 +554,7 @@ def after_invoke(self, coro): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro @@ -733,7 +733,7 @@ def __new__(cls, *args, **kwargs) -> SlashCommand: def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") self.callback = func @@ -1666,7 +1666,7 @@ def __new__(cls, *args, **kwargs) -> ContextMenuCommand: def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") self.callback = func diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 4a58ecfca3..c24b9b393d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -325,7 +325,7 @@ def __init__( ), **kwargs: Any, ): - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") name = kwargs.get("name") or func.__name__ @@ -993,7 +993,7 @@ def error(self, coro: ErrorT) -> ErrorT: The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The error handler must be a coroutine.") self.on_error: Error = coro @@ -1027,7 +1027,7 @@ def before_invoke(self, coro: HookT) -> HookT: TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro @@ -1054,7 +1054,7 @@ def after_invoke(self, coro: HookT) -> HookT: TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 9bdde87f23..a864538249 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -26,12 +26,14 @@ from __future__ import annotations import asyncio +import contextlib import contextvars import datetime import inspect +import logging import sys import traceback -from collections.abc import Sequence +from collections.abc import Coroutine, Sequence from typing import Any, Awaitable, Callable, Generic, TypeVar, cast import aiohttp @@ -42,12 +44,13 @@ __all__ = ("loop",) +_log = logging.getLogger(__name__) T = TypeVar("T") _func = Callable[..., Awaitable[Any]] LF = TypeVar("LF", bound=_func) FT = TypeVar("FT", bound=_func) ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) -_current_loop_ctx: contextvars.ContextVar[int] = contextvars.ContextVar( +_current_loop_ctx: contextvars.ContextVar[int | None] = contextvars.ContextVar( "_current_loop_ctx", default=None ) @@ -59,18 +62,21 @@ def __init__( self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop ) -> None: self.loop = loop - self.future = future = loop.create_future() + self.future = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = loop.call_later(relative_delta, future.set_result, True) + self.handle = loop.call_later(relative_delta, self._safe_result, self.future) - def _set_result_safe(self): - if not self.future.done(): - self.future.set_result(True) + @staticmethod + def _safe_result(future: asyncio.Future[Any]) -> None: + if not future.done(): + future.set_result(None) def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = self.loop.call_later(relative_delta, self._set_result_safe) + self.handle = self.loop.call_later( + relative_delta, self._safe_result, self.future + ) def wait(self) -> asyncio.Future[Any]: return self.future @@ -98,12 +104,28 @@ def __init__( time: datetime.time | Sequence[datetime.time], count: int | None, reconnect: bool, - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop | None, + create_loop: bool, + name: str | None, overlap: bool | int, ) -> None: self.coro: LF = coro self.reconnect: bool = reconnect - self.loop: asyncio.AbstractEventLoop = loop + + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + if create_loop: + loop = asyncio.new_event_loop() + + self.loop: asyncio.AbstractEventLoop | None = loop + + self.name: str = ( + f"pycord-ext-task ({id(self):#x}): {coro.__qualname__}" + if name in (None, MISSING) + else name + ) self.overlap: bool | int = overlap self.count: int | None = count self._current_loop = 0 @@ -117,6 +139,7 @@ def __init__( aiohttp.ClientError, asyncio.TimeoutError, ) + self._create_loop = create_loop self._before_loop = None self._after_loop = None @@ -139,6 +162,10 @@ def __init__( raise TypeError( f"Expected coroutine function, not {type(self.coro).__name__!r}." ) + + if loop is None and not create_loop: + discord.Client._pending_loops.add_loop(self) + if isinstance(overlap, bool): if overlap: self._run_with_semaphore = self._run_direct @@ -154,7 +181,7 @@ async def _run_direct(self, *args: Any, **kwargs: Any) -> None: """Run the coroutine directly.""" await self.coro(*args, **kwargs) - def _semaphore_runner_factory(self) -> Callable[..., Awaitable[None]]: + def _semaphore_runner_factory(self) -> Callable[..., Coroutine[Any, Any, None]]: """Return a function that runs the coroutine with a semaphore.""" async def runner(*args: Any, **kwargs: Any) -> None: @@ -179,8 +206,15 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non if name.endswith("_loop"): setattr(self, f"_{name}_running", False) + def _create_task(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + if self.loop is None: + meth = asyncio.create_task + else: + meth = self.loop.create_task + return meth(self._loop(*args, **kwargs), name=self.name) + def _try_sleep_until(self, dt: datetime.datetime): - self._handle = SleepHandle(dt=dt, loop=self.loop) + self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop()) return self._handle.wait() async def _loop(self, *args: Any, **kwargs: Any) -> None: @@ -194,7 +228,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: else: self._next_iteration = datetime.datetime.now(datetime.timezone.utc) try: - await self._try_sleep_until(self._next_iteration) + if self._stop_next_iteration: + return + while True: if not self._last_iteration_failed: self._last_iteration = self._next_iteration @@ -237,9 +273,10 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: except asyncio.CancelledError: self._is_being_cancelled = True - for task in self._tasks: - task.cancel() - await asyncio.gather(*self._tasks, return_exceptions=True) + if self._tasks: + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) raise except Exception as exc: self._has_failed = True @@ -266,7 +303,9 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]: count=self.count, reconnect=self.reconnect, loop=self.loop, + name=self.name, overlap=self.overlap, + create_loop=self._create_loop, ) copy._injected = obj copy._before_loop = self._before_loop @@ -318,11 +357,7 @@ def time(self) -> list[datetime.time] | None: @property def current_loop(self) -> int: """The current iteration of the loop.""" - return ( - _current_loop_ctx.get() - if _current_loop_ctx.get() is not None - else self._current_loop - ) + return self._current_loop if (clc := _current_loop_ctx.get()) is None else clc @property def next_iteration(self) -> datetime.datetime | None: @@ -356,9 +391,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: return await self.coro(*args, **kwargs) - def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None] | None: r"""Starts the internal task in the event loop. + If this loop was created with the ``create_loop`` parameter set as ``False`` and + no running loop is found (eg this method is not called from an async context), + then this task will be started automatically when any kind of :class:`~discord.Client` + (subclasses included) starts. + Parameters ------------ \*args @@ -377,16 +417,31 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: The task that has been created. """ + loop = None + with contextlib.suppress(RuntimeError): + loop = asyncio.get_running_loop() + + if loop: + self.loop = loop + + if self.loop is None: + _log.warning( + f"The task {self.name} has been set to be bound to a discord.Client instance, and will start running automatically " + "when the client starts. If you want this task to be executed without it being bound to a discord.Client, " + "set the create_loop parameter in the decorator to True, and don't forget to set the client.loop to the loop.loop" + ) + return None + if self._task is not MISSING and not self._task.done(): raise RuntimeError("Task is already launched and is not completed.") if self._injected is not None: args = (self._injected, *args) - if self.loop is MISSING: - self.loop = asyncio.get_event_loop() - - self._task = self.loop.create_task(self._loop(*args, **kwargs)) + self._task = asyncio.ensure_future( + self.loop.create_task(self._loop(*args, **kwargs), name=self.name), + loop=self.loop, + ) return self._task def stop(self) -> None: @@ -412,13 +467,20 @@ def stop(self) -> None: def _can_be_cancelled(self) -> bool: return bool( - not self._is_being_cancelled and self._task and not self._task.done() + not self._is_being_cancelled + and ( + (self._task is not MISSING and (self._task and not self._task.done())) + or self._tasks + ) ) def cancel(self) -> None: """Cancels the internal task, if it is running.""" if self._can_be_cancelled(): - self._task.cancel() + if self._task is not MISSING: + self._task.cancel() + for task in self._tasks: + task.cancel() def restart(self, *args: Any, **kwargs: Any) -> None: r"""A convenience method to restart the internal task. @@ -769,15 +831,9 @@ def change_interval( self._time = self._get_time_parameter(time) self._sleep = self._seconds = self._minutes = self._hours = MISSING - if self.is_running() and not ( - self._before_loop_running or self._after_loop_running - ): - if self._time is not MISSING: - # prepare the next time index starting from after the last iteration - self._prepare_time_index(now=self._last_iteration) - + if self.is_running() and self._last_iteration is not MISSING: self._next_iteration = self._get_next_sleep_time() - if not self._handle.done(): + if self._handle and not self._handle.done(): # the loop is sleeping, recalculate based on new interval self._handle.recalculate(self._next_iteration) @@ -790,8 +846,10 @@ def loop( time: datetime.time | Sequence[datetime.time] = MISSING, count: int | None = None, reconnect: bool = True, - loop: asyncio.AbstractEventLoop = MISSING, + loop: asyncio.AbstractEventLoop | None = None, + name: str | None = MISSING, overlap: bool | int = False, + create_loop: bool = False, ) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -824,9 +882,25 @@ def loop( Whether to handle errors and restart the task using an exponential back-off algorithm similar to the one used in :meth:`discord.Client.connect`. - loop: :class:`asyncio.AbstractEventLoop` - The loop to use to register the task, if not given - defaults to :func:`asyncio.get_event_loop`. + loop: Optional[:class:`asyncio.AbstractEventLoop`] + The loop to use to register the task, defaults to ``None``. + + .. versionchanged:: 2.7 + This can now be ``None`` + name: Optional[:class:`str`] + The name to create the task with, defaults to ``None``. + + .. versionadded:: 2.7 + create_loop: :class:`bool` + Whether this task should create its own :class:`asyncio.AbstractEventLoop` to run if + no already running one is found. + + Loops must be in an async context in order to run, this means :meth:`Loop.start` should be + called from an async context (e.g. coroutines). + + Defaults to ``False``. + + .. versionadded:: 2.7 overlap: Union[:class:`bool`, :class:`int`] Controls whether overlapping executions of the task loop are allowed. Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs. @@ -851,8 +925,10 @@ def decorator(func: LF) -> Loop[LF]: count=count, time=time, reconnect=reconnect, + name=name, loop=loop, overlap=overlap, + create_loop=create_loop, ) return decorator diff --git a/discord/http.py b/discord/http.py index 0717feadf5..f6f17b6e0f 100644 --- a/discord/http.py +++ b/discord/http.py @@ -191,9 +191,7 @@ def __init__( loop: asyncio.AbstractEventLoop | None = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + self.loop: asyncio.AbstractEventLoop = loop or MISSING self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() diff --git a/discord/state.py b/discord/state.py index 574c973c52..cf02aa8495 100644 --- a/discord/state.py +++ b/discord/state.py @@ -96,6 +96,8 @@ CS = TypeVar("CS", bound="ConnectionState") Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] +MISSING = utils.MISSING + class ChunkRequest: def __init__( @@ -168,16 +170,16 @@ def __init__( handlers: dict[str, Callable], hooks: dict[str, Callable], http: HTTPClient, - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop | None, **options: Any, ) -> None: - self.loop: asyncio.AbstractEventLoop = loop + self.loop: asyncio.AbstractEventLoop = loop or MISSING self.http: HTTPClient = http self.max_messages: int | None = options.get("max_messages", 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 - self.dispatch: Callable = dispatch + self._dispatch: Callable = dispatch self.handlers: dict[str, Callable] = handlers self.hooks: dict[str, Callable] = hooks self.shard_count: int | None = None @@ -266,6 +268,10 @@ def __init__( self.clear() + def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any: + _log.debug("Dispatching event %s", event) + return self._dispatch(event, *args, **kwargs) + def clear(self, *, views: bool = True) -> None: self.user: ClientUser | None = None # Originally, this code used WeakValueDictionary to maintain references to the diff --git a/discord/utils.py b/discord/utils.py index cc6d9d3b19..6505968215 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -31,6 +31,7 @@ import datetime import functools import importlib.resources +import inspect import io import itertools import json @@ -1569,7 +1570,7 @@ def _filter(ctx: AutocompleteContext, item: Any) -> bool: gen = (val for val in _values if _filter(ctx, val)) - elif asyncio.iscoroutinefunction(filter): + elif inspect.iscoroutinefunction(filter): gen = (val for val in _values if await filter(ctx, val)) elif callable(filter):