@@ -2,41 +2,32 @@ using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data
22using Graphs: AbstractEdge, AbstractGraph
33using NamedGraphs. GraphsExtensions: boundary_edges
44using NamedGraphs. PartitionedGraphs: QuotientEdge, QuotientView, parent
5+ using NamedGraphs: AbstractEdges, AbstractVertices, to_graph_index
56
6- messages (bp_cache :: AbstractGraph ) = edge_data (bp_cache )
7- messages (bp_cache :: AbstractGraph , edges) = map (e -> message (bp_cache , e), edges)
7+ messages (bpc :: AbstractDataGraph ) = edge_data (bpc )
8+ messages (bpc :: AbstractGraph , edges) = map (e -> message (bpc , e), edges)
89
9- message (bp_cache :: AbstractGraph , edge:: AbstractEdge ) = messages (bp_cache )[edge]
10+ message (bpc :: AbstractGraph , edge) = messages (bpc )[edge]
1011
11- deletemessage! (bp_cache:: AbstractGraph , edge) = not_implemented ()
12- function deletemessage! (bp_cache:: AbstractDataGraph , edge)
13- ms = messages (bp_cache)
14- delete! (ms, edge)
15- return bp_cache
16- end
12+ deletemessage! (bpc:: AbstractGraph , edge) = not_implemented ()
1713
18- function deletemessages! (bp_cache :: AbstractGraph , edges = edges (bp_cache ))
14+ function deletemessages! (bpc :: AbstractGraph , edges = edges (bpc ))
1915 for e in edges
20- deletemessage! (bp_cache , e)
16+ deletemessage! (bpc , e)
2117 end
22- return bp_cache
18+ return bpc
2319end
2420
25- setmessage! (bp_cache:: AbstractGraph , edge, message) = not_implemented ()
26- function setmessage! (bp_cache:: AbstractDataGraph , edge, message)
27- setindex! (bp_cache, message, edge)
28- return bp_cache
29- end
30- function setmessage! (bp_cache:: QuotientView , edge, message)
31- setmessages! (parent (bp_cache), QuotientEdge (edge), message)
32- return bp_cache
21+ # Fallback; assume `setindex!` is implemented.
22+ function setmessage! (bpc:: AbstractGraph , edge, message)
23+ bpc[edge] = message
24+ return bpc
3325end
34-
35- function setmessages! (bp_cache:: AbstractGraph , edge:: QuotientEdge , message)
36- for e in edges (bp_cache, edge)
37- setmessage! (parent (bp_cache), e, message[e])
26+ function setmessages! (bpc:: AbstractGraph , messages)
27+ for (key, val) in messages
28+ setmessage! (bpc, key, val)
3829 end
39- return bp_cache
30+ return bpc
4031end
4132function setmessages! (bpc_dst:: AbstractGraph , bpc_src:: AbstractGraph , edges)
4233 for e in edges
@@ -45,50 +36,51 @@ function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges)
4536 return bpc_dst
4637end
4738
48- factors (bpc:: AbstractGraph ) = vertex_data (bpc)
49- factors (bpc:: AbstractGraph , vertices:: Vector ) = [factor (bpc, v) for v in vertices]
50- factors (bpc:: AbstractGraph{V} , vertex:: V ) where {V} = factors (bpc, V[vertex])
39+ factors (bpc:: AbstractDataGraph ) = vertex_data (bpc)
40+ factors (bpc:: AbstractGraph , vertices) = map (v -> factor (bpc, v), vertices)
5141
5242factor (bpc:: AbstractGraph , vertex) = bpc[vertex]
5343
54- setfactor! (bpc:: AbstractGraph , vertex, factor) = not_implemented ()
55- function setfactor! (bpc:: AbstractDataGraph , vertex, factor)
56- fs = factors (bpc)
57- setindex! (fs, vertex, factor)
44+ function setfactor! (bpc:: AbstractGraph , vertex, factor)
45+ bpc[vertex] = factor
5846 return bpc
5947end
6048
61- function region_scalar (bp_cache:: AbstractGraph , edge:: AbstractEdge ; alg = " exact" )
49+ # Internal convenience only
50+ _graph_index_scalar (bpc:: AbstractGraph , vertex) = vertex_scalar (bpc, vertex)
51+ _graph_index_scalar (bpc:: AbstractGraph , edge:: AbstractEdge ) = edge_scalar (bpc, edge)
52+
53+ function edge_scalar (bp_cache:: AbstractGraph , edge; kwargs... )
6254 # Make generic to deal with the possibilty of multiple messages.
6355 m1s = messages (bp_cache, [edge])
6456 m2s = messages (bp_cache, [reverse (edge)])
65- return contract_network (vcat (m1s, m2s); alg )[]
57+ return contract_network (vcat (m1s, m2s); kwargs ... )[]
6658end
6759
68- function region_scalar (bp_cache:: AbstractGraph , vertex; alg = " exact " )
60+ function vertex_scalar (bp_cache:: AbstractGraph , vertex; kwargs ... )
6961 messages = incoming_messages (bp_cache, vertex)
70- state = factors (bp_cache, vertex)
62+ state = factors (bp_cache, [ vertex] )
7163
72- return contract_network (vcat (messages, state); alg )[]
64+ return contract_network (vcat (messages, state); kwargs ... )[]
7365end
7466
7567message_type (bpc:: AbstractGraph ) = message_type (typeof (bpc))
7668message_type (G:: Type{<:AbstractGraph} ) = eltype (Base. promote_op (messages, G))
7769message_type (type:: Type{<:AbstractDataGraph} ) = edge_data_type (type)
7870
7971function vertex_scalars (bp_cache:: AbstractGraph , vertices = vertices (bp_cache))
80- return map (v -> region_scalar (bp_cache, v), vertices)
72+ return map (v -> vertex_scalar (bp_cache, v), vertices)
8173end
8274
8375function edge_scalars (
8476 bp_cache:: AbstractGraph ,
8577 edges = edges (undirected_graph (underlying_graph (bp_cache)))
8678 )
87- return map (e -> region_scalar (bp_cache, e), edges)
79+ return map (e -> edge_scalar (bp_cache, e), edges)
8880end
8981
90- function scalar_factors_quotient (bp_cache :: AbstractGraph )
91- return vertex_scalars (bp_cache ), edge_scalars (bp_cache )
82+ function region_scalar (bpc :: AbstractGraph , region )
83+ return mapreduce (ind -> _graph_index_scalar (bpc, ind ), * , region )
9284end
9385
9486function incoming_messages (bp_cache:: AbstractGraph , vertices; ignore_edges = [])
@@ -127,8 +119,9 @@ factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD
127119message_type (bpc:: AbstractBeliefPropagationCache ) = message_type (typeof (bpc))
128120message_type (:: Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}} ) where {ED} = ED
129121
130- function logscalar (bp_cache:: AbstractBeliefPropagationCache )
131- numerator_terms, denominator_terms = scalar_factors_quotient (bp_cache)
122+ function logscalar (bpc:: AbstractBeliefPropagationCache )
123+ numerator_terms = vertex_scalars (bpc)
124+ denominator_terms = edge_scalars (bpc)
132125
133126 if any (t -> real (t) < 0 , numerator_terms)
134127 numerator_terms = complex .(numerator_terms)
0 commit comments