forked from egraphs-good/egglog-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_run_report.py
More file actions
200 lines (157 loc) · 7.27 KB
/
Copy pathtest_run_report.py
File metadata and controls
200 lines (157 loc) · 7.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# mypy: disable-error-code="empty-body"
from __future__ import annotations
from datetime import timedelta
from egglog import *
class TestPrettyRunReport:
    """Tests for the Python-friendly ("pretty") reports returned by ``EGraph.run`` / ``EGraph.stats``.

    The reports are expected to expose rule keys rendered in Python syntax
    (e.g. ``rewrite(...).to(...)``) instead of raw egglog s-expressions.

    NOTE(review): the ``"__main__" not in ...`` checks assume untranslated egglog
    names embed the defining module's name. Under pytest this module is not
    ``__main__``, so those checks may pass vacuously — confirm. The ``"(rewrite"``,
    ``"(birewrite"`` and ``"(rule"`` checks are the ones that actually distinguish
    translated keys from raw s-expressions.
    """

    def _setup_simple_egraph(self) -> EGraph:
        """Build an e-graph with one commutativity rewrite and the term ``Num(1) + Num(2)``."""
        egraph = EGraph()

        class Num(Expr):
            def __init__(self, n: i64Like) -> None: ...

            def __add__(self, other: Num) -> Num: ...

        x, y = vars_("x y", Num)
        egraph.register(rewrite(x + y).to(y + x))
        egraph.register(Num(1) + Num(2))
        return egraph

    @staticmethod
    def _assert_python_rewrite_key(key: str) -> None:
        """Assert *key* is a translated Python ``rewrite(...)`` rather than an egglog s-expression."""
        assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
        # "rewrite" alone also matches the raw "(rewrite ...)" s-expression,
        # so explicitly reject that form as well.
        assert "(rewrite" not in key, f"Key is still an egglog s-expression: {key}"
        assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"

    def test_run_returns_pretty_report(self):
        """``EGraph.run`` returns a ``PrettyRunReport``."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        assert type(report).__name__ == "PrettyRunReport"

    def test_stats_returns_pretty_report(self):
        """``EGraph.stats`` returns a ``PrettyRunReport`` after a run."""
        egraph = self._setup_simple_egraph()
        egraph.run(10)
        report = egraph.stats()
        assert type(report).__name__ == "PrettyRunReport"

    def test_rule_names_translated_in_top_level_dicts(self):
        """Per-rule dictionaries on the report are keyed by Python-syntax rule text."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        for key in report.search_and_apply_time_per_rule:
            self._assert_python_rewrite_key(key)
        for key in report.num_matches_per_rule:
            self._assert_python_rewrite_key(key)

    def test_rule_names_translated_in_iterations(self):
        """Rule keys inside each iteration's rule-set report are translated too."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        assert len(report.iterations) > 0
        for iteration in report.iterations:
            for key in iteration.rule_set_report.rule_reports:
                self._assert_python_rewrite_key(key)

    def test_updated_field(self):
        """Running a rewrite that adds new terms reports ``updated is True``."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        # `is True` already implies the value is a bool.
        assert report.updated is True, f"Expected updated=True, got: {report.updated!r}"

    def test_num_matches(self):
        """The commutativity rewrite matches at least once."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        total_matches = sum(report.num_matches_per_rule.values())
        assert total_matches > 0

    def test_timedelta_types(self):
        """Every timing dictionary on the report holds ``timedelta`` values."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        timing_dicts = {
            "search_and_apply_time_per_rule": report.search_and_apply_time_per_rule,
            "search_and_apply_time_per_ruleset": report.search_and_apply_time_per_ruleset,
            "merge_time_per_ruleset": report.merge_time_per_ruleset,
            "rebuild_time_per_ruleset": report.rebuild_time_per_ruleset,
        }
        for dict_name, timings in timing_dicts.items():
            for name, value in timings.items():
                assert isinstance(value, timedelta), (
                    f"{dict_name}[{name!r}] is {type(value).__name__}, expected timedelta"
                )

    def test_iteration_reports_are_pretty(self):
        """Nested iteration, rule-set and rule reports are all the pretty variants."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        for it in report.iterations:
            assert type(it).__name__ == "PrettyIterationReport"
            assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport"
            for rule_reports in it.rule_set_report.rule_reports.values():
                for rr in rule_reports:
                    assert type(rr).__name__ == "PrettyRuleReport"

    def test_str_no_egglog_sexprs(self):
        """``str(report)`` contains neither egglog s-expressions nor mangled names."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        output = str(report)
        assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}"
        assert "__main__" not in output, f"str() still contains mangled names:\n{output}"

    def test_multiple_rules(self):
        """Two distinct rewrites produce two distinct translated rule keys."""
        egraph = EGraph()

        class Math(Expr):
            def __init__(self, value: i64Like) -> None: ...

            def __add__(self, other: Math) -> Math: ...

            def __mul__(self, other: Math) -> Math: ...

        a, b = vars_("a b", Math)
        egraph.register(
            rewrite(a + b).to(b + a),
            rewrite(a * b).to(b * a),
        )
        egraph.register(Math(1) + Math(2), Math(3) * Math(4))
        report = egraph.run(10)
        # should have two distinct translated rule keys
        rule_keys = list(report.search_and_apply_time_per_rule.keys())
        assert len(rule_keys) == 2
        for key in rule_keys:
            self._assert_python_rewrite_key(key)

    def test_empty_run(self):
        """Running an empty e-graph still returns a well-formed pretty report."""
        egraph = EGraph()
        report = egraph.run(1)
        assert type(report).__name__ == "PrettyRunReport"
        assert isinstance(report.updated, bool)

    def test_named_rule(self):
        """A rule registered with an explicit ``name=`` renders without mangled names."""
        egraph = EGraph()

        class Num(Expr):
            def __init__(self, n: i64Like) -> None: ...

            def __add__(self, other: Num) -> Num: ...

        x, y = vars_("x y", Num)
        egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x)))
        egraph.register(Num(1) + Num(2))
        report = egraph.run(10)
        output = str(report)
        assert "__main__" not in output, f"str() still contains mangled names:\n{output}"

    def test_unnamed_rule_decl(self):
        """An unnamed ``rule(...)`` is keyed by its Python ``rule(...)`` rendering."""
        egraph = EGraph()

        class Num(Expr):
            def __init__(self, n: i64Like) -> None: ...

            def __add__(self, other: Num) -> Num: ...

        x, y = vars_("x y", Num)
        egraph.register(rule(x + y).then(union(x + y).with_(y + x)))
        egraph.register(Num(1) + Num(2))
        report = egraph.run(10)
        output = str(report)
        assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}"
        rule_keys = list(report.search_and_apply_time_per_rule.keys())
        assert len(rule_keys) > 0
        for key in rule_keys:
            assert "__main__" not in key, f"RuleDecl key not translated: {key}"
            assert "(rule" not in key, f"RuleDecl key is still an egglog s-expression: {key}"
        # Should contain Python rule() syntax somewhere in the keys
        assert any("rule(" in key for key in rule_keys), f"Expected rule() syntax in keys, got: {rule_keys}"

    def test_birewrite_decl(self):
        """A ``birewrite`` is keyed by its Python ``birewrite(...)`` rendering."""
        egraph = EGraph()

        class Num(Expr):
            def __init__(self, n: i64Like) -> None: ...

            def __add__(self, other: Num) -> Num: ...

            def __mul__(self, other: Num) -> Num: ...

        x, y = vars_("x y", Num)
        egraph.register(birewrite(x + y).to(y + x))
        egraph.register(Num(1) + Num(2))
        report = egraph.run(10)
        output = str(report)
        assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}"
        rule_keys = list(report.search_and_apply_time_per_rule.keys())
        assert len(rule_keys) > 0
        for key in rule_keys:
            assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}"
            assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}"
            # Reject the raw s-expression forms, which also contain "birewrite"/"rewrite".
            assert "(birewrite" not in key, f"BiRewriteDecl key is still an egglog s-expression: {key}"
            assert "(rewrite" not in key, f"BiRewriteDecl key is still an egglog s-expression: {key}"

    def test_rewrite_decl(self):
        """A single ``rewrite`` yields exactly one translated rule key."""
        egraph = self._setup_simple_egraph()
        report = egraph.run(10)
        rule_keys = list(report.search_and_apply_time_per_rule.keys())
        assert len(rule_keys) == 1
        self._assert_python_rewrite_key(rule_keys[0])