2626from __future__ import annotations
2727
2828import asyncio
29+ import contextvars
2930import datetime
3031import inspect
3132import sys
4748LF = TypeVar ("LF" , bound = _func )
4849FT = TypeVar ("FT" , bound = _func )
4950ET = TypeVar ("ET" , bound = Callable [[Any , BaseException ], Awaitable [Any ]])
51+ _current_loop_ctx : contextvars .ContextVar [int ] = contextvars .ContextVar (
52+ "_current_loop_ctx" , default = None
53+ )
5054
5155
5256def compute_timedelta (dt : datetime .datetime ):
@@ -65,10 +69,14 @@ def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) ->
6569 relative_delta = discord .utils .compute_timedelta (dt )
6670 self .handle = loop .call_later (relative_delta , future .set_result , True )
6771
72+ def _set_result_safe (self ):
73+ if not self .future .done ():
74+ self .future .set_result (True )
75+
6876 def recalculate (self , dt : datetime .datetime ) -> None :
6977 self .handle .cancel ()
7078 relative_delta = discord .utils .compute_timedelta (dt )
71- self .handle = self .loop .call_later (relative_delta , self .future . set_result , True )
79+ self .handle = self .loop .call_later (relative_delta , self ._set_result_safe )
7280
7381 def wait (self ) -> asyncio .Future [Any ]:
7482 return self .future
@@ -97,10 +105,12 @@ def __init__(
97105 count : int | None ,
98106 reconnect : bool ,
99107 loop : asyncio .AbstractEventLoop ,
108+ overlap : bool | int ,
100109 ) -> None :
101110 self .coro : LF = coro
102111 self .reconnect : bool = reconnect
103112 self .loop : asyncio .AbstractEventLoop = loop
113+ self .overlap : bool | int = overlap
104114 self .count : int | None = count
105115 self ._current_loop = 0
106116 self ._handle : SleepHandle | utils .Undefined = MISSING
@@ -121,6 +131,7 @@ def __init__(
121131 self ._is_being_cancelled = False
122132 self ._has_failed = False
123133 self ._stop_next_iteration = False
134+ self ._tasks : set [asyncio .Task [Any ]] = set ()
124135
125136 if self .count is not None and self .count <= 0 :
126137 raise ValueError ("count must be greater than 0 or None." )
@@ -132,6 +143,29 @@ def __init__(
132143
133144 if not inspect .iscoroutinefunction (self .coro ):
134145 raise TypeError (f"Expected coroutine function, not { type (self .coro ).__name__ !r} ." )
146+ if isinstance (overlap , bool ):
147+ if overlap :
148+ self ._run_with_semaphore = self ._run_direct
149+ elif isinstance (overlap , int ):
150+ if overlap <= 1 :
151+ raise ValueError ("overlap as an integer must be greater than 1." )
152+ self ._semaphore = asyncio .Semaphore (overlap )
153+ self ._run_with_semaphore = self ._semaphore_runner_factory ()
154+ else :
155+ raise TypeError ("overlap must be a bool or a positive integer." )
156+
157+ async def _run_direct (self , * args : Any , ** kwargs : Any ) -> None :
158+ """Run the coroutine directly."""
159+ await self .coro (* args , ** kwargs )
160+
161+ def _semaphore_runner_factory (self ) -> Callable [..., Awaitable [None ]]:
162+ """Return a function that runs the coroutine with a semaphore."""
163+
164+ async def runner (* args : Any , ** kwargs : Any ) -> None :
165+ async with self ._semaphore :
166+ await self .coro (* args , ** kwargs )
167+
168+ return runner
135169
136170 async def _call_loop_function (self , name : str , * args : Any , ** kwargs : Any ) -> None :
137171 coro = getattr (self , f"_{ name } " )
@@ -170,7 +204,18 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
170204 self ._last_iteration = self ._next_iteration
171205 self ._next_iteration = self ._get_next_sleep_time ()
172206 try :
173- await self .coro (* args , ** kwargs )
207+ token = _current_loop_ctx .set (self ._current_loop )
208+ if not self .overlap :
209+ await self .coro (* args , ** kwargs )
210+ else :
211+ task = asyncio .create_task (
212+ self ._run_with_semaphore (* args , ** kwargs ),
213+ name = f"pycord-loop-{ self .coro .__name__ } -{ self ._current_loop } " ,
214+ )
215+ task .add_done_callback (self ._tasks .discard )
216+ self ._tasks .add (task )
217+
218+ _current_loop_ctx .reset (token )
174219 self ._last_iteration_failed = False
175220 backoff = ExponentialBackoff ()
176221 except self ._valid_exception :
@@ -196,6 +241,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
196241
197242 except asyncio .CancelledError :
198243 self ._is_being_cancelled = True
244+ for task in self ._tasks :
245+ task .cancel ()
246+ await asyncio .gather (* self ._tasks , return_exceptions = True )
199247 raise
200248 except Exception as exc :
201249 self ._has_failed = True
@@ -222,6 +270,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
222270 count = self .count ,
223271 reconnect = self .reconnect ,
224272 loop = self .loop ,
273+ overlap = self .overlap ,
225274 )
226275 copy ._injected = obj
227276 copy ._before_loop = self ._before_loop
@@ -273,7 +322,11 @@ def time(self) -> list[datetime.time] | None:
273322 @property
274323 def current_loop (self ) -> int :
275324 """The current iteration of the loop."""
276- return self ._current_loop
325+ return (
326+ _current_loop_ctx .get ()
327+ if _current_loop_ctx .get () is not None
328+ else self ._current_loop
329+ )
277330
278331 @property
279332 def next_iteration (self ) -> datetime .datetime | None :
@@ -712,6 +765,7 @@ def loop(
712765 count : int | None = None ,
713766 reconnect : bool = True ,
714767 loop : asyncio .AbstractEventLoop | utils .Undefined = MISSING ,
768+ overlap : bool | int = False ,
715769) -> Callable [[LF ], Loop [LF ]]:
716770 """A decorator that schedules a task in the background for you with
717771 optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -747,6 +801,11 @@ def loop(
747801 loop: :class:`asyncio.AbstractEventLoop`
748802 The loop to use to register the task, if not given
749803 defaults to :func:`asyncio.get_event_loop`.
804+ overlap: Union[:class:`bool`, :class:`int`]
805+ Controls whether overlapping executions of the task loop are allowed.
806+ Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs.
807+
808+ .. versionadded:: 2.7
750809
751810 Raises
752811 ------
@@ -767,6 +826,7 @@ def decorator(func: LF) -> Loop[LF]:
767826 time = time ,
768827 reconnect = reconnect ,
769828 loop = loop ,
829+ overlap = overlap ,
770830 )
771831
772832 return decorator
0 commit comments