Skip to content

Commit 130b1b8

Browse files
LumabotsPaillat-dev
authored andcommitted
feat(loop): add optional overlap support to allow concurrent loop executions (Pycord-Development#2771)
Co-authored-by: plun1331 <plun1331@gmail.com> Co-authored-by: JustaSqu1d <89910983+JustaSqu1d@users.noreply.github.com> Co-authored-by: DA344 <108473820+DA-344@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 65b8b61)
1 parent b3916c7 commit 130b1b8

3 files changed

Lines changed: 94 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ These changes are available on the `master` branch, but have not yet been releas
117117
([#2714](https://github.com/Pycord-Development/pycord/pull/2714))
118118
- Added the ability to pass a `datetime.time` object to `format_dt`.
119119
([#2747](https://github.com/Pycord-Development/pycord/pull/2747))
120+
- Added the ability to pass an `overlap` parameter to the `loop` decorator and `Loop`
121+
class, allowing concurrent iterations if enabled.
122+
([#2765](https://github.com/Pycord-Development/pycord/pull/2765))
120123
- Added various missing channel parameters and allow `default_reaction_emoji` to be
121124
`None`. ([#2772](https://github.com/Pycord-Development/pycord/pull/2772))
122125
- Added support for type hinting slash command options with `typing.Annotated`.

discord/ext/tasks/__init__.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from __future__ import annotations
2727

2828
import asyncio
29+
import contextvars
2930
import datetime
3031
import inspect
3132
import sys
@@ -47,6 +48,9 @@
4748
LF = TypeVar("LF", bound=_func)
4849
FT = TypeVar("FT", bound=_func)
4950
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
51+
_current_loop_ctx: contextvars.ContextVar[int] = contextvars.ContextVar(
52+
"_current_loop_ctx", default=None
53+
)
5054

5155

5256
def compute_timedelta(dt: datetime.datetime):
@@ -65,10 +69,14 @@ def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) ->
6569
relative_delta = discord.utils.compute_timedelta(dt)
6670
self.handle = loop.call_later(relative_delta, future.set_result, True)
6771

72+
def _set_result_safe(self):
73+
if not self.future.done():
74+
self.future.set_result(True)
75+
6876
def recalculate(self, dt: datetime.datetime) -> None:
6977
self.handle.cancel()
7078
relative_delta = discord.utils.compute_timedelta(dt)
71-
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
79+
self.handle = self.loop.call_later(relative_delta, self._set_result_safe)
7280

7381
def wait(self) -> asyncio.Future[Any]:
7482
return self.future
@@ -97,10 +105,12 @@ def __init__(
97105
count: int | None,
98106
reconnect: bool,
99107
loop: asyncio.AbstractEventLoop,
108+
overlap: bool | int,
100109
) -> None:
101110
self.coro: LF = coro
102111
self.reconnect: bool = reconnect
103112
self.loop: asyncio.AbstractEventLoop = loop
113+
self.overlap: bool | int = overlap
104114
self.count: int | None = count
105115
self._current_loop = 0
106116
self._handle: SleepHandle | utils.Undefined = MISSING
@@ -121,6 +131,7 @@ def __init__(
121131
self._is_being_cancelled = False
122132
self._has_failed = False
123133
self._stop_next_iteration = False
134+
self._tasks: set[asyncio.Task[Any]] = set()
124135

125136
if self.count is not None and self.count <= 0:
126137
raise ValueError("count must be greater than 0 or None.")
@@ -132,6 +143,29 @@ def __init__(
132143

133144
if not inspect.iscoroutinefunction(self.coro):
134145
raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
146+
if isinstance(overlap, bool):
147+
if overlap:
148+
self._run_with_semaphore = self._run_direct
149+
elif isinstance(overlap, int):
150+
if overlap <= 1:
151+
raise ValueError("overlap as an integer must be greater than 1.")
152+
self._semaphore = asyncio.Semaphore(overlap)
153+
self._run_with_semaphore = self._semaphore_runner_factory()
154+
else:
155+
raise TypeError("overlap must be a bool or a positive integer.")
156+
157+
async def _run_direct(self, *args: Any, **kwargs: Any) -> None:
158+
"""Run the coroutine directly."""
159+
await self.coro(*args, **kwargs)
160+
161+
def _semaphore_runner_factory(self) -> Callable[..., Awaitable[None]]:
162+
"""Return a function that runs the coroutine with a semaphore."""
163+
164+
async def runner(*args: Any, **kwargs: Any) -> None:
165+
async with self._semaphore:
166+
await self.coro(*args, **kwargs)
167+
168+
return runner
135169

136170
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
137171
coro = getattr(self, f"_{name}")
@@ -170,7 +204,18 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
170204
self._last_iteration = self._next_iteration
171205
self._next_iteration = self._get_next_sleep_time()
172206
try:
173-
await self.coro(*args, **kwargs)
207+
token = _current_loop_ctx.set(self._current_loop)
208+
if not self.overlap:
209+
await self.coro(*args, **kwargs)
210+
else:
211+
task = asyncio.create_task(
212+
self._run_with_semaphore(*args, **kwargs),
213+
name=f"pycord-loop-{self.coro.__name__}-{self._current_loop}",
214+
)
215+
task.add_done_callback(self._tasks.discard)
216+
self._tasks.add(task)
217+
218+
_current_loop_ctx.reset(token)
174219
self._last_iteration_failed = False
175220
backoff = ExponentialBackoff()
176221
except self._valid_exception:
@@ -196,6 +241,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
196241

197242
except asyncio.CancelledError:
198243
self._is_being_cancelled = True
244+
for task in self._tasks:
245+
task.cancel()
246+
await asyncio.gather(*self._tasks, return_exceptions=True)
199247
raise
200248
except Exception as exc:
201249
self._has_failed = True
@@ -222,6 +270,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
222270
count=self.count,
223271
reconnect=self.reconnect,
224272
loop=self.loop,
273+
overlap=self.overlap,
225274
)
226275
copy._injected = obj
227276
copy._before_loop = self._before_loop
@@ -273,7 +322,11 @@ def time(self) -> list[datetime.time] | None:
273322
@property
274323
def current_loop(self) -> int:
275324
"""The current iteration of the loop."""
276-
return self._current_loop
325+
return (
326+
_current_loop_ctx.get()
327+
if _current_loop_ctx.get() is not None
328+
else self._current_loop
329+
)
277330

278331
@property
279332
def next_iteration(self) -> datetime.datetime | None:
@@ -712,6 +765,7 @@ def loop(
712765
count: int | None = None,
713766
reconnect: bool = True,
714767
loop: asyncio.AbstractEventLoop | utils.Undefined = MISSING,
768+
overlap: bool | int = False,
715769
) -> Callable[[LF], Loop[LF]]:
716770
"""A decorator that schedules a task in the background for you with
717771
optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -747,6 +801,11 @@ def loop(
747801
loop: :class:`asyncio.AbstractEventLoop`
748802
The loop to use to register the task, if not given
749803
defaults to :func:`asyncio.get_event_loop`.
804+
overlap: Union[:class:`bool`, :class:`int`]
805+
Controls whether overlapping executions of the task loop are allowed.
806+
Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs.
807+
808+
.. versionadded:: 2.7
750809
751810
Raises
752811
------
@@ -767,6 +826,7 @@ def decorator(func: LF) -> Loop[LF]:
767826
time=time,
768827
reconnect=reconnect,
769828
loop=loop,
829+
overlap=overlap,
770830
)
771831

772832
return decorator

examples/background_task.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import random
13
from datetime import time, timezone
24

35
import discord
@@ -10,7 +12,6 @@ def __init__(self, *args, **kwargs):
1012

1113
# An attribute we can access from our task
1214
self.counter = 0
13-
1415
# Start the tasks to run in the background
1516
self.my_background_task.start()
1617
self.time_task.start()
@@ -35,6 +36,32 @@ async def time_task(self):
3536
async def before_my_task(self):
3637
await self.wait_until_ready() # Wait until the bot logs in
3738

39+
# Schedule every 10s; each run takes between 5 to 20s. With overlap=2, at most 2 runs
40+
# execute concurrently so we don't build an ever-growing backlog.
41+
@tasks.loop(seconds=10, overlap=2)
42+
async def fetch_status_task(self):
43+
"""
44+
Practical overlap use-case:
45+
46+
Poll an external service and post a short summary. Each poll may take
47+
between 5 to 20s due to network latency or rate limits, but we want fresh data
48+
every 10s. Allowing a small amount of overlap avoids drifting schedules
49+
without opening the floodgates to unlimited concurrency.
50+
"""
51+
print(f"[status] start run #{self.fetch_status_task.current_loop}")
52+
53+
# Simulate slow I/O (e.g., HTTP requests, DB queries, file I/O)
54+
await asyncio.sleep(random.randint(5, 20))
55+
56+
channel = self.get_channel(1234567) # Replace with your channel ID
57+
msg = f"[status] run #{self.fetch_status_task.current_loop} complete"
58+
if channel:
59+
await channel.send(msg)
60+
else:
61+
print(msg)
62+
63+
print(f"[status] end run #{self.fetch_status_task.current_loop}")
64+
3865

3966
client = MyClient()
4067
client.run("TOKEN")

0 commit comments

Comments
 (0)