11from abc import ABC , abstractmethod
22from contextlib import AbstractAsyncContextManager , AbstractContextManager
3- from typing import AsyncIterator , Callable , final , Iterator , Type , TypeVar
3+ from typing import AsyncIterator , Callable , final , Generic , Iterator , TypeVar
44from sqlalchemy .ext .asyncio import AsyncSession
55from sqlalchemy .future import select
66from sqlalchemy .orm import Session , Query
99from pymfdata .rdb .connection import Base
1010from pymfdata .common .errors import NotFoundException
1111
12- ModelType = TypeVar ("ModelType " , bound = Base )
12+ MT = TypeVar ("MT " , bound = Base )
1313
1414
15- class AsyncRepository (ABC ):
16- def __init__ (self , model : Type [ ModelType ] , session_factory : Callable [..., AbstractAsyncContextManager ]) -> None :
15+ class AsyncRepository (ABC , Generic [ MT ] ):
16+ def __init__ (self , model : MT , session_factory : Callable [..., AbstractAsyncContextManager ]) -> None :
1717 self .model = model
1818 self ._session_factory = session_factory
1919
2020 @abstractmethod
21- async def find_by_pk (self , pk ) -> Type [ ModelType ] :
21+ async def find_by_pk (self , pk ) -> MT :
2222 raise NotImplementedError ("Required implementation" )
2323
2424 @final
25- async def find_by_col (self , ** kwargs ) -> Type [ ModelType ] :
25+ async def find_by_col (self , ** kwargs ) -> MT :
2626 if not await self .is_exists (** kwargs ):
2727 raise NotFoundException ()
2828
@@ -40,7 +40,7 @@ def _gen_stmt_for_param(self, **kwargs) -> Select:
4040 return stmt
4141
4242 @final
43- async def find_all (self , ** kwargs ) -> AsyncIterator [Type [ ModelType ] ]:
43+ async def find_all (self , ** kwargs ) -> AsyncIterator [MT ]:
4444 session : AsyncSession
4545 async with self ._session_factory () as session :
4646 stmt = self ._gen_stmt_for_param (** kwargs )
@@ -62,17 +62,17 @@ async def save(self, item: Base):
6262 await session .commit ()
6363
6464
65- class SyncRepository (ABC ):
66- def __init__ (self , model : Type [ ModelType ] , session_factory : Callable [..., AbstractContextManager ]) -> None :
65+ class SyncRepository (ABC , Generic [ MT ] ):
66+ def __init__ (self , model : MT , session_factory : Callable [..., AbstractContextManager ]) -> None :
6767 self .model = model
6868 self ._session_factory = session_factory
6969
7070 @abstractmethod
71- def find_by_pk (self , pk ) -> Type [ ModelType ] :
71+ def find_by_pk (self , pk ) -> MT :
7272 raise NotImplementedError ("Required implementation" )
7373
7474 @final
75- def find_by_col (self , ** kwargs ) -> Type [ ModelType ] :
75+ def find_by_col (self , ** kwargs ) -> MT :
7676 if not self .is_exists (** kwargs ):
7777 raise NotFoundException ()
7878
@@ -89,7 +89,7 @@ def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
8989 return query
9090
9191 @final
92- def find_all (self , ** kwargs ) -> Iterator [Type [ ModelType ] ]:
92+ def find_all (self , ** kwargs ) -> Iterator [MT ]:
9393 session : Session
9494 with self ._session_factory () as session :
9595 query = self ._gen_query_for_param (session , ** kwargs )
0 commit comments