Skip to content

Commit c0cb74d

Browse files
mtfishmanclaude
andauthored
Drop BP default_message; explicit identity_messages discipline (#371)
## Summary - Loopy BP on QN-graded networks was silently NaN-ing because the single-leg `delta(i)` initial messages collapsed half the QN sectors to empty blocks. Replaces the implicit `default_message` family with an explicit two-leg `delta(b, k)` `identity_messages` that pairs bra and ket link inds, keeping the QN sectors aligned through the contractions. - `BeliefPropagationCache(ptn)` now starts with an empty messages dict. Form-network `identity_messages(fn, ptn)` builds the bra/ket pairings from cross-partition vertex pairs and works for both per-vertex and coarser partitionings (e.g. column-grouped 2D grids). - The only auto-init lives on `QuadraticFormNetwork` (structurally ψ-vs-ψ, so identity messages are canonical): `scalar`/`logscalar`/`normalize`/`rescale`/`expect` on a QFN-backed network thread identity messages on a loopy quotient graph. Asymmetric form networks (general `LFN`/`BFN` built from ϕ ≠ ψ) fall through to the generic path, and the caller must supply messages — `identity_messages` isn't well-defined when bra and ket link dims can differ. - Adds a 4-cycle QN-conserving regression test routed through `scalar(QuadraticFormNetwork(ψ); alg = "bp")` that pins the old NaN failure mode. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent eb5958f commit c0cb74d

18 files changed

Lines changed: 200 additions & 145 deletions

src/ITensorNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ include("inner.jl")
5555
include("normalize.jl")
5656
include("expect.jl")
5757
include("environment.jl")
58+
include("initialize_cache.jl")
5859
include("exports.jl")
5960

6061
end

src/abstractitensornetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ function inner_network(
517517
return BilinearFormNetwork(A, x, y; kwargs...)
518518
end
519519

520-
norm_sqr_network::AbstractITensorNetwork) = inner_network(ψ, ψ)
520+
norm_sqr_network::AbstractITensorNetwork) = QuadraticFormNetwork(ψ)
521521

522522
#
523523
# Printing

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using Adapt: Adapt, adapt, adapt_structure
22
using DataGraphs: DataGraphs, underlying_graph, vertex_data
3+
using Dictionaries: Dictionary
34
using Graphs: Graphs, IsDirected, dst, src
4-
using ITensors: commoninds, delta, dir
5+
using ITensors: dir
56
using LinearAlgebra: diag, dot
67
using NDTensors: NDTensors
78
using NamedGraphs.GraphsExtensions: subgraph
@@ -34,14 +35,6 @@ function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
3435
return 1 - f
3536
end
3637

37-
function default_message(datatype::Type{<:AbstractArray}, inds_e)
38-
return [adapt(datatype, denseblocks(delta(i))) for i in inds_e]
39-
end
40-
41-
function default_message(elt::Type{<:Number}, inds_e)
42-
return default_message(Vector{elt}, inds_e)
43-
end
44-
default_messages(ptn::PartitionedGraph) = Dictionary()
4538
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
4639
@traitfn function default_bp_maxiter(g::::IsDirected)
4740
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
@@ -50,11 +43,6 @@ default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertice
5043

5144
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
5245
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
53-
function default_message(
54-
bpc::AbstractBeliefPropagationCache, edge::QuotientEdge; kwargs...
55-
)
56-
return not_implemented()
57-
end
5846
default_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
5947
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
6048
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
@@ -162,11 +150,6 @@ function PartitionedGraphs.quotientedge(
162150
return PartitionedGraphs.quotientedge(partitioned_tensornetwork(bpc), edge)
163151
end
164152

165-
function linkinds(bpc::AbstractBeliefPropagationCache, pe::QuotientEdge)
166-
pitn = partitioned_tensornetwork(bpc)
167-
return commoninds(subgraph(pitn, src(pe)), subgraph(pitn, dst(pe)))
168-
end
169-
170153
NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc))
171154

