Skip to content

Commit 7857664

Browse files
add typed get_substate variant, improve modify_states typing
1 parent c93441b commit 7857664

3 files changed

Lines changed: 46 additions & 18 deletions

File tree

reflex/app.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import traceback
1717
from datetime import datetime
1818
from pathlib import Path
19+
from types import NoneType
1920
from typing import (
2021
Any,
2122
AsyncIterator,
@@ -26,9 +27,11 @@
2627
Optional,
2728
Set,
2829
Type,
30+
TypeVar,
2931
Union,
3032
get_args,
3133
get_type_hints,
34+
overload,
3235
)
3336

3437
from fastapi import FastAPI, HTTPException, Request, UploadFile
@@ -1101,11 +1104,23 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
11011104
sid=state.router.session.session_id,
11021105
)
11031106

1107+
S = TypeVar("S", bound=BaseState)
1108+
1109+
@overload
1110+
async def modify_states(
1111+
self, substate_cls: Type[S], from_state: NoneType
1112+
) -> AsyncIterator[S]: ...
1113+
1114+
@overload
1115+
async def modify_states(
1116+
self, substate_cls: NoneType, from_state: BaseState
1117+
) -> AsyncIterator[BaseState]: ...
1118+
11041119
async def modify_states(
11051120
self,
1106-
substate_cls: Type[BaseState] | None = None,
1121+
substate_cls: Type[S] | Type[BaseState] | None = None,
11071122
from_state: BaseState | None = None,
1108-
) -> AsyncIterator[BaseState]:
1123+
) -> AsyncIterator[S] | AsyncIterator[BaseState]:
11091124
"""Iterate over the states.
11101125
11111126
Args:
@@ -1128,11 +1143,11 @@ async def modify_states(
11281143
from_state is not None
11291144
and from_state.router.session.client_token == token
11301145
):
1131-
yield from_state
1146+
state = from_state
11321147
continue
11331148
async with self.modify_state(token) as state:
11341149
if substate_cls is not None:
1135-
state = state.get_substate(substate_cls.get_name())
1150+
state = state.get_substate(substate_cls)
11361151
yield state
11371152

11381153
def _process_background(

reflex/state.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
Sequence,
2525
Set,
2626
Type,
27+
TypeVar,
2728
Union,
2829
cast,
30+
overload,
2931
)
3032

3133
import dill
@@ -291,6 +293,9 @@ def __call__(self, *args: Any) -> EventSpec:
291293
return super().__call__(*args)
292294

293295

296+
S = TypeVar("S", bound="BaseState")
297+
298+
294299
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
295300
"""The state of the app."""
296301

@@ -1151,18 +1156,28 @@ def _reset_client_storage(self):
11511156
for substate in self.substates.values():
11521157
substate._reset_client_storage()
11531158

1154-
def get_substate(self, path: Sequence[str]) -> BaseState:
1159+
@overload
1160+
def get_substate(self, path: Sequence[str]) -> BaseState: ...
1161+
1162+
@overload
1163+
def get_substate(self, path: Type[S]) -> S: ...
1164+
1165+
def get_substate(self, path: Sequence[str] | Type[S]) -> BaseState | S:
11551166
"""Get the substate.
11561167
11571168
Args:
1158-
path: The path to the substate.
1169+
path: The path to the substate or the class of the substate.
11591170
11601171
Returns:
11611172
The substate.
11621173
11631174
Raises:
11641175
ValueError: If the substate is not found.
11651176
"""
1177+
if isinstance(path, type):
1178+
path = (
1179+
path.get_full_name().removeprefix(f"{self.get_full_name()}.").split(".")
1180+
)
11661181
if len(path) == 0:
11671182
return self
11681183
if path[0] == self.get_name():
@@ -1295,7 +1310,7 @@ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
12951310
root_state = self
12961311
else:
12971312
root_state = self._get_parent_states()[-1][1]
1298-
return root_state.get_substate(state_cls.get_full_name().split("."))
1313+
return root_state.get_substate(state_cls)
12991314

13001315
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
13011316
"""Get a state instance from redis.

