Skip to content

Commit e58670e

Browse files
mtfishmanclaude
andcommitted
Refactor messagecache.jl: drop AbstractMessageCache supertype
`MessageCache` and `SqrtMessageCache` now subtype `AbstractDataGraph` directly rather than going through a shared `AbstractMessageCache` abstract type. Shared methods are emitted per-type via the existing `for Cache in (:MessageCache, :SqrtMessageCache)` `@eval` loop, which already wrapped the constructors and now covers the rest of the interface: key/val types, `NamedGraphs.add_edge!` / `rem_edge!` / `induced_subgraph_from_vertices`, `DataGraphs` accessors, `==`, the four `copyto!` variants, and `Base.show`. The `copyto!_messagecache` helper drops its first-arg type constraint (was `::AbstractMessageCache`, now untyped — internal helper). Once `AbstractEdgeDataGraph` lands in DataGraphs.jl (PR #121), both types can subtype that and most of the `@eval` loop can collapse into shared methods on the new abstract type. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent b6f824a commit e58670e

1 file changed

Lines changed: 100 additions & 114 deletions

File tree

src/beliefpropagation/messagecache.jl

Lines changed: 100 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph
1010
using NamedGraphs: NamedDiGraph, Vertices, convert_vertextype, ordered_vertices,
1111
parent_graph_indices, position_graph, to_graph_index, vertex_positions
1212

13-
abstract type AbstractMessageCache{T, V} <: AbstractDataGraph{V, Nothing, T} end
14-
15-
struct MessageCache{T, V} <: AbstractMessageCache{T, V}
13+
struct MessageCache{T, V} <: AbstractDataGraph{V, Nothing, T}
1614
messages::Dictionary{NamedEdge{V}, T}
1715
underlying_graph::NamedDiGraph{V}
1816
function MessageCache{T, V}(::UndefInitializer, vertices) where {T, V}
@@ -28,7 +26,7 @@ end
2826
# "full" message `M`. Structurally identical to `MessageCache`; the apply-
2927
# operator BP path dispatches on the type to use the messages as gauge
3028
# factors directly and skip the sqrt-via-eigh step.
31-
struct SqrtMessageCache{T, V} <: AbstractMessageCache{T, V}
29+
struct SqrtMessageCache{T, V} <: AbstractDataGraph{V, Nothing, T}
3230
messages::Dictionary{NamedEdge{V}, T}
3331
underlying_graph::NamedDiGraph{V}
3432
function SqrtMessageCache{T, V}(::UndefInitializer, vertices) where {T, V}
@@ -38,11 +36,16 @@ struct SqrtMessageCache{T, V} <: AbstractMessageCache{T, V}
3836
end
3937
end
4038

41-
# Constructors and convenience factories shared between `MessageCache` and
42-
# `SqrtMessageCache`: the storage and graph structure are identical, only the
43-
# semantic interpretation of the message values differs.
39+
# `MessageCache` and `SqrtMessageCache` are sibling concrete types: the storage
40+
# and graph structure are identical, only the semantic interpretation of the
41+
# message values differs. Shared methods are emitted per-type via this loop
42+
# rather than via a shared abstract supertype. Once
43+
# `DataGraphs.AbstractEdgeDataGraph` (DataGraphs.jl#121) lands, both can
44+
# subtype that and most of this loop can fall away.
4445
for Cache in (:MessageCache, :SqrtMessageCache)
4546
@eval begin
47+
# ============================ constructors ===================================== #
48+
4649
function $Cache{T}(::UndefInitializer, vertices) where {T}
4750
return $Cache{T, eltype(vertices)}(undef, vertices)
4851
end
@@ -66,117 +69,120 @@ for Cache in (:MessageCache, :SqrtMessageCache)
6669
end
6770

6871
Base.copy(cache::$Cache) = $Cache(copy(cache.messages))
69-
end
70-
end
71-
72-
messagecache(pairs) = MessageCache(Dict(pairs))
73-
messagecache(f, edges) = messagecache(edge => f(edge) for edge in edges)
7472

75-
sqrtmessagecache(pairs) = SqrtMessageCache(Dict(pairs))
76-
sqrtmessagecache(f, edges) = sqrtmessagecache(edge => f(edge) for edge in edges)
73+
# ============================ key/val types ==================================== #
7774

78-
# compatibility with generic key-val iterables
79-
Base.keytype(c::AbstractMessageCache) = keytype(typeof(c))
80-
Base.keytype(::Type{<:AbstractMessageCache{T, V}}) where {T, V} = NamedEdge{V}
75+
Base.keytype(c::$Cache) = keytype(typeof(c))
76+
Base.keytype(::Type{<:$Cache{T, V}}) where {T, V} = NamedEdge{V}
77+
Base.valtype(c::$Cache) = valtype(typeof(c))
78+
Base.valtype(::Type{<:$Cache{T}}) where {T} = T
79+
Base.keys(cache::$Cache) = edges(cache)
8180

82-
Base.valtype(c::AbstractMessageCache) = valtype(typeof(c))
83-
Base.valtype(::Type{<:AbstractMessageCache{T}}) where {T} = T
81+
# ============================ NamedGraphs interface ============================ #
8482

85-
Base.keys(cache::AbstractMessageCache) = edges(cache)
83+
function NamedGraphs.add_edge!(c::$Cache, edge)
84+
add_edge!(c.underlying_graph, edge)
85+
return c
86+
end
8687

87-
# ================================ NamedGraphs interface ================================= #
88-
function NamedGraphs.add_edge!(c::AbstractMessageCache, edge)
89-
add_edge!(c.underlying_graph, edge)
90-
return c
91-
end
88+
function NamedGraphs.rem_edge!(c::$Cache, edge)
89+
delete!(c.messages, to_graph_index(c, edge))
90+
rem_edge!(c.underlying_graph, edge)
91+
return c
92+
end
9293

93-
function NamedGraphs.rem_edge!(c::AbstractMessageCache, edge)
94-
delete!(c.messages, to_graph_index(c, edge))
95-
rem_edge!(c.underlying_graph, edge)
96-
return c
97-
end
94+
function NamedGraphs.induced_subgraph_from_vertices(cache::$Cache, subvertices)
95+
# TODO: once we have `subgraph_edges` in `NamedGraphs`, simplify this.
96+
underlying_subgraph, vlist =
97+
Graphs.induced_subgraph(cache.underlying_graph, subvertices)
98+
assigned = v -> isassigned(cache, v)
99+
assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph))
100+
messages = getindices(cache.messages, Indices(assigned_subedges))
101+
return $Cache(messages), vlist
102+
end
98103

