11from __future__ import annotations
22
3- from collections .abc import Callable
43from dataclasses import dataclass , field
54from datetime import timedelta
65
76from . import bindings
8- from .declarations import CommandDecl , Declarations
7+ from .declarations import BiRewriteDecl , Declarations , RewriteDecl , RuleDecl
98from .egraph_state import EGraphState
109from .pretty import pretty_decl
1110
11+ RewriteOrRuleDecl = RuleDecl | BiRewriteDecl | RewriteDecl
1212
13- def _format_rule_key (decls : Declarations , key : CommandDecl | str ) -> str :
14- if isinstance (key , str ):
15- return key
13+
14+ def _format_rule_key (decls : Declarations , key : RewriteOrRuleDecl ) -> str :
1615 return pretty_decl (decls , key )
1716
1817
@@ -35,20 +34,26 @@ def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
3534class RuleSetReport :
3635 _decls : Declarations = field (repr = False )
3736 changed : bool = False
38- rule_reports : dict [CommandDecl | str , list [RuleReport ]] = field (default_factory = dict )
37+ rule_reports : dict [RewriteOrRuleDecl , list [RuleReport ]] = field (default_factory = dict )
3938 search_and_apply_time : timedelta = field (default_factory = timedelta )
4039 merge_time : timedelta = field (default_factory = timedelta )
4140
4241 @classmethod
4342 def _from_bindings (
44- cls , report : bindings .RuleSetReport , translate_key : Callable [[ str ], CommandDecl | str ], decls : Declarations
43+ cls , report : bindings .RuleSetReport , rule_map : dict [ str , RewriteOrRuleDecl ], decls : Declarations
4544 ) -> RuleSetReport :
45+ rule_reports : dict [RewriteOrRuleDecl , list [RuleReport ]] = {}
46+ for k , v in report .rule_reports .items ():
47+ translated = rule_map [k ]
48+ reports = [RuleReport ._from_bindings (rr ) for rr in v ]
49+ if translated in rule_reports :
50+ rule_reports [translated ].extend (reports )
51+ else :
52+ rule_reports [translated ] = reports
4653 return cls (
4754 _decls = decls ,
4855 changed = report .changed ,
49- rule_reports = {
50- translate_key (k ): [RuleReport ._from_bindings (rr ) for rr in v ] for k , v in report .rule_reports .items ()
51- },
56+ rule_reports = rule_reports ,
5257 search_and_apply_time = report .search_and_apply_time ,
5358 merge_time = report .merge_time ,
5459 )
@@ -70,10 +75,10 @@ class IterationReport:
7075
7176 @classmethod
7277 def _from_bindings (
73- cls , report : bindings .IterationReport , translate_key : Callable [[ str ], CommandDecl | str ], decls : Declarations
78+ cls , report : bindings .IterationReport , rule_map : dict [ str , RewriteOrRuleDecl ], decls : Declarations
7479 ) -> IterationReport :
7580 return cls (
76- rule_set_report = RuleSetReport ._from_bindings (report .rule_set_report , translate_key , decls ),
81+ rule_set_report = RuleSetReport ._from_bindings (report .rule_set_report , rule_map , decls ),
7782 rebuild_time = report .rebuild_time ,
7883 )
7984
@@ -85,8 +90,8 @@ class RunReport:
8590 _decls : Declarations = field (repr = False )
8691 iterations : list [IterationReport ] = field (default_factory = list )
8792 updated : bool = False
88- search_and_apply_time_per_rule : dict [CommandDecl | str , timedelta ] = field (default_factory = dict )
89- num_matches_per_rule : dict [CommandDecl | str , int ] = field (default_factory = dict )
93+ search_and_apply_time_per_rule : dict [RewriteOrRuleDecl , timedelta ] = field (default_factory = dict )
94+ num_matches_per_rule : dict [RewriteOrRuleDecl , int ] = field (default_factory = dict )
9095 search_and_apply_time_per_ruleset : dict [str , timedelta ] = field (default_factory = dict )
9196 merge_time_per_ruleset : dict [str , timedelta ] = field (default_factory = dict )
9297 rebuild_time_per_ruleset : dict [str , timedelta ] = field (default_factory = dict )
@@ -106,17 +111,31 @@ def __repr__(self) -> str:
106111
107112 @classmethod
108113 def _from_bindings (cls , report : bindings .RunReport , state : EGraphState ) -> RunReport :
114+ rule_map = state .rule_name_to_command_decl
115+ decls = state .__egg_decls__
116+
117+ search_and_apply_time_per_rule : dict [RewriteOrRuleDecl , timedelta ] = {}
118+ for k , v in report .search_and_apply_time_per_rule .items ():
119+ translated = rule_map [k ]
120+ if translated in search_and_apply_time_per_rule :
121+ search_and_apply_time_per_rule [translated ] += v
122+ else :
123+ search_and_apply_time_per_rule [translated ] = v
124+
125+ num_matches_per_rule : dict [RewriteOrRuleDecl , int ] = {}
126+ for k , v in report .num_matches_per_rule .items ():
127+ translated = rule_map [k ]
128+ if translated in num_matches_per_rule :
129+ num_matches_per_rule [translated ] += v
130+ else :
131+ num_matches_per_rule [translated ] = v
132+
109133 return cls (
110- _decls = state .__egg_decls__ ,
111- iterations = [
112- IterationReport ._from_bindings (it , state .translate_rule_key , state .__egg_decls__ )
113- for it in report .iterations
114- ],
134+ _decls = decls ,
135+ iterations = [IterationReport ._from_bindings (it , rule_map , decls ) for it in report .iterations ],
115136 updated = report .updated ,
116- search_and_apply_time_per_rule = {
117- state .translate_rule_key (k ): v for k , v in report .search_and_apply_time_per_rule .items ()
118- },
119- num_matches_per_rule = {state .translate_rule_key (k ): v for k , v in report .num_matches_per_rule .items ()},
137+ search_and_apply_time_per_rule = search_and_apply_time_per_rule ,
138+ num_matches_per_rule = num_matches_per_rule ,
120139 search_and_apply_time_per_ruleset = report .search_and_apply_time_per_ruleset ,
121140 merge_time_per_ruleset = report .merge_time_per_ruleset ,
122141 rebuild_time_per_ruleset = report .rebuild_time_per_ruleset ,
0 commit comments