Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 163 additions & 22 deletions reflex/vars/dep_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import dataclasses
import dis
import enum
import importlib
import inspect
import sys
from types import CellType, CodeType, FunctionType
from types import CellType, CodeType, FunctionType, ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, cast

from reflex.utils.exceptions import VarValueError
Expand Down Expand Up @@ -43,9 +44,38 @@ class ScanStatus(enum.Enum):
SCANNING = enum.auto()
GETTING_ATTR = enum.auto()
GETTING_STATE = enum.auto()
GETTING_STATE_POST_AWAIT = enum.auto()
GETTING_VAR = enum.auto()


class UntrackedLocalVarError(VarValueError):
"""Raised when a local variable is referenced, but it is not tracked in the current scope."""


def assert_base_state(
local_value: Any,
local_name: str | None = None,
) -> type[BaseState]:
"""Assert that a local variable is a BaseState subclass.

Args:
local_value: The value of the local variable to check.
local_name: The name of the local variable to check.

Returns:
The local variable value if it is a BaseState subclass.

Raises:
VarValueError: If the object is not a BaseState subclass.
"""
from reflex.state import BaseState

if not isinstance(local_value, type) or not issubclass(local_value, BaseState):
msg = f"Cannot determine dependencies in fetched state {local_name!r}: {local_value!r} is not a BaseState."
raise VarValueError(msg)
return local_value


@dataclasses.dataclass
class DependencyTracker:
"""State machine for identifying state attributes that are accessed by a function."""
Expand All @@ -58,10 +88,15 @@ class DependencyTracker:
scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
top_of_stack: str | None = dataclasses.field(default=None)

tracked_locals: dict[str, type[BaseState]] = dataclasses.field(default_factory=dict)
tracked_locals: dict[str, type[BaseState] | ModuleType] = dataclasses.field(
default_factory=dict
)

_getting_state_class: type[BaseState] | None = dataclasses.field(default=None)
_getting_state_class: type[BaseState] | ModuleType | None = dataclasses.field(
default=None
)
_get_var_value_positions: dis.Positions | None = dataclasses.field(default=None)
_last_import_name: str | None = dataclasses.field(default=None)

INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]

Expand Down Expand Up @@ -90,6 +125,26 @@ def _merge_deps(self, tracker: DependencyTracker) -> None:
for state_name, dep_name in tracker.dependencies.items():
self.dependencies.setdefault(state_name, set()).update(dep_name)

def get_tracked_local(self, local_name: str) -> type[BaseState] | ModuleType:
"""Get the value of a local name tracked in the current function scope.

Args:
local_name: The name of the local variable to fetch.

Returns:
The value of local name tracked in the current scope (a referenced
BaseState subclass or imported module).

Raises:
UntrackedLocalVarError: If the local variable is not being tracked.
"""
try:
local_value = self.tracked_locals[local_name]
except KeyError as ke:
msg = f"{local_name!r} is not tracked in the current scope."
raise UntrackedLocalVarError(msg) from ke
return local_value

def load_attr_or_method(self, instruction: dis.Instruction) -> None:
"""Handle loading an attribute or method from the object on top of the stack.

Expand All @@ -100,7 +155,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
instruction: The dis instruction to process.

Raises:
VarValueError: if the attribute is an disallowed name.
VarValueError: if the attribute is an disallowed name or attribute
does not reference a BaseState.
"""
from .base import ComputedVar

Expand All @@ -122,7 +178,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
self.scan_status = ScanStatus.SCANNING
if not self.top_of_stack:
return
target_state = self.tracked_locals[self.top_of_stack]
target_obj = self.get_tracked_local(self.top_of_stack)
target_state = assert_base_state(target_obj, local_name=self.top_of_stack)
try:
ref_obj = getattr(target_state, instruction.argval)
except AttributeError:
Expand Down Expand Up @@ -190,15 +247,14 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
Raises:
VarValueError: if the state class cannot be determined from the instruction.
"""
from reflex.state import BaseState

if instruction.opname in ("LOAD_FAST", "LOAD_FAST_BORROW"):
msg = f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
raise VarValueError(msg)
if isinstance(self.func, CodeType):
msg = "Dependency detection cannot identify get_state class from a code object."
raise VarValueError(msg)
if instruction.opname == "LOAD_GLOBAL":
if instruction.opname in ("LOAD_FAST", "LOAD_FAST_BORROW"):
self._getting_state_class = self.get_tracked_local(
local_name=instruction.argval,
)
elif instruction.opname == "LOAD_GLOBAL":
# Special case: referencing state class from global scope.
try:
self._getting_state_class = self._get_globals()[instruction.argval]
Expand All @@ -212,16 +268,43 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
except (ValueError, KeyError) as ve:
msg = f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
raise VarValueError(msg) from ve
elif instruction.opname == "STORE_FAST":
elif instruction.opname in ("LOAD_ATTR", "LOAD_METHOD"):
self._getting_state_class = getattr(
self._getting_state_class,
instruction.argval,
)
elif instruction.opname == "GET_AWAITABLE":
# Now inside the `await` machinery, subsequent instructions
# operate on the result of the `get_state` call.
self.scan_status = ScanStatus.GETTING_STATE_POST_AWAIT
if self._getting_state_class is not None:
self.top_of_stack = "_"
self.tracked_locals[self.top_of_stack] = self._getting_state_class
self._getting_state_class = None

def handle_getting_state_post_await(self, instruction: dis.Instruction) -> None:
"""Handle bytecode analysis after `get_state` was called in the function.

