Skip to content

Commit 7ecc467

Browse files
committed
chore: test_run_report is functional style
1 parent 3e45d9c commit 7ecc467

1 file changed

Lines changed: 182 additions & 171 deletions

File tree

python/tests/test_run_report.py

Lines changed: 182 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -7,178 +7,189 @@
77
from egglog.declarations import BiRewriteDecl, RewriteDecl, RuleDecl
88

99

10-
class TestRunReport:
11-
def _setup_simple_egraph(self):
12-
egraph = EGraph()
13-
14-
class Num(Expr):
15-
def __init__(self, n: i64Like) -> None: ...
16-
def __add__(self, other: Num) -> Num: ...
17-
18-
x, y = vars_("x y", Num)
19-
egraph.register(rewrite(x + y).to(y + x))
20-
egraph.register(Num(1) + Num(2))
21-
return egraph
22-
23-
def test_run_returns_report(self):
24-
egraph = self._setup_simple_egraph()
25-
report = egraph.run(10)
26-
assert type(report).__name__ == "RunReport"
27-
28-
def test_stats_returns_report(self):
29-
egraph = self._setup_simple_egraph()
30-
egraph.run(10)
31-
report = egraph.stats()
32-
assert type(report).__name__ == "RunReport"
33-
34-
def test_rule_names_translated_in_top_level_dicts(self):
35-
egraph = self._setup_simple_egraph()
36-
report = egraph.run(10)
37-
38-
for key in report.search_and_apply_time_per_rule:
39-
assert isinstance(key, RewriteDecl)
10+
def _setup_simple_egraph():
11+
egraph = EGraph()
4012

41-
for key in report.num_matches_per_rule:
42-
assert isinstance(key, RewriteDecl)
13+
class Num(Expr):
14+
def __init__(self, n: i64Like) -> None: ...
15+
def __add__(self, other: Num) -> Num: ...
16+
17+
x, y = vars_("x y", Num)
18+
egraph.register(rewrite(x + y).to(y + x))
19+
egraph.register(Num(1) + Num(2))
20+
return egraph
21+
22+
23+
def test_run_returns_report():
24+
egraph = _setup_simple_egraph()
25+
report = egraph.run(10)
26+
assert type(report).__name__ == "RunReport"
27+
28+
29+
def test_stats_returns_report():
30+
egraph = _setup_simple_egraph()
31+
egraph.run(10)
32+
report = egraph.stats()
33+
assert type(report).__name__ == "RunReport"
34+
35+
36+
def test_rule_names_translated_in_top_level_dicts():
37+
egraph = _setup_simple_egraph()
38+
report = egraph.run(10)
39+
40+
for key in report.search_and_apply_time_per_rule:
41+
assert isinstance(key, RewriteDecl)
42+
43+
for key in report.num_matches_per_rule:
44+
assert isinstance(key, RewriteDecl)
4345

44-
def test_rule_names_translated_in_iterations(self):
45-
egraph = self._setup_simple_egraph()
46-
report = egraph.run(10)
47-
48-
assert len(report.iterations) > 0
49-
for iteration in report.iterations:
50-
for key in iteration.rule_set_report.rule_reports:
51-
assert isinstance(key, RewriteDecl)
52-
53-
def test_updated_field(self):
54-
egraph = self._setup_simple_egraph()
55-
report = egraph.run(10)
56-
assert isinstance(report.updated, bool)
57-
assert report.updated is True
58-
59-
def test_num_matches(self):
60-
egraph = self._setup_simple_egraph()
61-
report = egraph.run(10)
62-
63-
total_matches = sum(report.num_matches_per_rule.values())
64-
assert total_matches > 0
65-
66-
def test_timedelta_types(self):
67-
egraph = self._setup_simple_egraph()
68-
report = egraph.run(10)
69-
70-
for v in report.search_and_apply_time_per_rule.values():
71-
assert isinstance(v, timedelta)
72-
for v in report.search_and_apply_time_per_ruleset.values():
73-
assert isinstance(v, timedelta)
74-
for v in report.merge_time_per_ruleset.values():
75-
assert isinstance(v, timedelta)
76-
for v in report.rebuild_time_per_ruleset.values():
77-
assert isinstance(v, timedelta)
78-
79-
def test_iteration_reports(self):
80-
egraph = self._setup_simple_egraph()
81-
report = egraph.run(10)
82-
83-
for it in report.iterations:
84-
assert type(it).__name__ == "IterationReport"
85-
assert type(it.rule_set_report).__name__ == "RuleSetReport"
86-
for rule_reports in it.rule_set_report.rule_reports.values():
87-
for rr in rule_reports:
88-
assert type(rr).__name__ == "RuleReport"
89-
90-
def test_str_no_egglog_sexprs(self):
91-
egraph = self._setup_simple_egraph()
92-
report = egraph.run(10)
93-
output = str(report)
94-
95-
assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}"
96-
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"
97-
assert "rewrite(" in output, f"Expected 'rewrite(' in:\n{output}"
98-
99-
def test_multiple_rules(self):
100-
egraph = EGraph()
101-
102-
class Math(Expr):
103-
def __init__(self, value: i64Like) -> None: ...
104-
def __add__(self, other: Math) -> Math: ...
105-
def __mul__(self, other: Math) -> Math: ...
106-
107-
a, b = vars_("a b", Math)
108-
egraph.register(
109-
rewrite(a + b).to(b + a),
110-
rewrite(a * b).to(b * a),
111-
)
112-
egraph.register(Math(1) + Math(2), Math(3) * Math(4))
113-
report = egraph.run(10)
114-
115-
# should have two distinct translated rule keys
116-
rule_keys = list(report.search_and_apply_time_per_rule.keys())
117-
assert len(rule_keys) == 2
118-
for key in rule_keys:
46+
47+
def test_rule_names_translated_in_iterations():
48+
egraph = _setup_simple_egraph()
49+
report = egraph.run(10)
50+
51+
assert len(report.iterations) > 0
52+
for iteration in report.iterations:
53+
for key in iteration.rule_set_report.rule_reports:
11954
assert isinstance(key, RewriteDecl)
12055