99-
# ================================= DataGraphs interface ================================= #
104+
# ============================ DataGraphs interface ============================= #
100105

101-
DataGraphs.underlying_graph(cache::AbstractMessageCache) = cache.underlying_graph
106+
DataGraphs.underlying_graph(cache::$Cache) = cache.underlying_graph
107+
DataGraphs.is_vertex_assigned(::$Cache, _) = false
108+
DataGraphs.is_edge_assigned(c::$Cache, edge) = haskey(c.messages, edge)
102109

103-
DataGraphs.is_vertex_assigned(::AbstractMessageCache, _) = false
104-
DataGraphs.is_edge_assigned(c::AbstractMessageCache, edge) = haskey(c.messages, edge)
110+
function DataGraphs.get_edge_data(c::$Cache, edge::AbstractEdge)
111+
return c.messages[edge]
112+
end
113+
function DataGraphs.set_edge_data!(c::$Cache, val, edge)
114+
return set!(c.messages, edge, val)
115+
end
105116

106-
function DataGraphs.get_edge_data(c::AbstractMessageCache, edge::AbstractEdge)
107-
return c.messages[edge]
108-
end
109-
function DataGraphs.set_edge_data!(c::AbstractMessageCache, val, edge)
110-
return set!(c.messages, edge, val)
111-
end
117+
# ============================ equality ========================================= #
112118

113-
function Base.:(==)(cache1::C, cache2::C) where {C <: AbstractMessageCache}
114-
ug1 = cache1.underlying_graph
115-
ug2 = cache2.underlying_graph
119+
function Base.:(==)(c1::$Cache, c2::$Cache)
120+
return c1.underlying_graph == c2.underlying_graph && c1.messages == c2.messages
121+
end
116122

117-
ms1 = cache1.messages
118-
ms2 = cache2.messages
123+
# ============================ copyto! ========================================== #
124+
125+
# see: copyto!(dest, src) for analogous behaviour to 2 argument method
126+
# see: copyto!(dest, Rdest::CartesianIndices, src, Rsrc::CartesianIndices)
127+
# for analogous behaviour to 3 argument method.
128+
# TODO: these can be made generic for `AbstractDataGraph` in `DataGraphs.jl`.
129+
function Base.copyto!(
130+
cache_dst::$Cache, cache_src::AbstractDataGraph, inds = nothing
131+
)
132+
copyto!_messagecache(cache_dst, edge_data(cache_src), inds)
133+
return cache_dst
134+
end
119135

120-
return (ug1 == ug2 && ms1 == ms2)
121-
end
136+
function Base.copyto!(
137+
cache_dst::$Cache, dictionary_src::Dictionary, inds = nothing
138+
)
139+
copyto!_messagecache(cache_dst, dictionary_src, inds)
140+
return cache_dst
141+
end
122142

