1010import json
1111from collections.abc import Callable, Sequence
1212from types import MethodType
13- from typing import TYPE_CHECKING, Any, SupportsIndex
13+ from typing import TYPE_CHECKING, Any, SupportsIndex, TypeVar
1414
1515import pydantic
1616import wrapt
2727if TYPE_CHECKING:
2828 from reflex.state import BaseState, StateUpdate
2929
30+ T_STATE = TypeVar("T_STATE", bound="BaseState")
31+
3032
3133class StateProxy(wrapt.ObjectProxy):
3234 """Proxy of a state instance to control mutability of vars for a background task.
@@ -269,7 +271,7 @@ def get_substate(self, path: Sequence[str]) -> BaseState:
269271 raise ImmutableStateError(msg)
270272 return self.__wrapped__.get_substate(path)
271273
272- async def get_state(self, state_cls: type[BaseState ]) -> BaseState :
274+ async def get_state(self, state_cls: type[T_STATE ]) -> T_STATE :
273275 """Get an instance of the state associated with this token.
274276
275277 Args:
@@ -289,7 +291,7 @@ async def get_state(self, state_cls: type[BaseState]) -> BaseState:
289291 raise ImmutableStateError(msg)
290292 return type(self)(
291293 await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
292- )
294+ ) # pyright: ignore [reportReturnType]
293295
294296 async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
295297 """Temporarily allow mutability to access parent_state.
0 commit comments