From c8459471e67c4163539808b8d432032bdde93838 Mon Sep 17 00:00:00 2001 From: Joey Date: Fri, 13 Sep 2024 10:19:49 -0400 Subject: [PATCH 1/6] Blah --- examples/test.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 examples/test.jl diff --git a/examples/test.jl b/examples/test.jl new file mode 100644 index 00000000..6bf1bfcf --- /dev/null +++ b/examples/test.jl @@ -0,0 +1,21 @@ +using ITensorNetworks: IndsNetwork, siteinds, ttn +using ITensorNetworks.ModelHamiltonians: ising +using ITensors: Index, OpSum, terms, sites +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: rem_vertex + +function filter_terms(H, verts) + H_new = OpSum() + for term in terms(H) + if isempty(filter(v -> v ∈ verts, sites(term))) + H_new += term + end + end + return H_new +end + +g = named_grid((8,1)) +s = siteinds("S=1/2", g) +H = ising(s) +H_mod = filter_terms(H, [(4,1)]) +ttno = ttn(H_mod, s) \ No newline at end of file From 6ff0cd572c947e9b1ed3642e690b43233277beb0 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 17 Oct 2024 14:56:22 +0100 Subject: [PATCH 2/6] Bug fix in current ortho. Change test --- .../alternating_update/region_update.jl | 45 ++++++++----------- .../test_solvers/test_dmrg.jl | 12 ++--- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl index b92adc8c..97241c20 100644 --- a/src/solvers/alternating_update/region_update.jl +++ b/src/solvers/alternating_update/region_update.jl @@ -7,36 +7,27 @@ function current_ortho(sweep_plan, which_region_update) if !isa(region, AbstractEdge) && length(region) == 1 return only(current_verts) end - if which_region_update == length(regions) - # look back by one should be sufficient, but may be brittle? - overlapping_vertex = only( - intersect(current_verts, support(regions[which_region_update - 1])) - ) - return overlapping_vertex - else - # look forward - other_regions = filter( - x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) + # look forward + other_regions = filter( + x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) + ) + # find the first region that has overlapping support with current region + ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) + if isnothing(ind) + # look backward + other_regions = reverse( + filter( + x -> !(issetequal(x, current_verts)), support.(regions[1:(which_region_update - 1)]) + ), ) - # find the first region that has overlapping support with current region ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - if isnothing(ind) - # look backward - other_regions = reverse( - filter( - x -> !(issetequal(x, current_verts)), - support.(regions[1:(which_region_update - 1)]), - ), - ) - ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - end - @assert !isnothing(ind) - future_verts = union(support(other_regions[ind])) - # return ortho_ceter as the vertex in current region that does not overlap with following one - overlapping_vertex = intersect(current_verts, future_verts) - nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) - return nonoverlapping_vertex end + @assert !isnothing(ind) + future_verts = union(support(other_regions[ind])) + # return ortho_ceter as the vertex in current region that does not overlap with following one + overlapping_vertex = intersect(current_verts, future_verts) + nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) + return nonoverlapping_vertex end function region_update( diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index cf8a1caf..004ec561 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -1,7 +1,7 @@ @eval module $(gensym()) using DataGraphs: edge_data, vertex_data using Dictionaries: Dictionary -using Graphs: nv, vertices +using Graphs: nv, vertices, uniform_tree using ITensorMPS: ITensorMPS using ITensorNetworks: ITensorNetworks, @@ -19,6 +19,7 @@ using ITensorNetworks.ITensorsExtensions: replace_vertices using ITensorNetworks.ModelHamiltonians: ModelHamiltonians using ITensors: ITensors using KrylovKit: eigsolve +using NamedGraphs: NamedGraph, rename_vertices using NamedGraphs.NamedGraphGenerators: named_comb_tree using Observers: observer using StableRNGs: StableRNG @@ -313,11 +314,12 @@ end nsites = 2 nsweeps = 10 - c = named_comb_tree((3, 2)) - s = siteinds("S=1/2", c) - os = ModelHamiltonians.heisenberg(c) - H = ttn(os, s) rng = StableRNG(1234) + g = NamedGraph(uniform_tree(10)) + g = rename_vertices(v -> (v, 1), g) + s = siteinds("S=1/2", g) + os = ModelHamiltonians.heisenberg(g) + H = ttn(os, s) psi = random_ttn(rng, s; link_space=5) e, psi = dmrg(H, psi; nsweeps, maxdim, nsites) From d0967229e2c9c0d645ad110bb3944566b52d3385 Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 26 Nov 2024 13:50:41 -0500 Subject: [PATCH 3/6] Fix bug --- src/abstractitensornetwork.jl | 40 ++++++++++++------- .../abstracttreetensornetwork.jl | 12 ++++-- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index fc0edce4..afdbbb41 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -19,6 +19,7 @@ using Graphs: using ITensors: ITensors, ITensor, + @Algorithm_str, addtags, combiner, commoninds, @@ -44,7 +45,7 @@ using MacroTools: @capture using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rename_vertices, vertextype -using NDTensors: NDTensors, dim +using NDTensors: NDTensors, dim, Algorithm using SplitApplyCombine: flatten abstract type AbstractITensorNetwork{V} <: AbstractDataGraph{V,ITensor,ITensor} end @@ -585,17 +586,22 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs. end # For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged? -function orthogonalize_walk(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...) - return orthogonalize_walk(tn, [edge]; kwargs...) +function gauge_walk( + alg::Algorithm, tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs... +) + return gauge_walk(tn, [edge]; kwargs...) end -function orthogonalize_walk(tn::AbstractITensorNetwork, edge::Pair; kwargs...) - return orthogonalize_walk(tn, edgetype(tn)(edge); kwargs...) +function gauge_walk(alg::Algorithm, tn::AbstractITensorNetwork, edge::Pair; kwargs...) + return gauge_walk(alg::Algorithm, tn, edgetype(tn)(edge); kwargs...) end # For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged? -function orthogonalize_walk( - tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs... +function gauge_walk( + alg::Algorithm"orthogonalize", + tn::AbstractITensorNetwork, + edges::Vector{<:AbstractEdge}; + kwargs..., ) # tn = factorize(tn, edge; kwargs...) # # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))` @@ -612,22 +618,28 @@ function orthogonalize_walk( return tn end -function orthogonalize_walk(tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...) - return orthogonalize_walk(tn, edgetype(tn).(edges); kwargs...) +function gauge_walk( + alg::Algorithm, tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs... +) + return gauge_walk(alg, tn, edgetype(tn).(edges); kwargs...) end -# Orthogonalize an ITensorNetwork towards a region, treating +# Gauge a ITensorNetwork towards a region, treating # the network as a tree spanned by a spanning tree. -function tree_orthogonalize(ψ::AbstractITensorNetwork, region::Vector) +function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region::Vector) region_center = length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region) path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center) path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) - return orthogonalize_walk(ψ, path) + return gauge_walk(alg, ψ, path) +end + +function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region) + return tree_gauge(alg, ψ, [region]) end -function tree_orthogonalize(ψ::AbstractITensorNetwork, region) - return tree_orthogonalize(ψ, [region]) +function tree_orthogonalize(ψ::AbstractITensorNetwork, region; kwargs...) + return tree_gauge(Algorithm("orthogonalize"), ψ, region; kwargs...) end # TODO: decide whether to use graph mutating methods when resulting graph is unchanged? diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl index 8815b33f..f6c8f49f 100644 --- a/src/treetensornetworks/abstracttreetensornetwork.jl +++ b/src/treetensornetworks/abstracttreetensornetwork.jl @@ -8,7 +8,7 @@ using NamedGraphs.GraphsExtensions: a_star using NamedGraphs: namedgraph_a_star, steiner_tree using IsApprox: IsApprox, Approx -using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev +using ITensors: ITensors, Algorithm, @Algorithm_str, directsum, hasinds, permute, plev using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize using TupleTools: TupleTools @@ -35,19 +35,23 @@ function set_ortho_region(tn::AbstractTTN, new_region) return error("Not implemented") end -function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...) +function gauge(alg::Algorithm, ttn::AbstractTTN, region::Vector; kwargs...) issetequal(region, ortho_region(ttn)) && return ttn st = steiner_tree(ttn, union(region, ortho_region(ttn))) path = post_order_dfs_edges(st, first(region)) path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) if !isempty(path) - ttn = typeof(ttn)(orthogonalize_walk(ITensorNetwork(ttn), path; kwargs...)) + ttn = typeof(ttn)(gauge_walk(alg, ITensorNetwork(ttn), path; kwargs...)) end return set_ortho_region(ttn, region) end +function gauge(alg::Algorithm, ttn::AbstractTTN, region; kwargs...) + return gauge(alg, ttn, [region]; kwargs...) +end + function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...) - return orthogonalize(ttn, [region]; kwargs...) + return gauge(Algorithm("orthogonalize"), ttn, region; kwargs...) end function tree_orthogonalize(ttn::AbstractTTN, args...; kwargs...) From 9d6c1bcb7aa8b36163df43e95b063dd9cd03761f Mon Sep 17 00:00:00 2001 From: Joey Date: Wed, 19 Mar 2025 12:07:10 -0700 Subject: [PATCH 4/6] File removed --- examples/test.jl | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 examples/test.jl diff --git a/examples/test.jl b/examples/test.jl deleted file mode 100644 index 6bf1bfcf..00000000 --- a/examples/test.jl +++ /dev/null @@ -1,21 +0,0 @@ -using ITensorNetworks: IndsNetwork, siteinds, ttn -using ITensorNetworks.ModelHamiltonians: ising -using ITensors: Index, OpSum, terms, sites -using NamedGraphs.NamedGraphGenerators: named_grid -using NamedGraphs.GraphsExtensions: rem_vertex - -function filter_terms(H, verts) - H_new = OpSum() - for term in terms(H) - if isempty(filter(v -> v ∈ verts, sites(term))) - H_new += term - end - end - return H_new -end - -g = named_grid((8,1)) -s = siteinds("S=1/2", g) -H = ising(s) -H_mod = filter_terms(H, [(4,1)]) -ttno = ttn(H_mod, s) \ No newline at end of file From 03cb8fe8d283274cbd295d1c8fe970468163ae69 Mon Sep 17 00:00:00 2001 From: Joey Date: Fri, 11 Apr 2025 12:44:31 -0400 Subject: [PATCH 5/6] Add Truncation Error Report to apply --- Project.toml | 2 +- src/apply.jl | 27 +++++++++++++++++++++------ test/test_apply.jl | 6 +++++- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index e969820b..069691c6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.13.4" +version = "0.13.5" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/apply.jl b/src/apply.jl index 5250c806..f45d2b3c 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -79,7 +79,9 @@ function full_update_bp( return ψᵥ₁, ψᵥ₂ end -function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_kwargs...) +function simple_update_bp_full( + o, ψ, v⃗; envs, (singular_values!)=nothing, (truncation_error!)=nothing, apply_kwargs... +) cutoff = 10 * eps(real(scalartype(ψ))) envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs) envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs) @@ -116,9 +118,12 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, ap v1_inds = [v1_inds; siteinds(ψ, v⃗[1])] v2_inds = [v2_inds; siteinds(ψ, v⃗[2])] e = v⃗[1] => v⃗[2] - ψᵥ₁, ψᵥ₂ = factorize_svd( + ψᵥ₁, ψᵥ₂, spec = factorize_svd( oψ, v1_inds; ortho="none", tags=edge_tag(e), singular_values!, apply_kwargs... ) + if !isnothing(truncation_error!) + truncation_error[] = spec.truncerr + end for inv_sqrt_env_v1 in inv_sqrt_envs_v1 ψᵥ₁ *= dag(inv_sqrt_env_v1) end @@ -129,7 +134,9 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, ap end # Reduced version -function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_kwargs...) +function simple_update_bp( + o, ψ, v⃗; envs, (singular_values!)=nothing, (truncation_error!)=nothing, apply_kwargs... +) cutoff = 10 * eps(real(scalartype(ψ))) envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs) envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs) @@ -164,7 +171,7 @@ function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_k rᵥ₂ = commoninds(Qᵥ₂, Rᵥ₂) oR = apply(o, Rᵥ₁ * Rᵥ₂) e = v⃗[1] => v⃗[2] - Rᵥ₁, Rᵥ₂ = factorize_svd( + Rᵥ₁, Rᵥ₂, spec = factorize_svd( oR, unioninds(rᵥ₁, sᵥ₁); ortho="none", @@ -172,6 +179,9 @@ function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_k singular_values!, apply_kwargs..., ) + if !isnothing(truncation_error!) + truncation_error![] = spec.truncerr + end Qᵥ₁ = contract([Qᵥ₁; dag.(inv_sqrt_envs_v1)]) Qᵥ₂ = contract([Qᵥ₂; dag.(inv_sqrt_envs_v2)]) ψᵥ₁ = Qᵥ₁ * Rᵥ₁ @@ -189,6 +199,7 @@ function ITensors.apply( print_fidelity_loss=false, envisposdef=false, (singular_values!)=nothing, + (truncation_error!)=nothing, variational_optimization_only=false, symmetrize=false, reduced=true, @@ -230,9 +241,13 @@ function ITensors.apply( ) else if reduced - ψᵥ₁, ψᵥ₂ = simple_update_bp(o, ψ, v⃗; envs, singular_values!, apply_kwargs...) + ψᵥ₁, ψᵥ₂ = simple_update_bp( + o, ψ, v⃗; envs, singular_values!, truncation_error!, apply_kwargs... + ) else - ψᵥ₁, ψᵥ₂ = simple_update_bp_full(o, ψ, v⃗; envs, singular_values!, apply_kwargs...) + ψᵥ₁, ψᵥ₂ = simple_update_bp_full( + o, ψ, v⃗; envs, singular_values!, truncation_error!, apply_kwargs... + ) end end if normalize diff --git a/test/test_apply.jl b/test/test_apply.jl index 315e0e2e..99f3c9cd 100644 --- a/test/test_apply.jl +++ b/test/test_apply.jl @@ -39,9 +39,10 @@ using Test: @test, @testset envsGBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")]) inner_alg = "exact" ngates = 5 + truncerr_exact, truncerr_bp = Ref(Float64(0)), Ref(Float64(0)) for i in 1:ngates o = op("RandomUnitary", s[v1]..., s[v2]...) - ψOexact = apply(o, ψ; cutoff=1e-16) + ψOexact = apply(o, ψ; (truncation_error!)=truncerr_exact, cutoff=nothing) ψOSBP = apply( o, ψ; @@ -50,6 +51,7 @@ using Test: @test, @testset normalize=true, print_fidelity_loss=true, envisposdef=true, + (truncation_error!)=truncerr_bp, ) ψOv = apply(o, ψv; maxdim=χ, normalize=true) ψOVidal_symm = ITensorNetwork(ψOv) @@ -73,6 +75,8 @@ using Test: @test, @testset fGBP = inner(ψOGBP, ψOexact; alg=inner_alg) / sqrt(inner(ψOexact, ψOexact; alg=inner_alg) * inner(ψOGBP, ψOGBP; alg=inner_alg)) + @test iszero(truncerr_exact[]) + @test !iszero(truncerr_bp[]) @test real(fGBP * conj(fGBP)) >= real(fSBP * conj(fSBP)) @test isapprox(real(fSBP * conj(fSBP)), real(fVidal * conj(fVidal)); atol=1e-3) end From 59eb43940b1911b0294c2c30255faadd0df877f9 Mon Sep 17 00:00:00 2001 From: Joey Date: Fri, 11 Apr 2025 13:44:08 -0400 Subject: [PATCH 6/6] Use Callback --- src/apply.jl | 37 ++++++++++++++----------------------- test/test_apply.jl | 16 ++++++++++------ 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/src/apply.jl b/src/apply.jl index f45d2b3c..6be06756 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -34,7 +34,7 @@ function full_update_bp( nfullupdatesweeps=10, print_fidelity_loss=false, envisposdef=false, - (singular_values!)=nothing, + callback=Returns(nothing), symmetrize=false, apply_kwargs..., ) @@ -65,7 +65,8 @@ function full_update_bp( apply_kwargs..., ) if symmetrize - Rᵥ₁, Rᵥ₂ = factorize_svd( + singular_values! = Ref(ITensor()) + Rᵥ₁, Rᵥ₂, spec = factorize_svd( Rᵥ₁ * Rᵥ₂, inds(Rᵥ₁); ortho="none", @@ -73,15 +74,14 @@ function full_update_bp( singular_values!, apply_kwargs..., ) + callback(; singular_values=singular_values![], truncation_error=spec.truncerr) end ψᵥ₁ = Qᵥ₁ * Rᵥ₁ ψᵥ₂ = Qᵥ₂ * Rᵥ₂ return ψᵥ₁, ψᵥ₂ end -function simple_update_bp_full( - o, ψ, v⃗; envs, (singular_values!)=nothing, (truncation_error!)=nothing, apply_kwargs... -) +function simple_update_bp_full(o, ψ, v⃗; envs, callback=Returns(nothing), apply_kwargs...) cutoff = 10 * eps(real(scalartype(ψ))) envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs) envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs) @@ -118,12 +118,11 @@ function simple_update_bp_full( v1_inds = [v1_inds; siteinds(ψ, v⃗[1])] v2_inds = [v2_inds; siteinds(ψ, v⃗[2])] e = v⃗[1] => v⃗[2] + singular_values! = Ref(ITensor()) ψᵥ₁, ψᵥ₂, spec = factorize_svd( oψ, v1_inds; ortho="none", tags=edge_tag(e), singular_values!, apply_kwargs... ) - if !isnothing(truncation_error!) - truncation_error[] = spec.truncerr - end + callback(; singular_values=singular_values![], truncation_error=spec.truncerr) for inv_sqrt_env_v1 in inv_sqrt_envs_v1 ψᵥ₁ *= dag(inv_sqrt_env_v1) end @@ -134,9 +133,7 @@ function simple_update_bp_full( end # Reduced version -function simple_update_bp( - o, ψ, v⃗; envs, (singular_values!)=nothing, (truncation_error!)=nothing, apply_kwargs... -) +function simple_update_bp(o, ψ, v⃗; envs, callback=Returns(nothing), apply_kwargs...) cutoff = 10 * eps(real(scalartype(ψ))) envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs) envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs) @@ -171,6 +168,7 @@ function simple_update_bp( rᵥ₂ = commoninds(Qᵥ₂, Rᵥ₂) oR = apply(o, Rᵥ₁ * Rᵥ₂) e = v⃗[1] => v⃗[2] + singular_values! = Ref(ITensor()) Rᵥ₁, Rᵥ₂, spec = factorize_svd( oR, unioninds(rᵥ₁, sᵥ₁); @@ -179,9 +177,7 @@ function simple_update_bp( singular_values!, apply_kwargs..., ) - if !isnothing(truncation_error!) - truncation_error![] = spec.truncerr - end + callback(; singular_values=singular_values![], truncation_error=spec.truncerr) Qᵥ₁ = contract([Qᵥ₁; dag.(inv_sqrt_envs_v1)]) Qᵥ₂ = contract([Qᵥ₂; dag.(inv_sqrt_envs_v2)]) ψᵥ₁ = Qᵥ₁ * Rᵥ₁ @@ -198,8 +194,7 @@ function ITensors.apply( nfullupdatesweeps=10, print_fidelity_loss=false, envisposdef=false, - (singular_values!)=nothing, - (truncation_error!)=nothing, + callback=Returns(nothing), variational_optimization_only=false, symmetrize=false, reduced=true, @@ -235,19 +230,15 @@ function ITensors.apply( nfullupdatesweeps, print_fidelity_loss, envisposdef, - singular_values!, + callback, symmetrize, apply_kwargs..., ) else if reduced - ψᵥ₁, ψᵥ₂ = simple_update_bp( - o, ψ, v⃗; envs, singular_values!, truncation_error!, apply_kwargs... - ) + ψᵥ₁, ψᵥ₂ = simple_update_bp(o, ψ, v⃗; envs, callback, apply_kwargs...) else - ψᵥ₁, ψᵥ₂ = simple_update_bp_full( - o, ψ, v⃗; envs, singular_values!, truncation_error!, apply_kwargs... - ) + ψᵥ₁, ψᵥ₂ = simple_update_bp_full(o, ψ, v⃗; envs, callback, apply_kwargs...) end end if normalize diff --git a/test/test_apply.jl b/test/test_apply.jl index 99f3c9cd..92793a28 100644 --- a/test/test_apply.jl +++ b/test/test_apply.jl @@ -11,7 +11,7 @@ using ITensorNetworks: random_tensornetwork, siteinds, update -using ITensors: ITensors, inner, op +using ITensors: ITensors, ITensor, inner, op using NamedGraphs.NamedGraphGenerators: named_grid using NamedGraphs.PartitionedGraphs: PartitionVertex using SplitApplyCombine: group @@ -39,10 +39,15 @@ using Test: @test, @testset envsGBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")]) inner_alg = "exact" ngates = 5 - truncerr_exact, truncerr_bp = Ref(Float64(0)), Ref(Float64(0)) + truncerr = 0.0 + singular_values = ITensor() + function callback(; singular_values, truncation_error) + truncerr = truncation_error + singular_values = singular_values + end for i in 1:ngates o = op("RandomUnitary", s[v1]..., s[v2]...) - ψOexact = apply(o, ψ; (truncation_error!)=truncerr_exact, cutoff=nothing) + ψOexact = apply(o, ψ; cutoff=nothing) ψOSBP = apply( o, ψ; @@ -51,7 +56,7 @@ using Test: @test, @testset normalize=true, print_fidelity_loss=true, envisposdef=true, - (truncation_error!)=truncerr_bp, + callback, ) ψOv = apply(o, ψv; maxdim=χ, normalize=true) ψOVidal_symm = ITensorNetwork(ψOv) @@ -75,8 +80,7 @@ using Test: @test, @testset fGBP = inner(ψOGBP, ψOexact; alg=inner_alg) / sqrt(inner(ψOexact, ψOexact; alg=inner_alg) * inner(ψOGBP, ψOGBP; alg=inner_alg)) - @test iszero(truncerr_exact[]) - @test !iszero(truncerr_bp[]) + @test !iszero(truncerr) @test real(fGBP * conj(fGBP)) >= real(fSBP * conj(fSBP)) @test isapprox(real(fSBP * conj(fSBP)), real(fVidal * conj(fVidal)); atol=1e-3) end