123-
function NamedGraphs.induced_subgraph_from_vertices(cache::MessageCache, subvertices)
124-
# TODO: once we have `subgraph_edges` in `NamedGraphs`, simplify this.
125-
underlying_subgraph, vlist =
126-
Graphs.induced_subgraph(cache.underlying_graph, subvertices)
143+
function Base.copyto!(
144+
cache_dst::$Cache, dict_src::Dict, inds = keys(dict_src)
145+
)
146+
for key in inds
147+
cache_dst[key] = dict_src[key]
148+
end
149+
return cache_dst
150+
end
127151

128-
assigned = v -> isassigned(cache, v)
152+
# ============================ printing ========================================= #
153+
154+
# TODO: This is the definition for the proposed `DataGraphs.AbstractEdgeDataGraph`.
155+
function Base.show(io::IO, mime::MIME"text/plain", graph::$Cache)
156+
println(io, "$(typeof(graph)) with $(nv(graph)) vertices:")
157+
show(io, mime, vertices(graph))
158+
println(io, "\n")
159+
println(io, "and $(ne(graph)) edge(s):")
160+
for e in edges(graph)
161+
show(io, mime, e)
162+
println(io)
163+
end
164+
println(io)
165+
println(io, "with edge data:")
166+
show(io, mime, edge_data(graph))
167+
return nothing
168+
end
129169

130-
assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph))
170+
Base.show(io::IO, graph::$Cache) = show(io, MIME"text/plain"(), graph)
171+
end
172+
end
131173

132-
messages = getindices(cache.messages, Indices(assigned_subedges))
174+
messagecache(pairs) = MessageCache(Dict(pairs))
175+
messagecache(f, edges) = messagecache(edge => f(edge) for edge in edges)
133176

134-
return MessageCache(messages), vlist
135-
end
177+
sqrtmessagecache(pairs) = SqrtMessageCache(Dict(pairs))
178+
sqrtmessagecache(f, edges) = sqrtmessagecache(edge => f(edge) for edge in edges)
136179

137-
# see: copyto!(dest, src) for analogous behaviour to 2 argument method
138-
# see: copyto!(dest, Rdest::CartesianIndices, src, Rsrc::CartesianIndices)
139-
# for analogous behaviour to 3 argument method.
140-
# TODO: these can be made generic for `AbtractDataGraph` in `DataGraphs.jl`
141-
function copyto!_messagecache(
142-
cache_dst::AbstractMessageCache,
143-
cache_src,
144-
inds = nothing
145-
)
180+
function copyto!_messagecache(cache_dst, cache_src, inds = nothing)
146181
inds = isnothing(inds) ? Indices(keys(cache_src)) : Indices(inds)
147182
view(edge_data(cache_dst), inds) .= view(cache_src, inds)
148183
return cache_dst
149184
end
150185

151-
function Base.copyto!(
152-
cache_dst::AbstractMessageCache,
153-
cache_src::AbstractDataGraph,
154-
inds = nothing
155-
)
156-
copyto!_messagecache(cache_dst, edge_data(cache_src), inds)
157-
return cache_dst
158-
end
159-
160-
function Base.copyto!(
161-
cache_dst::AbstractMessageCache,
162-
dictionary_src::Dictionary,
163-
inds = nothing
164-
)
165-
copyto!_messagecache(cache_dst, dictionary_src, inds)
166-
return cache_dst
167-
end
168-
169-
function Base.copyto!(
170-
cache_dst::AbstractMessageCache,
171-
dict_src::Dict,
172-
inds = keys(dict_src)
173-
)
174-
for key in inds
175-
cache_dst[key] = dict_src[key]
176-
end
177-
return cache_dst
178-
end
179-
180186
# ===================================== contraction ====================================== #
181187

182188
function incoming_messages(cache::AbstractGraph, pair::Pair)
@@ -274,23 +280,3 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo
274280
end
275281
return rv
276282
end
277-
278-
# ======================================= printing ======================================= #
279-
280-
# TODO: This is the definition for the proposed `DataGraphs.AbstractEdgeDataGraph`.
281-
function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractMessageCache)
282-
println(io, "$(typeof(graph)) with $(nv(graph)) vertices:")
283-
show(io, mime, vertices(graph))
284-
println(io, "\n")
285-
println(io, "and $(ne(graph)) edge(s):")
286-
for e in edges(graph)
287-
show(io, mime, e)
288-
println(io)
289-
end
290-
println(io)
291-
println(io, "with edge data:")
292-
show(io, mime, edge_data(graph))
293-
return nothing
294-
end
295-
296-
Base.show(io::IO, graph::AbstractMessageCache) = show(io, MIME"text/plain"(), graph)

0 commit comments

Comments
 (0)