Skip to content

Commit 92633fb

Browse files
authored
Skip MutableProxy rebind self for _sa_instrumented functions (#6168)
* Skip MutableProxy rebind self for _sa_instrumented functions Not a general solution to the issue of cases where rebinding `self` might break descriptors, but at least fix the case for sqlalchemy models. Fix #6167 * add test case for mp/sqlalchemy rebinding
1 parent 3c11451 commit 92633fb

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

reflex/istate/proxy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,8 @@ def __getattr__(self, __name: str) -> Any:
572572
)
573573
and (func := getattr(value, "__func__", None)) is not None
574574
and not inspect.isclass(getattr(value, "__self__", None))
575+
# skip SQLAlchemy instrumented methods
576+
and not getattr(value, "_sa_instrumented", False)
575577
):
576578
# Rebind `self` to the proxy on methods to capture nested mutations.
577579
return functools.partial(func, self)

tests/units/test_model.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import reflex.model
99
from reflex.constants.state import FIELD_MARKER
1010
from reflex.model import Model, ModelRegistry
11-
from reflex.state import BaseState
11+
from reflex.state import BaseState, State
1212
from tests.units.test_state import (
1313
mock_app_simple, # noqa: F401 # for pytest.mark.usefixtures
1414
)
@@ -240,3 +240,43 @@ async def test_upcast_event_handler_arg(handler, payload):
240240
assert update.delta == {
241241
UpcastStateWithSqlAlchemy.get_full_name(): {"passed" + FIELD_MARKER: True}
242242
}
243+
244+
245+
def test_no_rebind_mutable_proxy_for_instrumented_functions():
246+
"""Test that we don't rebind mutable proxies for instrumented functions."""
247+
import sqlalchemy
248+
import sqlalchemy.orm
249+
250+
class SABase(sqlalchemy.orm.MappedAsDataclass, sqlalchemy.orm.DeclarativeBase):
251+
pass
252+
253+
class SAKeyword(SABase):
254+
__tablename__ = "sa_keyword"
255+
256+
id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
257+
primary_key=True, init=False, default=None
258+
)
259+
value: sqlalchemy.orm.Mapped[str] = sqlalchemy.orm.mapped_column(default="")
260+
obj_id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
261+
sqlalchemy.ForeignKey("sa_obj.id"), default=None
262+
)
263+
264+
class SAObj(SABase):
265+
__tablename__ = "sa_obj"
266+
267+
id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
268+
primary_key=True, init=False, default=None
269+
)
270+
keywords: sqlalchemy.orm.Mapped[list[SAKeyword]] = sqlalchemy.orm.relationship(
271+
lazy="selectin", # codespell:ignore
272+
cascade="all, delete",
273+
default_factory=list,
274+
)
275+
276+
class SAState(State):
277+
sa_obj: SAObj = SAObj()
278+
279+
sa_state = SAState()
280+
assert "sa_obj" not in sa_state.dirty_vars
281+
sa_state.sa_obj.keywords.append(SAKeyword(value="test"))
282+
assert "sa_obj" in sa_state.dirty_vars

0 commit comments

Comments
 (0)