This function is called _after_ awaiting self.get_state to capture the
local variable holding the state instance or directly record access to
attributes accessed on the result of get_state.

Args:
instruction: The dis instruction to process.

Raises:
VarValueError: if the state class cannot be determined from the instruction.
"""
if instruction.opname == "STORE_FAST" and self.top_of_stack:
# Storing the result of get_state in a local variable.
if not isinstance(self._getting_state_class, type) or not issubclass(
self._getting_state_class, BaseState
):
msg = f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
raise VarValueError(msg)
self.tracked_locals[instruction.argval] = self._getting_state_class
self.tracked_locals[instruction.argval] = self.tracked_locals.pop(
self.top_of_stack
)
self.top_of_stack = None
self.scan_status = ScanStatus.SCANNING
self._getting_state_class = None
elif instruction.opname in ("LOAD_ATTR", "LOAD_METHOD"):
# Attribute access on an inline `get_state`, not assigned to a variable.
self.load_attr_or_method(instruction)

def _eval_var(self, positions: dis.Positions) -> Var:
"""Evaluate instructions from the wrapped function to get the Var object.
Expand Down Expand Up @@ -262,8 +345,12 @@ def _eval_var(self, positions: dis.Positions) -> Var:
])
else:
snipped_source = source[0][start_column:end_column]
# Evaluate the string in the context of the function's globals and closure.
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
# Evaluate the string in the context of the function's globals, closure and tracked local scope.
return eval(
f"({snipped_source})",
self._get_globals(),
{**self._get_closure(), **self.tracked_locals},
)

def handle_getting_var(self, instruction: dis.Instruction) -> None:
"""Handle bytecode analysis when `get_var_value` was called in the function.
Expand Down Expand Up @@ -304,16 +391,38 @@ def _populate_dependencies(self) -> None:
for instruction in dis.get_instructions(self.func):
if self.scan_status == ScanStatus.GETTING_STATE:
self.handle_getting_state(instruction)
elif self.scan_status == ScanStatus.GETTING_STATE_POST_AWAIT:
self.handle_getting_state_post_await(instruction)
elif self.scan_status == ScanStatus.GETTING_VAR:
self.handle_getting_var(instruction)
elif (
instruction.opname in ("LOAD_FAST", "LOAD_DEREF", "LOAD_FAST_BORROW")
instruction.opname
in (
"LOAD_FAST",
"LOAD_DEREF",
"LOAD_FAST_BORROW",
"LOAD_FAST_CHECK",
"LOAD_FAST_AND_CLEAR",
)
and instruction.argval in self.tracked_locals
):
# bytecode loaded the class instance to the top of stack, next load instruction
# is referencing an attribute on self
self.top_of_stack = instruction.argval
self.scan_status = ScanStatus.GETTING_ATTR
elif (
instruction.opname
in (
"LOAD_FAST_LOAD_FAST",
"LOAD_FAST_BORROW_LOAD_FAST_BORROW",
"STORE_FAST_LOAD_FAST",
)
and instruction.argval[-1] in self.tracked_locals
):
# Double LOAD_FAST family instructions load multiple values onto the stack,
# the last value in the argval list is the top of the stack.
self.top_of_stack = instruction.argval[-1]
self.scan_status = ScanStatus.GETTING_ATTR
elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
"LOAD_ATTR",
"LOAD_METHOD",
Expand All @@ -332,3 +441,35 @@ def _populate_dependencies(self) -> None:
tracked_locals=self.tracked_locals,
)
)
elif instruction.opname == "IMPORT_NAME" and instruction.argval is not None:
self._last_import_name = instruction.argval
importlib.import_module(instruction.argval)
top_module_name = instruction.argval.split(".")[0]
self.tracked_locals[instruction.argval] = sys.modules[top_module_name]
self.top_of_stack = instruction.argval
elif instruction.opname == "IMPORT_FROM":
if not self._last_import_name:
msg = f"Cannot find package associated with import {instruction.argval} in {self.func!r}."
raise VarValueError(msg)
if instruction.argval in self._last_import_name.split("."):
# `import ... as ...` case:
# import from interim package, update tracked_locals for the last imported name.
self.tracked_locals[self._last_import_name] = getattr(
self.tracked_locals[self._last_import_name], instruction.argval
)
continue
# Importing a name from a package/module.
if self._last_import_name is not None and self.top_of_stack:
# The full import name does NOT end up in scope for a `from ... import`.
self.tracked_locals.pop(self._last_import_name)
self.tracked_locals[instruction.argval] = getattr(
importlib.import_module(self._last_import_name),
instruction.argval,
)
# If we see a STORE_FAST, we can assign the top of stack to an aliased name.
self.top_of_stack = instruction.argval
elif instruction.opname == "STORE_FAST" and self.top_of_stack is not None:
self.tracked_locals[instruction.argval] = self.tracked_locals.pop(
self.top_of_stack
)
self.top_of_stack = None
3 changes: 3 additions & 0 deletions tests/units/states/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ def reassign_mutables(self):
"mod_third_key": {"key": "value"},
}
self.test_set = {1, 2, 3, 4, "five"}

def _get_array(self) -> list[str | int | list | dict[str, str]]:
return self.array
Loading
Loading