11"""Decorator that makes sure the object is 'connected' according to it's connected predicate."""
22
33from collections .abc import Callable
4- from typing import Any
4+ from typing import (
5+ Concatenate ,
6+ ParamSpec ,
7+ Protocol ,
8+ TypeVar ,
9+ runtime_checkable ,
10+ )
511
12+ P = ParamSpec ("P" )
13+ R = TypeVar ("R" )
14+ S = TypeVar ("S" , bound = "Connectable" ) # the method's self type
615
7- def connection (f : Callable ) -> Callable :
16+
17+ @runtime_checkable
18+ class Connectable (Protocol ):
19+ """Any class that implements methods connected and connect."""
20+
21+ def connected (self ) -> bool :
22+ """Check if DB is connected."""
23+ return False
24+
25+ def connect (self ) -> None :
26+ """Connect or reconnect the database."""
27+
28+
29+ def connection (
30+ f : Callable [Concatenate [S , P ], R ],
31+ ) -> Callable [Concatenate [S , P ], R ]:
832 """
933 Ensure a connectable object is connected before invoking the wrapped function.
1034
@@ -31,7 +55,7 @@ def list_history(self) -> list[str]:
3155 ```
3256 """
3357
34- def wrapper (connectable : Any , * args : Any , ** kwargs : Any ) -> Callable :
58+ def wrapper (self : S , * args : P . args , ** kwargs : P . kwargs ) -> R :
3559 """
3660 Ensure the provided connectable is connected, then call the wrapped with the same arguments.
3761
@@ -46,8 +70,8 @@ def wrapper(connectable: Any, *args: Any, **kwargs: Any) -> Callable:
4670 -------
4771 Any: The value returned by the wrapped callable.
4872 """
49- if not connectable .connected ():
50- connectable .connect ()
51- return f (connectable , * args , ** kwargs )
73+ if not self .connected ():
74+ self .connect ()
75+ return f (self , * args , ** kwargs )
5276
5377 return wrapper
0 commit comments