Skip to content

Commit 3e45d9c

Browse files
committed
chore: clean up by removing egg_rule_to_command_decl fallback and clear type
1 parent e217a3f commit 3e45d9c

3 files changed

Lines changed: 47 additions & 38 deletions

File tree

python/egglog/egraph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
"GreedyDagCost",
7272
"RewriteOrRule",
7373
"Ruleset",
74+
"RunReport",
7475
"Schedule",
7576
"_BirewriteBuilder",
7677
"_EqBuilder",

python/egglog/egraph_state.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ class EGraphState:
7676
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
7777
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
7878

79-
egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)
80-
8179
# Cache of egg expressions for converting to egg
8280
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
8381

@@ -91,7 +89,7 @@ class EGraphState:
9189
# Counter for numeric rule names
9290
rule_name_counter: int = 0
9391
# Mapping from numeric name (str) to command decl
94-
rule_name_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)
92+
rule_name_to_command_decl: dict[str, RuleDecl | BiRewriteDecl | RewriteDecl] = field(default_factory=dict)
9593

9694
def copy(self) -> EGraphState:
9795
"""
@@ -256,17 +254,6 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
256254
case _:
257255
assert_never(schedule)
258256

259-
def translate_rule_key(self, egglog_key: str) -> CommandDecl | str:
260-
"""
261-
Look up the original Python CommandDecl for an egglog rule key.
262-
"""
263-
clean_key = egglog_key.removesuffix("=>").removesuffix("<=")
264-
if clean_key in self.rule_name_to_command_decl:
265-
return self.rule_name_to_command_decl[clean_key]
266-
if egglog_key in self.egg_rule_to_command_decl:
267-
return self.egg_rule_to_command_decl[egglog_key]
268-
return egglog_key
269-
270257
def ruleset_to_egg(self, ident: Ident) -> None:
271258
"""
272259
Registers a ruleset if it's not already registered.
@@ -305,7 +292,6 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
305292
self.type_ref_to_egg(tp)
306293
name = str(self.rule_name_counter)
307294
self.rule_name_counter += 1
308-
self.rule_name_to_command_decl[name] = cmd
309295
rewrite = bindings.Rewrite(
310296
span(),
311297
self._expr_to_egg(lhs),
@@ -315,8 +301,11 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
315301
)
316302
egg_cmd: bindings._Command
317303
if isinstance(cmd, RewriteDecl):
304+
self.rule_name_to_command_decl[name] = cmd
318305
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
319306
else:
307+
self.rule_name_to_command_decl[f"{name}=>"] = cmd
308+
self.rule_name_to_command_decl[f"{name}<="] = cmd
320309
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)
321310
return egg_cmd
322311
case RuleDecl(head, body, name):

python/egglog/run_report.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
43
from dataclasses import dataclass, field
54
from datetime import timedelta
65

76
from . import bindings
8-
from .declarations import CommandDecl, Declarations
7+
from .declarations import BiRewriteDecl, Declarations, RewriteDecl, RuleDecl
98
from .egraph_state import EGraphState
109
from .pretty import pretty_decl
1110

11+
RewriteOrRuleDecl = RuleDecl | BiRewriteDecl | RewriteDecl
1212

13-
def _format_rule_key(decls: Declarations, key: CommandDecl | str) -> str:
14-
if isinstance(key, str):
15-
return key
13+
14+
def _format_rule_key(decls: Declarations, key: RewriteOrRuleDecl) -> str:
1615
return pretty_decl(decls, key)
1716

1817

@@ -35,20 +34,26 @@ def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
3534
class RuleSetReport:
3635
_decls: Declarations = field(repr=False)
3736
changed: bool = False
38-
rule_reports: dict[CommandDecl | str, list[RuleReport]] = field(default_factory=dict)
37+
rule_reports: dict[RewriteOrRuleDecl, list[RuleReport]] = field(default_factory=dict)
3938
search_and_apply_time: timedelta = field(default_factory=timedelta)
4039
merge_time: timedelta = field(default_factory=timedelta)
4140

