1- from abc import ABC , abstractmethod
1+ from abc import abstractmethod
22from contextlib import AbstractAsyncContextManager , AbstractContextManager
3- from typing import AsyncIterator , Callable , final , Generic , Iterator , TypeVar
3+ from typing import AsyncIterator , Callable , final , Iterator , Protocol , TypeVar
44from sqlalchemy .ext .asyncio import AsyncSession
55from sqlalchemy .future import select
66from sqlalchemy .orm import Session , Query
99from pymfdata .rdb .connection import Base
1010from pymfdata .common .errors import NotFoundException
1111
12- MT = TypeVar ("MT " , bound = Base )
12+ _MT = TypeVar ("_MT " , bound = Base )
1313
1414
15- class AsyncRepository (ABC , Generic [MT ]):
16- def __init__ (self , model : MT , session_factory : Callable [..., AbstractAsyncContextManager ]) -> None :
17- self .model = model
18- self ._session_factory = session_factory
15+ class AsyncRepository (Protocol [_MT ]):
16+ _model : _MT
17+ _session_factory : Callable [..., AbstractAsyncContextManager ]
1918
2019 @abstractmethod
21- async def find_by_pk (self , pk ) -> MT :
22- raise NotImplementedError ("Required implementation" )
20+ async def find_by_pk (self , pk ) -> _MT :
21+ raise NotImplementedError ("Required implementation {}" . format ( self . find_by_pk . __name__ ) )
2322
2423 @final
25- async def find_by_col (self , ** kwargs ) -> MT :
24+ async def find_by_col (self , ** kwargs ) -> _MT :
2625 if not await self .is_exists (** kwargs ):
2726 raise NotFoundException ()
2827
@@ -33,14 +32,14 @@ async def find_by_col(self, **kwargs) -> MT:
3332
3433 @final
3534 def _gen_stmt_for_param (self , ** kwargs ) -> Select :
36- stmt = select (self .model )
35+ stmt = select (self ._model )
3736 if kwargs :
3837 for key , value in kwargs .items ():
39- stmt = stmt .where (getattr (self .model , key ) == value )
38+ stmt = stmt .where (getattr (self ._model , key ) == value )
4039 return stmt
4140
4241 @final
43- async def find_all (self , ** kwargs ) -> AsyncIterator [MT ]:
42+ async def find_all (self , ** kwargs ) -> AsyncIterator [_MT ]:
4443 session : AsyncSession
4544 async with self ._session_factory () as session :
4645 stmt = self ._gen_stmt_for_param (** kwargs )
@@ -62,17 +61,16 @@ async def save(self, item: Base):
6261 await session .commit ()
6362
6463
65- class SyncRepository (ABC , Generic [MT ]):
66- def __init__ (self , model : MT , session_factory : Callable [..., AbstractContextManager ]) -> None :
67- self .model = model
68- self ._session_factory = session_factory
64+ class SyncRepository (Protocol [_MT ]):
65+ _model : _MT
66+ _session_factory : Callable [..., AbstractContextManager ]
6967
7068 @abstractmethod
71- def find_by_pk (self , pk ) -> MT :
69+ def find_by_pk (self , pk ) -> _MT :
7270 raise NotImplementedError ("Required implementation" )
7371
7472 @final
75- def find_by_col (self , ** kwargs ) -> MT :
73+ def find_by_col (self , ** kwargs ) -> _MT :
7674 if not self .is_exists (** kwargs ):
7775 raise NotFoundException ()
7876
@@ -82,14 +80,14 @@ def find_by_col(self, **kwargs) -> MT:
8280
8381 @final
8482 def _gen_query_for_param (self , session : Session , ** kwargs ) -> Query :
85- query = session .query (self .model )
83+ query = session .query (self ._model )
8684 if kwargs :
8785 for key , value in kwargs .items ():
88- query = query .filter (getattr (self .model .key ) == value )
86+ query = query .filter (getattr (self ._model .key ) == value )
8987 return query
9088
9189 @final
92- def find_all (self , ** kwargs ) -> Iterator [MT ]:
90+ def find_all (self , ** kwargs ) -> Iterator [_MT ]:
9391 session : Session
9492 with self ._session_factory () as session :
9593 query = self ._gen_query_for_param (session , ** kwargs )
0 commit comments