Skip to content

Commit f4295df

Browse files
add typed get_substate variant, improve modify_states typing
1 parent c93441b commit f4295df

File tree

4 files changed

+45
-21
lines changed

4 files changed

+45
-21
lines changed

reflex/app.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
Optional,
2727
Set,
2828
Type,
29+
TypeVar,
2930
Union,
3031
get_args,
3132
get_type_hints,
33+
overload,
3234
)
3335

3436
from fastapi import FastAPI, HTTPException, Request, UploadFile
@@ -1101,11 +1103,23 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
11011103
sid=state.router.session.session_id,
11021104
)
11031105

1106+
S = TypeVar("S", bound=BaseState)
1107+
1108+
@overload
1109+
async def modify_states(
1110+
self, substate_cls: Type[S], from_state: None
1111+
) -> AsyncIterator[S]: ...
1112+
1113+
@overload
1114+
async def modify_states(
1115+
self, substate_cls: None, from_state: BaseState
1116+
) -> AsyncIterator[BaseState]: ...
1117+
11041118
async def modify_states(
11051119
self,
1106-
substate_cls: Type[BaseState] | None = None,
1120+
substate_cls: Type[S] | Type[BaseState] | None = None,
11071121
from_state: BaseState | None = None,
1108-
) -> AsyncIterator[BaseState]:
1122+
) -> AsyncIterator[S] | AsyncIterator[BaseState]:
11091123
"""Iterate over the states.
11101124
11111125
Args:
@@ -1128,11 +1142,11 @@ async def modify_states(
11281142
from_state is not None
11291143
and from_state.router.session.client_token == token
11301144
):
1131-
yield from_state
1145+
state = from_state
11321146
continue
11331147
async with self.modify_state(token) as state:
11341148
if substate_cls is not None:
1135-
state = state.get_substate(substate_cls.get_name())
1149+
state = state.get_substate(substate_cls)
11361150
yield state
11371151

11381152
def _process_background(

reflex/state.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
Sequence,
2525
Set,
2626
Type,
27+
TypeVar,
2728
Union,
2829
cast,
30+
overload,
2931
)
3032

3133
import dill
@@ -291,6 +293,9 @@ def __call__(self, *args: Any) -> EventSpec:
291293
return super().__call__(*args)
292294

293295

296+
S = TypeVar("S", bound="BaseState")
297+
298+
294299
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
295300
"""The state of the app."""
296301

@@ -1151,18 +1156,28 @@ def _reset_client_storage(self):
11511156
for substate in self.substates.values():
11521157
substate._reset_client_storage()
11531158

1154-
def get_substate(self, path: Sequence[str]) -> BaseState:
1159+
@overload
1160+
def get_substate(self, path: Sequence[str]) -> BaseState: ...
1161+
1162+
@overload
1163+
def get_substate(self, path: Type[S]) -> S: ...
1164+
1165+
def get_substate(self, path: Sequence[str] | Type[S]) -> BaseState | S:
11551166
"""Get the substate.
11561167
11571168
Args:
1158-
path: The path to the substate.
1169+
path: The path to the substate or the class of the substate.
11591170
11601171
Returns:
11611172
The substate.
11621173
11631174
Raises:
11641175
ValueError: If the substate is not found.
11651176
"""
1177+
if isinstance(path, type):
1178+
path = (
1179+
path.get_full_name().removeprefix(f"{self.get_full_name()}.").split(".")
1180+
)
11661181
if len(path) == 0:
11671182
return self
11681183
if path[0] == self.get_name():
@@ -1295,7 +1310,7 @@ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
12951310
root_state = self
12961311
else:
12971312
root_state = self._get_parent_states()[-1][1]
1298-
return root_state.get_substate(state_cls.get_full_name().split("."))
1313+
return root_state.get_substate(state_cls)
12991314

13001315
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
13011316
"""Get a state instance from redis.

