Skip to content

Commit ea8fb04

Browse files
committed
Cleanup AbstractBeliefPropagationCache interface.
1 parent d0f7da6 commit ea8fb04

1 file changed

Lines changed: 36 additions & 43 deletions

File tree

src/beliefpropagation/abstractbeliefpropagationcache.jl

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,32 @@ using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data
22
using Graphs: AbstractEdge, AbstractGraph
33
using NamedGraphs.GraphsExtensions: boundary_edges
44
using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent
5+
using NamedGraphs: AbstractEdges, AbstractVertices, to_graph_index
56

6-
messages(bp_cache::AbstractGraph) = edge_data(bp_cache)
7-
messages(bp_cache::AbstractGraph, edges) = map(e -> message(bp_cache, e), edges)
7+
messages(bpc::AbstractDataGraph) = edge_data(bpc)
8+
messages(bpc::AbstractGraph, edges) = map(e -> message(bpc, e), edges)
89

9-
message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge]
10+
message(bpc::AbstractGraph, edge) = messages(bpc)[edge]
1011

11-
deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented()
12-
function deletemessage!(bp_cache::AbstractDataGraph, edge)
13-
ms = messages(bp_cache)
14-
delete!(ms, edge)
15-
return bp_cache
16-
end
12+
deletemessage!(bpc::AbstractGraph, edge) = not_implemented()
1713

18-
function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache))
14+
function deletemessages!(bpc::AbstractGraph, edges = edges(bpc))
1915
for e in edges
20-
deletemessage!(bp_cache, e)
16+
deletemessage!(bpc, e)
2117
end
22-
return bp_cache
18+
return bpc
2319
end
2420

25-
setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented()
26-
function setmessage!(bp_cache::AbstractDataGraph, edge, message)
27-
setindex!(bp_cache, message, edge)
28-
return bp_cache
29-
end
30-
function setmessage!(bp_cache::QuotientView, edge, message)
31-
setmessages!(parent(bp_cache), QuotientEdge(edge), message)
32-
return bp_cache
21+
# Fallback; assume `setindex!` is implemented.
22+
function setmessage!(bpc::AbstractGraph, edge, message)
23+
bpc[edge] = message
24+
return bpc
3325
end
34-
35-
function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message)
36-
for e in edges(bp_cache, edge)
37-
setmessage!(parent(bp_cache), e, message[e])
26+
function setmessages!(bpc::AbstractGraph, messages)
27+
for (key, val) in messages
28+
setmessage!(bpc, key, val)
3829
end
39-
return bp_cache
30+
return bpc
4031
end
4132
function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges)
4233
for e in edges
@@ -45,50 +36,51 @@ function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges)
4536
return bpc_dst
4637
end
4738

48-
factors(bpc::AbstractGraph) = vertex_data(bpc)
49-
factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices]
50-
factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex])
39+
factors(bpc::AbstractDataGraph) = vertex_data(bpc)
40+
factors(bpc::AbstractGraph, vertices) = map(v -> factor(bpc, v), vertices)
5141

5242
factor(bpc::AbstractGraph, vertex) = bpc[vertex]
5343

54-
setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented()
55-
function setfactor!(bpc::AbstractDataGraph, vertex, factor)
56-
fs = factors(bpc)
57-
setindex!(fs, vertex, factor)
44+
function setfactor!(bpc::AbstractGraph, vertex, factor)
45+
bpc[vertex] = factor
5846
return bpc
5947
end
6048

61-
function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge; alg = "exact")
49+
# Internal convenience only
50+
_graph_index_scalar(bpc::AbstractGraph, vertex) = vertex_scalar(bpc, vertex)
51+
_graph_index_scalar(bpc::AbstractGraph, edge::AbstractEdge) = edge_scalar(bpc, edge)
52+
53+
function edge_scalar(bp_cache::AbstractGraph, edge; kwargs...)
6254
# Make generic to deal with the possibilty of multiple messages.
6355
m1s = messages(bp_cache, [edge])
6456
m2s = messages(bp_cache, [reverse(edge)])
65-
return contract_network(vcat(m1s, m2s); alg)[]
57+
return contract_network(vcat(m1s, m2s); kwargs...)[]
6658
end
6759

68-
function region_scalar(bp_cache::AbstractGraph, vertex; alg = "exact")
60+
function vertex_scalar(bp_cache::AbstractGraph, vertex; kwargs...)
6961
messages = incoming_messages(bp_cache, vertex)
70-
state = factors(bp_cache, vertex)
62+
state = factors(bp_cache, [vertex])
7163

72-
return contract_network(vcat(messages, state); alg)[]
64+
return contract_network(vcat(messages, state); kwargs...)[]
7365
end
7466

7567
message_type(bpc::AbstractGraph) = message_type(typeof(bpc))
7668
message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G))
7769
message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type)
7870

7971
function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache))
80-
return map(v -> region_scalar(bp_cache, v), vertices)
72+
return map(v -> vertex_scalar(bp_cache, v), vertices)
8173
end
8274

8375
function edge_scalars(
8476
bp_cache::AbstractGraph,
8577
edges = edges(undirected_graph(underlying_graph(bp_cache)))
8678
)
87-
return map(e -> region_scalar(bp_cache, e), edges)
79+
return map(e -> edge_scalar(bp_cache, e), edges)
8880
end
8981

90-
function scalar_factors_quotient(bp_cache::AbstractGraph)
91-
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
82+
function region_scalar(bpc::AbstractGraph, region)
83+
return mapreduce(ind -> _graph_index_scalar(bpc, ind), *, region)
9284
end
9385

9486
function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = [])
@@ -127,8 +119,9 @@ factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD
127119
message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc))
128120
message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED
129121

130-
function logscalar(bp_cache::AbstractBeliefPropagationCache)
131-
numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)
122+
function logscalar(bpc::AbstractBeliefPropagationCache)
123+
numerator_terms = vertex_scalars(bpc)
124+
denominator_terms = edge_scalars(bpc)
132125

133126
if any(t -> real(t) < 0, numerator_terms)
134127
numerator_terms = complex.(numerator_terms)

0 commit comments

Comments
 (0)