Skip to content

Commit 6d09a71

Browse files
mtfishmanclaude
andcommitted
Drop LFN/BFN identity_messages; route auto-init through QFN only
- Make identity_messages a QFN-only method. The QFN version pairs each ket Index with its bra counterpart explicitly via dual_index_map(fn), so the construction stays correct when multiple link indices share an edge (where commoninds-zip ordering across the two layers is not guaranteed). - Drop the function-tag argument from initialize_cache. Auto-init is a property of (algorithm, network type) only, and QFN is the only form network where identity_messages is canonical, so a single initialize_cache(alg::bp, fn::QFN) specialization is enough. - Change norm_sqr_network(psi) to return QuadraticFormNetwork(psi), matching its name. Route LinearAlgebra.norm_sqr and normalize through it so they pick up the QFN auto-init on loopy graphs. - test_apply.jl is structurally LFN-based (apply's local-env expectations don't generalize to QFN), so it now builds its own LFN messages explicitly via a small _lfn_identity_messages helper. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2fc91cd commit 6d09a71

12 files changed

Lines changed: 97 additions & 157 deletions

src/abstractitensornetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ function inner_network(
517517
return BilinearFormNetwork(A, x, y; kwargs...)
518518
end
519519

520-
norm_sqr_network::AbstractITensorNetwork) = inner_network(ψ, ψ)
520+
norm_sqr_network::AbstractITensorNetwork) = QuadraticFormNetwork(ψ)
521521

522522
#
523523
# Printing

src/contract.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function logscalar(
4747
cache_update_kwargs = (;)
4848
)
4949
if isnothing(cache!)
50-
cache! = Ref(initialize_cache(scalar, alg, tn; cache_construction_kwargs...))
50+
cache! = Ref(initialize_cache(alg, tn; cache_construction_kwargs...))
5151
end
5252

5353
if update_cache

src/environment.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ function environment(
2828
cache_update_kwargs = (;)
2929
)
3030
if isnothing(cache!)
31-
cache! = Ref(
32-
initialize_cache(environment, alg, ptn; cache_construction_kwargs...)
33-
)
31+
cache! = Ref(initialize_cache(alg, ptn; cache_construction_kwargs...))
3432
end
3533

3634
if update_cache

src/expect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function expect(
3030
)
3131
ψIψ = QuadraticFormNetwork(ψ)
3232
if isnothing(cache!)
33-
cache! = Ref(initialize_cache(expect, alg, ψIψ; cache_construction_kwargs...))
33+
cache! = Ref(initialize_cache(alg, ψIψ; cache_construction_kwargs...))
3434
end
3535

3636
if update_cache

src/formnetworks/bilinearformnetwork.jl

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
using Adapt: adapt
22
using DataGraphs: DataGraphs, set_vertex_data!
3-
using Dictionaries: Dictionary, set!
43
using ITensors.NDTensors: datatype, denseblocks
5-
using ITensors: ITensor, Index, Op, commoninds, dag, delta, prime, sim
4+
using ITensors: ITensor, Index, Op, dag, delta, prime, sim
65
using NamedGraphs.GraphsExtensions: disjoint_union
7-
using NamedGraphs.PartitionedGraphs:
8-
PartitionedGraph, QuotientEdge, partitioned_vertices, quotientedges
96

