11from contextlib import AbstractAsyncContextManager , AbstractContextManager
2- from typing import AsyncIterator , Callable , final , Iterator , Protocol , TypeVar
2+ from typing import Callable , final , Iterator , List , Protocol , TypeVar , Optional
33from sqlalchemy .ext .asyncio import AsyncSession
44from sqlalchemy .future import select
55from sqlalchemy .orm import Session , Query
66from sqlalchemy .sql .selectable import Select
77
88from pymfdata .rdb .connection import Base
9- from pymfdata .common .errors import NotFoundException
109
11- _MT = TypeVar ("_MT" , bound = Base )
10+ _MT = TypeVar ("_MT" , bound = Base ) # Model Type
11+ _T = TypeVar ("_T" ) # Primary key Type
1212
1313
14- class AsyncRepository (Protocol [_MT ]):
14+ class AsyncRepository (Protocol [_MT , _T ]):
1515 _model : _MT
1616 _session_factory : Callable [..., AbstractAsyncContextManager ]
1717 _pk_column : str
1818
19- async def delete_by_pk (self , pk ):
20- if not await self .is_exists (** {self ._pk_column : pk }):
21- raise NotFoundException ()
22-
19+ async def delete_by_pk (self , pk : _T ) -> bool :
2320 item = await self .find_by_pk (pk )
21+ if item is not None :
22+ session : AsyncSession
23+ async with self ._session_factory () as session :
24+ await session .delete (item )
25+ await session .commit ()
2426
25- session : AsyncSession
26- async with self ._session_factory () as session :
27- await session .delete (item )
28- await session .commit ()
27+ return True
28+ return False
2929
30- async def find_by_pk (self , pk ) -> _MT :
30+ async def find_by_pk (self , pk : _T ) -> Optional [ _MT ] :
3131 return await self .find_by_col (** {self ._pk_column : pk })
3232
3333 @final
34- async def find_by_col (self , ** kwargs ) -> _MT :
34+ async def find_by_col (self , ** kwargs ) -> Optional [ _MT ] :
3535 if not await self .is_exists (** kwargs ):
36- raise NotFoundException ()
36+ return None
3737
3838 session : AsyncSession
3939 async with self ._session_factory () as session :
@@ -49,7 +49,7 @@ def _gen_stmt_for_param(self, **kwargs) -> Select:
4949 return stmt
5050
5151 @final
52- async def find_all (self , ** kwargs ) -> AsyncIterator [_MT ]:
52+ async def find_all (self , ** kwargs ) -> List [_MT ]:
5353 session : AsyncSession
5454 async with self ._session_factory () as session :
5555 stmt = self ._gen_stmt_for_param (** kwargs )
@@ -71,34 +71,45 @@ async def save(self, item: Base):
7171 await session .commit ()
7272 await session .refresh (item )
7373
74- async def update_by_pk (self , pk , req : dict ):
75- if not await self .is_exists (** {self ._pk_column : pk }):
76- raise NotFoundException ()
77-
74+ async def update_by_pk (self , pk : _T , req : dict ) -> bool :
7875 item = await self .find_by_pk (pk )
76+ if item is not None :
77+ session : AsyncSession
78+ async with self ._session_factory () as session :
79+ for k , v in req .items ():
80+ if v is not None :
81+ setattr (item , k , v )
7982
80- session : AsyncSession
81- async with self ._session_factory () as session :
82- for k , v in req .items ():
83- if v is not None :
84- setattr (item , k , v )
83+ await session .commit ()
84+ await session .refresh (item )
8585
86- await session . commit ()
87- await session . refresh ( item )
86+ return True
87+ return False
8888
8989
90- class SyncRepository (Protocol [_MT ]):
90+ class SyncRepository (Protocol [_MT , _T ]):
9191 _model : _MT
9292 _session_factory : Callable [..., AbstractContextManager ]
9393 _pk_column : str
9494
95- def find_by_pk (self , pk ) -> _MT :
95+ def delete_by_pk (self , pk : _T ) -> bool :
96+ item = self .find_by_pk (pk )
97+ if item is not None :
98+ session : Session
99+ with self ._session_factory () as session :
100+ session .delete (item )
101+ session .commit ()
102+
103+ return True
104+ return False
105+
106+ def find_by_pk (self , pk : _T ) -> Optional [_MT ]:
96107 return self .find_by_col (** {self ._pk_column : pk })
97108
98109 @final
99- def find_by_col (self , ** kwargs ) -> _MT :
110+ def find_by_col (self , ** kwargs ) -> Optional [ _MT ] :
100111 if not self .is_exists (** kwargs ):
101- raise NotFoundException ()
112+ return None
102113
103114 with self ._session_factory () as session :
104115 query = self ._gen_query_for_param (session , ** kwargs )
@@ -121,8 +132,8 @@ def find_all(self, **kwargs) -> Iterator[_MT]:
121132
122133 @final
123134 def is_exists (self , ** kwargs ) -> bool :
135+ session : Session
124136 with self ._session_factory () as session :
125- session : Session
126137 return session .query (self ._gen_query_for_param (session , ** kwargs ).exists ()).scalar ()
127138
128139 @final
@@ -131,3 +142,18 @@ def save(self, item: Base):
131142 with self ._session_factory () as session :
132143 session .add (item )
133144 session .commit ()
145+
146+ def update_by_pk (self , pk : _T , req : dict ) -> bool :
147+ item = self .find_by_pk (pk )
148+ if item is not None :
149+ session : Session
150+ with self ._session_factory () as session :
151+ for k , v in req .items ():
152+ if v is not None :
153+ setattr (item , k , v )
154+
155+ await session .commit ()
156+ await session .refresh (item )
157+
158+ return True
159+ return False
0 commit comments