Skip to content

Commit 0d4dfc6

Browse files
committed
make for query columns
1 parent 2f3a293 commit 0d4dfc6

File tree

20 files changed

+368
-223
lines changed

20 files changed

+368
-223
lines changed

fastapi-postgres/app/service/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _inject_user_filter(
3232
return Depends(_inject_user_filter)
3333

3434
@classmethod
35-
async def authenticate_user(cls, _id: int, phone: int) -> Token:
35+
async def authenticate_user(cls, _id: int, phone: str) -> Token:
3636
filter_query = FilterQuery(
3737
query=[
3838
DBQuery(key="id", opt=DBOperator.eq, value=_id),

fastapi-postgres/app/service/create_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ class FactoryModel:
1616
def teacher(cls) -> dict:
1717
return {
1818
"id": randint(10, 900),
19-
"phone": float(randint(10_000_000, 9_999_999_999)),
19+
"phone": str(randint(10_000_000, 9_999_999_999)),
2020
}
2121

2222
@classmethod
2323
def student(cls, teacher_id: int | None = None) -> dict:
2424
data = {
2525
"name": fake.name(),
2626
"grade": choice(Grade.values()),
27-
"phone": float(randint(10**9, 10**10)),
27+
"phone": str(randint(10**9, 10**10)),
2828
}
2929
if teacher_id is not None:
3030
data["teacher_id"] = teacher_id

fastapi-postgres/common/db_model/__init__.py

Lines changed: 87 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,20 @@
1-
import asyncio
21
from contextlib import asynccontextmanager
32
from typing import Any, Type
43

54
from pydantic import BaseModel
6-
from sqlalchemy.exc import IntegrityError, OperationalError
5+
from sqlmodel import and_, update
6+
from sqlalchemy.exc import IntegrityError
77
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
8-
from sqlalchemy.orm import selectinload
9-
from sqlmodel import select
108
from sqlmodel.ext.asyncio.session import AsyncSession
119

1210
from common.config import conf
11+
from common.db_model.query_db import QueryDB
1312
from common.enums import DBOperator, ModelType
1413
from common.serializers import BaseTable, DBQuery, FilterQuery, RowLike
1514
from common.utils import BaseUtils
1615

17-
opts = {
18-
DBOperator.eq: "__eq__",
19-
DBOperator.ne: "__ne__",
20-
DBOperator.lt: "__lt__",
21-
DBOperator.le: "__le__",
22-
DBOperator.gt: "__gt__",
23-
DBOperator.ge: "__ge__",
24-
DBOperator.like: "like",
25-
DBOperator.ilike: "ilike", # PostgreSQL only
26-
DBOperator.in_: "in_",
27-
DBOperator.not_in: "notin_",
28-
DBOperator.is_null: "is_",
29-
}
30-
31-
32-
class DBModel(BaseUtils):
16+
17+
class DBModel(BaseUtils, QueryDB):
3318
table: Type[BaseTable] | None = None
3419
payload: Type[BaseModel] | None = None
3520
model_type: ModelType | None = None
@@ -66,6 +51,29 @@ async def get_session(cls):
6651
finally:
6752
await session.close()
6853

54+
@classmethod
55+
def _raise_db_error(cls, ex, row=None):
56+
dump = (
57+
row
58+
if isinstance(row, dict)
59+
else (row.model_dump() if hasattr(row, "model_dump") else None)
60+
)
61+
if dump:
62+
cls.logger.error(f"{dump}: {ex}")
63+
else:
64+
cls.logger.error(str(ex))
65+
if isinstance(ex, IntegrityError):
66+
orig = str(getattr(ex, "orig", ""))
67+
if "duplicate key value" in orig:
68+
field = "unique"
69+
if "Key (" in orig:
70+
try:
71+
field = orig.split("Key (")[1].split(")=")[0]
72+
except Exception:
73+
pass
74+
cls.error_400(details=f"unique:{field}")
75+
cls.error_400(details=ex)
76+
6977
@classmethod
7078
async def add_update(cls, row: BaseTable | list[BaseTable]):
7179
async with cls.get_session() as session:
@@ -85,100 +93,78 @@ async def add_update(cls, row: BaseTable | list[BaseTable]):
8593
await session.refresh(row)
8694
return row
8795
except Exception as ex:
88-
dump = "" if isinstance(row, list) else row.model_dump()
89-
cls.logger.error(f"{dump}: {ex}")
90-
if isinstance(ex, IntegrityError):
91-
orig = str(getattr(ex, "orig", ""))
92-
if "duplicate key value" in orig:
93-
field = "unique"
94-
if "Key (" in orig:
95-
try:
96-
field = orig.split("Key (")[1].split(")=")[0]
97-
except Exception:
98-
pass
99-
cls.error_400(details=f"unique:{field}")
100-
cls.error_400(details=ex)
96+
cls._raise_db_error(ex, row)
97+
98+
@classmethod
99+
async def update_exist_bulk(cls, filter_query: FilterQuery, values: dict) -> int:
100+
statement = update(cls.table).values(**values)
101+
102+
conditions = cls.generate_where_query(filter_query.query)
103+
104+
if conditions:
105+
statement = statement.where(and_(*conditions))
106+
107+
try:
108+
async with cls.get_session() as session:
109+
results = await session.exec(statement)
110+
return results.rowcount
111+
except Exception as ex:
112+
cls._raise_db_error(ex, values)
101113

102114
@classmethod
103115
async def add_or_find_update_single(
104-
cls,
105-
add_or_id: str | int,
106-
body: BaseModel | dict[str, Any],
107-
**kwargs: Any,
116+
cls,
117+
add_or_id: str | int,
118+
body: BaseModel,
119+
**kwargs: Any,
108120
) -> BaseTable:
109121
user_auth = kwargs.get("user_auth")
110122

111123
if add_or_id != "add":
112-
db_obj = await cls.get_by_id(_id=add_or_id, **kwargs)
124+
_id = int(add_or_id)
125+
query = [DBQuery(key=cls.table.id.key, opt=DBOperator.eq, value=_id)]
126+
127+
if user_auth:
128+
query.append(
129+
DBQuery(
130+
key=conf.AUTH_PARENT_FIELD,
131+
opt=DBOperator.eq,
132+
value=user_auth.id,
133+
)
134+
)
135+
136+
body_dict = body.model_dump()
137+
values = {
138+
key: value for key, value in body_dict.items() if key not in {"id"}
139+
}
113140

114-
if db_obj and (
115-
not user_auth or user_auth.id == getattr(db_obj, conf.AUTH_PARENT_FIELD)
116-
):
117-
cls.set_elements_by_dict(db_obj, body, exclude_items=["id"])
118-
else:
141+
updated = await cls.update_exist_bulk(
142+
filter_query=FilterQuery(query=query),
143+
values=values,
144+
)
145+
146+
if not updated:
119147
cls.error_400(details="not found")
120148

121-
else:
122-
db_obj = cls.table()
123-
cls.set_elements_by_dict(db_obj, body)
149+
body_dict["id"] = _id
150+
new_obj = cls.table(**body_dict)
124151

152+
else:
153+
db_obj = cls.table(**body.model_dump())
125154
if user_auth and hasattr(db_obj, conf.AUTH_PARENT_FIELD):
126155
setattr(db_obj, conf.AUTH_PARENT_FIELD, user_auth.id)
127156

128-
new_obj = await cls.add_update(row=db_obj)
129-
return new_obj
130-
131-
# Generate SQLAlchemy filter conditions with AND logic from a list of DBQuery objects
132-
@classmethod
133-
def generate_where_query(cls, query: list[DBQuery]):
134-
# Build a tuple of SQLAlchemy expressions by applying the operator (e.g., ==, >=) on each field
135-
ans = tuple(
136-
getattr(getattr(cls.table, q.key), opts.get(q.opt, q.opt))(q.value)
137-
for q in query
138-
if hasattr(cls.table, q.key)
139-
and hasattr(getattr(cls.table, q.key), opts.get(q.opt, q.opt))
140-
)
141-
return ans
142-
143-
@classmethod
144-
def build_query(cls, filter_query: FilterQuery, offset: int = 0, limit: int = 1000):
145-
statement = select(cls.table)
146-
147-
if filter_query.query:
148-
statement = statement.where(*cls.generate_where_query(filter_query.query))
149-
150-
if filter_query.relation_model:
151-
statement = statement.options(selectinload("*"))
157+
new_obj = await cls.add_update(row=db_obj)
152158

153-
if filter_query.sort:
154-
try:
155-
field, direction = filter_query.sort.split(":")
156-
col = getattr(cls.table, field, None)
157-
if col:
158-
direction = direction.lower()
159-
statement = statement.order_by(
160-
col.asc() if direction == "asc" else col.desc()
161-
)
162-
except Exception:
163-
raise ValueError(
164-
f"Invalid sort format: '{filter_query.sort}'. Expected 'field:asc' or 'field:desc'"
165-
)
166-
167-
if offset:
168-
statement = statement.offset(offset)
169-
170-
if limit:
171-
statement = statement.limit(limit)
172-
173-
return statement
159+
return new_obj
174160

175161
@classmethod
176162
async def fetch_rows(
177-
cls,
178-
filter_query: FilterQuery = FilterQuery(),
179-
offset: int = 0,
180-
limit: int = 1000,
181-
as_dict: bool = False,
163+
cls,
164+
filter_query: FilterQuery = FilterQuery(),
165+
offset: int = 0,
166+
limit: int = 1000,
167+
as_dict: bool = False,
182168
) -> RowLike | list[RowLike]:
183169
statement = cls.build_query(filter_query, offset, limit)
184170

@@ -204,19 +190,22 @@ async def fetch_rows(
204190

205191
@classmethod
206192
async def delete_rows(
207-
cls, filter_query: FilterQuery = FilterQuery(), offset: int = 0
193+
cls, filter_query: FilterQuery = FilterQuery(), offset: int = 0
208194
):
209195
try:
210196
async with cls.get_session() as session:
197+
filter_query.columns = [cls.table.id]
211198
select_stmt = cls.build_query(filter_query, offset, limit=0)
212199
results = (await session.exec(select_stmt)).all()
213200
if not results:
214201
cls.error_400(details="not delete nothing")
202+
215203
for obj in results:
216204
await session.delete(obj)
205+
return results
217206

218207
except Exception as ex:
219-
cls.error_400(details=ex)
208+
cls._raise_db_error(ex)
220209

221210
@classmethod
222211
async def get_by_id(cls, _id: str | int, **kwargs):

0 commit comments

Comments
 (0)