@@ -35,7 +35,7 @@ struct ENode
3535 istree:: Bool
3636 head:: Any
3737 operation:: Any
38- args:: Vector{EClassId }
38+ args:: Vector{Id }
3939 hash:: Ref{UInt}
4040 ENode (head, operation, args) = new (true , head, operation, args, Ref {UInt} (0 ))
4141 ENode (literal) = new (false , nothing , literal, UNDEF_ID_VEC, Ref {UInt} (0 ))
7979
8080# parametrize metadata by M
8181mutable struct EClass{D}
82- id:: EClassId
82+ id:: Id
8383 nodes:: Vector{ENode}
84- parents:: Vector{Pair{ENode,EClassId }}
84+ parents:: Vector{Pair{ENode,Id }}
8585 data:: Union{D,Nothing}
8686end
8787
@@ -101,7 +101,7 @@ function Base.show(io::IO, a::EClass)
101101 print (io, " )" )
102102end
103103
104- function addparent! (@nospecialize (a:: EClass ), n:: ENode , id:: EClassId )
104+ function addparent! (@nospecialize (a:: EClass ), n:: ENode , id:: Id )
105105 push! (a. parents, (n => id))
106106end
107107
@@ -134,22 +134,22 @@ mutable struct EGraph{Head,Analysis}
134134 " stores the equality relations over e-class ids"
135135 uf:: UnionFind
136136 " map from eclass id to eclasses"
137- classes:: Dict{EClassId ,EClass{Analysis}}
137+ classes:: Dict{Id ,EClass{Analysis}}
138138 " hashcons"
139- memo:: Dict{ENode,EClassId }
139+ memo:: Dict{ENode,Id }
140140 " Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass."
141- pending:: Vector{Pair{ENode,EClassId }}
142- analysis_pending:: UniqueQueue{Pair{ENode,EClassId }}
143- root:: EClassId
141+ pending:: Vector{Pair{ENode,Id }}
142+ analysis_pending:: UniqueQueue{Pair{ENode,Id }}
143+ root:: Id
144144 " a cache mapping function symbols and their arity to e-classes that contain e-nodes with that function symbol."
145- classes_by_op:: Dict{Pair{Any,Int},Vector{EClassId }}
145+ classes_by_op:: Dict{Pair{Any,Int},Vector{Id }}
146146 clean:: Bool
147147 " If we use global buffers we may need to lock. Defaults to false."
148148 needslock:: Bool
149149 " Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
150150 buffer:: Vector{EMatchBindings}
151151 " Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
152- merges_buffer:: Vector{EClassId }
152+ merges_buffer:: Vector{Id }
153153 lock:: ReentrantLock
154154end
155155
@@ -161,16 +161,16 @@ Construct an EGraph from a starting symbolic expression `expr`.
161161function EGraph {Head,Analysis} (; needslock:: Bool = false ) where {Head,Analysis}
162162 EGraph {Head,Analysis} (
163163 UnionFind (),
164- Dict {EClassId ,EClass{Analysis}} (),
165- Dict {ENode,EClassId } (),
166- Pair{ENode,EClassId }[],
167- UniqueQueue {Pair{ENode,EClassId }} (),
164+ Dict {Id ,EClass{Analysis}} (),
165+ Dict {ENode,Id } (),
166+ Pair{ENode,Id }[],
167+ UniqueQueue {Pair{ENode,Id }} (),
168168 0 ,
169- Dict {Pair{Any,Int},Vector{EClassId }} (),
169+ Dict {Pair{Any,Int},Vector{Id }} (),
170170 false ,
171171 needslock,
172172 EMatchBindings[],
173- EClassId [],
173+ Id [],
174174 ReentrantLock (),
175175 )
176176end
@@ -199,16 +199,16 @@ end
199199"""
200200Returns the canonical e-class id for a given e-class.
201201"""
202- @inline find (g:: EGraph , a:: EClassId ):: EClassId = find (g. uf, a)
203- @inline find (@nospecialize (g:: EGraph ), @nospecialize (a:: EClass )):: EClassId = find (g, a. id)
202+ @inline find (g:: EGraph , a:: Id ):: Id = find (g. uf, a)
203+ @inline find (@nospecialize (g:: EGraph ), @nospecialize (a:: EClass )):: Id = find (g, a. id)
204204
205- @inline Base. getindex (@nospecialize (g:: EGraph ), i:: EClassId ):: EClass = g. classes[find (g, i)]
205+ @inline Base. getindex (@nospecialize (g:: EGraph ), i:: Id ):: EClass = g. classes[find (g, i)]
206206
207207function canonicalize (g:: EGraph , n:: ENode ):: ENode
208208 n. istree || return n
209209 ar = length (n. args)
210210 ar == 0 && return n
211- canonicalized_args = Vector {EClassId } (undef, ar)
211+ canonicalized_args = Vector {Id } (undef, ar)
212212 for i in 1 : ar
213213 @inbounds canonicalized_args[i] = find (g, n. args[i])
214214 end
@@ -224,7 +224,7 @@ function canonicalize!(g::EGraph, n::ENode)
224224 return n
225225end
226226
227- function lookup (g:: EGraph , n:: ENode ):: EClassId
227+ function lookup (g:: EGraph , n:: ENode ):: Id
228228 cc = canonicalize (g, n)
229229 haskey (g. memo, cc) ? find (g, g. memo[cc]) : 0
230230end
242242"""
243243Inserts an e-node in an [`EGraph`](@ref)
244244"""
245- function add! (g:: EGraph{Head,Analysis} , n:: ENode ):: EClassId where {Head,Analysis}
245+ function add! (g:: EGraph{Head,Analysis} , n:: ENode ):: Id where {Head,Analysis}
246246 n = canonicalize (g, n)
247247 haskey (g. memo, n) && return g. memo[n]
248248
@@ -257,7 +257,7 @@ function add!(g::EGraph{Head,Analysis}, n::ENode)::EClassId where {Head,Analysis
257257 g. memo[n] = id
258258
259259 add_class_by_op (g, n, id)
260- eclass = EClass {Analysis} (id, ENode[n], Pair{ENode,EClassId }[], make (g, n))
260+ eclass = EClass {Analysis} (id, ENode[n], Pair{ENode,Id }[], make (g, n))
261261 g. classes[id] = eclass
262262 modify! (g, eclass)
263263 push! (g. pending, n => id)
@@ -282,14 +282,14 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int
282282[`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly
283283insert the literal into the [`EGraph`](@ref).
284284"""
285- function addexpr! (g:: EGraph , se):: EClassId
285+ function addexpr! (g:: EGraph , se):: Id
286286 se isa EClass && return se. id
287287 e = preprocess (se)
288288
289289 n = if istree (se)
290290 args = arguments (e)
291291 ar = arity (e)
292- class_ids = Vector {EClassId } (undef, ar)
292+ class_ids = Vector {Id } (undef, ar)
293293 for i in 1 : ar
294294 @inbounds class_ids[i] = addexpr! (g, args[i])
295295 end
305305Given an [`EGraph`](@ref) and two e-class ids, set
306306the two e-classes as equal.
307307"""
308- function Base. union! (g:: EGraph , enode_id1:: EClassId , enode_id2:: EClassId ):: Bool
308+ function Base. union! (g:: EGraph , enode_id1:: Id , enode_id2:: Id ):: Bool
309309 g. clean = false
310310
311311 id_1 = find (g, enode_id1)
@@ -335,7 +335,7 @@ function Base.union!(g::EGraph, enode_id1::EClassId, enode_id2::EClassId)::Bool
335335 return true
336336end
337337
338- function in_same_class (g:: EGraph , ids:: EClassId ... ):: Bool
338+ function in_same_class (g:: EGraph , ids:: Id ... ):: Bool
339339 nids = length (ids)
340340 nids == 1 && return true
341341
@@ -377,7 +377,7 @@ function process_unions!(@nospecialize(g::EGraph))::Int
377377
378378 while ! isempty (g. pending) || ! isempty (g. analysis_pending)
379379 while ! isempty (g. pending)
380- (node:: ENode , eclass_id:: EClassId ) = pop! (g. pending)
380+ (node:: ENode , eclass_id:: Id ) = pop! (g. pending)
381381 canonicalize! (g, node)
382382 if haskey (g. memo, node)
383383 old_class_id = g. memo[node]
@@ -390,7 +390,7 @@ function process_unions!(@nospecialize(g::EGraph))::Int
390390 end
391391
392392 while ! isempty (g. analysis_pending)
393- (node:: ENode , eclass_id:: EClassId ) = pop! (g. analysis_pending)
393+ (node:: ENode , eclass_id:: Id ) = pop! (g. analysis_pending)
394394 eclass_id = find (g, eclass_id)
395395 eclass = g[eclass_id]
396396
@@ -414,7 +414,7 @@ function process_unions!(@nospecialize(g::EGraph))::Int
414414end
415415
416416function check_memo (g:: EGraph ):: Bool
417- test_memo = Dict {ENode,EClassId } ()
417+ test_memo = Dict {ENode,Id } ()
418418 for (id, class) in g. classes
419419 @assert id == class. id
420420 for node in class. nodes
@@ -463,10 +463,10 @@ end
463463Recursive function that traverses an [`EGraph`](@ref) and
464464returns a vector of all reachable e-classes from a given e-class id.
465465"""
466- function reachable (g:: EGraph , id:: EClassId )
466+ function reachable (g:: EGraph , id:: Id )
467467 id = find (g, id)
468- hist = EClassId [id]
469- todo = EClassId [id]
468+ hist = Id [id]
469+ todo = Id [id]
470470
471471
472472 function reachable_node (xn:: ENode )
493493
494494import Metatheory: lookup_pat
495495
496- function lookup_pat (g:: EGraph{Head} , p:: PatTerm ):: EClassId where {Head}
496+ function lookup_pat (g:: EGraph{Head} , p:: PatTerm ):: Id where {Head}
497497 @assert isground (p)
498498
499499 op = operation (p)
@@ -502,7 +502,7 @@ function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head}
502502
503503 eh = Head (head_symbol (head (p)))
504504
505- ids = Vector {EClassId } (undef, ar)
505+ ids = Vector {Id } (undef, ar)
506506 for i in 1 : ar
507507 @inbounds ids[i] = lookup_pat (g, args[i])
508508 ids[i] <= 0 && return 0
@@ -516,20 +516,20 @@ function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head}
516516 end
517517end
518518
519- lookup_pat (g:: EGraph , p:: Any ):: EClassId = lookup (g, ENode (p))
519+ lookup_pat (g:: EGraph , p:: Any ):: Id = lookup (g, ENode (p))
520520
521521struct Extractor{CostFun,Cost}
522522 g:: EGraph
523523 cost_function:: CostFun
524- costs:: Dict{EClassId ,Tuple{Cost,Int64}} # Cost and index in eclass
524+ costs:: Dict{Id ,Tuple{Cost,Int64}} # Cost and index in eclass
525525end
526526
527527"""
528528Given a cost function, extract the expression
529529with the smallest computed cost from an [`EGraph`](@ref)
530530"""
531531function Extractor (g:: EGraph , cost_function:: Function , cost_type = Float64)
532- extractor = Extractor {typeof(cost_function),cost_type} (g, cost_function, Dict {EClassId ,Tuple{cost_type,Int}} ())
532+ extractor = Extractor {typeof(cost_function),cost_type} (g, cost_function, Dict {Id ,Tuple{cost_type,Int}} ())
533533 find_costs! (extractor)
534534 extractor
535535end
@@ -546,13 +546,13 @@ end
546546
547547
548548function (extractor:: Extractor )(root = extractor. g. root)
549- get_node (eclass_id:: EClassId ) = find_best_node (extractor, eclass_id)
549+ get_node (eclass_id:: Id ) = find_best_node (extractor, eclass_id)
550550 # TODO check if infinite cost?
551551 extract_expr_recursive (find_best_node (extractor, root), get_node)
552552end
553553
554554# costs dict stores index of enode. get this enode from the eclass
555- function find_best_node (extractor:: Extractor , eclass_id:: EClassId )
555+ function find_best_node (extractor:: Extractor , eclass_id:: Id )
556556 eclass = extractor. g[eclass_id]
557557 (_, node_index) = extractor. costs[eclass. id]
558558 eclass. nodes[node_index]
0 commit comments