Skip to content

Commit bf58975

Browse files
committed
add tests
1 parent 34073f2 commit bf58975

2 files changed

Lines changed: 583 additions & 0 deletions

File tree

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
"""Tests for dependency tracking functionality."""
2+
3+
from __future__ import annotations
4+
5+
import pytest
6+
7+
import reflex as rx
8+
from reflex.state import State
9+
from reflex.utils.exceptions import VarValueError
10+
from reflex.vars.dep_tracking import DependencyTracker, get_cell_value
11+
12+
13+
class DependencyTestState(State):
14+
"""Test state for dependency tracking tests."""
15+
16+
count: rx.Field[int] = rx.field(default=0)
17+
name: rx.Field[str] = rx.field(default="test")
18+
items: rx.Field[list[str]] = rx.field(default_factory=list)
19+
20+
21+
class AnotherTestState(State):
22+
"""Another test state for cross-state dependencies."""
23+
24+
value: rx.Field[int] = rx.field(default=42)
25+
text: rx.Field[str] = rx.field(default="hello")
26+
27+
28+
def test_simple_attribute_access():
29+
"""Test tracking simple attribute access on self."""
30+
31+
def simple_func(self: DependencyTestState):
32+
return self.count
33+
34+
tracker = DependencyTracker(simple_func, DependencyTestState)
35+
36+
expected_deps = {DependencyTestState.get_full_name(): {"count"}}
37+
assert tracker.dependencies == expected_deps
38+
39+
40+
def test_multiple_attribute_access():
41+
"""Test tracking multiple attribute access on self."""
42+
43+
def multi_attr_func(self: DependencyTestState):
44+
return self.count + len(self.name) + len(self.items)
45+
46+
tracker = DependencyTracker(multi_attr_func, DependencyTestState)
47+
48+
expected_deps = {DependencyTestState.get_full_name(): {"count", "name", "items"}}
49+
assert tracker.dependencies == expected_deps
50+
51+
52+
def test_method_call_dependencies():
53+
"""Test tracking dependencies from method calls."""
54+
55+
class StateWithMethod(State):
56+
value: int = 0
57+
58+
def helper_method(self):
59+
return self.value * 2
60+
61+
def func_with_method_call(self):
62+
return self.helper_method()
63+
64+
tracker = DependencyTracker(StateWithMethod.func_with_method_call, StateWithMethod)
65+
66+
# Should track dependencies from both the method call and the method itself
67+
expected_deps = {StateWithMethod.get_full_name(): {"value"}}
68+
assert tracker.dependencies == expected_deps
69+
70+
71+
def test_nested_function_dependencies():
72+
"""Test tracking dependencies in nested functions."""
73+
74+
def func_with_nested(self: DependencyTestState):
75+
def inner():
76+
return self.count
77+
78+
return inner()
79+
80+
tracker = DependencyTracker(func_with_nested, DependencyTestState)
81+
82+
expected_deps = {DependencyTestState.get_full_name(): {"count"}}
83+
assert tracker.dependencies == expected_deps
84+
85+
86+
def test_list_comprehension_dependencies():
87+
"""Test tracking dependencies in list comprehensions."""
88+
89+
def func_with_comprehension(self: DependencyTestState):
90+
return [x for x in self.items if len(x) > self.count]
91+
92+
tracker = DependencyTracker(func_with_comprehension, DependencyTestState)
93+
94+
expected_deps = {DependencyTestState.get_full_name(): {"items", "count"}}
95+
assert tracker.dependencies == expected_deps
96+
97+
98+
def test_invalid_attribute_access():
99+
"""Test that accessing invalid attributes raises VarValueError."""
100+
101+
def invalid_func(self: DependencyTestState):
102+
return self.parent_state
103+
104+
with pytest.raises(
105+
VarValueError, match="cannot access arbitrary state via `parent_state`"
106+
):
107+
DependencyTracker(invalid_func, DependencyTestState)
108+
109+
110+
def test_get_state_functionality():
111+
"""Test tracking dependencies when using get_state."""
112+
113+
async def func_with_get_state(self: DependencyTestState):
114+
other_state = await self.get_state(AnotherTestState)
115+
return other_state.value
116+
117+
tracker = DependencyTracker(func_with_get_state, DependencyTestState)
118+
119+
expected_deps = {AnotherTestState.get_full_name(): {"value"}}
120+
assert tracker.dependencies == expected_deps
121+
122+
123+
def test_get_state_with_local_var_error():
124+
"""Test that get_state with local variables raises appropriate error."""
125+
126+
async def invalid_get_state_func(self: DependencyTestState):
127+
state_cls = AnotherTestState
128+
return (await self.get_state(state_cls)).value
129+
130+
with pytest.raises(
131+
VarValueError, match="cannot identify get_state class from local var"
132+
):
133+
DependencyTracker(invalid_get_state_func, DependencyTestState)
134+
135+
136+
def test_get_var_value_functionality():
137+
"""Test tracking dependencies when using get_var_value."""
138+
139+
async def func_with_get_var_value(self: DependencyTestState):
140+
return await self.get_var_value(DependencyTestState.count)
141+
142+
tracker = DependencyTracker(func_with_get_var_value, DependencyTestState)
143+
expected_deps = {DependencyTestState.get_full_name(): {"count"}}
144+
assert tracker.dependencies == expected_deps
145+
146+
147+
def test_merge_deps():
148+
"""Test merging dependencies from multiple trackers."""
149+
150+
def func1(self: DependencyTestState):
151+
return self.count
152+
153+
def func2(self: DependencyTestState):
154+
return self.name
155+
156+
tracker1 = DependencyTracker(func1, DependencyTestState)
157+
tracker2 = DependencyTracker(func2, DependencyTestState)
158+
159+
tracker1._merge_deps(tracker2)
160+
161+
expected_deps = {DependencyTestState.get_full_name(): {"count", "name"}}
162+
assert tracker1.dependencies == expected_deps
163+
164+
165+
def test_get_globals_with_function():
166+
"""Test _get_globals method with a function."""
167+
168+
def test_func(self: DependencyTestState):
169+
return self.count
170+
171+
tracker = DependencyTracker(test_func, DependencyTestState)
172+
globals_dict = tracker._get_globals()
173+
174+
assert isinstance(globals_dict, dict)
175+
assert "DependencyTestState" in globals_dict
176+
assert "State" in globals_dict
177+
178+
179+
def test_get_globals_with_code_object():
180+
"""Test _get_globals method with a code object."""
181+
182+
def test_func(self: DependencyTestState):
183+
return self.count
184+
185+
code_obj = test_func.__code__
186+
tracker = DependencyTracker(code_obj, DependencyTestState)
187+
globals_dict = tracker._get_globals()
188+
189+
assert not globals_dict
190+
191+
192+
def test_get_closure_with_function():
193+
"""Test _get_closure method with a function that has closure."""
194+
outer_var = "test"
195+
196+
def func_with_closure(self: DependencyTestState):
197+
return self.count + len(outer_var)
198+
199+
tracker = DependencyTracker(func_with_closure, DependencyTestState)
200+
closure_dict = tracker._get_closure()
201+
202+
assert isinstance(closure_dict, dict)
203+
assert "outer_var" in closure_dict
204+
assert closure_dict["outer_var"] == "test"
205+
206+
207+
def test_get_closure_with_code_object():
208+
"""Test _get_closure method with a code object."""
209+
210+
def test_func(self: DependencyTestState):
211+
return self.count
212+
213+
code_obj = test_func.__code__
214+
tracker = DependencyTracker(code_obj, DependencyTestState)
215+
closure_dict = tracker._get_closure()
216+
217+
assert not closure_dict
218+
219+
220+
def test_property_dependencies():
221+
"""Test tracking dependencies through property access."""
222+
223+
class StateWithProperty(State):
224+
_value: int = 0
225+
226+
def computed_value(self) -> int:
227+
return self._value * 2
228+
229+
def func_with_property(self):
230+
return self.computed_value
231+
232+
tracker = DependencyTracker(StateWithProperty.func_with_property, StateWithProperty)
233+
234+
# Should track dependencies from the property getter
235+
expected_deps = {StateWithProperty.get_full_name(): {"_value"}}
236+
assert tracker.dependencies == expected_deps
237+
238+
239+
def test_no_dependencies():
240+
"""Test functions with no state dependencies."""
241+
242+
def func_no_deps(self: DependencyTestState):
243+
return 42
244+
245+
tracker = DependencyTracker(func_no_deps, DependencyTestState)
246+
247+
assert not tracker.dependencies
248+
249+
250+
def test_complex_expression_dependencies():
251+
"""Test tracking dependencies in complex expressions."""
252+
253+
def complex_func(self: DependencyTestState):
254+
return (self.count * 2 + len(self.name)) if self.items else 0
255+
256+
tracker = DependencyTracker(complex_func, DependencyTestState)
257+
258+
expected_deps = {DependencyTestState.get_full_name(): {"count", "name", "items"}}
259+
assert tracker.dependencies == expected_deps
260+
261+
262+
def test_get_cell_value_with_valid_cell():
263+
"""Test get_cell_value with a valid cell containing a value."""
264+
# Create a closure to get a cell object
265+
value = "test_value"
266+
267+
def outer():
268+
def inner():
269+
return value
270+
271+
return inner
272+
273+
inner_func = outer()
274+
275+
assert inner_func.__closure__ is not None
276+
277+
cell = inner_func.__closure__[0]
278+
result = get_cell_value(cell)
279+
assert result == "test_value"
280+
281+
282+
def test_cross_state_dependencies_complex():
283+
"""Test complex cross-state dependency scenarios."""
284+
285+
class StateA(State):
286+
value_a: int = 1
287+
288+
class StateB(State):
289+
value_b: int = 2
290+
291+
async def get_a_value(self):
292+
return (await self.get_state(StateA)).value_a
293+
294+
async def complex_cross_state_func(self: DependencyTestState):
295+
state_a = await self.get_state(StateA)
296+
state_b = await self.get_state(StateB)
297+
return state_a.value_a + state_b.value_b
298+
299+
tracker = DependencyTracker(complex_cross_state_func, DependencyTestState)
300+
301+
expected_deps = {
302+
StateA.get_full_name(): {"value_a"},
303+
StateB.get_full_name(): {"value_b"},
304+
}
305+
assert tracker.dependencies == expected_deps
306+
307+
308+
def test_lambda_function_dependencies():
309+
"""Test tracking dependencies in lambda functions."""
310+
311+
def lambda_func(self: DependencyTestState):
312+
return self.count * 2
313+
314+
tracker = DependencyTracker(lambda_func, DependencyTestState)
315+
316+
expected_deps = {DependencyTestState.get_full_name(): {"count"}}
317+
assert tracker.dependencies == expected_deps
318+
319+
320+
def test_dependencies_with_computed_var():
321+
"""Test that computed vars are handled correctly in dependency tracking."""
322+
323+
class StateWithComputedVar(State):
324+
base_value: int = 0
325+
326+
@rx.var
327+
def computed_value(self) -> int:
328+
return self.base_value * 2
329+
330+
def func_using_computed_var(self: StateWithComputedVar):
331+
return self.computed_value
332+
333+
tracker = DependencyTracker(func_using_computed_var, StateWithComputedVar)
334+
335+
# Should track the computed var, not its dependencies
336+
expected_deps = {StateWithComputedVar.get_full_name(): {"computed_value"}}
337+
assert tracker.dependencies == expected_deps

0 commit comments

Comments
 (0)