107
default_dual_site_index_map = prime
118
default_dual_link_index_map = sim
@@ -108,50 +105,3 @@ function update(
108105
tensornetwork(blf)[ket_vertex(blf, original_ket_state_vertex)] = ket_state
109106
return blf
110107
end
111-
112-
# Initial BP messages from bra↔ket pairings on each quotient edge.
113-
# Errors when the operator subnet has its own inter-vertex links (the
114-
# multi-site-operator case): those legs are operator-internal and have
115-
# no bra/ket pair, so no canonical identity initialization exists — the
116-
# caller must supply `messages` explicitly.
117-
function identity_messages(fn::BilinearFormNetwork, ptn::PartitionedGraph)
118-
pairings = Dictionary{QuotientEdge, Pair{Vector{Index}, Vector{Index}}}()
119-
tn = tensornetwork(fn)
120-
pv = partitioned_vertices(ptn)
121-
ket_s = ket_vertex_suffix(fn)
122-
for pe in quotientedges(ptn)
123-
src_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(src(pe))])))
124-
dst_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(dst(pe))])))
125-
for v_from in src_orig, v_to in dst_orig
126-
op_inds = commoninds(
127-
tn[operator_vertex(fn, v_from)], tn[operator_vertex(fn, v_to)]
128-
)
129-
if !isempty(op_inds)
130-
error(
131-
"BilinearFormNetwork: operator-internal cross-Index between " *
132-
"$v_from and $v_to has no bra/ket pair; supply `messages` " *
133-
"explicitly to BP."
134-
)
135-
end
136-
end
137-
for (from_orig, to_orig, e) in (
138-
(src_orig, dst_orig, pe),
139-
(dst_orig, src_orig, reverse(pe)),
140-
)
141-
bras = Index[]
142-
kets = Index[]
143-
for v_from in from_orig, v_to in to_orig
144-
append!(
145-
bras,
146-
commoninds(tn[bra_vertex(fn, v_from)], tn[bra_vertex(fn, v_to)])
147-
)
148-
append!(
149-
kets,
150-
commoninds(tn[ket_vertex(fn, v_from)], tn[ket_vertex(fn, v_to)])
151-
)
152-
end
153-
set!(pairings, e, bras => kets)
154-
end
155-
end
156-
return identity_messages(scalartype(tn), pairings)
157-
end

src/formnetworks/linearformnetwork.jl

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
using DataGraphs: DataGraphs, set_vertex_data!
2-
using Dictionaries: Dictionary, set!
3-
using ITensors: ITensor, Index, commoninds, dag, prime
2+
using ITensors: ITensor, dag, prime
43
using NamedGraphs.GraphsExtensions: disjoint_union
5-
using NamedGraphs.PartitionedGraphs:
6-
PartitionedGraph, QuotientEdge, partitioned_vertices, quotientedges
74

85
default_dual_link_index_map = prime
96

@@ -60,42 +57,3 @@ function update(lf::LinearFormNetwork, original_ket_state_vertex, ket_state::ITe
6057
tensornetwork(lf)[ket_vertex(blf, original_ket_state_vertex)] = ket_state
6158
return lf
6259
end
63-
64-
# Initial BP messages on each quotient edge built from the bra↔ket leg
65-
# pairs crossing that edge: legs are taken from the `from`-partition side
66-
# of each layer (so the bra leg and the ket leg carry opposite directions
67-
# because the bra layer is `dag(dual_link_index_map(ket))`), and the
68-
# forward/reverse messages use opposite-end views so each one's open
69-
# legs face the correct receiving partition when read during BP updates.
70-
# Iteration is over Cartesian products of the original ket-graph vertices
71-
# in each partition, so this works for arbitrary partitionings (per-vertex
72-
# or coarser groupings such as whole columns).
73-
function identity_messages(fn::LinearFormNetwork, ptn::PartitionedGraph)
74-
pairings = Dictionary{QuotientEdge, Pair{Vector{Index}, Vector{Index}}}()
75-
tn = tensornetwork(fn)
76-
pv = partitioned_vertices(ptn)
77-
ket_s = ket_vertex_suffix(fn)
78-
for pe in quotientedges(ptn)
79-
src_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(src(pe))])))
80-
dst_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(dst(pe))])))
81-
for (from_orig, to_orig, e) in (
82-
(src_orig, dst_orig, pe),
83-
(dst_orig, src_orig, reverse(pe)),
84-
)
85-
bras = Index[]
86-
kets = Index[]
87-
for v_from in from_orig, v_to in to_orig
88-
append!(
89-
bras,
90-
commoninds(tn[bra_vertex(fn, v_from)], tn[bra_vertex(fn, v_to)])
91-
)
92-
append!(
93-
kets,
94-
commoninds(tn[ket_vertex(fn, v_from)], tn[ket_vertex(fn, v_to)])
95-
)
96-
end
97-
set!(pairings, e, bras => kets)
98-
end
99-
end
100-
return identity_messages(scalartype(tn), pairings)
101-
end

