1- from abc import abstractmethod
21from contextlib import AbstractAsyncContextManager , AbstractContextManager
32from typing import AsyncIterator , Callable , final , Iterator , Protocol , TypeVar
43from sqlalchemy .ext .asyncio import AsyncSession
1514class AsyncRepository (Protocol [_MT ]):
1615 _model : _MT
1716 _session_factory : Callable [..., AbstractAsyncContextManager ]
17+ _pk_column : str
18+
19+ async def delete_by_pk (self , pk ):
20+ if not await self .is_exists (** {self ._pk_column : pk }):
21+ raise NotFoundException ()
22+
23+ item = await self .find_by_pk (pk )
24+
25+ session : AsyncSession
26+ async with self ._session_factory () as session :
27+ await session .delete (item )
28+ await session .commit ()
1829
19- @abstractmethod
2030 async def find_by_pk (self , pk ) -> _MT :
21- raise NotImplementedError ( "Required implementation {}" . format ( self .find_by_pk . __name__ ) )
31+ return await self . find_by_col ( ** { self ._pk_column : pk } )
2232
2333 @final
2434 async def find_by_col (self , ** kwargs ) -> _MT :
@@ -59,15 +69,31 @@ async def save(self, item: Base):
5969 async with self ._session_factory () as session :
6070 session .add (item )
6171 await session .commit ()
72+ await session .refresh (item )
73+
74+ async def update_by_pk (self , pk , req : dict ):
75+ if not await self .is_exists (** {self ._pk_column : pk }):
76+ raise NotFoundException ()
77+
78+ item = await self .find_by_pk (pk )
79+
80+ session : AsyncSession
81+ async with self ._session_factory () as session :
82+ for k , v in req .items ():
83+ if v is not None :
84+ setattr (item , k , v )
85+
86+ await session .commit ()
87+ await session .refresh (item )
6288
6389
6490class SyncRepository (Protocol [_MT ]):
6591 _model : _MT
6692 _session_factory : Callable [..., AbstractContextManager ]
93+ _pk_column : str
6794
68- @abstractmethod
6995 def find_by_pk (self , pk ) -> _MT :
70- raise NotImplementedError ( "Required implementation" )
96+ return self . find_by_col ( ** { self . _pk_column : pk } )
7197
7298 @final
7399 def find_by_col (self , ** kwargs ) -> _MT :
@@ -83,7 +109,7 @@ def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
83109 query = session .query (self ._model )
84110 if kwargs :
85111 for key , value in kwargs .items ():
86- query = query .filter (getattr (self ._model . key ) == value )
112+ query = query .filter (getattr (self ._model , key ) == value )
87113 return query
88114
89115 @final
0 commit comments