4241
@classmethod
4342
def _from_bindings(
44-
cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations
43+
cls, report: bindings.RuleSetReport, rule_map: dict[str, RewriteOrRuleDecl], decls: Declarations
4544
) -> RuleSetReport:
45+
rule_reports: dict[RewriteOrRuleDecl, list[RuleReport]] = {}
46+
for k, v in report.rule_reports.items():
47+
translated = rule_map[k]
48+
reports = [RuleReport._from_bindings(rr) for rr in v]
49+
if translated in rule_reports:
50+
rule_reports[translated].extend(reports)
51+
else:
52+
rule_reports[translated] = reports
4653
return cls(
4754
_decls=decls,
4855
changed=report.changed,
49-
rule_reports={
50-
translate_key(k): [RuleReport._from_bindings(rr) for rr in v] for k, v in report.rule_reports.items()
51-
},
56+
rule_reports=rule_reports,
5257
search_and_apply_time=report.search_and_apply_time,
5358
merge_time=report.merge_time,
5459
)
@@ -70,10 +75,10 @@ class IterationReport:
7075

7176
@classmethod
7277
def _from_bindings(
73-
cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations
78+
cls, report: bindings.IterationReport, rule_map: dict[str, RewriteOrRuleDecl], decls: Declarations
7479
) -> IterationReport:
7580
return cls(
76-
rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls),
81+
rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, rule_map, decls),
7782
rebuild_time=report.rebuild_time,
7883
)
7984

@@ -85,8 +90,8 @@ class RunReport:
8590
_decls: Declarations = field(repr=False)
8691
iterations: list[IterationReport] = field(default_factory=list)
8792
updated: bool = False
88-
search_and_apply_time_per_rule: dict[CommandDecl | str, timedelta] = field(default_factory=dict)
89-
num_matches_per_rule: dict[CommandDecl | str, int] = field(default_factory=dict)
93+
search_and_apply_time_per_rule: dict[RewriteOrRuleDecl, timedelta] = field(default_factory=dict)
94+
num_matches_per_rule: dict[RewriteOrRuleDecl, int] = field(default_factory=dict)
9095
search_and_apply_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
9196
merge_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
9297
rebuild_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
@@ -106,17 +111,31 @@ def __repr__(self) -> str:
106111

107112
@classmethod
108113
def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
114+
rule_map = state.rule_name_to_command_decl
115+
decls = state.__egg_decls__
116+
117+
search_and_apply_time_per_rule: dict[RewriteOrRuleDecl, timedelta] = {}
118+
for k, v in report.search_and_apply_time_per_rule.items():
119+
translated = rule_map[k]
120+
if translated in search_and_apply_time_per_rule:
121+
search_and_apply_time_per_rule[translated] += v
122+
else:
123+
search_and_apply_time_per_rule[translated] = v
124+
125+
num_matches_per_rule: dict[RewriteOrRuleDecl, int] = {}
126+
for k, v in report.num_matches_per_rule.items():
127+
translated = rule_map[k]
128+
if translated in num_matches_per_rule:
129+
num_matches_per_rule[translated] += v
130+
else:
131+
num_matches_per_rule[translated] = v
132+
109133
return cls(
110-
_decls=state.__egg_decls__,
111-
iterations=[
112-
IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__)
113-
for it in report.iterations
114-
],
134+
_decls=decls,
135+
iterations=[IterationReport._from_bindings(it, rule_map, decls) for it in report.iterations],
115136
updated=report.updated,
116-
search_and_apply_time_per_rule={
117-
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
118-
},
119-
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
137+
search_and_apply_time_per_rule=search_and_apply_time_per_rule,
138+
num_matches_per_rule=num_matches_per_rule,
120139
search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
121140
merge_time_per_ruleset=report.merge_time_per_ruleset,
122141
rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,

0 commit comments

Comments
 (0)