src/formnetworks/quadraticformnetwork.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using DataGraphs: DataGraphs, set_vertex_data!, underlying_graph, vertex_data
2-
using NamedGraphs.PartitionedGraphs: PartitionedGraph
2+
using Dictionaries: Dictionary, set!
3+
using ITensors: Index, commoninds, dag
4+
using NamedGraphs.PartitionedGraphs:
5+
PartitionedGraph, QuotientEdge, partitioned_vertices, quotientedges
36

47
default_index_map = prime
58
default_inv_index_map = noprime
@@ -86,8 +89,37 @@ function QuadraticFormNetwork(
8689
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
8790
end
8891

92+
# Build initial BP messages on each quotient edge as `delta(bra, ket)`
93+
# pairs, one per ket link Index crossing the cut. The bra-side counterpart
94+
# of each ket Index is computed explicitly via `dual_index_map(fn)`, so
95+
# the pairing is correct even when multiple link indices share an edge
96+
# (where `commoninds`-zip ordering between layers is not guaranteed).
8997
function identity_messages(fn::QuadraticFormNetwork, ptn::PartitionedGraph)
90-
return identity_messages(bilinear_formnetwork(fn), ptn)
98+
pairings = Dictionary{QuotientEdge, Pair{Vector{Index}, Vector{Index}}}()
99+
tn = tensornetwork(fn)
100+
map_idx = dual_index_map(fn)
101+
pv = partitioned_vertices(ptn)
102+
ket_s = ket_vertex_suffix(fn)
103+
for pe in quotientedges(ptn)
104+
src_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(src(pe))])))
105+
dst_orig = unique(first.(filter(v -> last(v) == ket_s, pv[parent(dst(pe))])))
106+
for (from_orig, to_orig, e) in (
107+
(src_orig, dst_orig, pe),
108+
(dst_orig, src_orig, reverse(pe)),
109+
)
110+
kets = Index[]
111+
bras = Index[]
112+
for v_from in from_orig, v_to in to_orig
113+
cur_kets = collect(
114+
commoninds(tn[ket_vertex(fn, v_from)], tn[ket_vertex(fn, v_to)])
115+
)
116+
append!(kets, cur_kets)
117+
append!(bras, dag.(map_idx.(cur_kets)))
118+
end
119+
set!(pairings, e, bras => kets)
120+
end
121+
end
122+
return identity_messages(scalartype(tn), pairings)
91123
end
92124

93125
function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor)

src/initialize_cache.jl

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,24 @@
11
using Dictionaries: Dictionary
22
using Graphs: is_tree
33
using ITensors.NDTensors: @Algorithm_str, Algorithm
4-
using ITensors: scalar
5-
using LinearAlgebra: normalize
64
using NamedGraphs.PartitionedGraphs:
75
AbstractPartitionedGraph, PartitionedGraph, quotient_graph
86

9-
# Build a cache appropriate for `f` on `tn` using algorithm `alg`. The
10-
# `f` tag carries the calling context (e.g. `scalar`, `normalize`,
11-
# `expect`, `rescale`, `environment`) so per-purpose methods can inject
12-
# context-specific initialization. The fallback constructs a plain
13-
# `BeliefPropagationCache` with no message defaults.
14-
function initialize_cache(f, alg::Algorithm"bp", tn::AbstractITensorNetwork; kwargs...)
7+
# Build a cache for algorithm `alg` on `tn`. The fallback constructs a
8+
# plain `BeliefPropagationCache` with no message defaults; the
9+
# `QuadraticFormNetwork` specialization injects `identity_messages` on
10+
# loopy quotient graphs (canonical for the structurally ψ-vs-ψ case).
11+
function initialize_cache(alg::Algorithm"bp", tn::AbstractITensorNetwork; kwargs...)
1512
return BeliefPropagationCache(tn; kwargs...)
1613
end
1714

