Skip to content

Commit c461e4f

Browse files
authored
Merge pull request #1678 from tisnik/lcore-2106-better-connection-decorator
LCORE-2106: better connection decorator with full type hints
2 parents 41abeaf + ce7255f commit c461e4f

1 file changed

Lines changed: 30 additions & 6 deletions

File tree

src/utils/connection_decorator.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,34 @@
11
"""Decorator that makes sure the object is 'connected' according to it's connected predicate."""
22

33
from collections.abc import Callable
4-
from typing import Any
4+
from typing import (
5+
Concatenate,
6+
ParamSpec,
7+
Protocol,
8+
TypeVar,
9+
runtime_checkable,
10+
)
511

12+
P = ParamSpec("P")
13+
R = TypeVar("R")
14+
S = TypeVar("S", bound="Connectable") # the method's self type
615

7-
def connection(f: Callable) -> Callable:
16+
17+
@runtime_checkable
18+
class Connectable(Protocol):
19+
"""Any class that implements methods connected and connect."""
20+
21+
def connected(self) -> bool:
22+
"""Check if DB is connected."""
23+
return False
24+
25+
def connect(self) -> None:
26+
"""Connect or reconnect the database."""
27+
28+
29+
def connection(
30+
f: Callable[Concatenate[S, P], R],
31+
) -> Callable[Concatenate[S, P], R]:
832
"""
933
Ensure a connectable object is connected before invoking the wrapped function.
1034
@@ -31,7 +55,7 @@ def list_history(self) -> list[str]:
3155
```
3256
"""
3357

34-
def wrapper(connectable: Any, *args: Any, **kwargs: Any) -> Callable:
58+
def wrapper(self: S, *args: P.args, **kwargs: P.kwargs) -> R:
3559
"""
3660
Ensure the provided connectable is connected, then call the wrapped with the same arguments.
3761
@@ -46,8 +70,8 @@ def wrapper(connectable: Any, *args: Any, **kwargs: Any) -> Callable:
4670
-------
4771
Any: The value returned by the wrapped callable.
4872
"""
49-
if not connectable.connected():
50-
connectable.connect()
51-
return f(connectable, *args, **kwargs)
73+
if not self.connected():
74+
self.connect()
75+
return f(self, *args, **kwargs)
5276

5377
return wrapper

0 commit comments

Comments
 (0)