Skip to content

Commit 2766d0a

Browse files
committed
Remove error exception..
1 parent c3e18bd commit 2766d0a

4 files changed

Lines changed: 65 additions & 40 deletions

File tree

pymfdata/common/__init__.py

Whitespace-only changes.

pymfdata/common/errors/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

pymfdata/mongodb/repository.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from abc import ABC
22
from bson import ObjectId
3-
from typing import final
3+
from typing import final, Optional
44

55
from pymfdata.mongodb.connection import AsyncMotor
6-
from pymfdata.common.errors import NotFoundException
76

87

98
class AsyncRepository(ABC):
109
def __init__(self, collection_name: str, motor: AsyncMotor) -> None:
1110
self._collection = motor.client[motor.db_name][collection_name]
1211

1312
@final
14-
async def delete_by_id(self, item_id: str):
13+
async def delete_by_id(self, item_id: str) -> bool:
1514
row = await self._collection.delete_one({"_id": ObjectId(item_id)})
1615
if not row:
17-
raise NotFoundException()
16+
return False
17+
18+
return True
1819

1920
@final
2021
async def find_all(self):
@@ -24,10 +25,10 @@ async def find_all(self):
2425
return results
2526

2627
@final
27-
async def find_by_id(self, item_id: str):
28+
async def find_by_id(self, item_id: str) -> Optional[dict]:
2829
row = await self._collection.find_one({"_id": ObjectId(item_id)})
2930
if not row:
30-
raise NotFoundException()
31+
return None
3132

3233
return row
3334

pymfdata/rdb/repository.py

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
from contextlib import AbstractAsyncContextManager, AbstractContextManager
2-
from typing import AsyncIterator, Callable, final, Iterator, Protocol, TypeVar
2+
from typing import Callable, final, Iterator, List, Protocol, TypeVar, Optional
33
from sqlalchemy.ext.asyncio import AsyncSession
44
from sqlalchemy.future import select
55
from sqlalchemy.orm import Session, Query
66
from sqlalchemy.sql.selectable import Select
77

88
from pymfdata.rdb.connection import Base
9-
from pymfdata.common.errors import NotFoundException
109

11-
_MT = TypeVar("_MT", bound=Base)
10+
_MT = TypeVar("_MT", bound=Base) # Model Type
11+
_T = TypeVar("_T") # Primary key Type
1212

1313

14-
class AsyncRepository(Protocol[_MT]):
14+
class AsyncRepository(Protocol[_MT, _T]):
1515
_model: _MT
1616
_session_factory: Callable[..., AbstractAsyncContextManager]
1717
_pk_column: str
1818

19-
async def delete_by_pk(self, pk):
20-
if not await self.is_exists(**{self._pk_column: pk}):
21-
raise NotFoundException()
22-
19+
async def delete_by_pk(self, pk: _T) -> bool:
2320
item = await self.find_by_pk(pk)
21+
if item is not None:
22+
session: AsyncSession
23+
async with self._session_factory() as session:
24+
await session.delete(item)
25+
await session.commit()
2426

25-
session: AsyncSession
26-
async with self._session_factory() as session:
27-
await session.delete(item)
28-
await session.commit()
27+
return True
28+
return False
2929

30-
async def find_by_pk(self, pk) -> _MT:
30+
async def find_by_pk(self, pk: _T) -> Optional[_MT]:
3131
return await self.find_by_col(**{self._pk_column: pk})
3232

3333
@final
34-
async def find_by_col(self, **kwargs) -> _MT:
34+
async def find_by_col(self, **kwargs) -> Optional[_MT]:
3535
if not await self.is_exists(**kwargs):
36-
raise NotFoundException()
36+
return None
3737

3838
session: AsyncSession
3939
async with self._session_factory() as session:
@@ -49,7 +49,7 @@ def _gen_stmt_for_param(self, **kwargs) -> Select:
4949
return stmt
5050

5151
@final
52-
async def find_all(self, **kwargs) -> AsyncIterator[_MT]:
52+
async def find_all(self, **kwargs) -> List[_MT]:
5353
session: AsyncSession
5454
async with self._session_factory() as session:
5555
stmt = self._gen_stmt_for_param(**kwargs)
@@ -71,34 +71,45 @@ async def save(self, item: Base):
7171
await session.commit()
7272
await session.refresh(item)
7373

74-
async def update_by_pk(self, pk, req: dict):
75-
if not await self.is_exists(**{self._pk_column: pk}):
76-
raise NotFoundException()
77-
74+
async def update_by_pk(self, pk: _T, req: dict) -> bool:
7875
item = await self.find_by_pk(pk)
76+
if item is not None:
77+
session: AsyncSession
78+
async with self._session_factory() as session:
79+
for k, v in req.items():
80+
if v is not None:
81+
setattr(item, k, v)
7982

80-
session: AsyncSession
81-
async with self._session_factory() as session:
82-
for k, v in req.items():
83-
if v is not None:
84-
setattr(item, k, v)
83+
await session.commit()
84+
await session.refresh(item)
8585

86-
await session.commit()
87-
await session.refresh(item)
86+
return True
87+
return False
8888

8989

90-
class SyncRepository(Protocol[_MT]):
90+
class SyncRepository(Protocol[_MT, _T]):
9191
_model: _MT
9292
_session_factory: Callable[..., AbstractContextManager]
9393
_pk_column: str
9494

95-
def find_by_pk(self, pk) -> _MT:
95+
def delete_by_pk(self, pk: _T) -> bool:
96+
item = self.find_by_pk(pk)
97+
if item is not None:
98+
session: Session
99+
with self._session_factory() as session:
100+
session.delete(item)
101+
session.commit()
102+
103+
return True
104+
return False
105+
106+
def find_by_pk(self, pk: _T) -> Optional[_MT]:
96107
return self.find_by_col(**{self._pk_column: pk})
97108

98109
@final
99-
def find_by_col(self, **kwargs) -> _MT:
110+
def find_by_col(self, **kwargs) -> Optional[_MT]:
100111
if not self.is_exists(**kwargs):
101-
raise NotFoundException()
112+
return None
102113

103114
with self._session_factory() as session:
104115
query = self._gen_query_for_param(session, **kwargs)
@@ -121,8 +132,8 @@ def find_all(self, **kwargs) -> Iterator[_MT]:
121132

122133
@final
123134
def is_exists(self, **kwargs) -> bool:
135+
session: Session
124136
with self._session_factory() as session:
125-
session: Session
126137
return session.query(self._gen_query_for_param(session, **kwargs).exists()).scalar()
127138

128139
@final
@@ -131,3 +142,18 @@ def save(self, item: Base):
131142
with self._session_factory() as session:
132143
session.add(item)
133144
session.commit()
145+
146+
def update_by_pk(self, pk: _T, req: dict) -> bool:
147+
item = self.find_by_pk(pk)
148+
if item is not None:
149+
session: Session
150+
with self._session_factory() as session:
151+
for k, v in req.items():
152+
if v is not None:
153+
setattr(item, k, v)
154+
155+
await session.commit()
156+
await session.refresh(item)
157+
158+
return True
159+
return False

0 commit comments

Comments
 (0)