11from contextlib import AbstractAsyncContextManager , AbstractContextManager
2- from typing import Callable , final , Iterator , List , Protocol , TypeVar , Optional
2+ from typing import Callable , final , Iterator , get_args , List , Protocol , Optional , Type , TypeVar
33from sqlalchemy .ext .asyncio import AsyncSession
44from sqlalchemy .future import select
5+ from sqlalchemy .inspection import inspect
56from sqlalchemy .orm import Session , Query
67from sqlalchemy .sql .selectable import Select
78
1213
1314
1415class AsyncRepository (Protocol [_MT , _T ]):
15- _model : _MT
1616 _session_factory : Callable [..., AbstractAsyncContextManager ]
17- _pk_column : str
17+
18+ @property
19+ def _model (self ):
20+ return get_args (self .__orig_bases__ [0 ])[0 ]
21+
22+ @property
23+ def _pk_column (self ) -> str :
24+ return inspect (self ._model ).primary_key [0 ].name
1825
1926 async def delete_by_pk (self , pk : _T ) -> bool :
20- item = await self . find_by_pk ( pk )
21- if item is not None :
22- session : AsyncSession
23- async with self . _session_factory () as session :
27+ session : AsyncSession
28+ async with self . _session_factory () as session :
29+ item = await self . find_by_pk ( session , pk )
30+ if item is not None :
2431 await session .delete (item )
2532 await session .commit ()
2633
27- return True
28- return False
34+ return True
2935
30- async def find_by_pk (self , pk : _T ) -> Optional [_MT ]:
31- return await self .find_by_col (** {self ._pk_column : pk })
36+ return False
3237
33- @final
34- async def find_by_col (self , ** kwargs ) -> Optional [_MT ]:
35- if not await self .is_exists (** kwargs ):
36- return None
38+ async def find_by_pk (self , session : AsyncSession , pk : _T ) -> Optional [_MT ]:
39+ return await self .find_by_col (session , ** {self ._pk_column : pk })
3740
38- session : AsyncSession
39- async with self . _session_factory () as session :
40- item = await session .execute (self ._gen_stmt_for_param (** kwargs ))
41- return item .unique ().scalars ().one ()
41+ @ final
42+ async def find_by_col ( self , session : AsyncSession , ** kwargs ) -> Optional [ _MT ] :
43+ item = await session .execute (self ._gen_stmt_for_param (** kwargs ))
44+ return item .unique ().scalars ().one_or_none ()
4245
4346 @final
4447 def _gen_stmt_for_param (self , ** kwargs ) -> Select :
@@ -61,59 +64,66 @@ async def find_all(self, **kwargs) -> List[_MT]:
6164 async def is_exists (self , ** kwargs ) -> bool :
6265 session : AsyncSession
6366 async with self ._session_factory () as session :
64- return await session .execute (self ._gen_stmt_for_param (** kwargs ).exists ().select ())
67+ result = await session .execute (self ._gen_stmt_for_param (** kwargs ).exists ().select ())
68+ return result .scalar ()
6569
6670 @final
67- async def save (self , item : Base ):
71+ async def save (self , item : _MT ):
6872 session : AsyncSession
6973 async with self ._session_factory () as session :
7074 session .add (item )
7175 await session .commit ()
7276 await session .refresh (item )
7377
7478 async def update_by_pk (self , pk : _T , req : dict ) -> bool :
75- item = await self . find_by_pk ( pk )
76- if item is not None :
77- session : AsyncSession
78- async with self . _session_factory () as session :
79+ session : AsyncSession
80+ async with self . _session_factory () as session :
81+ item = await self . find_by_pk ( session , pk )
82+ if item is not None :
7983 for k , v in req .items ():
8084 if v is not None :
8185 setattr (item , k , v )
8286
8387 await session .commit ()
8488 await session .refresh (item )
8589
86- return True
87- return False
90+ return True
91+ return False
8892
8993
9094class SyncRepository (Protocol [_MT , _T ]):
91- _model : _MT
9295 _session_factory : Callable [..., AbstractContextManager ]
93- _pk_column : str
96+
97+ @property
98+ def _model (self ):
99+ return get_args (self .__orig_bases__ [0 ])[0 ]
100+
101+ @property
102+ def _pk_column (self ) -> str :
103+ return inspect (self ._model ).primary_key [0 ].name
104+
105+ @final
106+ def count (self , ** kwargs ) -> int :
107+ return self ._gen_query_for_param (** kwargs ).count ()
94108
95109 def delete_by_pk (self , pk : _T ) -> bool :
96- item = self . find_by_pk ( pk )
97- if item is not None :
98- session : Session
99- with self . _session_factory () as session :
110+ session : Session
111+ with self . _session_factory () as session :
112+ item = self . find_by_pk ( session , pk )
113+ if item is not None :
100114 session .delete (item )
101115 session .commit ()
102116
103- return True
104- return False
117+ return True
118+ return False
105119
106- def find_by_pk (self , pk : _T ) -> Optional [_MT ]:
107- return self .find_by_col (** {self ._pk_column : pk })
120+ def find_by_pk (self , session : Session , pk : _T ) -> Optional [_MT ]:
121+ return self .find_by_col (session , ** {self ._pk_column : pk })
108122
109123 @final
110- def find_by_col (self , ** kwargs ) -> Optional [_MT ]:
111- if not self .is_exists (** kwargs ):
112- return None
113-
114- with self ._session_factory () as session :
115- query = self ._gen_query_for_param (session , ** kwargs )
116- return query .one ()
124+ def find_by_col (self , session : Session , ** kwargs ) -> Optional [_MT ]:
125+ query = self ._gen_query_for_param (session , ** kwargs )
126+ return query .one_or_none ()
117127
118128 @final
119129 def _gen_query_for_param (self , session : Session , ** kwargs ) -> Query :
@@ -144,16 +154,16 @@ def save(self, item: Base):
144154 session .commit ()
145155
146156 def update_by_pk (self , pk : _T , req : dict ) -> bool :
147- item = self . find_by_pk ( pk )
148- if item is not None :
149- session : Session
150- with self . _session_factory () as session :
157+ session : Session
158+ with self . _session_factory () as session :
159+ item = self . find_by_pk ( session , pk )
160+ if item is not None :
151161 for k , v in req .items ():
152162 if v is not None :
153163 setattr (item , k , v )
154164
155- await session .commit ()
156- await session .refresh (item )
165+ session .commit ()
166+ session .refresh (item )
157167
158- return True
159- return False
168+ return True
169+ return False
0 commit comments