From 689f02cac9ec13e77eec404f98ca74b267ca4bc8 Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Mon, 1 Apr 2024 23:53:37 +0900 Subject: [PATCH 1/6] Add AsyncSQLModel and AwaitableField --- pyproject.toml | 2 + sqlmodel/ext/asyncio/async_model.py | 69 +++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 sqlmodel/ext/asyncio/async_model.py diff --git a/pyproject.toml b/pyproject.toml index 9da631b985..5f07203236 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ httpx = "0.24.1" dirty-equals = "^0.6.0" typer-cli = "^0.0.13" mkdocs-markdownextradata-plugin = ">=0.1.7,<0.3.0" +pytest-asyncio = "0.21.1" +aiosqlite = "0.19.0" [build-system] requires = ["poetry-core"] diff --git a/sqlmodel/ext/asyncio/async_model.py b/sqlmodel/ext/asyncio/async_model.py new file mode 100644 index 0000000000..bd8ca1ed8c --- /dev/null +++ b/sqlmodel/ext/asyncio/async_model.py @@ -0,0 +1,69 @@ +from typing import Any, ClassVar, Coroutine, Dict, Tuple, Type + +from pydantic._internal._repr import Representation +from sqlalchemy.util.concurrency import greenlet_spawn + +from ... import SQLModel +from ..._compat import get_annotations +from ...main import SQLModelMetaclass + + +class AwaitableFieldInfo(Representation): + def __init__(self, *, field: str): + self.field = field + + +def AwaitableField(*, field: str) -> Any: + return AwaitableFieldInfo(field=field) + + +class AsyncSQLModelMetaclass(SQLModelMetaclass): + __async_sqlmodel_awaitable_fields__: Dict[str, AwaitableFieldInfo] + + def __new__( + cls, + name: str, + bases: Tuple[Type[Any], ...], + class_dict: Dict[str, Any], + **kwargs: Any + ) -> Any: + awaitable_fields: Dict[str, AwaitableFieldInfo] = {} + dict_for_sqlmodel = {} + original_annotations = get_annotations(class_dict) + sqlmodel_annotations = {} + awaitable_fields_annotations = {} + for k, v in class_dict.items(): + if isinstance(v, AwaitableFieldInfo): + awaitable_fields[k] = v + else: + dict_for_sqlmodel[k] = v + for k, v in original_annotations.items(): + if k in awaitable_fields: + awaitable_fields_annotations[k] = v + else: + sqlmodel_annotations[k] = v + + dict_used = { + **dict_for_sqlmodel, + "__async_sqlmodel_awaitable_fields__": awaitable_fields, + "__annotations__": sqlmodel_annotations, + } + return super().__new__(cls, name, bases, dict_used, **kwargs) + + def __init__( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any + ) -> None: + for field_name, field_info in cls.__async_sqlmodel_awaitable_fields__.items(): + + def get_awaitable_field( + self, field: str = field_info.field + ) -> Coroutine[Any, Any, Any]: + return greenlet_spawn(getattr, self, field) + + setattr(cls, field_name, property(get_awaitable_field)) # type: ignore + + SQLModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + + +class AsyncSQLModel(SQLModel, metaclass=AsyncSQLModelMetaclass): + __async_sqlmodel_awaitable_fields__: ClassVar[Dict[str, AwaitableFieldInfo]] From dee0d2770a6d7ee65da98f72178aa5db79e9283e Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Mon, 1 Apr 2024 23:54:39 +0900 Subject: [PATCH 2/6] Add Test that use AwaitableFields --- tests/test_awaitable_field.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/test_awaitable_field.py diff --git a/tests/test_awaitable_field.py b/tests/test_awaitable_field.py new file mode 100644 index 0000000000..ec7bcfd657 --- /dev/null +++ b/tests/test_awaitable_field.py @@ -0,0 +1,76 @@ +from typing import Awaitable, List, Optional + +import pytest +from sqlalchemy.exc import MissingGreenlet +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.util.concurrency import greenlet_spawn + +from sqlmodel import Field, Relationship, SQLModel, select +from sqlmodel.ext.asyncio.async_model import AsyncSQLModel, AwaitableField +from sqlmodel.ext.asyncio.session import AsyncSession + + +@pytest.mark.asyncio +async def test_awaitable_nomral_field(clear_sqlmodel): + class Hero(AsyncSQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + awt_name: Awaitable[str] = AwaitableField(field="name") + awt_age: Awaitable[str] = AwaitableField(field="age") + + hero_deadpond = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_async_engine("sqlite+aiosqlite://") + await greenlet_spawn(SQLModel.metadata.create_all, engine.sync_engine) + + async with AsyncSession(engine) as session: + session.add(hero_deadpond) + await session.commit() + + # loading expired attribute will raise MissingGreenlet error + with pytest.raises(MissingGreenlet): + hero_deadpond.name + + name = await hero_deadpond.awt_name + assert name == "Deadpond" + + +@pytest.mark.asyncio +async def test_awaitable_relation_field(clear_sqlmodel): + class Team(AsyncSQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + + heroes: List["Hero"] = Relationship() + awt_heroes: Awaitable[List["Hero"]] = AwaitableField(field="heroes") + + class Hero(AsyncSQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") + awt_team: Awaitable[Optional[Team]] = AwaitableField(field="team") + + team_preventers = Team(name="Preventers") + hero_rusty_man = Hero(name="Rusty-Man", team=team_preventers) + + engine = create_async_engine("sqlite+aiosqlite://") + await greenlet_spawn(SQLModel.metadata.create_all, engine.sync_engine) + + async with AsyncSession(engine) as session: + session.add(hero_rusty_man) + await session.commit() + + async with AsyncSession(engine) as session: + hero = (await session.exec(select(Hero).where(Hero.name == "Rusty-Man"))).one() + + # loading lazy loading attribute will raise MissingGreenlet error + with pytest.raises(MissingGreenlet): + hero.team + + team = await hero.awt_team + assert team + assert team.name == "Preventers" From b446b771a5bf16ee43a04c026ee9eb725b6c3511 Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Tue, 2 Apr 2024 00:00:20 +0900 Subject: [PATCH 3/6] Format codes --- sqlmodel/ext/asyncio/async_model.py | 2 +- tests/test_awaitable_field.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sqlmodel/ext/asyncio/async_model.py b/sqlmodel/ext/asyncio/async_model.py index bd8ca1ed8c..4646c16286 100644 --- a/sqlmodel/ext/asyncio/async_model.py +++ b/sqlmodel/ext/asyncio/async_model.py @@ -25,7 +25,7 @@ def __new__( name: str, bases: Tuple[Type[Any], ...], class_dict: Dict[str, Any], - **kwargs: Any + **kwargs: Any, ) -> Any: awaitable_fields: Dict[str, AwaitableFieldInfo] = {} dict_for_sqlmodel = {} diff --git a/tests/test_awaitable_field.py b/tests/test_awaitable_field.py index ec7bcfd657..9c7fba94ff 100644 --- a/tests/test_awaitable_field.py +++ b/tests/test_awaitable_field.py @@ -4,7 +4,6 @@ from sqlalchemy.exc import MissingGreenlet from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.util.concurrency import greenlet_spawn - from sqlmodel import Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.async_model import AsyncSQLModel, AwaitableField from sqlmodel.ext.asyncio.session import AsyncSession @@ -31,7 +30,7 @@ class Hero(AsyncSQLModel, table=True): # loading expired attribute will raise MissingGreenlet error with pytest.raises(MissingGreenlet): - hero_deadpond.name + hero_deadpond.name # noqa: B018 name = await hero_deadpond.awt_name assert name == "Deadpond" @@ -69,7 +68,7 @@ class Hero(AsyncSQLModel, table=True): # loading lazy loading attribute will raise MissingGreenlet error with pytest.raises(MissingGreenlet): - hero.team + hero.team # noqa: B018 team = await hero.awt_team assert team From 2d9af2f8e3fc04172120ef062e8c8378e0bca708 Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Tue, 2 Apr 2024 01:11:16 +0900 Subject: [PATCH 4/6] Resolve lint error --- sqlmodel/ext/asyncio/async_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/ext/asyncio/async_model.py b/sqlmodel/ext/asyncio/async_model.py index 4646c16286..118454a11b 100644 --- a/sqlmodel/ext/asyncio/async_model.py +++ b/sqlmodel/ext/asyncio/async_model.py @@ -56,11 +56,11 @@ def __init__( for field_name, field_info in cls.__async_sqlmodel_awaitable_fields__.items(): def get_awaitable_field( - self, field: str = field_info.field + self: "AsyncSQLModel", field: str = field_info.field ) -> Coroutine[Any, Any, Any]: return greenlet_spawn(getattr, self, field) - setattr(cls, field_name, property(get_awaitable_field)) # type: ignore + setattr(cls, field_name, property(get_awaitable_field)) SQLModelMetaclass.__init__(cls, classname, bases, dict_, **kw) From 548f0050b3df31e0206d9157b2715912ff7ca643 Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Tue, 2 Apr 2024 01:34:24 +0900 Subject: [PATCH 5/6] Move test to the sub folder for test order --- tests/test_advanced/test_async/__init__.py | 0 tests/{ => test_advanced/test_async}/test_awaitable_field.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/test_advanced/test_async/__init__.py rename tests/{ => test_advanced/test_async}/test_awaitable_field.py (100%) diff --git a/tests/test_advanced/test_async/__init__.py b/tests/test_advanced/test_async/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_awaitable_field.py b/tests/test_advanced/test_async/test_awaitable_field.py similarity index 100% rename from tests/test_awaitable_field.py rename to tests/test_advanced/test_async/test_awaitable_field.py From 1d7f73e03ab0440d81719b7ce0a1398296acc1ab Mon Sep 17 00:00:00 2001 From: 2jun0 Date: Tue, 2 Apr 2024 01:37:44 +0900 Subject: [PATCH 6/6] Fix pydantic v1 import error --- sqlmodel/ext/asyncio/async_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sqlmodel/ext/asyncio/async_model.py b/sqlmodel/ext/asyncio/async_model.py index 118454a11b..88c8ee8db2 100644 --- a/sqlmodel/ext/asyncio/async_model.py +++ b/sqlmodel/ext/asyncio/async_model.py @@ -1,10 +1,9 @@ from typing import Any, ClassVar, Coroutine, Dict, Tuple, Type -from pydantic._internal._repr import Representation from sqlalchemy.util.concurrency import greenlet_spawn from ... import SQLModel -from ..._compat import get_annotations +from ..._compat import Representation, get_annotations from ...main import SQLModelMetaclass