Skip to content

Commit 810bd95

Browse files
committed
fix: store CommandDecl
1 parent c554d1a commit 810bd95

3 files changed

Lines changed: 75 additions & 50 deletions

File tree

python/egglog/egraph_state.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def _normalize_global_let_name(name: str) -> str:
4848
return name if name.startswith("$") else f"${name}"
4949

5050

51+
def _normalize_rule_key(key: str) -> str:
52+
"""Normalize an egglog rule string for consistent matching."""
53+
key = key.replace("'", '"')
54+
return re.sub(r"\s+", " ", key).strip()
55+
56+
5157
@dataclass
5258
class EGraphState:
5359
"""
@@ -249,13 +255,12 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
249255
case _:
250256
assert_never(schedule)
251257

252-
def translate_rule_key(self, egglog_key: str) -> str:
258+
def translate_rule_key(self, egglog_key: str) -> CommandDecl:
253259
"""
254-
Translate an egglog rule name to its Python representation.
260+
Look up the original Python CommandDecl for an egglog rule key.
255261
"""
256-
if egglog_key in self.egg_rule_to_command_decl:
257-
return pretty_decl(self.__egg_decls__, self.egg_rule_to_command_decl[egglog_key])
258-
return egglog_key
262+
normalized = _normalize_rule_key(egglog_key)
263+
return self.egg_rule_to_command_decl[normalized]
259264

260265
def ruleset_to_egg(self, ident: Ident) -> None:
261266
"""
@@ -304,7 +309,11 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
304309
else:
305310
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)
306311

307-
self.egg_rule_to_command_decl[str(egg_cmd)] = cmd
312+
normalized = _normalize_rule_key(str(egg_cmd))
313+
self.egg_rule_to_command_decl[normalized] = cmd
314+
if isinstance(cmd, BiRewriteDecl):
315+
self.egg_rule_to_command_decl[normalized + "=>"] = cmd
316+
self.egg_rule_to_command_decl[normalized + "<="] = cmd
308317
return egg_cmd
309318
case RuleDecl(head, body, name):
310319
egg_cmd = bindings.RuleCommand(
@@ -316,7 +325,9 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
316325
str(ruleset),
317326
)
318327
)
319-
self.egg_rule_to_command_decl[str(egg_cmd)] = cmd
328+
self.egg_rule_to_command_decl[_normalize_rule_key(str(egg_cmd))] = cmd
329+
if name:
330+
self.egg_rule_to_command_decl[name] = cmd
320331
return egg_cmd
321332
# TODO: Replace with just constants value and looking at REF of function
322333
case DefaultRewriteDecl(ref, expr, subsume):

python/egglog/run_report.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, field
44
from datetime import timedelta
55

66
from . import bindings
7+
from .declarations import CommandDecl, Declarations
78
from .egraph_state import EGraphState
9+
from .pretty import pretty_decl
10+
11+
12+
def _format_rule_key(decls: Declarations, key: CommandDecl) -> str:
13+
return pretty_decl(decls, key)
814

915

1016
@dataclass
@@ -25,19 +31,32 @@ def from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
2531
@dataclass
2632
class RuleSetReport:
2733
changed: bool
28-
rule_reports: dict[str, list[RuleReport]]
34+
rule_reports: dict[CommandDecl, list[RuleReport]]
2935
search_and_apply_time: timedelta
3036
merge_time: timedelta
37+
_decls: Declarations = field(repr=False, default=None)
3138

3239
@classmethod
33-
def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> RuleSetReport:
40+
def from_bindings(
41+
cls, report: bindings.RuleSetReport, translate_key: callable, decls: Declarations
42+
) -> RuleSetReport:
3443
return cls(
3544
changed=report.changed,
3645
rule_reports={
3746
translate_key(k): [RuleReport.from_bindings(rr) for rr in v] for k, v in report.rule_reports.items()
3847
},
3948
search_and_apply_time=report.search_and_apply_time,
4049
merge_time=report.merge_time,
50+
_decls=decls,
51+
)
52+
53+
def __repr__(self) -> str:
54+
rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()}
55+
return (
56+
f"RuleSetReport(changed={self.changed}, "
57+
f"rule_reports={rule_reports_str}, "
58+
f"search_and_apply_time={self.search_and_apply_time}, "
59+
f"merge_time={self.merge_time})"
4160
)
4261

4362

@@ -47,9 +66,11 @@ class IterationReport:
4766
rebuild_time: timedelta
4867

4968
@classmethod
50-
def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> IterationReport:
69+
def from_bindings(
70+
cls, report: bindings.IterationReport, translate_key: callable, decls: Declarations
71+
) -> IterationReport:
5172
return cls(
52-
rule_set_report=RuleSetReport.from_bindings(report.rule_set_report, translate_key),
73+
rule_set_report=RuleSetReport.from_bindings(report.rule_set_report, translate_key, decls),
5374
rebuild_time=report.rebuild_time,
5475
)
5576

@@ -60,26 +81,40 @@ class RunReport:
6081