172155
"""
@@ -187,12 +170,11 @@ function update_factor(bpc, vertex, factor)
187170
return bpc
188171
end
189172

190-
function message(bpc::AbstractBeliefPropagationCache, edge::QuotientEdge; kwargs...)
191-
mts = messages(bpc)
192-
return get(() -> default_message(bpc, edge; kwargs...), mts, edge)
173+
function message(bpc::AbstractBeliefPropagationCache, edge::QuotientEdge)
174+
return messages(bpc)[edge]
193175
end
194-
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
195-
return map(edge -> message(bpc, edge; kwargs...), edges)
176+
function messages(bpc::AbstractBeliefPropagationCache, edges)
177+
return map(edge -> message(bpc, edge), edges)
196178
end
197179
function set_messages!(bpc::AbstractBeliefPropagationCache, quotientedges_messages)
198180
ms = messages(bpc)

src/caches/beliefpropagationcache.jl

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DataGraphs: DataGraphs, set_vertex_data!
2+
using Dictionaries: Dictionary
23
using Graphs: IsDirected
34
using ITensors: dir
45
using LinearAlgebra: diag, dot
@@ -9,22 +10,14 @@ using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraph,
910
using SimpleTraits: SimpleTraits, @traitfn, Not
1011
using SplitApplyCombine: group
1112

12-
function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
13-
return (; partitioned_vertices = default_partitioned_vertices(ψ))
14-
end
15-
16-
function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGraph)
17-
return (;)
18-
end
19-
2013
struct BeliefPropagationCache{V, PV, PTN <: AbstractPartitionedGraph{V, PV}, MTS} <:
2114
AbstractBeliefPropagationCache{V, PV}
2215
partitioned_tensornetwork::PTN
2316
messages::MTS
2417
end
2518

2619
#Constructors...
27-
function BeliefPropagationCache(ptn::PartitionedGraph; messages = default_messages(ptn))
20+
function BeliefPropagationCache(ptn::PartitionedGraph; messages = Dictionary())
2821
return BeliefPropagationCache(ptn, messages)
2922
end
3023

@@ -41,20 +34,12 @@ function BeliefPropagationCache(
4134
return BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
4235
end
4336

44-
function cache(alg::Algorithm"bp", tn; kwargs...)
45-
return BeliefPropagationCache(tn; kwargs...)
46-
end
47-
4837
function partitioned_tensornetwork(bp_cache::BeliefPropagationCache)
4938
return bp_cache.partitioned_tensornetwork
5039
end
5140

5241
messages(bp_cache::BeliefPropagationCache) = bp_cache.messages
5342

54-
function default_message(bp_cache::BeliefPropagationCache, edge::QuotientEdge)
55-
return default_message(datatype(bp_cache), linkinds(bp_cache, edge))
56-
end
57-
5843
function Base.copy(bp_cache::BeliefPropagationCache)
5944
return BeliefPropagationCache(
6045
copy(partitioned_tensornetwork(bp_cache)), copy(messages(bp_cache))

src/contract.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ function logscalar(
4242
alg::Algorithm,
4343
tn::AbstractITensorNetwork;
4444
(cache!) = nothing,
45-
cache_construction_kwargs = default_cache_construction_kwargs(alg, tn),
45+
cache_construction_kwargs = (;),
4646
update_cache = isnothing(cache!),
4747
cache_update_kwargs = (;)
4848
)
4949
if isnothing(cache!)
50-
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
50+
cache! = Ref(initialize_cache(alg, tn; cache_construction_kwargs...))
5151
end
5252

5353
if update_cache

src/environment.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ function environment(
2424
vertices::Vector;
2525
(cache!) = nothing,
2626
update_cache = isnothing(cache!),
27-
cache_construction_kwargs = default_cache_construction_kwargs(alg, ptn),
27+
cache_construction_kwargs = (;),
2828
cache_update_kwargs = (;)
2929
)
3030
if isnothing(cache!)
31-
cache! = Ref(cache(alg, ptn; cache_construction_kwargs...))
31+
cache! = Ref(initialize_cache(alg, ptn; cache_construction_kwargs...))
3232
end
3333

3434
if update_cache

src/expect.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using Dictionaries: Dictionary, set!
21
using ITensors: Op, contract, op, which_op
32

43
default_expect_alg() = "bp"
@@ -31,7 +30,7 @@ function expect(
3130
)
3231
ψIψ = QuadraticFormNetwork(ψ)
3332
if isnothing(cache!)
34-
cache! = Ref(cache(alg, ψIψ; cache_construction_kwargs...))
33+
cache! = Ref(initialize_cache(alg, ψIψ; cache_construction_kwargs...))
3534
end
3635

3736
if update_cache

src/formnetworks/bilinearformnetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Adapt: adapt
22
using DataGraphs: DataGraphs, set_vertex_data!
33
using ITensors.NDTensors: datatype, denseblocks
4-
using ITensors: ITensor, Op, delta, prime, sim
4+
using ITensors: ITensor, Index, Op, dag, delta, prime, sim
55
using NamedGraphs.GraphsExtensions: disjoint_union
66

77
default_dual_site_index_map = prime

src/formnetworks/linearformnetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DataGraphs: DataGraphs, set_vertex_data!
2-
using ITensors: ITensor, prime
2+
using ITensors: ITensor, dag, prime
33
using NamedGraphs.GraphsExtensions: disjoint_union
44

55
default_dual_link_index_map = prime

src/formnetworks/quadraticformnetwork.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
using DataGraphs: DataGraphs, set_vertex_data!, underlying_graph, vertex_data
2+
using Dictionaries: Dictionary, set!
3+
using ITensors: ITensor, commoninds, dag, delta
4+
using NamedGraphs.PartitionedGraphs: PartitionedGraph, QuotientEdge, quotientedges
25

36
default_index_map = prime
47
default_inv_index_map = noprime
@@ -85,6 +88,41 @@ function QuadraticFormNetwork(
8588
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
8689
end
8790

91+
# Build initial BP messages on each quotient edge as `delta(bra, ket)`
92+
# pairs, one per ket link Index crossing the cut. The bra-side counterpart
93+
# of each ket Index is computed explicitly via `dual_index_map(fn)`, so
94+
# the pairing is correct even when multiple link indices share an edge
95+
# (where `commoninds`-zip ordering between layers is not guaranteed).
96+
function identity_messages(
97+
fn::QuadraticFormNetwork;
98+
partitioned_vertices = default_partitioned_vertices(fn)
99+
)
100+
ptn = PartitionedGraph(fn, partitioned_vertices)
101+
messages = Dictionary{QuotientEdge, Vector{ITensor}}()
102+
tn = tensornetwork(fn)
103+
elt = scalartype(tn)
104+
map_idx = dual_index_map(fn)
105+
pv = partitioned_vertices
106+
ket_s = ket_vertex_suffix(fn)
107+
for pe in quotientedges(ptn)
108+
src_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(src(pe))])))
109+
dst_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(dst(pe))])))
110+
for (from_orig, to_orig, e) in (
111+
(src_orig, dst_orig, pe),
112+
(dst_orig, src_orig, reverse(pe)),
113+
)
114+
ms = ITensor[]
115+
for v_from in from_orig, v_to in to_orig
116+
for k in commoninds(tn[ket_vertex(fn, v_from)], tn[ket_vertex(fn, v_to)])
117+
push!(ms, delta(elt, dag(map_idx(k)), k))
118+
end
119+
end
120+
set!(messages, e, ms)
121+
end
122+
end
123+
return messages
124+
end
125+
88126
function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor)
89127
state_inds = inds(ket_state)
90128
bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds))

0 commit comments

Comments
 (0)