33import random
44from . import types
55
6- T = types .TypeVar ('T' )
7- TC = types .TypeVar ('TC' , bound = types .Container [types .Any ])
8- P = types .ParamSpec ('P' )
6+ _T = types .TypeVar ('_T' )
7+ _TC = types .TypeVar ('_TC' , bound = types .Container [types .Any ])
8+ _P = types .ParamSpec ('_P' )
9+ _S = types .TypeVar ('_S' , covariant = True )
910
1011
1112def set_attributes (** kwargs : types .Any ) -> types .Callable [..., types .Any ]:
@@ -33,8 +34,8 @@ def set_attributes(**kwargs: types.Any) -> types.Callable[..., types.Any]:
3334 '''
3435
3536 def _set_attributes (
36- function : types .Callable [P , T ]
37- ) -> types .Callable [P , T ]:
37+ function : types .Callable [_P , _T ]
38+ ) -> types .Callable [_P , _T ]:
3839 for key , value in kwargs .items ():
3940 setattr (function , key , value )
4041 return function
@@ -43,11 +44,13 @@ def _set_attributes(
4344
4445
4546def listify (
46- collection : types .Callable [[types .Iterable [T ]], TC ] = list , # type: ignore
47+ collection : types .Callable [
48+ [types .Iterable [_T ]], _TC
49+ ] = list , # type: ignore
4750 allow_empty : bool = True ,
4851) -> types .Callable [
49- [types .Callable [..., types .Optional [types .Iterable [T ]]]],
50- types .Callable [..., TC ],
52+ [types .Callable [..., types .Optional [types .Iterable [_T ]]]],
53+ types .Callable [..., _TC ],
5154]:
5255 '''
5356 Convert any generator to a list or other type of collection.
@@ -96,10 +99,10 @@ def listify(
9699 '''
97100
98101 def _listify (
99- function : types .Callable [..., types .Optional [types .Iterable [T ]]]
100- ) -> types .Callable [..., TC ]:
101- def __listify (* args : types .Any , ** kwargs : types .Any ) -> TC :
102- result : types .Optional [types .Iterable [T ]] = function (
102+ function : types .Callable [..., types .Optional [types .Iterable [_T ]]]
103+ ) -> types .Callable [..., _TC ]:
104+ def __listify (* args : types .Any , ** kwargs : types .Any ) -> _TC :
105+ result : types .Optional [types .Iterable [_T ]] = function (
103106 * args , ** kwargs
104107 )
105108 if result is None :
@@ -134,10 +137,12 @@ def sample(sample_rate: float):
134137 '''
135138
136139 def _sample (
137- function : types .Callable [P , T ]
138- ) -> types .Callable [P , types .Optional [T ]]:
140+ function : types .Callable [_P , _T ]
141+ ) -> types .Callable [_P , types .Optional [_T ]]:
139142 @functools .wraps (function )
140- def __sample (* args : P .args , ** kwargs : P .kwargs ) -> types .Optional [T ]:
143+ def __sample (
144+ * args : _P .args , ** kwargs : _P .kwargs
145+ ) -> types .Optional [_T ]:
141146 if random .random () < sample_rate :
142147 return function (* args , ** kwargs )
143148 else :
@@ -152,3 +157,43 @@ def __sample(*args: P.args, **kwargs: P.kwargs) -> types.Optional[T]:
152157 return __sample
153158
154159 return _sample
160+
161+
162+ def wraps_classmethod (
163+ wrapped : types .Callable [types .Concatenate [_S , _P ], _T ],
164+ ) -> types .Callable [
165+ [
166+ types .Callable [types .Concatenate [types .Any , _P ], _T ],
167+ ],
168+ types .Callable [types .Concatenate [types .Type [_S ], _P ], _T ],
169+ ]:
170+ '''
171+ Like `functools.wraps`, but for wrapping classmethods with the type info
172+ from a regular method
173+ '''
174+
175+ def _wraps_classmethod (
176+ wrapper : types .Callable [types .Concatenate [types .Any , _P ], _T ],
177+ ) -> types .Callable [types .Concatenate [types .Type [_S ], _P ], _T ]:
178+ try : # pragma: no cover
179+ wrapper = functools .update_wrapper (
180+ wrapper ,
181+ wrapped ,
182+ assigned = tuple (
183+ a
184+ for a in functools .WRAPPER_ASSIGNMENTS
185+ if a != '__annotations__'
186+ ),
187+ )
188+ except AttributeError :
189+ # For some reason `functools.update_wrapper` fails on some test
190+ # runs but not while running actual code
191+ pass
192+
193+ if annotations := getattr (wrapped , '__annotations__' , {}):
194+ annotations .pop ('self' , None )
195+ wrapper .__annotations__ = annotations
196+
197+ return wrapper
198+
199+ return _wraps_classmethod
0 commit comments