@@ -63,7 +63,7 @@ class BucketType(Enum):
6363 category = 5
6464 role = 6
6565
66- def get_key (self , ctx : Context | ApplicationContext ) -> Any :
66+ def get_key (self , ctx : Context | ApplicationContext | Message ) -> Any :
6767 if self is BucketType .user :
6868 return ctx .author .id
6969 elif self is BucketType .guild :
@@ -90,7 +90,7 @@ def get_key(self, ctx: Context | ApplicationContext) -> Any:
9090 else ctx .author .top_role
9191 ).id # type: ignore
9292
93- def __call__ (self , ctx : Context | ApplicationContext ) -> Any :
93+ def __call__ (self , ctx : Context | ApplicationContext | Message ) -> Any :
9494 return self .get_key (ctx )
9595
9696
@@ -215,14 +215,14 @@ class CooldownMapping:
215215 def __init__ (
216216 self ,
217217 original : Cooldown | None ,
218- type : Callable [[Context | ApplicationContext ], Any ],
218+ type : Callable [[Context | ApplicationContext | Message ], Any ],
219219 ) -> None :
220220 if not callable (type ):
221221 raise TypeError ("Cooldown type must be a BucketType or callable" )
222222
223223 self ._cache : dict [Any , Cooldown ] = {}
224224 self ._cooldown : Cooldown | None = original
225- self ._type : Callable [[Context | ApplicationContext ], Any ] = type
225+ self ._type : Callable [[Context | ApplicationContext | Message ], Any ] = type
226226
227227 def copy (self ) -> CooldownMapping :
228228 ret = CooldownMapping (self ._cooldown , self ._type )
@@ -234,14 +234,14 @@ def valid(self) -> bool:
234234 return self ._cooldown is not None
235235
236236 @property
237- def type (self ) -> Callable [[Context | ApplicationContext ], Any ]:
237+ def type (self ) -> Callable [[Context | ApplicationContext | Message ], Any ]:
238238 return self ._type
239239
240240 @classmethod
241241 def from_cooldown (cls : type [C ], rate , per , type ) -> C :
242242 return cls (Cooldown (rate , per ), type )
243243
244- def _bucket_key (self , ctx : Context | ApplicationContext ) -> Any :
244+ def _bucket_key (self , ctx : Context | ApplicationContext | Message ) -> Any :
245245 return self ._type (ctx )
246246
247247 def _verify_cache_integrity (self , current : float | None = None ) -> None :
@@ -253,11 +253,11 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
253253 for k in dead_keys :
254254 del self ._cache [k ]
255255
256- async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
256+ async def create_bucket (self , ctx : Context | ApplicationContext | Message ) -> Cooldown :
257257 return self ._cooldown .copy () # type: ignore
258258
259259 async def get_bucket (
260- self , ctx : Context | ApplicationContext , current : float | None = None
260+ self , ctx : Context | ApplicationContext | Message , current : float | None = None
261261 ) -> Cooldown :
262262 if self ._type is BucketType .default :
263263 return self ._cooldown # type: ignore
@@ -274,7 +274,7 @@ async def get_bucket(
274274 return bucket
275275
276276 async def update_rate_limit (
277- self , ctx : Context | ApplicationContext , current : float | None = None
277+ self , ctx : Context | ApplicationContext | Message , current : float | None = None
278278 ) -> float | None :
279279 bucket = await self .get_bucket (ctx , current )
280280 return bucket .update_rate_limit (current )
@@ -284,13 +284,13 @@ class DynamicCooldownMapping(CooldownMapping):
284284 def __init__ (
285285 self ,
286286 factory : Callable [
287- [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
287+ [Context | ApplicationContext | Message ], Cooldown | Awaitable [Cooldown ]
288288 ],
289- type : Callable [[Context | ApplicationContext ], Any ],
289+ type : Callable [[Context | ApplicationContext | Message ], Any ],
290290 ) -> None :
291291 super ().__init__ (None , type )
292292 self ._factory : Callable [
293- [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
293+ [Context | ApplicationContext | Message ], Cooldown | Awaitable [Cooldown ]
294294 ] = factory
295295
296296 def copy (self ) -> DynamicCooldownMapping :
@@ -302,7 +302,7 @@ def copy(self) -> DynamicCooldownMapping:
302302 def valid (self ) -> bool :
303303 return True
304304
305- async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
305+ async def create_bucket (self , ctx : Context | ApplicationContext | Message ) -> Cooldown :
306306 from ...ext .commands import Context
307307
308308 if isinstance (ctx , Context ):
@@ -399,10 +399,10 @@ def __repr__(self) -> str:
399399 f"<MaxConcurrency per={ self .per !r} number={ self .number } wait={ self .wait } >"
400400 )
401401
402- def get_key (self , ctx : Context | ApplicationContext ) -> Any :
402+ def get_key (self , ctx : Context | ApplicationContext | Message ) -> Any :
403403 return self .per .get_key (ctx )
404404
405- async def acquire (self , ctx : Context | ApplicationContext ) -> None :
405+ async def acquire (self , ctx : Context | ApplicationContext | Message ) -> None :
406406 key = self .get_key (ctx )
407407
408408 try :
@@ -414,7 +414,7 @@ async def acquire(self, ctx: Context | ApplicationContext) -> None:
414414 if not acquired :
415415 raise MaxConcurrencyReached (self .number , self .per )
416416
417- async def release (self , ctx : Context | ApplicationContext ) -> None :
417+ async def release (self , ctx : Context | ApplicationContext | Message ) -> None :
418418 # Technically there's no reason for this function to be async
419419 # But it might be more useful in the future
420420 key = self .get_key (ctx )
0 commit comments