18-
function initialize_cache(
19-
f, alg::Algorithm"bp", ptn::AbstractPartitionedGraph; kwargs...
20-
)
15+
function initialize_cache(alg::Algorithm"bp", ptn::AbstractPartitionedGraph; kwargs...)
2116
return BeliefPropagationCache(ptn; kwargs...)
2217
end
2318

24-
# Core helper: build a BPC on a form network with `identity_messages`
25-
# on loopy quotient graphs (empty messages on trees). Used by the
26-
# per-purpose specializations below where the form network is
27-
# structurally ψ-vs-ψ, so `identity_messages(fn, ptn)` is canonical.
28-
function _bp_cache_identity_messages(
29-
fn::AbstractFormNetwork;
19+
function initialize_cache(
20+
alg::Algorithm"bp",
21+
fn::QuadraticFormNetwork;
3022
partitioned_vertices = default_partitioned_vertices(fn),
3123
messages = nothing
3224
)
@@ -36,27 +28,3 @@ function _bp_cache_identity_messages(
3628
end
3729
return BeliefPropagationCache(ptn; messages)
3830
end
39-
40-
function initialize_cache(
41-
::typeof(scalar), alg::Algorithm"bp", fn::QuadraticFormNetwork; kwargs...
42-
)
43-
return _bp_cache_identity_messages(fn; kwargs...)
44-
end
45-
46-
function initialize_cache(
47-
::typeof(normalize), alg::Algorithm"bp", fn::AbstractFormNetwork; kwargs...
48-
)
49-
return _bp_cache_identity_messages(fn; kwargs...)
50-
end
51-
52-
function initialize_cache(
53-
::typeof(rescale), alg::Algorithm"bp", fn::AbstractFormNetwork; kwargs...
54-
)
55-
return _bp_cache_identity_messages(fn; kwargs...)
56-
end
57-
58-
function initialize_cache(
59-
::typeof(expect), alg::Algorithm"bp", fn::QuadraticFormNetwork; kwargs...
60-
)
61-
return _bp_cache_identity_messages(fn; kwargs...)
62-
end

src/inner.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ end
173173

174174
# TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,
175175
# or `norm_sqr` to match `LinearAlgebra.norm_sqr`
176-
LinearAlgebra.norm_sqr::AbstractITensorNetwork; kwargs...) = inner(ψ, ψ; kwargs...)
176+
function LinearAlgebra.norm_sqr::AbstractITensorNetwork; kwargs...)
177+
return scalar(norm_sqr_network(ψ); kwargs...)
178+
end
177179

178180
function LinearAlgebra.norm::AbstractITensorNetwork; kwargs...)
179181
return sqrt(abs(real(norm_sqr(ψ; kwargs...))))

src/normalize.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function rescale(
2121
kwargs...
2222
)
2323
if isnothing(cache!)
24-
cache! = Ref(initialize_cache(rescale, alg, tn; cache_construction_kwargs...))
24+
cache! = Ref(initialize_cache(alg, tn; cache_construction_kwargs...))
2525
end
2626

2727
if update_cache
@@ -55,7 +55,7 @@ end
5555
function LinearAlgebra.normalize(
5656
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
5757
)
58-
logn = logscalar(alg, inner_network(tn, tn); kwargs...)
58+
logn = logscalar(alg, norm_sqr_network(tn); kwargs...)
5959
c = inv(exp(logn / (2 * length(vertices(tn)))))
6060
return map(t -> c * t, tn)
6161
end
@@ -68,10 +68,9 @@ function LinearAlgebra.normalize(
6868
cache_update_kwargs = (;),
6969
cache_construction_kwargs = (;)
7070
)
71-
norm_tn = inner_network(tn, tn)
71+
norm_tn = norm_sqr_network(tn)
7272
if isnothing(cache!)
73-
cache! =
74-
Ref(initialize_cache(normalize, alg, norm_tn; cache_construction_kwargs...))
73+
cache! = Ref(initialize_cache(alg, norm_tn; cache_construction_kwargs...))
7574
end
7675
vs = collect(vertices(tn))
7776
verts = vcat([ket_vertex(norm_tn, v) for v in vs], [bra_vertex(norm_tn, v) for v in vs])

0 commit comments

Comments
 (0)