6182
iterations: list[IterationReport]
6283
updated: bool
63-
search_and_apply_time_per_rule: dict[str, timedelta]
64-
num_matches_per_rule: dict[str, int]
84+
search_and_apply_time_per_rule: dict[CommandDecl, timedelta]
85+
num_matches_per_rule: dict[CommandDecl, int]
6586
search_and_apply_time_per_ruleset: dict[str, timedelta]
6687
merge_time_per_ruleset: dict[str, timedelta]
6788
rebuild_time_per_ruleset: dict[str, timedelta]
89+
_decls: Declarations = field(repr=False, default=None)
90+
91+
def __repr__(self) -> str:
92+
time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()}
93+
matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()}
94+
return (
95+
f"RunReport(iterations={self.iterations}, "
96+
f"updated={self.updated}, "
97+
f"search_and_apply_time_per_rule={time_per_rule}, "
98+
f"num_matches_per_rule={matches_per_rule}, "
99+
f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
100+
f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
101+
f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
102+
)
68103

69104
@classmethod
70105
def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
71106
return cls(
72-
iterations=[IterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations],
107+
iterations=[
108+
IterationReport.from_bindings(it, state.translate_rule_key, state.__egg_decls__)
109+
for it in report.iterations
110+
],
73111
updated=report.updated,
74112
search_and_apply_time_per_rule={
75113
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
76114
},
77115
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
78-
search_and_apply_time_per_ruleset={
79-
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items()
80-
},
81-
merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()},
82-
rebuild_time_per_ruleset={
83-
state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items()
84-
},
116+
search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
117+
merge_time_per_ruleset=report.merge_time_per_ruleset,
118+
rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
119+
_decls=state.__egg_decls__,
85120
)

python/tests/test_run_report.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import timedelta
55

66
from egglog import *
7+
from egglog.declarations import BiRewriteDecl, RewriteDecl, RuleDecl
78

89

910
class TestRunReport:
@@ -35,12 +36,10 @@ def test_rule_names_translated_in_top_level_dicts(self):
3536
report = egraph.run(10)
3637

3738
for key in report.search_and_apply_time_per_rule:
38-
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
39-
assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"
39+
assert isinstance(key, RewriteDecl)
4040

4141
for key in report.num_matches_per_rule:
42-
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
43-
assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"
42+
assert isinstance(key, RewriteDecl)
4443

4544
def test_rule_names_translated_in_iterations(self):
4645
egraph = self._setup_simple_egraph()
@@ -49,8 +48,7 @@ def test_rule_names_translated_in_iterations(self):
4948
assert len(report.iterations) > 0
5049
for iteration in report.iterations:
5150
for key in iteration.rule_set_report.rule_reports:
52-
assert "__main__" not in key, f"Iteration rule key not translated: {key}"
53-
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
51+
assert isinstance(key, RewriteDecl)
5452

5553
def test_updated_field(self):
5654
egraph = self._setup_simple_egraph()
@@ -117,7 +115,7 @@ def __mul__(self, other: Math) -> Math: ...
117115
rule_keys = list(report.search_and_apply_time_per_rule.keys())
118116
assert len(rule_keys) == 2
119117
for key in rule_keys:
120-
assert "__main__" not in key, f"Key not translated: {key}"
118+
assert isinstance(key, RewriteDecl)
121119

122120
def test_empty_run(self):
123121
egraph = EGraph()
@@ -158,7 +156,7 @@ def __add__(self, other: Num) -> Num: ...
158156
rule_keys = list(report.search_and_apply_time_per_rule.keys())
159157
assert len(rule_keys) > 0
160158
for key in rule_keys:
161-
assert "__main__" not in key, f"RuleDecl key not translated: {key}"
159+
assert isinstance(key, RuleDecl)
162160

163161
def test_birewrite_decl(self):
164162
egraph = EGraph()
@@ -178,23 +176,4 @@ def __mul__(self, other: Num) -> Num: ...
178176
rule_keys = list(report.search_and_apply_time_per_rule.keys())
179177
assert len(rule_keys) > 0
180178
for key in rule_keys:
181-
assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}"
182-
assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}"
183-
184-
def test_rewrite_decl(self):
185-
egraph = EGraph()
186-
187-
class Num(Expr):
188-
def __init__(self, n: i64Like) -> None: ...
189-
def __add__(self, other: Num) -> Num: ...
190-
191-
x, y = vars_("x y", Num)
192-
egraph.register(rewrite(x + y).to(y + x))
193-
egraph.register(Num(1) + Num(2))
194-
report = egraph.run(10)
195-
196-
rule_keys = list(report.search_and_apply_time_per_rule.keys())
197-
assert len(rule_keys) == 1
198-
key = rule_keys[0]
199-
assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}"
200-
assert "__main__" not in key, f"RewriteDecl key not translated: {key}"
179+
assert isinstance(key, BiRewriteDecl)

0 commit comments

Comments
 (0)