1- import asyncio
21from contextlib import asynccontextmanager
32from typing import Any , Type
43
54from pydantic import BaseModel
6- from sqlalchemy .exc import IntegrityError , OperationalError
5+ from sqlmodel import and_ , update
6+ from sqlalchemy .exc import IntegrityError
77from sqlalchemy .ext .asyncio import async_sessionmaker , create_async_engine
8- from sqlalchemy .orm import selectinload
9- from sqlmodel import select
108from sqlmodel .ext .asyncio .session import AsyncSession
119
1210from common .config import conf
11+ from common .db_model .query_db import QueryDB
1312from common .enums import DBOperator , ModelType
1413from common .serializers import BaseTable , DBQuery , FilterQuery , RowLike
1514from common .utils import BaseUtils
1615
17- opts = {
18- DBOperator .eq : "__eq__" ,
19- DBOperator .ne : "__ne__" ,
20- DBOperator .lt : "__lt__" ,
21- DBOperator .le : "__le__" ,
22- DBOperator .gt : "__gt__" ,
23- DBOperator .ge : "__ge__" ,
24- DBOperator .like : "like" ,
25- DBOperator .ilike : "ilike" , # PostgreSQL only
26- DBOperator .in_ : "in_" ,
27- DBOperator .not_in : "notin_" ,
28- DBOperator .is_null : "is_" ,
29- }
30-
31-
32- class DBModel (BaseUtils ):
16+
17+ class DBModel (BaseUtils , QueryDB ):
3318 table : Type [BaseTable ] | None = None
3419 payload : Type [BaseModel ] | None = None
3520 model_type : ModelType | None = None
@@ -66,6 +51,29 @@ async def get_session(cls):
6651 finally :
6752 await session .close ()
6853
54+ @classmethod
55+ def _raise_db_error (cls , ex , row = None ):
56+ dump = (
57+ row
58+ if isinstance (row , dict )
59+ else (row .model_dump () if hasattr (row , "model_dump" ) else None )
60+ )
61+ if dump :
62+ cls .logger .error (f"{ dump } : { ex } " )
63+ else :
64+ cls .logger .error (str (ex ))
65+ if isinstance (ex , IntegrityError ):
66+ orig = str (getattr (ex , "orig" , "" ))
67+ if "duplicate key value" in orig :
68+ field = "unique"
69+ if "Key (" in orig :
70+ try :
71+ field = orig .split ("Key (" )[1 ].split (")=" )[0 ]
72+ except Exception :
73+ pass
74+ cls .error_400 (details = f"unique:{ field } " )
75+ cls .error_400 (details = ex )
76+
6977 @classmethod
7078 async def add_update (cls , row : BaseTable | list [BaseTable ]):
7179 async with cls .get_session () as session :
@@ -85,100 +93,78 @@ async def add_update(cls, row: BaseTable | list[BaseTable]):
8593 await session .refresh (row )
8694 return row
8795 except Exception as ex :
88- dump = "" if isinstance (row , list ) else row .model_dump ()
89- cls .logger .error (f"{ dump } : { ex } " )
90- if isinstance (ex , IntegrityError ):
91- orig = str (getattr (ex , "orig" , "" ))
92- if "duplicate key value" in orig :
93- field = "unique"
94- if "Key (" in orig :
95- try :
96- field = orig .split ("Key (" )[1 ].split (")=" )[0 ]
97- except Exception :
98- pass
99- cls .error_400 (details = f"unique:{ field } " )
100- cls .error_400 (details = ex )
96+ cls ._raise_db_error (ex , row )
97+
98+ @classmethod
99+ async def update_exist_bulk (cls , filter_query : FilterQuery , values : dict ) -> int :
100+ statement = update (cls .table ).values (** values )
101+
102+ conditions = cls .generate_where_query (filter_query .query )
103+
104+ if conditions :
105+ statement = statement .where (and_ (* conditions ))
106+
107+ try :
108+ async with cls .get_session () as session :
109+ results = await session .exec (statement )
110+ return results .rowcount
111+ except Exception as ex :
112+ cls ._raise_db_error (ex , values )
101113
102114 @classmethod
103115 async def add_or_find_update_single (
104- cls ,
105- add_or_id : str | int ,
106- body : BaseModel | dict [ str , Any ] ,
107- ** kwargs : Any ,
116+ cls ,
117+ add_or_id : str | int ,
118+ body : BaseModel ,
119+ ** kwargs : Any ,
108120 ) -> BaseTable :
109121 user_auth = kwargs .get ("user_auth" )
110122
111123 if add_or_id != "add" :
112- db_obj = await cls .get_by_id (_id = add_or_id , ** kwargs )
124+ _id = int (add_or_id )
125+ query = [DBQuery (key = cls .table .id .key , opt = DBOperator .eq , value = _id )]
126+
127+ if user_auth :
128+ query .append (
129+ DBQuery (
130+ key = conf .AUTH_PARENT_FIELD ,
131+ opt = DBOperator .eq ,
132+ value = user_auth .id ,
133+ )
134+ )
135+
136+ body_dict = body .model_dump ()
137+ values = {
138+ key : value for key , value in body_dict .items () if key not in {"id" }
139+ }
113140
114- if db_obj and (
115- not user_auth or user_auth .id == getattr (db_obj , conf .AUTH_PARENT_FIELD )
116- ):
117- cls .set_elements_by_dict (db_obj , body , exclude_items = ["id" ])
118- else :
141+ updated = await cls .update_exist_bulk (
142+ filter_query = FilterQuery (query = query ),
143+ values = values ,
144+ )
145+
146+ if not updated :
119147 cls .error_400 (details = "not found" )
120148
121- else :
122- db_obj = cls .table ()
123- cls .set_elements_by_dict (db_obj , body )
149+ body_dict ["id" ] = _id
150+ new_obj = cls .table (** body_dict )
124151
152+ else :
153+ db_obj = cls .table (** body .model_dump ())
125154 if user_auth and hasattr (db_obj , conf .AUTH_PARENT_FIELD ):
126155 setattr (db_obj , conf .AUTH_PARENT_FIELD , user_auth .id )
127156
128- new_obj = await cls .add_update (row = db_obj )
129- return new_obj
130-
131- # Generate SQLAlchemy filter conditions with AND logic from a list of DBQuery objects
132- @classmethod
133- def generate_where_query (cls , query : list [DBQuery ]):
134- # Build a tuple of SQLAlchemy expressions by applying the operator (e.g., ==, >=) on each field
135- ans = tuple (
136- getattr (getattr (cls .table , q .key ), opts .get (q .opt , q .opt ))(q .value )
137- for q in query
138- if hasattr (cls .table , q .key )
139- and hasattr (getattr (cls .table , q .key ), opts .get (q .opt , q .opt ))
140- )
141- return ans
142-
143- @classmethod
144- def build_query (cls , filter_query : FilterQuery , offset : int = 0 , limit : int = 1000 ):
145- statement = select (cls .table )
146-
147- if filter_query .query :
148- statement = statement .where (* cls .generate_where_query (filter_query .query ))
149-
150- if filter_query .relation_model :
151- statement = statement .options (selectinload ("*" ))
157+ new_obj = await cls .add_update (row = db_obj )
152158
153- if filter_query .sort :
154- try :
155- field , direction = filter_query .sort .split (":" )
156- col = getattr (cls .table , field , None )
157- if col :
158- direction = direction .lower ()
159- statement = statement .order_by (
160- col .asc () if direction == "asc" else col .desc ()
161- )
162- except Exception :
163- raise ValueError (
164- f"Invalid sort format: '{ filter_query .sort } '. Expected 'field:asc' or 'field:desc'"
165- )
166-
167- if offset :
168- statement = statement .offset (offset )
169-
170- if limit :
171- statement = statement .limit (limit )
172-
173- return statement
159+ return new_obj
174160
175161 @classmethod
176162 async def fetch_rows (
177- cls ,
178- filter_query : FilterQuery = FilterQuery (),
179- offset : int = 0 ,
180- limit : int = 1000 ,
181- as_dict : bool = False ,
163+ cls ,
164+ filter_query : FilterQuery = FilterQuery (),
165+ offset : int = 0 ,
166+ limit : int = 1000 ,
167+ as_dict : bool = False ,
182168 ) -> RowLike | list [RowLike ]:
183169 statement = cls .build_query (filter_query , offset , limit )
184170
@@ -204,19 +190,22 @@ async def fetch_rows(
204190
205191 @classmethod
206192 async def delete_rows (
207- cls , filter_query : FilterQuery = FilterQuery (), offset : int = 0
193+ cls , filter_query : FilterQuery = FilterQuery (), offset : int = 0
208194 ):
209195 try :
210196 async with cls .get_session () as session :
197+ filter_query .columns = [cls .table .id ]
211198 select_stmt = cls .build_query (filter_query , offset , limit = 0 )
212199 results = (await session .exec (select_stmt )).all ()
213200 if not results :
214201 cls .error_400 (details = "not delete nothing" )
202+
215203 for obj in results :
216204 await session .delete (obj )
205+ return results
217206
218207 except Exception as ex :
219- cls .error_400 ( details = ex )
208+ cls ._raise_db_error ( ex )
220209
221210 @classmethod
222211 async def get_by_id (cls , _id : str | int , ** kwargs ):
0 commit comments