From a572e07fe91121f3cdabae56b0761c1abde3ae2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Sun, 12 Apr 2026 14:23:06 +0200 Subject: [PATCH 01/24] parametrize SimpleUpdate{Trunc} --- src/algorithms/time_evolution/apply_mpo.jl | 8 ++++---- src/algorithms/time_evolution/simpleupdate.jl | 11 ++++++----- src/algorithms/time_evolution/simpleupdate3site.jl | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/algorithms/time_evolution/apply_mpo.jl b/src/algorithms/time_evolution/apply_mpo.jl index b7ae6f42f..e6809c56a 100644 --- a/src/algorithms/time_evolution/apply_mpo.jl +++ b/src/algorithms/time_evolution/apply_mpo.jl @@ -238,11 +238,11 @@ function _get_allprojs( end return _proj_from_RL(Rs[i], Ls[i]; trunc) end - Pas = map(Base.Fix2(getindex, 1), projs_errs) - wts = map(Base.Fix2(getindex, 2), projs_errs) - Pbs = map(Base.Fix2(getindex, 3), projs_errs) + Pas = getindex.(projs_errs, 1) + wts = getindex.(projs_errs, 2) + Pbs = getindex.(projs_errs, 3) # local truncation error on each bond - ϵs = map(Base.Fix2(getindex, 4), projs_errs) + ϵs = getindex.(projs_errs, 4) return Pas, Pbs, wts, ϵs end diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index a28f86d8b..2e7a3fc09 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -7,19 +7,19 @@ Algorithm struct for simple update (SU) of InfinitePEPS or InfinitePEPO. $(TYPEDFIELDS) """ -@kwdef struct SimpleUpdate <: TimeEvolution +@kwdef struct SimpleUpdate{T <: TruncationStrategy} <: TimeEvolution "Truncation strategy for bonds updated by Trotter gates" - trunc::TruncationStrategy + trunc::T "When true (or false), the Trotter gate is `exp(-H dt)` (or `exp(-iH dt)`)" imaginary_time::Bool = true "When true, force decomposition of nearest neighbor gates to MPOs." force_mpo::Bool = false "When true, assume bipartite unit cell structure" bipartite::Bool = false - "(Only applicable to InfinitePEPO) + "(Only applicable to InfinitePEPO) When true, the PEPO is regarded as a purified PEPS, and updated as `|ρ(t + dt)⟩ = exp(-H dt/2) |ρ(t)⟩`. - When false, the PEPO is updated as + When false, the PEPO is updated as `ρ(t + dt) = exp(-H dt/2) ρ(t) exp(-H dt/2)`." purified::Bool = true end @@ -97,7 +97,8 @@ function _su_iter!( bond, rev = _nn_bondrev(sites..., (Nr, Nc)) A, B = _bond_rotation.(Ms, bond[1], rev; inv = false) # apply gate - ϵ, s = 0.0, nothing + ϵ = 0.0 + s = nothing gate_axs = alg.purified ? (1:1) : (1:2) for gate_ax in gate_axs X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true) diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 4b091c830..6c684ceef 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -176,7 +176,7 @@ function _su_iter!( _flip_virtuals!(Ms, flips) # apply gate MPOs and truncate gate_axs = alg.purified ? (1:1) : (1:2) - wts, ϵs = nothing, nothing + global wts, ϵs for gate_ax in gate_axs _apply_gatempo!(Ms, gate; gate_ax) if isa(state, InfinitePEPO) From 9e0222b4f7c0e2cbedb9fd4a062973e8d28ccd33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Sun, 12 Apr 2026 16:21:34 +0200 Subject: [PATCH 02/24] improve stability --- src/algorithms/time_evolution/apply_gate.jl | 22 +++++----- src/algorithms/time_evolution/simpleupdate.jl | 30 ++++++------- .../time_evolution/simpleupdate3site.jl | 43 ++++++++----------- 3 files changed, 45 insertions(+), 50 deletions(-) diff --git a/src/algorithms/time_evolution/apply_gate.jl b/src/algorithms/time_evolution/apply_gate.jl index 363b23f8d..8aaf81e7a 100644 --- a/src/algorithms/time_evolution/apply_gate.jl +++ b/src/algorithms/time_evolution/apply_gate.jl @@ -18,6 +18,15 @@ function _apply_sitegate( return a′ end +function _get_biperms(::PEPSTensor, ::Integer) + return ((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), ntuple(identity, 4) +end +function _get_biperms(::PEPOTensor, gate_ax::Integer) + if gate_ax == 1 + return ((2, 3, 5, 6), (1, 4)), ((2, 3, 4, 5), (1, 6)), (1, 2, 5, 3, 4), ntuple(identity, 5) + end + return ((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), ntuple(identity, 5) +end """ $(SIGNATURES) @@ -48,17 +57,10 @@ When `A`, `B` are PEPOTensors, 5 1 4 1 4 1 ``` """ -function _qr_bond(A::PT, B::PT; gate_ax::Int = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}} +function _qr_bond(A::PT, B::PT; gate_ax::Integer = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}} @assert 1 <= gate_ax <= numout(A) - permA, permB, permX, permY = if A isa PEPSTensor - ((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), Tuple(1:4) - else - if gate_ax == 1 - ((2, 3, 5, 6), (1, 4)), ((2, 3, 4, 5), (1, 6)), (1, 2, 5, 3, 4), Tuple(1:5) - else - ((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), Tuple(1:5) - end - end + permA, permB, permX, permY = _get_biperms(A, gate_ax) + X, a = left_orth!(permute(A, permA; copy = true); kwargs...) Y, b = left_orth!(permute(B, permB; copy = true); kwargs...) X, Y = permute(X, permX), permute(Y, permY) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 2e7a3fc09..a5a62096e 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -84,7 +84,7 @@ end Simple update optimized for nearest neighbor gates utilizing reduced bond tensors with the physical leg. """ -function _su_iter!( +function _su_iter_gate!( state::InfiniteState, gate::NNGate, env::SUWeight, sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate ) @@ -92,7 +92,6 @@ function _su_iter!( truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc)) @assert length(sites) == 2 && length(truncs) == 1 Ms, open_vaxs, = _get_cluster(state, sites, env; permute = false) - normalize!.(Ms, Inf) # rotate bond, rev = _nn_bondrev(sites..., (Nr, Nc)) A, B = _bond_rotation.(Ms, bond[1], rev; inv = false) @@ -144,24 +143,25 @@ function su_iter( elseif length(sites) == 2 (d, r, c), = _nn_bondrev(sites..., (Nr, Nc)) alg.bipartite && r > 1 && continue - ϵ′ = _su_iter!(state2, gate, env2, sites, alg) + ϵ′ = _su_iter_gate!(state2, gate, env2, sites, alg) ϵ = max(ϵ, ϵ′) - (!alg.bipartite) && continue - if d == 1 - rp1, cp1 = _next(r, Nr), _next(c, Nc) - state2[rp1, cp1] = deepcopy(state2[r, c]) - state2[rp1, c] = deepcopy(state2[r, cp1]) - env2[1, rp1, cp1] = deepcopy(env2[1, r, c]) - else - rm1, cm1 = _prev(r, Nr), _prev(c, Nc) - state2[rm1, cm1] = deepcopy(state2[r, c]) - state2[r, cm1] = deepcopy(state2[rm1, c]) - env2[2, rm1, cm1] = deepcopy(env2[2, r, c]) + if alg.bipartite + if d == 1 + rp1, cp1 = _next(r, Nr), _next(c, Nc) + state2[rp1, cp1] = deepcopy(state2[r, c]) + state2[rp1, c] = deepcopy(state2[r, cp1]) + env2[1, rp1, cp1] = deepcopy(env2[1, r, c]) + else + rm1, cm1 = _prev(r, Nr), _prev(c, Nc) + state2[rm1, cm1] = deepcopy(state2[r, c]) + state2[r, cm1] = deepcopy(state2[rm1, c]) + env2[2, rm1, cm1] = deepcopy(env2[2, r, c]) + end end else # N-site MPO gate (N ≥ 2) alg.bipartite && error("Multi-site MPO gates are not compatible with bipartite states.") - ϵ′ = _su_iter!(state2, gate, env2, sites, alg) + ϵ′ = _su_iter_mpo!(state2, gate, env2, sites, alg) ϵ = max(ϵ, ϵ′) end end diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 6c684ceef..0c8f59a42 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -68,12 +68,10 @@ Find the permutation to permute `out_ax`, `in_ax` legs to the first and the last position of a tensor with `Nax` legs, then assign the last leg to domain, and the others to codomain. """ -function _get_mpo_perm(out_ax::Int, in_ax::Int, Nax::Int) - perm = collect(1:Nax) - filter!(x -> x != out_ax && x != in_ax, perm) - pushfirst!(perm, out_ax) - push!(perm, in_ax) - return (Tuple(perm[1:(end - 1)]), (perm[end],)) +function _get_mpo_perm(out_ax::Integer, in_ax::Integer, ::Val{Nax}) where {Nax} + lo, hi = minmax(out_ax, in_ax) + perm = ntuple(k -> k < lo ? k : k < hi - 1 ? k + 1 : k + 2, Nax - 2) + return (out_ax, perm...), (in_ax,) end """ @@ -101,51 +99,46 @@ Otherwise, axes order of each tensor in `Ms` are preserved. - `open_vaxs`: Open virtual axes (1 to 4) of each cluster tensor before permutation. - `invperms`: Permutations to restore the axes order of each cluster tensor. """ -function _get_cluster(state, sites; permute::Bool = true) - return _get_cluster(state, sites, nothing; permute) -end function _get_cluster( state::InfiniteState, sites::Vector{CartesianIndex{2}}, env::Union{SUWeight, Nothing}; permute::Bool = true ) Nr, Nc = size(state) - # number of sites - Ns = length(sites) - # number of physical axes - Np = isa(state, InfinitePEPS) ? 1 : 2 + n_sites = length(sites) + n_physical_axes = numout(eltype(unitcell(state))) # number of axes of each state tensor - Nax = 4 + Np - out_axs = map(2:Ns) do i + Nax = 4 + n_physical_axes + out_axs = map(2:n_sites) do i return _nn_vec_direction(sites[i - 1] - sites[i]) end - in_axs = map(1:(Ns - 1)) do i + in_axs = map(1:(n_sites - 1)) do i return _nn_vec_direction(sites[i + 1] - sites[i]) end - all_vaxs = Tuple(1:4) - open_vaxs = map(1:Ns) do i + all_vaxs = (1, 2, 3, 4) + open_vaxs = map(1:n_sites) do i return if i == 1 filter(x -> x != in_axs[i], all_vaxs) - elseif i == Ns + elseif i == n_sites filter(x -> x != out_axs[i - 1], all_vaxs) else filter(x -> x != out_axs[i - 1] && x != in_axs[i], all_vaxs) end end - perms = map(1:Ns) do i - out_ax, in_ax = if i == 1 + perms = map(1:n_sites) do i + out_ax, in_ax = if i == 1 # first perm # use direction opposite to `in` as `out` mod1(2 + in_axs[i], 4), in_axs[i] - elseif i == Ns + elseif i == n_sites # last perm # use direction opposite to `out` as `in` out_axs[i - 1], mod1(2 + out_axs[i - 1], 4) else - out_axs[i - 1], in_axs[i] + out_axs[i - 1], in_axs[i] # mid perm end - return _get_mpo_perm(out_ax + Np, in_ax + Np, Nax) + return _get_mpo_perm(out_ax + n_physical_axes, in_ax + n_physical_axes, Val(Nax)) end invperms = map(perms) do (p1, p2) p = invperm((p1..., p2...)) - return (p[1:Np], p[(Np + 1):end]) + return (p[begin:n_physical_axes], p[(n_physical_axes + 1):end]) end Ms = map(zip(sites, open_vaxs, perms)) do (site, vaxs, perm) s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) From c8798e2f7796a5f1afe91cc9705e180930f7c93c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Sun, 12 Apr 2026 23:07:49 +0200 Subject: [PATCH 03/24] improve tuples stability --- .../time_evolution/simpleupdate3site.jl | 82 ++++++++++--------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 0c8f59a42..a463c578c 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -63,14 +63,23 @@ function _nn_bondrev(site1::CartesianIndex{2}, site2::CartesianIndex{2}, (Nrow, end end +""" +Return a size N-k tuple with values 1 to N but the missing ones. Accept k=1 and k=2. +""" +function _filtered_oneto(i, ::Val{N}) where {N} + return ntuple(k -> k < i ? k : k + 1, N - 1) +end +function _filtered_oneto(i, j, ::Val{N}) where {N} + lo, hi = minmax(i, j) + return ntuple(k -> k < lo ? k : k < hi - 1 ? k + 1 : k + 2, N - 2) +end """ Find the permutation to permute `out_ax`, `in_ax` legs to the first and the last position of a tensor with `Nax` legs, then assign the last leg to domain, and the others to codomain. """ function _get_mpo_perm(out_ax::Integer, in_ax::Integer, ::Val{Nax}) where {Nax} - lo, hi = minmax(out_ax, in_ax) - perm = ntuple(k -> k < lo ? k : k < hi - 1 ? k + 1 : k + 2, Nax - 2) + perm = _filtered_oneto(out_ax, in_ax, Val(Nax)) return (out_ax, perm...), (in_ax,) end @@ -95,68 +104,61 @@ Otherwise, axes order of each tensor in `Ms` are preserved. ## Returns -- `Ms`: Tensors in the cluster. +- `vertices`: Tensors in the cluster. - `open_vaxs`: Open virtual axes (1 to 4) of each cluster tensor before permutation. - `invperms`: Permutations to restore the axes order of each cluster tensor. """ function _get_cluster( state::InfiniteState, sites::Vector{CartesianIndex{2}}, - env::Union{SUWeight, Nothing}; permute::Bool = true + env::SUWeight; permute::Bool = true ) Nr, Nc = size(state) n_sites = length(sites) n_physical_axes = numout(eltype(unitcell(state))) # number of axes of each state tensor - Nax = 4 + n_physical_axes + Nax = Val(4 + n_physical_axes) out_axs = map(2:n_sites) do i return _nn_vec_direction(sites[i - 1] - sites[i]) end in_axs = map(1:(n_sites - 1)) do i return _nn_vec_direction(sites[i + 1] - sites[i]) end - all_vaxs = (1, 2, 3, 4) - open_vaxs = map(1:n_sites) do i - return if i == 1 - filter(x -> x != in_axs[i], all_vaxs) - elseif i == n_sites - filter(x -> x != out_axs[i - 1], all_vaxs) - else - filter(x -> x != out_axs[i - 1] && x != in_axs[i], all_vaxs) - end + first_open_vaxs = _filtered_oneto(in_axs[1], Val(4)) + last_open_vaxs = _filtered_oneto(out_axs[n_sites - 1], Val(4)) + mid_vaxs = map(i -> _filtered_oneto(out_axs[i - 1], in_axs[i], Val(4)), 2:(n_sites - 1)) + # use direction opposite to `in` as `out` + first_perm = _get_mpo_perm(mod1(2 + in_axs[1], 4) + n_physical_axes, in_axs[1] + n_physical_axes, Nax) + # use direction opposite to `out` as `in` + last_perm = _get_mpo_perm(out_axs[n_sites - 1] + n_physical_axes, mod1(2 + out_axs[n_sites - 1], 4) + n_physical_axes, Nax) + mid_perms = map(2:(n_sites - 1)) do i + return _get_mpo_perm(out_axs[i - 1] + n_physical_axes, in_axs[i] + n_physical_axes, Nax) end - perms = map(1:n_sites) do i - out_ax, in_ax = if i == 1 # first perm - # use direction opposite to `in` as `out` - mod1(2 + in_axs[i], 4), in_axs[i] - elseif i == n_sites # last perm - # use direction opposite to `out` as `in` - out_axs[i - 1], mod1(2 + out_axs[i - 1], 4) - else - out_axs[i - 1], in_axs[i] # mid perm - end - return _get_mpo_perm(out_ax + n_physical_axes, in_ax + n_physical_axes, Val(Nax)) - end - invperms = map(perms) do (p1, p2) - p = invperm((p1..., p2...)) - return (p[begin:n_physical_axes], p[(n_physical_axes + 1):end]) - end - Ms = map(zip(sites, open_vaxs, perms)) do (site, vaxs, perm) + + open_vaxs::Vector{Tuple{Vararg{Int}}} = [first_open_vaxs, mid_vaxs..., last_open_vaxs] + perms = [first_perm, mid_perms..., last_perm] + invperms = invbiperm.(perms, Val(n_physical_axes)) + vertices::Vector{TensorMap{scalartype(state), spacetype(env), <:Any, <:Any, storagetype(eltype(unitcell(state)))}} = map( + zip(sites, open_vaxs, perms) + ) do (site, vaxs, perm) s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) - M = if env === nothing - state[s] - else - absorb_weight(state[s], env, s[1], s[2], vaxs) - end - return permute ? TensorKit.permute(M, perm) : M + t = absorb_weight(state[s], env, s[1], s[2], vaxs) + return permute ? TensorKit.permute(t, perm) : t end - return Ms, open_vaxs, invperms + return vertices, open_vaxs, invperms end +function invbiperm(bituple::Tuple{Tuple, Tuple}, ::Val{N}) where {N} + return invbiperm((first(bituple)..., last(bituple)...), Val(N)) +end +function invbiperm(t::Tuple, ::Val{N}) where {N} + p = invperm(t) + return p[begin:N], p[(N + 1):end] +end """ Simple update with an N-site MPO `gate` (N ≥ 2). """ -function _su_iter!( - state::InfiniteState, gate::Vector{T}, env::SUWeight, +function _su_iter_mpo!( + state::InfiniteState, gates::Vector{T}, env::SUWeight, sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate ) where {T <: AbstractTensorMap} Nr, Nc = size(state) From dd7c58b1c795391ba2ef25230113488c31efde9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 13 Apr 2026 00:31:25 +0200 Subject: [PATCH 04/24] remove debug --- src/algorithms/time_evolution/simpleupdate3site.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index a463c578c..1980f1a56 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -134,10 +134,10 @@ function _get_cluster( return _get_mpo_perm(out_axs[i - 1] + n_physical_axes, in_axs[i] + n_physical_axes, Nax) end - open_vaxs::Vector{Tuple{Vararg{Int}}} = [first_open_vaxs, mid_vaxs..., last_open_vaxs] + open_vaxs = [first_open_vaxs, mid_vaxs..., last_open_vaxs] perms = [first_perm, mid_perms..., last_perm] invperms = invbiperm.(perms, Val(n_physical_axes)) - vertices::Vector{TensorMap{scalartype(state), spacetype(env), <:Any, <:Any, storagetype(eltype(unitcell(state)))}} = map( + vertices = map( zip(sites, open_vaxs, perms) ) do (site, vaxs, perm) s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) @@ -173,7 +173,7 @@ function _su_iter_mpo!( gate_axs = alg.purified ? (1:1) : (1:2) global wts, ϵs for gate_ax in gate_axs - _apply_gatempo!(Ms, gate; gate_ax) + _apply_gatempo!(Ms, gates; gate_ax) if isa(state, InfinitePEPO) Ms = [first(_fuse_physicalspaces(M)) for M in Ms] end From d2b37af9365a0201687d10fe3282386aee5937ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 14 Apr 2026 16:23:46 +0200 Subject: [PATCH 05/24] remove get_cluster --- src/algorithms/time_evolution/simpleupdate.jl | 53 +++++++++-- .../time_evolution/simpleupdate3site.jl | 89 +++++++------------ 2 files changed, 77 insertions(+), 65 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index a5a62096e..ccf08e599 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -80,6 +80,42 @@ function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false) end end +""" +Obtain the left (first) cluster tensor from `state` at `site`, +where `in_ax` is the virtual axis connecting to the next tensor. +The tensor is not permuted; the returned `invperm` is the identity. +""" +function _get_left( + state::InfiniteState, site::CartesianIndex{2}, in_ax::Int, + env::SUWeight + ) + Nr, Nc = size(state) + open_vaxs = _filtered_oneto(in_ax, Val(4)) + s = mod1(site[1], Nr), mod1(site[2], Nc) + t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) + Nax = 4 + numout(t) + invperm = (ntuple(identity, Nax - 1), (Nax,)) + return t, open_vaxs, invperm +end + +""" +Obtain the right (last) cluster tensor from `state` at `site`, +where `out_ax` is the virtual axis connecting to the previous tensor. +The tensor is not permuted; the returned `invperm` is the identity. +""" +function _get_right( + state::InfiniteState, site::CartesianIndex{2}, out_ax::Int, + env::SUWeight + ) + Nr, Nc = size(state) + open_vaxs = _filtered_oneto(out_ax, Val(4)) + s = mod1(site[1], Nr), mod1(site[2], Nc) + t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) + Nax = 4 + numout(t) + invperm = (ntuple(identity, Nax - 1), (Nax,)) + return t, open_vaxs, invperm +end + """ Simple update optimized for nearest neighbor gates utilizing reduced bond tensors with the physical leg. @@ -91,10 +127,14 @@ function _su_iter_gate!( Nr, Nc = size(state) truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc)) @assert length(sites) == 2 && length(truncs) == 1 - Ms, open_vaxs, = _get_cluster(state, sites, env; permute = false) + in_ax = _nn_vec_direction(sites[2] - sites[1]) + out_ax = _nn_vec_direction(sites[1] - sites[2]) + A, open_vaxs_A, = _get_left(state, sites[1], in_ax, env) + B, open_vaxs_B, = _get_right(state, sites[2], out_ax, env) # rotate bond, rev = _nn_bondrev(sites..., (Nr, Nc)) - A, B = _bond_rotation.(Ms, bond[1], rev; inv = false) + A = _bond_rotation(A, bond[1], rev; inv = false) + B = _bond_rotation(B, bond[1], rev; inv = false) # apply gate ϵ = 0.0 s = nothing @@ -113,14 +153,11 @@ function _su_iter_gate!( siteA, siteB = map(sites) do site return CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) end - A = absorb_weight(A, env, siteA[1], siteA[2], open_vaxs[1]; inv = true) - B = absorb_weight(B, env, siteB[1], siteB[2], open_vaxs[2]; inv = true) + A = absorb_weight(A, env, siteA[1], siteA[2], open_vaxs_A; inv = true) + B = absorb_weight(B, env, siteB[1], siteB[2], open_vaxs_B; inv = true) # update tensor dict and weight on current bond - normalize!(A, Inf) - normalize!(B, Inf) - normalize!(s, Inf) state[siteA], state[siteB] = A, B - env[bond...] = s + env[bond...] = normalize!(s, Inf) return ϵ end diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 1980f1a56..0cbe69ebe 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -84,69 +84,26 @@ function _get_mpo_perm(out_ax::Integer, in_ax::Integer, ::Val{Nax}) where {Nax} end """ -Obtain the cluster `Ms` along the (open) path `sites` in `state`. - -When the `SUWeight` environment `env` is provided, -it will be absorbed into tensors of `Ms`. - -When `permute = true`, permute tensors in `Ms` to MPS axis order -``` - PEPS: PEPO: - 3 3 4 - ╱ | ╱ - o -- M -- i o -- M -- i - ╱ | ╱ | - 4 2 5 2 - M[o 2 3 4; i] M[o 2 3 4 5; i] -``` -where `o` (`i`) connects to the previous (next) tensor. -Otherwise, axes order of each tensor in `Ms` are preserved. - -## Returns - -- `vertices`: Tensors in the cluster. -- `open_vaxs`: Open virtual axes (1 to 4) of each cluster tensor before permutation. -- `invperms`: Permutations to restore the axes order of each cluster tensor. +Obtain a middle cluster tensor from `state` at `site`, +where `out_ax` (`in_ax`) is the virtual axis connecting to the previous (next) tensor. +The tensor is permuted to MPS axis order. """ -function _get_cluster( - state::InfiniteState, sites::Vector{CartesianIndex{2}}, - env::SUWeight; permute::Bool = true +function _get_mid( + state::InfiniteState, site::CartesianIndex{2}, out_ax::Int, in_ax::Int, + env::SUWeight ) Nr, Nc = size(state) - n_sites = length(sites) n_physical_axes = numout(eltype(unitcell(state))) - # number of axes of each state tensor Nax = Val(4 + n_physical_axes) - out_axs = map(2:n_sites) do i - return _nn_vec_direction(sites[i - 1] - sites[i]) - end - in_axs = map(1:(n_sites - 1)) do i - return _nn_vec_direction(sites[i + 1] - sites[i]) - end - first_open_vaxs = _filtered_oneto(in_axs[1], Val(4)) - last_open_vaxs = _filtered_oneto(out_axs[n_sites - 1], Val(4)) - mid_vaxs = map(i -> _filtered_oneto(out_axs[i - 1], in_axs[i], Val(4)), 2:(n_sites - 1)) - # use direction opposite to `in` as `out` - first_perm = _get_mpo_perm(mod1(2 + in_axs[1], 4) + n_physical_axes, in_axs[1] + n_physical_axes, Nax) - # use direction opposite to `out` as `in` - last_perm = _get_mpo_perm(out_axs[n_sites - 1] + n_physical_axes, mod1(2 + out_axs[n_sites - 1], 4) + n_physical_axes, Nax) - mid_perms = map(2:(n_sites - 1)) do i - return _get_mpo_perm(out_axs[i - 1] + n_physical_axes, in_axs[i] + n_physical_axes, Nax) - end - - open_vaxs = [first_open_vaxs, mid_vaxs..., last_open_vaxs] - perms = [first_perm, mid_perms..., last_perm] - invperms = invbiperm.(perms, Val(n_physical_axes)) - vertices = map( - zip(sites, open_vaxs, perms) - ) do (site, vaxs, perm) - s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) - t = absorb_weight(state[s], env, s[1], s[2], vaxs) - return permute ? TensorKit.permute(t, perm) : t - end - return vertices, open_vaxs, invperms + open_vaxs = _filtered_oneto(out_ax, in_ax, Val(4)) + perm = _get_mpo_perm(out_ax + n_physical_axes, in_ax + n_physical_axes, Nax) + invperm = invbiperm(perm, Val(n_physical_axes)) + s = mod1(site[1], Nr), mod1(site[2], Nc) + t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) + return permute(t, perm), open_vaxs, invperm end + function invbiperm(bituple::Tuple{Tuple, Tuple}, ::Val{N}) where {N} return invbiperm((first(bituple)..., last(bituple)...), Val(N)) end @@ -162,8 +119,26 @@ function _su_iter_mpo!( sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate ) where {T <: AbstractTensorMap} Nr, Nc = size(state) + n_physical_axes = numout(eltype(unitcell(state))) + Nax = Val(4 + n_physical_axes) + n_sites = length(sites) truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc)) - Ms, open_vaxs, invperms = _get_cluster(state, sites, env) + out_axs = map(i -> _nn_vec_direction(sites[i - 1] - sites[i]), 2:n_sites) + in_axs = map(i -> _nn_vec_direction(sites[i + 1] - sites[i]), 1:(n_sites - 1)) + # left and right: get tensor without permutation, then permute to MPS form + left_M, left_vaxs, = _get_left(state, sites[1], in_axs[1], env) + right_M, right_vaxs, = _get_right(state, sites[end], out_axs[end], env) + left_perm = _get_mpo_perm(mod1(2 + in_axs[1], 4) + n_physical_axes, in_axs[1] + n_physical_axes, Nax) + right_perm = _get_mpo_perm(out_axs[end] + n_physical_axes, mod1(2 + out_axs[end], 4) + n_physical_axes, Nax) + left_M = TensorKit.permute(left_M, left_perm) + right_M = TensorKit.permute(right_M, right_perm) + left_invperm = invbiperm(left_perm, Val(n_physical_axes)) + right_invperm = invbiperm(right_perm, Val(n_physical_axes)) + # middle tensors: permuted to MPS form in _get_mid + mids = map(i -> _get_mid(state, sites[i], out_axs[i - 1], in_axs[i], env), 2:(n_sites - 1)) + Ms = [left_M, getindex.(mids, 1)..., right_M] + open_vaxs = [left_vaxs, getindex.(mids, 2)..., right_vaxs] + invperms = [left_invperm, getindex.(mids, 3)..., right_invperm] flips = [isdual(space(M, 1)) for M in Ms[2:end]] Vphys = [codomain(M, 2) for M in Ms] normalize!.(Ms, Inf) From ada6ba7a1a3948e79937b5d8024a311536f07002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 14 Apr 2026 17:08:52 +0200 Subject: [PATCH 06/24] factorize _su_iter_gate --- src/algorithms/time_evolution/simpleupdate.jl | 45 +++++++++---------- .../time_evolution/simpleupdate3site.jl | 5 +-- 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index ccf08e599..4105df376 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -93,7 +93,7 @@ function _get_left( open_vaxs = _filtered_oneto(in_ax, Val(4)) s = mod1(site[1], Nr), mod1(site[2], Nc) t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) - Nax = 4 + numout(t) + Nax = 4 + numout(eltype(state)) invperm = (ntuple(identity, Nax - 1), (Nax,)) return t, open_vaxs, invperm end @@ -111,7 +111,7 @@ function _get_right( open_vaxs = _filtered_oneto(out_ax, Val(4)) s = mod1(site[1], Nr), mod1(site[2], Nc) t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) - Nax = 4 + numout(t) + Nax = 4 + numout(eltype(state)) invperm = (ntuple(identity, Nax - 1), (Nax,)) return t, open_vaxs, invperm end @@ -122,41 +122,36 @@ utilizing reduced bond tensors with the physical leg. """ function _su_iter_gate!( state::InfiniteState, gate::NNGate, env::SUWeight, - sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate + siteA::CartesianIndex{2}, siteB::CartesianIndex{2}, alg::SimpleUpdate ) Nr, Nc = size(state) - truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc)) - @assert length(sites) == 2 && length(truncs) == 1 - in_ax = _nn_vec_direction(sites[2] - sites[1]) - out_ax = _nn_vec_direction(sites[1] - sites[2]) - A, open_vaxs_A, = _get_left(state, sites[1], in_ax, env) - B, open_vaxs_B, = _get_right(state, sites[2], out_ax, env) + trunc = only(_get_cluster_trunc(alg.trunc, [siteA, siteB], (Nr, Nc))) + in_ax = _nn_vec_direction(siteB - siteA) + out_ax = mod1(in_ax + 2, 4) + A0, open_vaxs_A, = _get_left(state, siteA, in_ax, env) + B0, open_vaxs_B, = _get_right(state, siteB, out_ax, env) # rotate - bond, rev = _nn_bondrev(sites..., (Nr, Nc)) - A = _bond_rotation(A, bond[1], rev; inv = false) - B = _bond_rotation(B, bond[1], rev; inv = false) + bond, rev = _nn_bondrev(siteA, siteB, (Nr, Nc)) + dir = first(bond) + A = _bond_rotation(A0, dir, rev; inv = false) + B = _bond_rotation(B0, dir, rev; inv = false) # apply gate ϵ = 0.0 - s = nothing + local s gate_axs = alg.purified ? (1:1) : (1:2) for gate_ax in gate_axs X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true) - a, s, b, ϵ′ = _apply_gate(a, b, gate, truncs[1]) + a, s, b, ϵ′ = _apply_gate(a, b, gate, trunc) ϵ = max(ϵ, ϵ′) A, B = _qr_bond_undo(X, a, b, Y) end - # rotate back - A = _bond_rotation(A, bond[1], rev; inv = true) - B = _bond_rotation(B, bond[1], rev; inv = true) rev && (s = transpose(s)) - # remove environment weights - siteA, siteB = map(sites) do site - return CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) + # rotate back & remove environment weights + for (site, vertex, open_vaxs) in ((siteA, A, open_vaxs_A), (siteB, B, open_vaxs_B)) + s′ = (mod1(site[1], Nr), mod1(site[2], Nc)) + rotated = _bond_rotation(vertex, dir, rev; inv = true) + state[s′...] = absorb_weight(rotated, env, s′..., open_vaxs; inv = true) end - A = absorb_weight(A, env, siteA[1], siteA[2], open_vaxs_A; inv = true) - B = absorb_weight(B, env, siteB[1], siteB[2], open_vaxs_B; inv = true) - # update tensor dict and weight on current bond - state[siteA], state[siteB] = A, B env[bond...] = normalize!(s, Inf) return ϵ end @@ -180,7 +175,7 @@ function su_iter( elseif length(sites) == 2 (d, r, c), = _nn_bondrev(sites..., (Nr, Nc)) alg.bipartite && r > 1 && continue - ϵ′ = _su_iter_gate!(state2, gate, env2, sites, alg) + ϵ′ = _su_iter_gate!(state2, gate, env2, sites[1], sites[2], alg) ϵ = max(ϵ, ϵ′) if alg.bipartite if d == 1 diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 0cbe69ebe..69bccfed0 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -130,8 +130,8 @@ function _su_iter_mpo!( right_M, right_vaxs, = _get_right(state, sites[end], out_axs[end], env) left_perm = _get_mpo_perm(mod1(2 + in_axs[1], 4) + n_physical_axes, in_axs[1] + n_physical_axes, Nax) right_perm = _get_mpo_perm(out_axs[end] + n_physical_axes, mod1(2 + out_axs[end], 4) + n_physical_axes, Nax) - left_M = TensorKit.permute(left_M, left_perm) - right_M = TensorKit.permute(right_M, right_perm) + left_M = permute(left_M, left_perm) + right_M = permute(right_M, right_perm) left_invperm = invbiperm(left_perm, Val(n_physical_axes)) right_invperm = invbiperm(right_perm, Val(n_physical_axes)) # middle tensors: permuted to MPS form in _get_mid @@ -141,7 +141,6 @@ function _su_iter_mpo!( invperms = [left_invperm, getindex.(mids, 3)..., right_invperm] flips = [isdual(space(M, 1)) for M in Ms[2:end]] Vphys = [codomain(M, 2) for M in Ms] - normalize!.(Ms, Inf) # flip virtual arrows in `Ms` to ← _flip_virtuals!(Ms, flips) # apply gate MPOs and truncate From 83faeff29192b6a1947ce5480faba924a732ed8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 15 Apr 2026 11:27:05 +0200 Subject: [PATCH 07/24] type stable absorb_weight --- src/environments/suweight.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index 2966d4155..4219ae665 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -209,10 +209,9 @@ function absorb_weight( t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false ) - Nr, Nc = size(weights)[2:end] - nin, nout, ntol = numin(t), numout(t), numind(t) + _, Nr, Nc = size(weights) @assert 1 <= row <= Nr && 1 <= col <= Nc - @assert 1 <= ax <= nin + @assert 1 <= ax <= numin(t) pow = inv ? -1 / 2 : 1 / 2 wt = sdiag_pow( if ax == NORTH @@ -226,18 +225,23 @@ function absorb_weight( end, pow, ) - t_idx = [(n - nout == ax) ? 1 : -n for n in 1:ntol] - ax′ = ax + nout - wt_idx = (ax == NORTH || ax == EAST) ? [1, -ax′] : [-ax′, 1] + ax′ = ax + numout(t) # make absorption/removal twist-free twistdual!(wt, 1) - return permute(ncon((t, wt), (t_idx, wt_idx)), (Tuple(1:nout), Tuple((nout + 1):ntol))) + if ax == SOUTH || ax == WEST + wt = transpose(wt) # not sure this can be factorized due to twistdual + end + biperm = (_filtered_oneto(ax′, Val(numind(t))), (ax′,)) + contracted = permute(t, biperm) * wt + invbp = invbiperm(biperm, Val(numout(t))) + return permute(contracted, invbp) end function absorb_weight( t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::NTuple{N, Int}; inv::Bool = false ) where {N} - t2 = copy(t) + t2 = t + # should not permute back and forth for a in ax t2 = absorb_weight(t2, weights, row, col, a; inv) end From 8a5c1962c03df9d4c7cc5be6567e23efb3243f79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 15 Apr 2026 12:00:36 +0200 Subject: [PATCH 08/24] WIP typing for _su_iter_mpo --- src/algorithms/time_evolution/simpleupdate.jl | 2 +- .../time_evolution/simpleupdate3site.jl | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 4105df376..aa64f2111 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -139,7 +139,7 @@ function _su_iter_gate!( ϵ = 0.0 local s gate_axs = alg.purified ? (1:1) : (1:2) - for gate_ax in gate_axs + for gate_ax in gate_axs # TODO try to use type stable helper function X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true) a, s, b, ϵ′ = _apply_gate(a, b, gate, trunc) ϵ = max(ϵ, ϵ′) diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 69bccfed0..3500a4d8b 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -126,21 +126,21 @@ function _su_iter_mpo!( out_axs = map(i -> _nn_vec_direction(sites[i - 1] - sites[i]), 2:n_sites) in_axs = map(i -> _nn_vec_direction(sites[i + 1] - sites[i]), 1:(n_sites - 1)) # left and right: get tensor without permutation, then permute to MPS form - left_M, left_vaxs, = _get_left(state, sites[1], in_axs[1], env) - right_M, right_vaxs, = _get_right(state, sites[end], out_axs[end], env) + left_M0, left_vaxs, = _get_left(state, sites[1], in_axs[1], env) + right_M0, right_vaxs, = _get_right(state, sites[end], out_axs[end], env) left_perm = _get_mpo_perm(mod1(2 + in_axs[1], 4) + n_physical_axes, in_axs[1] + n_physical_axes, Nax) right_perm = _get_mpo_perm(out_axs[end] + n_physical_axes, mod1(2 + out_axs[end], 4) + n_physical_axes, Nax) - left_M = permute(left_M, left_perm) - right_M = permute(right_M, right_perm) + left_M = permute(left_M0, left_perm) + right_M = permute(right_M0, right_perm) left_invperm = invbiperm(left_perm, Val(n_physical_axes)) right_invperm = invbiperm(right_perm, Val(n_physical_axes)) # middle tensors: permuted to MPS form in _get_mid mids = map(i -> _get_mid(state, sites[i], out_axs[i - 1], in_axs[i], env), 2:(n_sites - 1)) - Ms = [left_M, getindex.(mids, 1)..., right_M] - open_vaxs = [left_vaxs, getindex.(mids, 2)..., right_vaxs] + Ms = [left_M, getindex.(mids, 1)..., right_M] # TODO remove + open_vaxs = [left_vaxs, getindex.(mids, 2)..., right_vaxs] # TODO removve invperms = [left_invperm, getindex.(mids, 3)..., right_invperm] - flips = [isdual(space(M, 1)) for M in Ms[2:end]] - Vphys = [codomain(M, 2) for M in Ms] + flips = push!([isdual(space(first(x), 1)) for x in mids], isdual(space(right_M, 1))) + Vphys = [codomain(left_M, 2), map(x -> codomain(first(x), 2), mids)..., codomain(right_M, 2)] # flip virtual arrows in `Ms` to ← _flip_virtuals!(Ms, flips) # apply gate MPOs and truncate From 75cfad4b4264fb7a1c1ab749ef78e4acba2d7faf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 17 Apr 2026 01:59:25 +0200 Subject: [PATCH 09/24] avoid permuting twice in absorb_weight --- src/environments/suweight.jl | 50 +++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index 4219ae665..36825f959 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -205,13 +205,11 @@ absorb_weight(t, weights, 2, 3, 1) absorb_weight(t, weights, 2, 3, 2; inv=true) ``` """ -function absorb_weight( - t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, - row::Int, col::Int, ax::Int; inv::Bool = false +function weight_to_absorb( + weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false ) _, Nr, Nc = size(weights) @assert 1 <= row <= Nr && 1 <= col <= Nc - @assert 1 <= ax <= numin(t) pow = inv ? -1 / 2 : 1 / 2 wt = sdiag_pow( if ax == NORTH @@ -225,27 +223,43 @@ function absorb_weight( end, pow, ) - ax′ = ax + numout(t) # make absorption/removal twist-free twistdual!(wt, 1) - if ax == SOUTH || ax == WEST - wt = transpose(wt) # not sure this can be factorized due to twistdual - end - biperm = (_filtered_oneto(ax′, Val(numind(t))), (ax′,)) - contracted = permute(t, biperm) * wt - invbp = invbiperm(biperm, Val(numout(t))) - return permute(contracted, invbp) + (ax == SOUTH || ax == WEST) && return transpose(wt) # not sure this can be factorized due to twistdual + return wt +end + +function biperm_absorb_weight(legs::NTuple{N, Int}, vax::Int) where {N} + @assert N == 5 || N == 6 + nin = N - 4 + a = vax + nin + codomain_axes = _filtered_oneto(a, Val(N)) + biperm = (map(i -> findfirst(==(i), legs)::Int, codomain_axes), (findfirst(==(a), legs)::Int,)) + new_legs = (ntuple(i -> legs[biperm[1][i]], N - 1)..., a) + return new_legs, biperm +end + +function absorb_first_weight(t::Union{PEPSTensor, PEPOTensor}, wt, vax) + legs = ntuple(identity, numind(t)) + new_legs, biperm = biperm_absorb_weight(legs, vax) + t2 = permute(t, biperm) * wt + return new_legs, t2 end + function absorb_weight( t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, - row::Int, col::Int, ax::NTuple{N, Int}; inv::Bool = false + row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false ) where {N} - t2 = t - # should not permute back and forth - for a in ax - t2 = absorb_weight(t2, weights, row, col, a; inv) + vax = first(virt_axes) + weight_vax = weight_to_absorb(weights, row, col, vax; inv) + legs, t2 = absorb_first_weight(t, weight_vax, vax) + for vax in virt_axes[(begin + 1):end] + legs, biperm = biperm_absorb_weight(legs, vax) + weight_vax = weight_to_absorb(weights, row, col, vax; inv) + t2 = permute(t2, biperm) * weight_vax end - return t2 + perm_back = invperm(legs) + return permute(t2, (perm_back[begin:numout(t)], perm_back[(numout(t) + 1):end])) end #= Rotation of SUWeight. Example: 3 x 3 network From e3dc4ed389cf96b57a695a2461ac154fc57350ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 17 Apr 2026 10:03:48 +0200 Subject: [PATCH 10/24] suggestion --- src/algorithms/time_evolution/simpleupdate.jl | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index aa64f2111..87876652c 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -177,18 +177,17 @@ function su_iter( alg.bipartite && r > 1 && continue ϵ′ = _su_iter_gate!(state2, gate, env2, sites[1], sites[2], alg) ϵ = max(ϵ, ϵ′) - if alg.bipartite - if d == 1 - rp1, cp1 = _next(r, Nr), _next(c, Nc) - state2[rp1, cp1] = deepcopy(state2[r, c]) - state2[rp1, c] = deepcopy(state2[r, cp1]) - env2[1, rp1, cp1] = deepcopy(env2[1, r, c]) - else - rm1, cm1 = _prev(r, Nr), _prev(c, Nc) - state2[rm1, cm1] = deepcopy(state2[r, c]) - state2[r, cm1] = deepcopy(state2[rm1, c]) - env2[2, rm1, cm1] = deepcopy(env2[2, r, c]) - end + (!alg.bipartite) && continue + if d == 1 + rp1, cp1 = _next(r, Nr), _next(c, Nc) + state2[rp1, cp1] = copy(state2[r, c]) + state2[rp1, c] = copy(state2[r, cp1]) + env2[1, rp1, cp1] = copy(env2[1, r, c]) + else + rm1, cm1 = _prev(r, Nr), _prev(c, Nc) + state2[rm1, cm1] = copy(state2[r, c]) + state2[r, cm1] = copy(state2[rm1, c]) + env2[2, rm1, cm1] = copy(env2[2, r, c]) end else # N-site MPO gate (N ≥ 2) From 23ea6c090a2a4dc0751cb5e161af6c82d4548521 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 24 Apr 2026 11:02:14 +0200 Subject: [PATCH 11/24] add verbosity kwarg --- src/algorithms/time_evolution/simpleupdate.jl | 74 ++++++++++--------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 87876652c..5b1830c2c 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -45,7 +45,7 @@ end symmetrize_gates::Bool = false ) -Initialize a `TimeEvolver` with Hamiltonian `H` and simple update `alg`, +Initialize a `TimeEvolver` with Hamiltonian `H` and simple update `alg`, starting from the initial state `psi0` and `SUWeight` environment `env0`. - The initial time is specified by `t0`. @@ -235,7 +235,7 @@ end """ time_evolve( - it::TimeEvolver{<:SimpleUpdate}; + it::TimeEvolver{<:SimpleUpdate}; tol::Float64 = 0.0, check_interval::Int = 500 ) -> (psi, env, info) @@ -248,44 +248,46 @@ or until convergence of `SUWeight` set by a positive `tol`. """ function MPSKit.time_evolve( it::TimeEvolver{<:SimpleUpdate}; - tol::Float64 = 0.0, check_interval::Int = 500 + tol::Float64 = 0.0, check_interval::Int = 500, verbosity::Int = 0 ) - time_start = time() - check_convergence = (tol > 0) - @info "--- Time evolution (simple update), dt = $(it.dt) ---" - if check_convergence - @assert (it.state.psi isa InfinitePEPS) && it.alg.imaginary_time "Only imaginary time evolution of InfinitePEPS allows convergence checking." - end - env0, time0 = it.state.env, time() - for (psi, env, info) in it - iter = it.state.iter - diff = compare_weights(env0, env) - stop = (iter == it.nstep) || (diff < tol) - showinfo = (check_interval > 0) && - ((iter % check_interval == 0) || (iter == 1) || stop) - time1 = time() - if showinfo - @info "Space of x-weight at [1, 1] = $(space(env[1, 1, 1], 1))" - @info @sprintf("SU iter %-7d: |Δλ| = %.3e. Time = %.3f s/it", iter, diff, time1 - time0) - end + return LoggingExtras.withlevel(; verbosity) do + time_start = time() + check_convergence = (tol > 0) + @infov 2 "--- Time evolution (simple update), dt = $(it.dt) ---" if check_convergence - if (iter == it.nstep) && (diff >= tol) - @warn "SU: bond weights have not converged." + @assert (it.state.psi isa InfinitePEPS) && it.alg.imaginary_time "Only imaginary time evolution of InfinitePEPS allows convergence checking." + end + env0, time0 = it.state.env, time() + for (psi, env, info) in it + iter = it.state.iter + diff = compare_weights(env0, env) + stop = (iter == it.nstep) || (diff < tol) + showinfo = (check_interval > 0) && + ((iter % check_interval == 0) || (iter == 1) || stop) + time1 = time() + if showinfo + @infov 2 "Space of x-weight at [1, 1] = $(space(env[1, 1, 1], 1))" + @infov 2 @sprintf("SU iter %-7d: |Δλ| = %.3e. Time = %.3f s/it", iter, diff, time1 - time0) end - if diff < tol - @info "SU: bond weights have converged." + if check_convergence + if (iter == it.nstep) && (diff >= tol) + @warn "SU: bond weights have not converged." + end + if diff < tol + @infov 2 "SU: bond weights have converged." + end end + if stop + time_end = time() + @infov 2 @sprintf("Time evolution finished in %.2f s", time_end - time_start) + return psi, env, info + else + env0 = env + end + time0 = time() end - if stop - time_end = time() - @info @sprintf("Time evolution finished in %.2f s", time_end - time_start) - return psi, env, info - else - env0 = env - end - time0 = time() + return end - return end """ @@ -297,14 +299,14 @@ end Perform time evolution on the initial iPEPS or iPEPO `psi0` and initial environment `env0` with Hamiltonian `H`, using `SimpleUpdate` -algorithm `alg`, time step `dt` for `nstep` number of steps. +algorithm `alg`, time step `dt` for `nstep` number of steps. - Set `symmetrize_gates = true` for second-order Trotter decomposition. - Set `tol > 0` to enable convergence check (for imaginary time evolution of iPEPS only). For other usages it should not be changed. - Use `t0` to specify the initial time of the evolution. - `check_interval` sets the interval to output information. Output during the evolution can be turned off by setting `check_interval <= 0`. -- `info` is a NamedTuple containing information of the evolution, +- `info` is a NamedTuple containing information of the evolution, including the time `info.t` evolved since `psi0`. """ function MPSKit.time_evolve( From 63db27c41ce51f7436e09dadde46a3ca464547d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 24 Apr 2026 12:17:28 +0200 Subject: [PATCH 12/24] clean su_iter_mpo --- .../time_evolution/simpleupdate3site.jl | 56 +++++++++++-------- test/timeevol/j1j2_finiteT.jl | 4 +- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 3500a4d8b..0f5213907 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -111,6 +111,18 @@ function invbiperm(t::Tuple, ::Val{N}) where {N} p = invperm(t) return p[begin:N], p[(N + 1):end] end +function cluster_truncate!(vertices, truncs, ::InfinitePEPO) + Vphys = codomain.(vertices, 2) + fused_vertices = [first(_fuse_physicalspaces(v)) for v in vertices] + wts, ϵs, = _cluster_truncate!(fused_vertices, truncs) + new_vertices = [first(_unfuse_physicalspace(v, Vphy)) for (v, Vphy) in zip(fused_vertices, Vphys)] + return new_vertices, wts, ϵs +end + +function cluster_truncate!(Ms, truncs, ::InfinitePEPS) + wts, ϵs, = _cluster_truncate!(Ms2, truncs) + return Ms, wts, ϵs +end """ Simple update with an N-site MPO `gate` (N ≥ 2). """ @@ -136,28 +148,26 @@ function _su_iter_mpo!( right_invperm = invbiperm(right_perm, Val(n_physical_axes)) # middle tensors: permuted to MPS form in _get_mid mids = map(i -> _get_mid(state, sites[i], out_axs[i - 1], in_axs[i], env), 2:(n_sites - 1)) - Ms = [left_M, getindex.(mids, 1)..., right_M] # TODO remove + vertices = [left_M, getindex.(mids, 1)..., right_M] # TODO remove + # Ms has well defined eltype Here + # issue it is redefined later with Any eltype open_vaxs = [left_vaxs, getindex.(mids, 2)..., right_vaxs] # TODO removve + # open_vaxs however cannot be stable invperms = [left_invperm, getindex.(mids, 3)..., right_invperm] flips = push!([isdual(space(first(x), 1)) for x in mids], isdual(space(right_M, 1))) - Vphys = [codomain(left_M, 2), map(x -> codomain(first(x), 2), mids)..., codomain(right_M, 2)] - # flip virtual arrows in `Ms` to ← - _flip_virtuals!(Ms, flips) + # flip virtual arrows in `vertices` to ← + _flip_virtuals!(vertices, flips) + # apply gate MPOs and truncate - gate_axs = alg.purified ? (1:1) : (1:2) - global wts, ϵs - for gate_ax in gate_axs - _apply_gatempo!(Ms, gates; gate_ax) - if isa(state, InfinitePEPO) - Ms = [first(_fuse_physicalspaces(M)) for M in Ms] - end - wts, ϵs, = _cluster_truncate!(Ms, truncs) - if isa(state, InfinitePEPO) - Ms = [first(_unfuse_physicalspace(M, Vphy)) for (M, Vphy) in zip(Ms, Vphys)] - end + _apply_gatempo!(vertices, gates; gate_ax = 1) + new_vertices, wts, ϵs = cluster_truncate!(vertices, truncs, state) + if !alg.purified + _apply_gatempo!(new_vertices, gates; gate_ax = 2) + new_vertices, wts, ϵs = cluster_truncate!(new_vertices, truncs, state) end - # restore virtual arrows in `Ms` - _flip_virtuals!(Ms, flips) + + # restore virtual arrows in `new_vertices` + _flip_virtuals!(new_vertices, flips) # update env weights bond_revs = map(zip(sites, Iterators.drop(sites, 1))) do (site1, site2) _nn_bondrev(site1, site2, (Nr, Nc)) @@ -166,16 +176,14 @@ function _su_iter_mpo!( wt_new = flip ? _fliptwist_s(wt) : wt wt_new = rev ? transpose(wt_new) : wt_new @assert all(wt_new.data .>= 0) - env[CartesianIndex(bond)] = normalize(wt_new, Inf) + env[CartesianIndex(bond)] = normalize!(wt_new, Inf) end - for (M, s, invperm, vaxs) in zip(Ms, sites, invperms, open_vaxs) + for (vertex, s, invperm, vaxs) in zip(new_vertices, sites, invperms, open_vaxs) s′ = CartesianIndex(mod1(s[1], Nr), mod1(s[2], Nc)) # restore original axes order - M = permute(M, invperm) - # remove weights on open axes of the cluster - M = absorb_weight(M, env, s′[1], s′[2], vaxs; inv = true) - # update state tensors - state[s′] = normalize(M, Inf) + permuted = permute(vertex, invperm) + # remove weights on open axes of the cluster and update state + state[s′] = absorb_weight(permuted, env, s′[1], s′[2], vaxs; inv = true) end return maximum(ϵs) end diff --git a/test/timeevol/j1j2_finiteT.jl b/test/timeevol/j1j2_finiteT.jl index 63576e6a1..72fa6b00e 100644 --- a/test/timeevol/j1j2_finiteT.jl +++ b/test/timeevol/j1j2_finiteT.jl @@ -12,7 +12,7 @@ bm = [-0.1235, -0.213] function converge_env(state, χ::Int) trunc1 = truncrank(χ) & truncerror(; atol = 1.0e-12) env0 = CTMRGEnv(ones, Float64, state, Vect[SU2Irrep](0 => 1)) - env, = leading_boundary(env0, state; alg = :sequential, trunc = trunc1, tol = 1.0e-10) + env, = leading_boundary(env0, state; alg = :sequential, trunc = trunc1, tol = 1.0e-10, verbosity = 0) return env end @@ -25,7 +25,7 @@ pepo0 = PEPSKit.infinite_temperature_density_matrix(ham) wts0 = SUWeight(pepo0) # 7 = 1 (spin-0) + 2 x 3 (spin-1) trunc_pepo = truncrank(7) & truncerror(; atol = 1.0e-12) -check_interval = 100 +check_interval = 2^32 dt, nstep = 1.0e-3, 600 # PEPO approach From b166ef20f00129f146a66e1772b0d286358032cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 24 Apr 2026 16:02:38 +0200 Subject: [PATCH 13/24] _su_iter_mpo type stable --- src/algorithms/time_evolution/apply_mpo.jl | 135 +++++++++--------- .../time_evolution/simpleupdate3site.jl | 29 ++-- src/environments/suweight.jl | 13 +- 3 files changed, 94 insertions(+), 83 deletions(-) diff --git a/src/algorithms/time_evolution/apply_mpo.jl b/src/algorithms/time_evolution/apply_mpo.jl index e6809c56a..a97575016 100644 --- a/src/algorithms/time_evolution/apply_mpo.jl +++ b/src/algorithms/time_evolution/apply_mpo.jl @@ -1,4 +1,4 @@ -#= +#= # Mixed canonical form of an open boundary MPS ``` |ψ⟩ = M[1]-←-...-←-M[N] @@ -54,7 +54,7 @@ Note that Then `M̃[n]` (n = 1, ..., N - 1) satisfies the (generalized) left-orthogonal condition ``` ┌---←--M̃[n]--←- ┌-←- 2 - | | | + | | | s[n-1] ↓ = s[n] (s[0] = 1) | | | └---→--M̃†[n]-→- └-→- 1 @@ -71,16 +71,16 @@ Similarly, we can express M̃ using Qb Then `M̃[n]` (n = 2, ..., N) satisfies the (generalized) right-orthogonal condition ``` -←-M̃[n]-←┐ 1 -←-┐ - ↓ | | + ↓ | | * s[n] = s[n-1] (s[N] = 1) ↓ | | -→M̃†[n]-→┘ 2 -→-┘ ``` -Here `-*-` is the twist on the physical axis. +Here `-*-` is the twist on the physical axis. # Truncation of a bond on OBC-MPS -Suppose we want to truncate the bond between +Suppose we want to truncate the bond between the n-th and the (n+1)-th sites such that the truncated state ``` |ψ̃⟩ = M[1]-←-...-←-M̃[n]-←-M̃[n+1]-←-...-←-M[N] @@ -120,24 +120,24 @@ Perform QR decomposition through a `GenericMPSTensor` ``` """ function qr_through( - R0::MPSBondTensor, M::GenericMPSTensor{S, N}; normalize::Bool = true - ) where {S, N} + R0::MPSBondTensor, M::GenericMPSTensor{S,N}; normalize::Bool=true +) where {S,N} @assert !isdual(codomain(R0, 1)) @assert !isdual(domain(M, 1)) && !isdual(codomain(M, 1)) pR = (codomainind(R0), domainind(R0)) - pM = ((1,), Tuple(2:(N + 1))) + pM = ((1,), Tuple(2:(N+1))) pRM = (codomainind(M), domainind(M)) A = tensorcontract(R0, pR, false, M, pM, false, pRM) - _, r = left_orth!(A; positive = true) + _, r = left_orth!(A; positive=true) normalize && normalize!(r, Inf) return r end # for `M` at the left end of the MPS function qr_through( - ::Nothing, M::GenericMPSTensor{S, N}; normalize::Bool = true - ) where {S, N} + ::Nothing, M::GenericMPSTensor{S,N}; normalize::Bool=true +) where {S,N} @assert !isdual(domain(M, 1)) - _, r = left_orth(M; positive = true) + _, r = left_orth(M; positive=true) normalize && normalize!(r, Inf) return r end @@ -151,25 +151,25 @@ Perform LQ decomposition through a `GenericMPSTensor` ``` """ function lq_through( - M::GenericMPSTensor{S, N}, L1::MPSBondTensor; normalize::Bool = true - ) where {S, N} + M::GenericMPSTensor{S,N}, L1::MPSBondTensor; normalize::Bool=true +) where {S,N} @assert !isdual(domain(L1, 1)) @assert !isdual(codomain(M, 1)) && !isdual(domain(M, 1)) pM = (codomainind(M), domainind(M)) pL = (codomainind(L1), domainind(L1)) - pML = ((1,), Tuple(2:(N + 1))) + pML = ((1,), ntuple(i -> i + 1, N)) A = tensorcontract(M, pM, false, L1, pL, false, pML) - l, _ = right_orth!(A; positive = true) + l, _ = right_orth!(A; positive=true) normalize && normalize!(l, Inf) return l end # for `M` at the right end of the MPS function lq_through( - M::GenericMPSTensor{S, N}, ::Nothing; normalize::Bool = true - ) where {S, N} + M::GenericMPSTensor{S,N}, ::Nothing; normalize::Bool=true +) where {S,N} @assert !isdual(codomain(M, 1)) - A = permute(M, ((1,), Tuple(2:(N + 1))); copy = true) - l, _ = right_orth!(A; positive = true) + A = permute(M, ((1,), ntuple(i -> i + 1, N)); copy=true) + l, _ = right_orth!(A; positive=true) normalize && normalize!(l, Inf) return l end @@ -177,26 +177,23 @@ end """ Given a cluster `Ms`, find all `R`, `L` matrices on each internal bond """ -function _get_allRLs(Ms::Vector{T}) where {T <: GenericMPSTensor} +function _get_allRLs(vertices::Vector{T}) where {T<:GenericMPSTensor} # M1 -- (R1,L1) -- M2 -- (R2,L2) -- M3 - N = length(Ms) + N = length(vertices) # get the first R and the last L - R_first = qr_through(nothing, Ms[1]; normalize = true) - L_last = lq_through(Ms[N], nothing; normalize = true) - Rs = Vector{typeof(R_first)}(undef, N - 1) - Ls = Vector{typeof(L_last)}(undef, N - 1) - Rs[1], Ls[end] = R_first, L_last + Rs = [qr_through(nothing, first(vertices); normalize=true)] + Ls = [lq_through(last(vertices), nothing; normalize=true)] + # get remaining R, L matrices - for n in 2:(N - 1) - m = N - n + 1 - Rs[n] = qr_through(Rs[n - 1], Ms[n]; normalize = true) - Ls[m - 1] = lq_through(Ms[m], Ls[m]; normalize = true) + for n in 2:(N-1) + push!(Rs, qr_through(last(Rs), vertices[n]; normalize=true)) + pushfirst!(Ls, lq_through(vertices[N - n + 1], first(Ls); normalize=true)) end return Rs, Ls end """ -Given the tensors `R`, `L` on a bond, construct +Given the tensors `R`, `L` on a bond, construct the projectors `Pa`, `Pb` and the new bond weight `s` such that the contraction of `Pa`, `s`, `Pb` is identity when `trunc = notrunc`, @@ -207,9 +204,9 @@ The arrows between `Pa`, `s`, `Pb` are ``` """ function _proj_from_RL( - r::MPSBondTensor, l::MPSBondTensor; - trunc::TruncationStrategy = notrunc() - ) + r::MPSBondTensor, l::MPSBondTensor; + trunc::TruncationStrategy=notrunc() +) @assert isdual(domain(r, 1)) == isdual(codomain(r, 1)) == false @assert isdual(domain(l, 1)) == isdual(codomain(l, 1)) == false rl = r * l @@ -219,30 +216,30 @@ function _proj_from_RL( return Pa, s, Pb, ϵ end + +get_proj_trunc(t::TruncationStrategy, ::ElementarySpace) = t +function get_proj_trunc(::FixedSpaceTruncation, v::ElementarySpace) + isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace) +end """ Given a cluster `Ms`, find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds. """ function _get_allprojs( - Ms::Vector{T}, truncs::Vector{E} - ) where {T <: GenericMPSTensor, E <: TruncationStrategy} - N = length(Ms) - Rs, Ls = _get_allRLs(Ms) + vertices::Vector{T}, truncs::Vector{E} +) where {T<:GenericMPSTensor,E<:TruncationStrategy} + N = length(vertices) + Rs, Ls = _get_allRLs(vertices) @assert length(truncs) == N - 1 - projs_errs = map(1:(N - 1)) do i - trunc = if isa(truncs[i], FixedSpaceTruncation) - tspace = space(Ms[i + 1], 1) - isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace) - else - truncs[i] - end + projs_errs = map(1:(N-1)) do i + trunc = get_proj_trunc(truncs[i], space(vertices[i+1], 1)) return _proj_from_RL(Rs[i], Ls[i]; trunc) end - Pas = getindex.(projs_errs, 1) - wts = getindex.(projs_errs, 2) - Pbs = getindex.(projs_errs, 3) + Pas = first.(projs_errs) + wts = map(t -> t[2], projs_errs) + Pbs = map(t -> t[3], projs_errs) # local truncation error on each bond - ϵs = getindex.(projs_errs, 4) + ϵs = last.(projs_errs) return Pas, Pbs, wts, ϵs end @@ -250,14 +247,14 @@ end Flip the virtual arrows in the MPS `Ms` """ function _flip_virtuals!( - Ms::Vector{T}, flips::Vector{Bool}; inv::Bool = false - ) where {T <: GenericMPSTensor} + Ms::Vector{T}, flips::Vector{Bool}; inv::Bool=false +) where {T<:GenericMPSTensor} @assert length(flips) == length(Ms) - 1 for (n, flip) in enumerate(flips) !flip && continue - M1, M2 = Ms[n], Ms[n + 1] + M1, M2 = Ms[n], Ms[n+1] Ms[n] = TensorKit.flip(M1, numind(M1); inv) - Ms[n + 1] = TensorKit.flip(M2, 1; inv) + Ms[n+1] = TensorKit.flip(M2, 1; inv) end return Ms end @@ -266,17 +263,17 @@ end Find projectors to truncate internal bonds of the cluster `Ms`. """ function _cluster_truncate!( - Ms::Vector{T}, truncs::Vector{E} - ) where {T <: GenericMPSTensor, E <: TruncationStrategy} - Pas, Pbs, wts, ϵs = _get_allprojs(Ms, truncs) + vertices::Vector{T}, truncs::Vector{E} +) where {T<:GenericMPSTensor,E<:TruncationStrategy} + Pas, Pbs, wts, ϵs = _get_allprojs(vertices, truncs) # apply projectors # M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3 for (i, (Pa, Pb)) in enumerate(zip(Pas, Pbs)) - Ms[i] = Ms[i] * twistdual(Pa, 1) + vertices[i] = vertices[i] * twistdual(Pa, 1) pP = ((1,), (2,)) - pM = ((1,), Tuple(2:numind(Ms[i + 1]))) - pPM = (codomainind(Ms[i + 1]), domainind(Ms[i + 1])) - Ms[i + 1] = tensorcontract(Pb, pP, false, Ms[i + 1], pM, false, pPM) + pM = ((1,), ntuple(i -> i + 1, numind(eltype(vertices)) - 1)) + pPM = (codomainind(vertices[i+1]), domainind(vertices[i+1])) + vertices[i+1] = tensorcontract(Pb, pP, false, vertices[i+1], pM, false, pPM) end return wts, ϵs, Pas, Pbs end @@ -296,8 +293,8 @@ e.g. Cluster in PEPS with `gate_ax = 1`: ``` """ function _apply_gatempo!( - Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int = 1 - ) where {T1 <: GenericMPSTensor{<:ElementarySpace, 4}, T2 <: AbstractTensorMap} + Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int=1 +) where {T1<:GenericMPSTensor{<:ElementarySpace,4},T2<:AbstractTensorMap} @assert length(Ms) == length(gs) @assert gate_ax == 1 @assert all(!isdual(space(g, 1)) for g in gs[2:end]) @@ -323,10 +320,10 @@ function _apply_gatempo!( fr = fusers[i] @tensor (Ms[i])[-1 -2 -3 -4; -5] := M[-1 1 -3 -4; 2] * g[-2 1 3] * fr'[2 3; -5] elseif i == length(Ms) - fl = fusers[i - 1] + fl = fusers[i-1] @tensor (Ms[i])[-1 -2 -3 -4; -5] := fl[-1; 2 3] * M[2 1 -3 -4; -5] * g[3 -2 1] else - fl, fr = fusers[i - 1], fusers[i] + fl, fr = fusers[i-1], fusers[i] @tensor (Ms[i])[-1 -2 -3 -4; -5] := fl[-1; 2 3] * M[2 1 -3 -4; 4] * g[3 -2 1 5] * fr'[4 5; -5] end end @@ -334,8 +331,8 @@ function _apply_gatempo!( end function _apply_gatempo!( - Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int = 1 - ) where {T1 <: GenericMPSTensor{<:ElementarySpace, 5}, T2 <: AbstractTensorMap} + Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int=1 +) where {T1<:GenericMPSTensor{<:ElementarySpace,5},T2<:AbstractTensorMap} @assert length(Ms) == length(gs) @assert gate_ax == 1 || gate_ax == 2 @assert all(!isdual(space(g, 1)) for g in gs[2:end]) @@ -376,14 +373,14 @@ function _apply_gatempo!( @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := M[-1 -2 1 -4 -5; 2] * g[1 -3 3] * fr'[2 3; -6] end elseif i == length(Ms) - fl = fusers[i - 1] + fl = fusers[i-1] if gate_ax == 1 @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := fl[-1; 2 3] * M[2 1 -3 -4 -5; -6] * g[3 -2 1] else @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := fl[-1; 2 3] * M[2 -2 1 -4 -5; -6] * g[3 1 -3] end else - fl, fr = fusers[i - 1], fusers[i] + fl, fr = fusers[i-1], fusers[i] if gate_ax == 1 @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := fl[-1; 2 3] * M[2 1 -3 -4 -5; 4] * g[3 -2 1 5] * fr'[4 5; -6] else diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 0f5213907..6e795f948 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -119,9 +119,9 @@ function cluster_truncate!(vertices, truncs, ::InfinitePEPO) return new_vertices, wts, ϵs end -function cluster_truncate!(Ms, truncs, ::InfinitePEPS) - wts, ϵs, = _cluster_truncate!(Ms2, truncs) - return Ms, wts, ϵs +function cluster_truncate!(vertices, truncs, ::InfinitePEPS) + wts, ϵs, = _cluster_truncate!(vertices, truncs) + return vertices, wts, ϵs end """ Simple update with an N-site MPO `gate` (N ≥ 2). @@ -148,12 +148,9 @@ function _su_iter_mpo!( right_invperm = invbiperm(right_perm, Val(n_physical_axes)) # middle tensors: permuted to MPS form in _get_mid mids = map(i -> _get_mid(state, sites[i], out_axs[i - 1], in_axs[i], env), 2:(n_sites - 1)) - vertices = [left_M, getindex.(mids, 1)..., right_M] # TODO remove - # Ms has well defined eltype Here + vertices = [left_M, first.(mids)..., right_M] # TODO remove + #vertices has well defined eltype here # issue it is redefined later with Any eltype - open_vaxs = [left_vaxs, getindex.(mids, 2)..., right_vaxs] # TODO removve - # open_vaxs however cannot be stable - invperms = [left_invperm, getindex.(mids, 3)..., right_invperm] flips = push!([isdual(space(first(x), 1)) for x in mids], isdual(space(right_M, 1))) # flip virtual arrows in `vertices` to ← _flip_virtuals!(vertices, flips) @@ -175,15 +172,25 @@ function _su_iter_mpo!( for (wt, (bond, rev), flip) in zip(wts, bond_revs, flips) wt_new = flip ? _fliptwist_s(wt) : wt wt_new = rev ? transpose(wt_new) : wt_new - @assert all(wt_new.data .>= 0) env[CartesianIndex(bond)] = normalize!(wt_new, Inf) end - for (vertex, s, invperm, vaxs) in zip(new_vertices, sites, invperms, open_vaxs) + + # left + s′ = CartesianIndex(mod1(first(sites)[1], Nr), mod1(first(sites)[2], Nc)) + leftpermuted = permute(first(new_vertices), left_invperm) + state[s′] = absorb_weight(leftpermuted, env, s′, left_vaxs; inv = true) + + # right + s′ = CartesianIndex(mod1(last(sites)[1], Nr), mod1(last(sites)[2], Nc)) + rightpermuted = permute(last(new_vertices), right_invperm) + state[s′] = absorb_weight(rightpermuted, env, s′, right_vaxs; inv = true) + + for (vertex, s, invperm, vaxs) in zip(new_vertices[(begin + 1):(end - 1)], sites[(begin + 1):(end - 1)], map(t -> t[3], mids), map(t -> t[2], mids)) s′ = CartesianIndex(mod1(s[1], Nr), mod1(s[2], Nc)) # restore original axes order permuted = permute(vertex, invperm) # remove weights on open axes of the cluster and update state - state[s′] = absorb_weight(permuted, env, s′[1], s′[2], vaxs; inv = true) + state[s′] = absorb_weight(permuted, env, s′, vaxs; inv = true) end return maximum(ϵs) end diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index 36825f959..5afd764fa 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -77,7 +77,7 @@ end """ SUWeight(Nspace::S, Espace::S=Nspace; unitcell::Tuple{Int,Int}=(1, 1)) where {S<:ElementarySpace} -Create a trivial `SUWeight` by specifying its vertical (north) and horizontal (east) +Create a trivial `SUWeight` by specifying its vertical (north) and horizontal (east) as `ElementarySpace`s) and unit cell size. """ function SUWeight( @@ -170,7 +170,7 @@ end absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false) absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::NTuple{N, Int}; inv::Bool = false) -Absorb or remove (in a twist-free way) the square root of environment weight +Absorb or remove (in a twist-free way) the square root of environment weight on an axis of the PEPS/PEPO tensor `t` known to be at position (`row`, `col`) in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are ``` @@ -246,6 +246,13 @@ function absorb_first_weight(t::Union{PEPSTensor, PEPOTensor}, wt, vax) return new_legs, t2 end +function absorb_weight( + t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, + rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false + ) where {N} + return absorb_weight(t, weights, rowcol[1], rowcol[2], virt_axes; inv) +end + function absorb_weight( t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false @@ -407,7 +414,7 @@ end """ CTMRGEnv(wts::SUWeight) -Construct a CTMRG environment with a trivial environment space +Construct a CTMRG environment with a trivial environment space (bond dimension χ = 1) from SUWeight `wts`, which has the same real scalartype as ``wts`. """ From 2c30b629e9061cd4a0f386d42a14f4ba46437f11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 24 Apr 2026 17:32:06 +0200 Subject: [PATCH 14/24] runic --- src/algorithms/time_evolution/apply_mpo.jl | 88 +++++++++---------- .../time_evolution/simpleupdate3site.jl | 4 +- 2 files changed, 45 insertions(+), 47 deletions(-) diff --git a/src/algorithms/time_evolution/apply_mpo.jl b/src/algorithms/time_evolution/apply_mpo.jl index a97575016..3b4cd37ca 100644 --- a/src/algorithms/time_evolution/apply_mpo.jl +++ b/src/algorithms/time_evolution/apply_mpo.jl @@ -120,24 +120,24 @@ Perform QR decomposition through a `GenericMPSTensor` ``` """ function qr_through( - R0::MPSBondTensor, M::GenericMPSTensor{S,N}; normalize::Bool=true -) where {S,N} + R0::MPSBondTensor, M::GenericMPSTensor{S, N}; normalize::Bool = true + ) where {S, N} @assert !isdual(codomain(R0, 1)) @assert !isdual(domain(M, 1)) && !isdual(codomain(M, 1)) pR = (codomainind(R0), domainind(R0)) - pM = ((1,), Tuple(2:(N+1))) + pM = ((1,), Tuple(2:(N + 1))) pRM = (codomainind(M), domainind(M)) A = tensorcontract(R0, pR, false, M, pM, false, pRM) - _, r = left_orth!(A; positive=true) + _, r = left_orth!(A; positive = true) normalize && normalize!(r, Inf) return r end # for `M` at the left end of the MPS function qr_through( - ::Nothing, M::GenericMPSTensor{S,N}; normalize::Bool=true -) where {S,N} + ::Nothing, M::GenericMPSTensor{S, N}; normalize::Bool = true + ) where {S, N} @assert !isdual(domain(M, 1)) - _, r = left_orth(M; positive=true) + _, r = left_orth(M; positive = true) normalize && normalize!(r, Inf) return r end @@ -151,25 +151,25 @@ Perform LQ decomposition through a `GenericMPSTensor` ``` """ function lq_through( - M::GenericMPSTensor{S,N}, L1::MPSBondTensor; normalize::Bool=true -) where {S,N} + M::GenericMPSTensor{S, N}, L1::MPSBondTensor; normalize::Bool = true + ) where {S, N} @assert !isdual(domain(L1, 1)) @assert !isdual(codomain(M, 1)) && !isdual(domain(M, 1)) pM = (codomainind(M), domainind(M)) pL = (codomainind(L1), domainind(L1)) pML = ((1,), ntuple(i -> i + 1, N)) A = tensorcontract(M, pM, false, L1, pL, false, pML) - l, _ = right_orth!(A; positive=true) + l, _ = right_orth!(A; positive = true) normalize && normalize!(l, Inf) return l end # for `M` at the right end of the MPS function lq_through( - M::GenericMPSTensor{S,N}, ::Nothing; normalize::Bool=true -) where {S,N} + M::GenericMPSTensor{S, N}, ::Nothing; normalize::Bool = true + ) where {S, N} @assert !isdual(codomain(M, 1)) - A = permute(M, ((1,), ntuple(i -> i + 1, N)); copy=true) - l, _ = right_orth!(A; positive=true) + A = permute(M, ((1,), ntuple(i -> i + 1, N)); copy = true) + l, _ = right_orth!(A; positive = true) normalize && normalize!(l, Inf) return l end @@ -177,17 +177,17 @@ end """ Given a cluster `Ms`, find all `R`, `L` matrices on each internal bond """ -function _get_allRLs(vertices::Vector{T}) where {T<:GenericMPSTensor} +function _get_allRLs(vertices::Vector{T}) where {T <: GenericMPSTensor} # M1 -- (R1,L1) -- M2 -- (R2,L2) -- M3 N = length(vertices) # get the first R and the last L - Rs = [qr_through(nothing, first(vertices); normalize=true)] - Ls = [lq_through(last(vertices), nothing; normalize=true)] + Rs = [qr_through(nothing, first(vertices); normalize = true)] + Ls = [lq_through(last(vertices), nothing; normalize = true)] # get remaining R, L matrices - for n in 2:(N-1) - push!(Rs, qr_through(last(Rs), vertices[n]; normalize=true)) - pushfirst!(Ls, lq_through(vertices[N - n + 1], first(Ls); normalize=true)) + for n in 2:(N - 1) + push!(Rs, qr_through(last(Rs), vertices[n]; normalize = true)) + pushfirst!(Ls, lq_through(vertices[N - n + 1], first(Ls); normalize = true)) end return Rs, Ls end @@ -204,9 +204,9 @@ The arrows between `Pa`, `s`, `Pb` are ``` """ function _proj_from_RL( - r::MPSBondTensor, l::MPSBondTensor; - trunc::TruncationStrategy=notrunc() -) + r::MPSBondTensor, l::MPSBondTensor; + trunc::TruncationStrategy = notrunc() + ) @assert isdual(domain(r, 1)) == isdual(codomain(r, 1)) == false @assert isdual(domain(l, 1)) == isdual(codomain(l, 1)) == false rl = r * l @@ -219,20 +219,20 @@ end get_proj_trunc(t::TruncationStrategy, ::ElementarySpace) = t function get_proj_trunc(::FixedSpaceTruncation, v::ElementarySpace) - isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace) + return isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace) end """ Given a cluster `Ms`, find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds. """ function _get_allprojs( - vertices::Vector{T}, truncs::Vector{E} -) where {T<:GenericMPSTensor,E<:TruncationStrategy} + vertices::Vector{T}, truncs::Vector{E} + ) where {T <: GenericMPSTensor, E <: TruncationStrategy} N = length(vertices) Rs, Ls = _get_allRLs(vertices) @assert length(truncs) == N - 1 - projs_errs = map(1:(N-1)) do i - trunc = get_proj_trunc(truncs[i], space(vertices[i+1], 1)) + projs_errs = map(1:(N - 1)) do i + trunc = get_proj_trunc(truncs[i], space(vertices[i + 1], 1)) return _proj_from_RL(Rs[i], Ls[i]; trunc) end Pas = first.(projs_errs) @@ -247,14 +247,14 @@ end Flip the virtual arrows in the MPS `Ms` """ function _flip_virtuals!( - Ms::Vector{T}, flips::Vector{Bool}; inv::Bool=false -) where {T<:GenericMPSTensor} + Ms::Vector{T}, flips::Vector{Bool}; inv::Bool = false + ) where {T <: GenericMPSTensor} @assert length(flips) == length(Ms) - 1 for (n, flip) in enumerate(flips) !flip && continue - M1, M2 = Ms[n], Ms[n+1] + M1, M2 = Ms[n], Ms[n + 1] Ms[n] = TensorKit.flip(M1, numind(M1); inv) - Ms[n+1] = TensorKit.flip(M2, 1; inv) + Ms[n + 1] = TensorKit.flip(M2, 1; inv) end return Ms end @@ -263,8 +263,8 @@ end Find projectors to truncate internal bonds of the cluster `Ms`. """ function _cluster_truncate!( - vertices::Vector{T}, truncs::Vector{E} -) where {T<:GenericMPSTensor,E<:TruncationStrategy} + vertices::Vector{T}, truncs::Vector{E} + ) where {T <: GenericMPSTensor, E <: TruncationStrategy} Pas, Pbs, wts, ϵs = _get_allprojs(vertices, truncs) # apply projectors # M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3 @@ -272,8 +272,8 @@ function _cluster_truncate!( vertices[i] = vertices[i] * twistdual(Pa, 1) pP = ((1,), (2,)) pM = ((1,), ntuple(i -> i + 1, numind(eltype(vertices)) - 1)) - pPM = (codomainind(vertices[i+1]), domainind(vertices[i+1])) - vertices[i+1] = tensorcontract(Pb, pP, false, vertices[i+1], pM, false, pPM) + pPM = (codomainind(vertices[i + 1]), domainind(vertices[i + 1])) + vertices[i + 1] = tensorcontract(Pb, pP, false, vertices[i + 1], pM, false, pPM) end return wts, ϵs, Pas, Pbs end @@ -293,8 +293,8 @@ e.g. Cluster in PEPS with `gate_ax = 1`: ``` """ function _apply_gatempo!( - Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int=1 -) where {T1<:GenericMPSTensor{<:ElementarySpace,4},T2<:AbstractTensorMap} + Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int = 1 + ) where {T1 <: GenericMPSTensor{<:ElementarySpace, 4}, T2 <: AbstractTensorMap} @assert length(Ms) == length(gs) @assert gate_ax == 1 @assert all(!isdual(space(g, 1)) for g in gs[2:end]) @@ -320,10 +320,10 @@ function _apply_gatempo!( fr = fusers[i] @tensor (Ms[i])[-1 -2 -3 -4; -5] := M[-1 1 -3 -4; 2] * g[-2 1 3] * fr'[2 3; -5] elseif i == length(Ms) - fl = fusers[i-1] + fl = fusers[i - 1] @tensor (Ms[i])[-1 -2 -3 -4; -5] := fl[-1; 2 3] * M[2 1 -3 -4; -5] * g[3 -2 1] else - fl, fr = fusers[i-1], fusers[i] + fl, fr = fusers[i - 1], fusers[i] @tensor (Ms[i])[-1 -2 -3 -4; -5] := fl[-1; 2 3] * M[2 1 -3 -4; 4] * g[3 -2 1 5] * fr'[4 5; -5] end end @@ -331,8 +331,8 @@ function _apply_gatempo!( end function _apply_gatempo!( - Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int=1 -) where {T1<:GenericMPSTensor{<:ElementarySpace,5},T2<:AbstractTensorMap} + Ms::Vector{T1}, gs::Vector{T2}; gate_ax::Int = 1 + ) where {T1 <: GenericMPSTensor{<:ElementarySpace, 5}, T2 <: AbstractTensorMap} @assert length(Ms) == length(gs) @assert gate_ax == 1 || gate_ax == 2 @assert all(!isdual(space(g, 1)) for g in gs[2:end]) @@ -373,14 +373,14 @@ function _apply_gatempo!( @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := M[-1 -2 1 -4 -5; 2] * g[1 -3 3] * fr'[2 3; -6] end elseif i == length(Ms) - fl = fusers[i-1] + fl = fusers[i - 1] if gate_ax == 1 @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := fl[-1; 2 3] * M[2 1 -3 -4 -5; -6] * g[3 -2 1] else @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := fl[-1; 2 3] * M[2 -2 1 -4 -5; -6] * g[3 1 -3] end else - fl, fr = fusers[i-1], fusers[i] + fl, fr = fusers[i - 1], fusers[i] if gate_ax == 1 @tensor (Ms[i])[-1 -2 -3 -4 -5; -6] := fl[-1; 2 3] * M[2 1 -3 -4 -5; 4] * g[3 -2 1 5] * fr'[4 5; -6] else diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 6e795f948..94a1c1c48 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -149,8 +149,6 @@ function _su_iter_mpo!( # middle tensors: permuted to MPS form in _get_mid mids = map(i -> _get_mid(state, sites[i], out_axs[i - 1], in_axs[i], env), 2:(n_sites - 1)) vertices = [left_M, first.(mids)..., right_M] # TODO remove - #vertices has well defined eltype here - # issue it is redefined later with Any eltype flips = push!([isdual(space(first(x), 1)) for x in mids], isdual(space(right_M, 1))) # flip virtual arrows in `vertices` to ← _flip_virtuals!(vertices, flips) @@ -185,7 +183,7 @@ function _su_iter_mpo!( rightpermuted = permute(last(new_vertices), right_invperm) state[s′] = absorb_weight(rightpermuted, env, s′, right_vaxs; inv = true) - for (vertex, s, invperm, vaxs) in zip(new_vertices[(begin + 1):(end - 1)], sites[(begin + 1):(end - 1)], map(t -> t[3], mids), map(t -> t[2], mids)) + for (vertex, s, invperm, vaxs) in zip(new_vertices[(begin + 1):(end - 1)], sites[(begin + 1):(end - 1)], map(t -> t[3], mids), map(t -> t[2], mids)) s′ = CartesianIndex(mod1(s[1], Nr), mod1(s[2], Nc)) # restore original axes order permuted = permute(vertex, invperm) From aee67d12d9a3953588d25ede72bba3f5c0cc7244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 18:51:11 +0200 Subject: [PATCH 15/24] use TupleTools --- Project.toml | 4 +++- src/PEPSKit.jl | 1 + src/algorithms/time_evolution/simpleupdate.jl | 4 ++-- src/algorithms/time_evolution/simpleupdate3site.jl | 14 ++------------ src/environments/suweight.jl | 2 +- 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index 2dec89104..4074ca592 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -29,7 +30,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Accessors = "0.1" ChainRulesCore = "1.0" ChainRulesTestUtils = "1.13" -SafeTestsets = "0.1" Compat = "3.46, 4.2" DocStringExtensions = "0.9.3" FiniteDifferences = "0.12" @@ -44,10 +44,12 @@ OptimKit = "0.4" Printf = "1" QuadGK = "2.11.1" Random = "1" +SafeTestsets = "0.1" Statistics = "1" TensorKit = "0.16.2" TensorOperations = "5" TestExtras = "0.3" +TupleTools = "1.6.0" VectorInterface = "0.4, 0.5" Zygote = "0.6, 0.7" julia = "1.10" diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index f1c537ee6..4606a24c3 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -25,6 +25,7 @@ using KrylovKit: Lanczos, BlockLanczos using TensorOperations, OptimKit using ChainRulesCore, Zygote using LoggingExtras +import TupleTools using MPSKit using MPSKit: MPSTensor, MPOTensor, GenericMPSTensor, MPSBondTensor, ProductTransferMatrix diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 5b1830c2c..03a6de719 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -90,7 +90,7 @@ function _get_left( env::SUWeight ) Nr, Nc = size(state) - open_vaxs = _filtered_oneto(in_ax, Val(4)) + open_vaxs = TupleTools.deleteat((1, 2, 3, 4), in_ax) s = mod1(site[1], Nr), mod1(site[2], Nc) t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) Nax = 4 + numout(eltype(state)) @@ -108,7 +108,7 @@ function _get_right( env::SUWeight ) Nr, Nc = size(state) - open_vaxs = _filtered_oneto(out_ax, Val(4)) + open_vaxs = TupleTools.deleteat((1, 2, 3, 4), out_ax) s = mod1(site[1], Nr), mod1(site[2], Nc) t = absorb_weight(state[s...], env, s[1], s[2], open_vaxs) Nax = 4 + numout(eltype(state)) diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 94a1c1c48..ce502f717 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -63,23 +63,13 @@ function _nn_bondrev(site1::CartesianIndex{2}, site2::CartesianIndex{2}, (Nrow, end end -""" -Return a size N-k tuple with values 1 to N but the missing ones. Accept k=1 and k=2. -""" -function _filtered_oneto(i, ::Val{N}) where {N} - return ntuple(k -> k < i ? k : k + 1, N - 1) -end -function _filtered_oneto(i, j, ::Val{N}) where {N} - lo, hi = minmax(i, j) - return ntuple(k -> k < lo ? k : k < hi - 1 ? k + 1 : k + 2, N - 2) -end """ Find the permutation to permute `out_ax`, `in_ax` legs to the first and the last position of a tensor with `Nax` legs, then assign the last leg to domain, and the others to codomain. """ function _get_mpo_perm(out_ax::Integer, in_ax::Integer, ::Val{Nax}) where {Nax} - perm = _filtered_oneto(out_ax, in_ax, Val(Nax)) + perm = TupleTools.deleteat(ntuple(identity, Nax), (out_ax, in_ax)) return (out_ax, perm...), (in_ax,) end @@ -95,7 +85,7 @@ function _get_mid( Nr, Nc = size(state) n_physical_axes = numout(eltype(unitcell(state))) Nax = Val(4 + n_physical_axes) - open_vaxs = _filtered_oneto(out_ax, in_ax, Val(4)) + open_vaxs = TupleTools.deleteat((1, 2, 3, 4), (out_ax, in_ax)) perm = _get_mpo_perm(out_ax + n_physical_axes, in_ax + n_physical_axes, Nax) invperm = invbiperm(perm, Val(n_physical_axes)) s = mod1(site[1], Nr), mod1(site[2], Nc) diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index 5afd764fa..a242de783 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -233,7 +233,7 @@ function biperm_absorb_weight(legs::NTuple{N, Int}, vax::Int) where {N} @assert N == 5 || N == 6 nin = N - 4 a = vax + nin - codomain_axes = _filtered_oneto(a, Val(N)) + codomain_axes = TupleTools.deleteat(ntuple(identity, N), a) biperm = (map(i -> findfirst(==(i), legs)::Int, codomain_axes), (findfirst(==(a), legs)::Int,)) new_legs = (ntuple(i -> legs[biperm[1][i]], N - 1)..., a) return new_legs, biperm From c0f7070e047db2d118af91425db4d1eb8cadb583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 21:23:30 +0200 Subject: [PATCH 16/24] explicit circuit as varname --- src/algorithms/time_evolution/simpleupdate.jl | 6 +++--- src/algorithms/time_evolution/time_evolve.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 03a6de719..9eff366da 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -60,10 +60,10 @@ function TimeEvolver( _timeevol_sanity_check(psi0, physicalspace(H), alg) dt′ = _get_dt(psi0, dt, alg.imaginary_time) # create Trotter gates - gate = trotterize(H, dt′; symmetrize_gates, force_mpo = alg.force_mpo) + circ = trotterize(H, dt′; symmetrize_gates, force_mpo = alg.force_mpo) state = SUState(0, t0, psi0, env0) # TODO: check gates for bipartite case - return TimeEvolver(alg, dt, nstep, gate, state) + return TimeEvolver(alg, dt, nstep, circ, state) end function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false) @@ -202,7 +202,7 @@ end function Base.iterate(it::TimeEvolver{<:SimpleUpdate}, state = it.state) iter, t = state.iter, state.t (iter == it.nstep) && return nothing - psi, env, ϵ = su_iter(state.psi, it.gate, it.alg, state.env) + psi, env, ϵ = su_iter(state.psi, it.circuit, it.alg, state.env) # update internal state iter += 1 t += it.dt diff --git a/src/algorithms/time_evolution/time_evolve.jl b/src/algorithms/time_evolution/time_evolve.jl index ee5f236c1..9679945a8 100644 --- a/src/algorithms/time_evolution/time_evolve.jl +++ b/src/algorithms/time_evolution/time_evolve.jl @@ -14,15 +14,15 @@ Iterator for Trotter-based time evolution of InfinitePEPS or InfinitePEPO. $(TYPEDFIELDS) """ -mutable struct TimeEvolver{TE <: TimeEvolution, G, S, N <: Number} +mutable struct TimeEvolver{TE <: TimeEvolution, C, S, N <: Number} "Time evolution algorithm (currently supported: `SimpleUpdate`)" alg::TE "Trotter time step" dt::N "The number of iteration steps" nstep::Int - "Trotter gates" - gate::G + "LocalCircuit representing trotterized gates" + circuit::C "Internal state of the iterator, including the number of already performed iterations, evolved time, PEPS/PEPO and its environment" state::S From e96c5150577270d0d34786ead1ff97d70f8d3817 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 21:39:30 +0200 Subject: [PATCH 17/24] revert unneeded changes --- src/algorithms/time_evolution/apply_mpo.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/algorithms/time_evolution/apply_mpo.jl b/src/algorithms/time_evolution/apply_mpo.jl index 3b4cd37ca..e5cb46760 100644 --- a/src/algorithms/time_evolution/apply_mpo.jl +++ b/src/algorithms/time_evolution/apply_mpo.jl @@ -181,13 +181,16 @@ function _get_allRLs(vertices::Vector{T}) where {T <: GenericMPSTensor} # M1 -- (R1,L1) -- M2 -- (R2,L2) -- M3 N = length(vertices) # get the first R and the last L - Rs = [qr_through(nothing, first(vertices); normalize = true)] - Ls = [lq_through(last(vertices), nothing; normalize = true)] - + R_first = qr_through(nothing, first(vertices); normalize = true) + L_last = lq_through(last(vertices), nothing; normalize = true) + Rs = Vector{typeof(R_first)}(undef, N - 1) + Ls = Vector{typeof(L_last)}(undef, N - 1) + Rs[1], Ls[end] = R_first, L_last # get remaining R, L matrices for n in 2:(N - 1) - push!(Rs, qr_through(last(Rs), vertices[n]; normalize = true)) - pushfirst!(Ls, lq_through(vertices[N - n + 1], first(Ls); normalize = true)) + m = N - n + 1 + Rs[n] = qr_through(Rs[n - 1], vertices[n]; normalize = true) + Ls[m - 1] = lq_through(vertices[m], Ls[m]; normalize = true) end return Rs, Ls end @@ -222,7 +225,7 @@ function get_proj_trunc(::FixedSpaceTruncation, v::ElementarySpace) return isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace) end """ -Given a cluster `Ms`, find all projectors `Pa`, `Pb` +Given a cluster `vertices`, find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds. """ function _get_allprojs( From 3b53deb33bbcd995a7b39c29ce5b224c1c587749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 21:52:48 +0200 Subject: [PATCH 18/24] uniformize map --- src/algorithms/time_evolution/apply_mpo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algorithms/time_evolution/apply_mpo.jl b/src/algorithms/time_evolution/apply_mpo.jl index e5cb46760..620a76b0a 100644 --- a/src/algorithms/time_evolution/apply_mpo.jl +++ b/src/algorithms/time_evolution/apply_mpo.jl @@ -238,11 +238,11 @@ function _get_allprojs( trunc = get_proj_trunc(truncs[i], space(vertices[i + 1], 1)) return _proj_from_RL(Rs[i], Ls[i]; trunc) end - Pas = first.(projs_errs) + Pas = map(t -> t[1], projs_errs) wts = map(t -> t[2], projs_errs) Pbs = map(t -> t[3], projs_errs) # local truncation error on each bond - ϵs = last.(projs_errs) + ϵs = map(t -> t[4], projs_errs) return Pas, Pbs, wts, ϵs end From 843028611a451f1a1df11086f4a6ca93f8acf2c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 22:10:33 +0200 Subject: [PATCH 19/24] avoid explit type annotation --- src/environments/suweight.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index a242de783..81f4ea60d 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -234,7 +234,8 @@ function biperm_absorb_weight(legs::NTuple{N, Int}, vax::Int) where {N} nin = N - 4 a = vax + nin codomain_axes = TupleTools.deleteat(ntuple(identity, N), a) - biperm = (map(i -> findfirst(==(i), legs)::Int, codomain_axes), (findfirst(==(a), legs)::Int,)) + q = invperm(legs) + biperm = (map(i -> q[i], codomain_axes), (q[a],)) new_legs = (ntuple(i -> legs[biperm[1][i]], N - 1)..., a) return new_legs, biperm end From 6b29aa9b296206beff12008775cd1f9be99dea1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 22:11:52 +0200 Subject: [PATCH 20/24] keep logs in test --- test/timeevol/j1j2_finiteT.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/timeevol/j1j2_finiteT.jl b/test/timeevol/j1j2_finiteT.jl index 72fa6b00e..f0a3c0f93 100644 --- a/test/timeevol/j1j2_finiteT.jl +++ b/test/timeevol/j1j2_finiteT.jl @@ -25,7 +25,7 @@ pepo0 = PEPSKit.infinite_temperature_density_matrix(ham) wts0 = SUWeight(pepo0) # 7 = 1 (spin-0) + 2 x 3 (spin-1) trunc_pepo = truncrank(7) & truncerror(; atol = 1.0e-12) -check_interval = 2^32 +check_interval = 100 dt, nstep = 1.0e-3, 600 # PEPO approach From 650ac76e330cad37061c35e464373ea3dc2dc08e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 22:14:50 +0200 Subject: [PATCH 21/24] remove redundant spactype/sectortype --- src/environments/suweight.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index 81f4ea60d..8e8d00ca9 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -133,10 +133,7 @@ Base.axes(W::SUWeight, args...) = axes(W.data, args...) Base.iterate(W::SUWeight, args...) = iterate(W.data, args...) ## spaces -TensorKit.spacetype(w::SUWeight) = spacetype(typeof(w)) TensorKit.spacetype(::Type{T}) where {E, T <: SUWeight{E}} = spacetype(E) -TensorKit.sectortype(w::SUWeight) = sectortype(typeof(w)) -TensorKit.sectortype(::Type{<:SUWeight{T}}) where {T} = sectortype(spacetype(T)) ## (Approximate) equality function Base.:(==)(wts1::SUWeight, wts2::SUWeight) From c431a00b0af68b57007cbfbd1a8a135da13e0619 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 28 Apr 2026 22:17:56 +0200 Subject: [PATCH 22/24] move absorb_weights to its own file --- src/PEPSKit.jl | 1 + src/algorithms/contractions/absorb_weight.jl | 103 ++++++++++++++++++ src/environments/suweight.jl | 104 ------------------- 3 files changed, 104 insertions(+), 104 deletions(-) create mode 100644 src/algorithms/contractions/absorb_weight.jl diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index 4606a24c3..3560064e0 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -83,6 +83,7 @@ include("algorithms/contractions/ctmrg/renormalize_edge.jl") include("algorithms/contractions/ctmrg/contract_site.jl") include("algorithms/contractions/ctmrg/gaugefix.jl") +include("algorithms/contractions/absorb_weight.jl") include("algorithms/contractions/transfer.jl") include("algorithms/contractions/localoperator.jl") include("algorithms/contractions/vumps_contractions.jl") diff --git a/src/algorithms/contractions/absorb_weight.jl b/src/algorithms/contractions/absorb_weight.jl new file mode 100644 index 000000000..bec5fe4d9 --- /dev/null +++ b/src/algorithms/contractions/absorb_weight.jl @@ -0,0 +1,103 @@ +""" + absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false) + absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::NTuple{N, Int}; inv::Bool = false) + +Absorb or remove (in a twist-free way) the square root of environment weight +on an axis of the PEPS/PEPO tensor `t` known to be at position (`row`, `col`) +in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are +``` + | + [2,r,c] + | + - [1,r,c-1] - T[r,c] - [1,r,c] - + | + [1,r+1,c] + | +``` + +## Arguments + +- `t::Union{PEPSTensor, PEPOTensor}` : PEPSTensor or PEPOTensor to which the weight will be absorbed. +- `weights::SUWeight` : All simple update weights. +- `row::Int` : The row index specifying the position in the tensor network. +- `col::Int` : The column index specifying the position in the tensor network. +- `ax::Int` : The axis into which the weight is absorbed, taking values from 1 to 4, standing for north, east, south, west respectively. + +## Keyword arguments + +- `inv::Bool=false` : If `true`, the inverse square root of the weight is absorbed. + +## Examples + +```julia +# Absorb the weight into the north axis of tensor at position (2, 3) +absorb_weight(t, weights, 2, 3, 1) + +# Absorb the inverse of (i.e. remove) the weight into the east axis +absorb_weight(t, weights, 2, 3, 2; inv=true) +``` +""" +function weight_to_absorb( + weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false + ) + _, Nr, Nc = size(weights) + @assert 1 <= row <= Nr && 1 <= col <= Nc + pow = inv ? -1 / 2 : 1 / 2 + wt = sdiag_pow( + if ax == NORTH + weights[2, row, col] + elseif ax == EAST + weights[1, row, col] + elseif ax == SOUTH + weights[2, _next(row, Nr), col] + else # WEST + weights[1, row, _prev(col, Nc)] + end, + pow, + ) + # make absorption/removal twist-free + twistdual!(wt, 1) + (ax == SOUTH || ax == WEST) && return transpose(wt) # not sure this can be factorized due to twistdual + return wt +end + +function biperm_absorb_weight(legs::NTuple{N, Int}, vax::Int) where {N} + @assert N == 5 || N == 6 + nin = N - 4 + a = vax + nin + codomain_axes = TupleTools.deleteat(ntuple(identity, N), a) + q = invperm(legs) + biperm = (map(i -> q[i], codomain_axes), (q[a],)) + new_legs = (ntuple(i -> legs[biperm[1][i]], N - 1)..., a) + return new_legs, biperm +end + +function absorb_first_weight(t::Union{PEPSTensor, PEPOTensor}, wt, vax) + legs = ntuple(identity, numind(t)) + new_legs, biperm = biperm_absorb_weight(legs, vax) + t2 = permute(t, biperm) * wt + return new_legs, t2 +end + +function absorb_weight( + t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, + rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false + ) where {N} + return absorb_weight(t, weights, rowcol[1], rowcol[2], virt_axes; inv) +end + +function absorb_weight( + t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, + row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false + ) where {N} + vax = first(virt_axes) + weight_vax = weight_to_absorb(weights, row, col, vax; inv) + legs, t2 = absorb_first_weight(t, weight_vax, vax) + for vax in virt_axes[(begin + 1):end] + legs, biperm = biperm_absorb_weight(legs, vax) + weight_vax = weight_to_absorb(weights, row, col, vax; inv) + t2 = permute(t2, biperm) * weight_vax + end + perm_back = invperm(legs) + return permute(t2, (perm_back[begin:numout(t)], perm_back[(numout(t) + 1):end])) +end diff --git a/src/environments/suweight.jl b/src/environments/suweight.jl index 8e8d00ca9..d69b58000 100644 --- a/src/environments/suweight.jl +++ b/src/environments/suweight.jl @@ -163,110 +163,6 @@ function Base.show(io::IO, ::MIME"text/plain", wts::SUWeight) return nothing end -""" - absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false) - absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::NTuple{N, Int}; inv::Bool = false) - -Absorb or remove (in a twist-free way) the square root of environment weight -on an axis of the PEPS/PEPO tensor `t` known to be at position (`row`, `col`) -in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are -``` - | - [2,r,c] - | - - [1,r,c-1] - T[r,c] - [1,r,c] - - | - [1,r+1,c] - | -``` - -## Arguments - -- `t::Union{PEPSTensor, PEPOTensor}` : PEPSTensor or PEPOTensor to which the weight will be absorbed. -- `weights::SUWeight` : All simple update weights. -- `row::Int` : The row index specifying the position in the tensor network. -- `col::Int` : The column index specifying the position in the tensor network. -- `ax::Int` : The axis into which the weight is absorbed, taking values from 1 to 4, standing for north, east, south, west respectively. - -## Keyword arguments - -- `inv::Bool=false` : If `true`, the inverse square root of the weight is absorbed. - -## Examples - -```julia -# Absorb the weight into the north axis of tensor at position (2, 3) -absorb_weight(t, weights, 2, 3, 1) - -# Absorb the inverse of (i.e. remove) the weight into the east axis -absorb_weight(t, weights, 2, 3, 2; inv=true) -``` -""" -function weight_to_absorb( - weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false - ) - _, Nr, Nc = size(weights) - @assert 1 <= row <= Nr && 1 <= col <= Nc - pow = inv ? -1 / 2 : 1 / 2 - wt = sdiag_pow( - if ax == NORTH - weights[2, row, col] - elseif ax == EAST - weights[1, row, col] - elseif ax == SOUTH - weights[2, _next(row, Nr), col] - else # WEST - weights[1, row, _prev(col, Nc)] - end, - pow, - ) - # make absorption/removal twist-free - twistdual!(wt, 1) - (ax == SOUTH || ax == WEST) && return transpose(wt) # not sure this can be factorized due to twistdual - return wt -end - -function biperm_absorb_weight(legs::NTuple{N, Int}, vax::Int) where {N} - @assert N == 5 || N == 6 - nin = N - 4 - a = vax + nin - codomain_axes = TupleTools.deleteat(ntuple(identity, N), a) - q = invperm(legs) - biperm = (map(i -> q[i], codomain_axes), (q[a],)) - new_legs = (ntuple(i -> legs[biperm[1][i]], N - 1)..., a) - return new_legs, biperm -end - -function absorb_first_weight(t::Union{PEPSTensor, PEPOTensor}, wt, vax) - legs = ntuple(identity, numind(t)) - new_legs, biperm = biperm_absorb_weight(legs, vax) - t2 = permute(t, biperm) * wt - return new_legs, t2 -end - -function absorb_weight( - t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, - rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false - ) where {N} - return absorb_weight(t, weights, rowcol[1], rowcol[2], virt_axes; inv) -end - -function absorb_weight( - t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, - row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false - ) where {N} - vax = first(virt_axes) - weight_vax = weight_to_absorb(weights, row, col, vax; inv) - legs, t2 = absorb_first_weight(t, weight_vax, vax) - for vax in virt_axes[(begin + 1):end] - legs, biperm = biperm_absorb_weight(legs, vax) - weight_vax = weight_to_absorb(weights, row, col, vax; inv) - t2 = permute(t2, biperm) * weight_vax - end - perm_back = invperm(legs) - return permute(t2, (perm_back[begin:numout(t)], perm_back[(numout(t) + 1):end])) -end - #= Rotation of SUWeight. Example: 3 x 3 network - Original From b35a54aa73c4dce32ce26c004ac8571d1b8e3be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 29 Apr 2026 08:56:05 +0200 Subject: [PATCH 23/24] absorb_weight docstring --- src/algorithms/contractions/absorb_weight.jl | 57 ++++++++++---------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/algorithms/contractions/absorb_weight.jl b/src/algorithms/contractions/absorb_weight.jl index bec5fe4d9..63908d30c 100644 --- a/src/algorithms/contractions/absorb_weight.jl +++ b/src/algorithms/contractions/absorb_weight.jl @@ -1,13 +1,13 @@ """ - absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false) - absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, ax::NTuple{N, Int}; inv::Bool = false) + absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false) + absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false) Absorb or remove (in a twist-free way) the square root of environment weight on an axis of the PEPS/PEPO tensor `t` known to be at position (`row`, `col`) in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are ``` | - [2,r,c] + [2,r,c] | - [1,r,c-1] - T[r,c] - [1,r,c] - | @@ -21,7 +21,7 @@ in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are - `weights::SUWeight` : All simple update weights. - `row::Int` : The row index specifying the position in the tensor network. - `col::Int` : The column index specifying the position in the tensor network. -- `ax::Int` : The axis into which the weight is absorbed, taking values from 1 to 4, standing for north, east, south, west respectively. +- `virt_axes::Int` : The axis into which the weight is absorbed, taking values from 1 to 4, standing for north, east, south, west respectively. ## Keyword arguments @@ -31,12 +31,35 @@ in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are ```julia # Absorb the weight into the north axis of tensor at position (2, 3) -absorb_weight(t, weights, 2, 3, 1) +absorb_weight(t, weights, 2, 3, (1,)) # Absorb the inverse of (i.e. remove) the weight into the east axis -absorb_weight(t, weights, 2, 3, 2; inv=true) +absorb_weight(t, weights, 2, 3, (2,); inv=true) ``` """ +function absorb_weight( + t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, + rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false + ) where {N} + return absorb_weight(t, weights, rowcol[1], rowcol[2], virt_axes; inv) +end + +function absorb_weight( + t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, + row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false + ) where {N} + vax = first(virt_axes) + weight_vax = weight_to_absorb(weights, row, col, vax; inv) + legs, t2 = absorb_first_weight(t, weight_vax, vax) + for vax in Base.tail(virt_axes) + legs, biperm = biperm_absorb_weight(legs, vax) + weight_vax = weight_to_absorb(weights, row, col, vax; inv) + t2 = permute(t2, biperm) * weight_vax + end + perm_back = invperm(legs) + return permute(t2, (perm_back[begin:numout(t)], perm_back[(numout(t) + 1):end])) +end + function weight_to_absorb( weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false ) @@ -79,25 +102,3 @@ function absorb_first_weight(t::Union{PEPSTensor, PEPOTensor}, wt, vax) return new_legs, t2 end -function absorb_weight( - t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, - rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false - ) where {N} - return absorb_weight(t, weights, rowcol[1], rowcol[2], virt_axes; inv) -end - -function absorb_weight( - t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, - row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false - ) where {N} - vax = first(virt_axes) - weight_vax = weight_to_absorb(weights, row, col, vax; inv) - legs, t2 = absorb_first_weight(t, weight_vax, vax) - for vax in virt_axes[(begin + 1):end] - legs, biperm = biperm_absorb_weight(legs, vax) - weight_vax = weight_to_absorb(weights, row, col, vax; inv) - t2 = permute(t2, biperm) * weight_vax - end - perm_back = invperm(legs) - return permute(t2, (perm_back[begin:numout(t)], perm_back[(numout(t) + 1):end])) -end From 78946df4c2e178e8f30dcf93d7e2ef81fa74dc4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 29 Apr 2026 14:18:32 +0200 Subject: [PATCH 24/24] fix docstring --- src/algorithms/contractions/absorb_weight.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/contractions/absorb_weight.jl b/src/algorithms/contractions/absorb_weight.jl index 63908d30c..00ee8ae91 100644 --- a/src/algorithms/contractions/absorb_weight.jl +++ b/src/algorithms/contractions/absorb_weight.jl @@ -11,7 +11,7 @@ in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are | - [1,r,c-1] - T[r,c] - [1,r,c] - | - [1,r+1,c] + [2,r+1,c] | ```