tests/test_state.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def child_state(test_state) -> ChildState:
218218
Returns:
219219
A test child state.
220220
"""
221-
child_state = test_state.get_substate([ChildState.get_name()])
221+
child_state = test_state.get_substate(ChildState)
222222
assert child_state is not None
223223
return child_state
224224

@@ -233,7 +233,7 @@ def child_state2(test_state) -> ChildState2:
233233
Returns:
234234
A second test child state.
235235
"""
236-
child_state2 = test_state.get_substate([ChildState2.get_name()])
236+
child_state2 = test_state.get_substate(ChildState2)
237237
assert child_state2 is not None
238238
return child_state2
239239

@@ -248,7 +248,7 @@ def grandchild_state(child_state) -> GrandchildState:
248248
Returns:
249249
A test state.
250250
"""
251-
grandchild_state = child_state.get_substate([GrandchildState.get_name()])
251+
grandchild_state = child_state.get_substate(GrandchildState)
252252
assert grandchild_state is not None
253253
return grandchild_state
254254

@@ -1183,7 +1183,7 @@ def set_v4(self, v: int):
11831183
assert ms.v == 2
11841184

11851185
# ensure handler can be called from substate (referencing grandparent handler)
1186-
ms.get_substate(tuple(SubSubState.get_full_name().split("."))).set_v4(3)
1186+
ms.get_substate(SubSubState).set_v4(3)
11871187
assert ms.v == 3
11881188

11891189

@@ -2854,7 +2854,7 @@ async def test_get_state(mock_app: rx.App, token: str):
28542854
)
28552855

28562856
# Get the child_state2 directly.
2857-
child_state2_direct = test_state.get_substate([ChildState2.get_name()])
2857+
child_state2_direct = test_state.get_substate(ChildState2)
28582858
child_state2_get_state = await test_state.get_state(ChildState2)
28592859
# These should be the same object.
28602860
assert child_state2_direct is child_state2_get_state
@@ -2871,15 +2871,13 @@ async def test_get_state(mock_app: rx.App, token: str):
28712871
)
28722872

28732873
# ChildState should be retrievable
2874-
child_state_direct = test_state.get_substate([ChildState.get_name()])
2874+
child_state_direct = test_state.get_substate(ChildState)
28752875
child_state_get_state = await test_state.get_state(ChildState)
28762876
# These should be the same object.
28772877
assert child_state_direct is child_state_get_state
28782878

28792879
# GrandchildState instance should be the same as the one retrieved from the child_state2.
2880-
assert grandchild_state is child_state_direct.get_substate(
2881-
[GrandchildState.get_name()]
2882-
)
2880+
assert grandchild_state is child_state_direct.get_substate(GrandchildState)
28832881
grandchild_state.value2 = "set_value"
28842882

28852883
assert test_state.get_delta() == {
@@ -2920,7 +2918,7 @@ async def test_get_state(mock_app: rx.App, token: str):
29202918
)
29212919

29222920
# Set a value on child_state2, should update cached var in grandchild_state2
2923-
child_state2 = new_test_state.get_substate((ChildState2.get_name(),))
2921+
child_state2 = new_test_state.get_substate(ChildState2)
29242922
child_state2.value = "set_c2_value"
29252923

29262924
assert new_test_state.get_delta() == {
@@ -3015,7 +3013,7 @@ class GreatGrandchild3(Grandchild3):
30153013
assert Child3.get_name() in root.substates # (due to @rx.var)
30163014

30173015
# Get the unconnected sibling state, which will be used to `get_state` other instances.
3018-
child = root.get_substate(Child.get_full_name().split("."))
3016+
child = root.get_substate(Child)
30193017

30203018
# Get an uncached child state.
30213019
child2 = await child.get_state(Child2)

0 commit comments

Comments
 (0)