121-
def test_empty_run(self):
122-
egraph = EGraph()
123-
report = egraph.run(1)
124-
assert type(report).__name__ == "RunReport"
125-
assert isinstance(report.updated, bool)
126-
127-
def test_named_rule(self):
128-
egraph = EGraph()
129-
130-
class Num(Expr):
131-
def __init__(self, n: i64Like) -> None: ...
132-
def __add__(self, other: Num) -> Num: ...
133-
134-
x, y = vars_("x y", Num)
135-
egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x)))
136-
egraph.register(Num(1) + Num(2))
137-
report = egraph.run(10)
138-
139-
output = str(report)
140-
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"
141-
assert "rule(" in output, f"Expected 'rule(' in:\n{output}"
142-
assert "comm" in output, f"Expected rule name 'comm' in:\n{output}"
143-
144-
def test_unnamed_rule_decl(self):
145-
egraph = EGraph()
146-
147-
class Num(Expr):
148-
def __init__(self, n: i64Like) -> None: ...
149-
def __add__(self, other: Num) -> Num: ...
150-
151-
x, y = vars_("x y", Num)
152-
egraph.register(rule(x + y).then(union(x + y).with_(y + x)))
153-
egraph.register(Num(1) + Num(2))
154-
report = egraph.run(10)
155-
156-
output = str(report)
157-
assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}"
158-
assert "rule(" in output, f"Expected 'rule(' in:\n{output}"
159-
# Should contain Python rule() syntax somewhere in the keys
160-
rule_keys = list(report.search_and_apply_time_per_rule.keys())
161-
assert len(rule_keys) > 0
162-
for key in rule_keys:
163-
assert isinstance(key, RuleDecl)
164-
165-
def test_birewrite_decl(self):
166-
egraph = EGraph()
167-
168-
class Num(Expr):
169-
def __init__(self, n: i64Like) -> None: ...
170-
def __add__(self, other: Num) -> Num: ...
171-
def __mul__(self, other: Num) -> Num: ...
172-
173-
x, y = vars_("x y", Num)
174-
egraph.register(birewrite(x + y).to(y + x))
175-
egraph.register(Num(1) + Num(2))
176-
report = egraph.run(10)
177-
178-
output = str(report)
179-
assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}"
180-
assert "birewrite(" in output, f"Expected 'birewrite(' in:\n{output}"
181-
rule_keys = list(report.search_and_apply_time_per_rule.keys())
182-
assert len(rule_keys) > 0
183-
for key in rule_keys:
184-
assert isinstance(key, BiRewriteDecl)
56+
57+
def test_updated_field():
58+
egraph = _setup_simple_egraph()
59+
report = egraph.run(10)
60+
assert isinstance(report.updated, bool)
61+
assert report.updated is True
62+
63+
64+
def test_num_matches():
65+
egraph = _setup_simple_egraph()
66+
report = egraph.run(10)
67+
68+
total_matches = sum(report.num_matches_per_rule.values())
69+
assert total_matches > 0
70+
71+
72+
def test_timedelta_types():
73+
egraph = _setup_simple_egraph()
74+
report = egraph.run(10)
75+
76+
for v in report.search_and_apply_time_per_rule.values():
77+
assert isinstance(v, timedelta)
78+
for v in report.search_and_apply_time_per_ruleset.values():
79+
assert isinstance(v, timedelta)
80+
for v in report.merge_time_per_ruleset.values():
81+
assert isinstance(v, timedelta)
82+
for v in report.rebuild_time_per_ruleset.values():
83+
assert isinstance(v, timedelta)
84+
85+
86+
def test_iteration_reports():
87+
egraph = _setup_simple_egraph()
88+
report = egraph.run(10)
89+
90+
for it in report.iterations:
91+
assert type(it).__name__ == "IterationReport"
92+
assert type(it.rule_set_report).__name__ == "RuleSetReport"
93+
for rule_reports in it.rule_set_report.rule_reports.values():
94+
for rr in rule_reports:
95+
assert type(rr).__name__ == "RuleReport"
96+
97+
98+
def test_str_no_egglog_sexprs():
99+
egraph = _setup_simple_egraph()
100+
report = egraph.run(10)
101+
output = str(report)
102+
103+
assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}"
104+
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"
105+
assert "rewrite(" in output, f"Expected 'rewrite(' in:\n{output}"
106+
107+
108+
def test_multiple_rules():
109+
egraph = EGraph()
110+
111+
class Math(Expr):
112+
def __init__(self, value: i64Like) -> None: ...
113+
def __add__(self, other: Math) -> Math: ...
114+
def __mul__(self, other: Math) -> Math: ...
115+
116+
a, b = vars_("a b", Math)
117+
egraph.register(
118+
rewrite(a + b).to(b + a),
119+
rewrite(a * b).to(b * a),
120+
)
121+
egraph.register(Math(1) + Math(2), Math(3) * Math(4))
122+
report = egraph.run(10)
123+
124+
rule_keys = list(report.search_and_apply_time_per_rule.keys())
125+
assert len(rule_keys) == 2
126+
for key in rule_keys:
127+
assert isinstance(key, RewriteDecl)
128+
129+
130+
def test_empty_run():
131+
egraph = EGraph()
132+
report = egraph.run(1)
133+
assert type(report).__name__ == "RunReport"
134+
assert isinstance(report.updated, bool)
135+
136+
137+
def test_named_rule():
138+
egraph = EGraph()
139+
140+
class Num(Expr):
141+
def __init__(self, n: i64Like) -> None: ...
142+
def __add__(self, other: Num) -> Num: ...
143+
144+
x, y = vars_("x y", Num)
145+
egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x)))
146+
egraph.register(Num(1) + Num(2))
147+
report = egraph.run(10)
148+
149+
output = str(report)
150+
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"
151+
assert "rule(" in output, f"Expected 'rule(' in:\n{output}"
152+
assert "comm" in output, f"Expected rule name 'comm' in:\n{output}"
153+
154+
155+
def test_unnamed_rule_decl():
156+
egraph = EGraph()
157+
158+
class Num(Expr):
159+
def __init__(self, n: i64Like) -> None: ...
160+
def __add__(self, other: Num) -> Num: ...
161+
162+
x, y = vars_("x y", Num)
163+
egraph.register(rule(x + y).then(union(x + y).with_(y + x)))
164+
egraph.register(Num(1) + Num(2))
165+
report = egraph.run(10)
166+
167+
output = str(report)
168+
assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}"
169+
assert "rule(" in output, f"Expected 'rule(' in:\n{output}"
170+
rule_keys = list(report.search_and_apply_time_per_rule.keys())
171+
assert len(rule_keys) > 0
172+
for key in rule_keys:
173+
assert isinstance(key, RuleDecl)
174+
175+
176+
def test_birewrite_decl():
177+
egraph = EGraph()
178+
179+
class Num(Expr):
180+
def __init__(self, n: i64Like) -> None: ...
181+
def __add__(self, other: Num) -> Num: ...
182+
def __mul__(self, other: Num) -> Num: ...
183+
184+
x, y = vars_("x y", Num)
185+
egraph.register(birewrite(x + y).to(y + x))
186+
egraph.register(Num(1) + Num(2))
187+
report = egraph.run(10)
188+
189+
output = str(report)
190+
assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}"
191+
assert "birewrite(" in output, f"Expected 'birewrite(' in:\n{output}"
192+
rule_keys = list(report.search_and_apply_time_per_rule.keys())
193+
assert len(rule_keys) > 0
194+
for key in rule_keys:
195+
assert isinstance(key, BiRewriteDecl)

0 commit comments

Comments
 (0)