reflex/utils/prerequisites.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,6 @@ def get_app(reload: bool = False) -> App:
287287
288288
Returns:
289289
The app based on the default config.
290-
291-
Raises:
292-
RuntimeError: If the app name is not set in the config.
293290
"""
294291
return getattr(get_app_module(reload=reload), constants.CompileVars.APP)
295292

tests/test_state.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def child_state(test_state) -> ChildState:
218218
Returns:
219219
A test child state.
220220
"""
221-
child_state = test_state.get_substate([ChildState.get_name()])
221+
child_state = test_state.get_substate(ChildState)
222222
assert child_state is not None
223223
return child_state
224224

@@ -233,7 +233,7 @@ def child_state2(test_state) -> ChildState2:
233233
Returns:
234234
A second test child state.
235235
"""
236-
child_state2 = test_state.get_substate([ChildState2.get_name()])
236+
child_state2 = test_state.get_substate(ChildState2)
237237
assert child_state2 is not None
238238
return child_state2
239239

@@ -248,7 +248,7 @@ def grandchild_state(child_state) -> GrandchildState:
248248
Returns:
249249
A test state.
250250
"""
251-
grandchild_state = child_state.get_substate([GrandchildState.get_name()])
251+
grandchild_state = child_state.get_substate(GrandchildState)
252252
assert grandchild_state is not None
253253
return grandchild_state
254254

@@ -1183,7 +1183,7 @@ def set_v4(self, v: int):
11831183
assert ms.v == 2
11841184

11851185
# ensure handler can be called from substate (referencing grandparent handler)
1186-
ms.get_substate(tuple(SubSubState.get_full_name().split("."))).set_v4(3)
1186+
ms.get_substate(SubSubState).set_v4(3)
11871187
assert ms.v == 3
11881188

11891189

@@ -2854,7 +2854,7 @@ async def test_get_state(mock_app: rx.App, token: str):
28542854
)
28552855

28562856
# Get the child_state2 directly.
2857-
child_state2_direct = test_state.get_substate([ChildState2.get_name()])
2857+
child_state2_direct = test_state.get_substate(ChildState2)
28582858
child_state2_get_state = await test_state.get_state(ChildState2)
28592859
# These should be the same object.
28602860
assert child_state2_direct is child_state2_get_state
@@ -2871,15 +2871,13 @@ async def test_get_state(mock_app: rx.App, token: str):
28712871
)
28722872

28732873
# ChildState should be retrievable
2874-
child_state_direct = test_state.get_substate([ChildState.get_name()])
2874+
child_state_direct = test_state.get_substate(ChildState)
28752875
child_state_get_state = await test_state.get_state(ChildState)
28762876
# These should be the same object.
28772877
assert child_state_direct is child_state_get_state
28782878

28792879
# GrandchildState instance should be the same as the one retrieved from the child_state2.
2880-
assert grandchild_state is child_state_direct.get_substate(
2881-
[GrandchildState.get_name()]
2882-
)
2880+
assert grandchild_state is child_state_direct.get_substate(GrandchildState)
28832881
grandchild_state.value2 = "set_value"
28842882

28852883
assert test_state.get_delta() == {
@@ -2920,7 +2918,7 @@ async def test_get_state(mock_app: rx.App, token: str):
29202918
)
29212919

29222920
# Set a value on child_state2, should update cached var in grandchild_state2
2923-
child_state2 = new_test_state.get_substate((ChildState2.get_name(),))
2921+
child_state2 = new_test_state.get_substate(ChildState2)
29242922
child_state2.value = "set_c2_value"
29252923

29262924
assert new_test_state.get_delta() == {
@@ -3015,7 +3013,7 @@ class GreatGrandchild3(Grandchild3):
30153013
assert Child3.get_name() in root.substates # (due to @rx.var)
30163014

30173015
# Get the unconnected sibling state, which will be used to `get_state` other instances.
3018-
child = root.get_substate(Child.get_full_name().split("."))
3016+
child = root.get_substate(Child)
30193017

30203018
# Get an uncached child state.
30213019
child2 = await child.get_state(Child2)

0 commit comments

Comments
 (0)