Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Accessors = "0.1"
ChainRulesCore = "1.0"
ChainRulesTestUtils = "1.13"
SafeTestsets = "0.1"
Compat = "3.46, 4.2"
DocStringExtensions = "0.9.3"
FiniteDifferences = "0.12"
Expand All @@ -44,10 +44,12 @@ OptimKit = "0.4"
Printf = "1"
QuadGK = "2.11.1"
Random = "1"
SafeTestsets = "0.1"
Statistics = "1"
TensorKit = "0.16.2"
TensorOperations = "5"
TestExtras = "0.3"
TupleTools = "1.6.0"
VectorInterface = "0.4, 0.5"
Zygote = "0.6, 0.7"
julia = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using KrylovKit: Lanczos, BlockLanczos
using TensorOperations, OptimKit
using ChainRulesCore, Zygote
using LoggingExtras
import TupleTools

using MPSKit
using MPSKit: MPSTensor, MPOTensor, GenericMPSTensor, MPSBondTensor, ProductTransferMatrix
Expand Down Expand Up @@ -82,6 +83,7 @@ include("algorithms/contractions/ctmrg/renormalize_edge.jl")
include("algorithms/contractions/ctmrg/contract_site.jl")
include("algorithms/contractions/ctmrg/gaugefix.jl")

include("algorithms/contractions/absorb_weight.jl")
include("algorithms/contractions/transfer.jl")
include("algorithms/contractions/localoperator.jl")
include("algorithms/contractions/vumps_contractions.jl")
Expand Down
104 changes: 104 additions & 0 deletions src/algorithms/contractions/absorb_weight.jl
Comment thread
ogauthe marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false)
absorb_weight(t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight, row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false)

Absorb or remove (in a twist-free way) the square root of environment weight
on an axis of the PEPS/PEPO tensor `t` known to be at position (`row`, `col`)
in the unit cell of an InfinitePEPS/InfinitePEPO. The involved weights are
```
|
[2,r,c]
|
- [1,r,c-1] - T[r,c] - [1,r,c] -
|
[2,r+1,c]
|
```

## Arguments

- `t::Union{PEPSTensor, PEPOTensor}` : PEPSTensor or PEPOTensor to which the weight will be absorbed.
- `weights::SUWeight` : All simple update weights.
- `row::Int` : The row index specifying the position in the tensor network.
- `col::Int` : The column index specifying the position in the tensor network.
- `virt_axes::Int` : The axis into which the weight is absorbed, taking values from 1 to 4, standing for north, east, south, west respectively.

## Keyword arguments

- `inv::Bool=false` : If `true`, the inverse square root of the weight is absorbed.

## Examples

```julia
# Absorb the weight into the north axis of tensor at position (2, 3)
absorb_weight(t, weights, 2, 3, (1,))

# Absorb the inverse of (i.e. remove) the weight into the east axis
absorb_weight(t, weights, 2, 3, (2,); inv=true)
```
"""
function absorb_weight(
t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight,
rowcol::CartesianIndex{2}, virt_axes::NTuple{N, Int}; inv::Bool = false
) where {N}
return absorb_weight(t, weights, rowcol[1], rowcol[2], virt_axes; inv)
end

function absorb_weight(
t::Union{PEPSTensor, PEPOTensor}, weights::SUWeight,
row::Int, col::Int, virt_axes::NTuple{N, Int}; inv::Bool = false
) where {N}
vax = first(virt_axes)
weight_vax = weight_to_absorb(weights, row, col, vax; inv)
legs, t2 = absorb_first_weight(t, weight_vax, vax)
for vax in Base.tail(virt_axes)
legs, biperm = biperm_absorb_weight(legs, vax)
weight_vax = weight_to_absorb(weights, row, col, vax; inv)
t2 = permute(t2, biperm) * weight_vax
end
perm_back = invperm(legs)
return permute(t2, (perm_back[begin:numout(t)], perm_back[(numout(t) + 1):end]))
end

function weight_to_absorb(
weights::SUWeight, row::Int, col::Int, ax::Int; inv::Bool = false
)
_, Nr, Nc = size(weights)
@assert 1 <= row <= Nr && 1 <= col <= Nc
pow = inv ? -1 / 2 : 1 / 2
wt = sdiag_pow(
if ax == NORTH
weights[2, row, col]
elseif ax == EAST
weights[1, row, col]
elseif ax == SOUTH
weights[2, _next(row, Nr), col]
else # WEST
weights[1, row, _prev(col, Nc)]
end,
pow,
)
# make absorption/removal twist-free
twistdual!(wt, 1)
(ax == SOUTH || ax == WEST) && return transpose(wt) # not sure this can be factorized due to twistdual
return wt
end

function biperm_absorb_weight(legs::NTuple{N, Int}, vax::Int) where {N}
@assert N == 5 || N == 6
nin = N - 4
a = vax + nin
codomain_axes = TupleTools.deleteat(ntuple(identity, N), a)
q = invperm(legs)
biperm = (map(i -> q[i], codomain_axes), (q[a],))
new_legs = (ntuple(i -> legs[biperm[1][i]], N - 1)..., a)
return new_legs, biperm
end

function absorb_first_weight(t::Union{PEPSTensor, PEPOTensor}, wt, vax)
legs = ntuple(identity, numind(t))
new_legs, biperm = biperm_absorb_weight(legs, vax)
t2 = permute(t, biperm) * wt
return new_legs, t2
end

22 changes: 12 additions & 10 deletions src/algorithms/time_evolution/apply_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ function _apply_sitegate(
return a′
end

function _get_biperms(::PEPSTensor, ::Integer)
return ((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), ntuple(identity, 4)
end
function _get_biperms(::PEPOTensor, gate_ax::Integer)
if gate_ax == 1
return ((2, 3, 5, 6), (1, 4)), ((2, 3, 4, 5), (1, 6)), (1, 2, 5, 3, 4), ntuple(identity, 5)
end
return ((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), ntuple(identity, 5)
end
"""
$(SIGNATURES)

Expand Down Expand Up @@ -48,17 +57,10 @@ When `A`, `B` are PEPOTensors,
5 1 4 1 4 1
```
"""
function _qr_bond(A::PT, B::PT; gate_ax::Int = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}}
function _qr_bond(A::PT, B::PT; gate_ax::Integer = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these functions, both for type stability and human readability, I would really prefer to just use dispatch and write out the two cases manually. _get_biperms is an indirection that doesn't hurt because the compiler decides to inline this, but depending on compiler-internals like this tends to not be great, and honestly I think this function would overall be better suited to simply be two functions with using dispatch and hardcoding the permutations (which could then use @tensor calls for the permutations, which is just more readable in general.

Copy link
Copy Markdown
Member

@Yue-Zhengyuan Yue-Zhengyuan Apr 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See latest changes in #144 where I replaced _qr_bond with multiple bond_tensor_xxx functions (link).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ogauthe You can have a look at related parts of #144 and port them here if you wish.

@assert 1 <= gate_ax <= numout(A)
permA, permB, permX, permY = if A isa PEPSTensor
((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), Tuple(1:4)
else
if gate_ax == 1
((2, 3, 5, 6), (1, 4)), ((2, 3, 4, 5), (1, 6)), (1, 2, 5, 3, 4), Tuple(1:5)
else
((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), Tuple(1:5)
end
end
permA, permB, permX, permY = _get_biperms(A, gate_ax)

X, a = left_orth!(permute(A, permA; copy = true); kwargs...)
Y, b = left_orth!(permute(B, permB; copy = true); kwargs...)
X, Y = permute(X, permX), permute(Y, permY)
Expand Down
68 changes: 34 additions & 34 deletions src/algorithms/time_evolution/apply_mpo.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#=
#=
# Mixed canonical form of an open boundary MPS
```
|ψ⟩ = M[1]-←-...-←-M[N]
Expand Down Expand Up @@ -54,7 +54,7 @@ Note that
Then `M̃[n]` (n = 1, ..., N - 1) satisfies the (generalized) left-orthogonal condition
```
┌---←--M̃[n]--←- ┌-←- 2
| | |
| | |
s[n-1] ↓ = s[n] (s[0] = 1)
| | |
└---→--M̃†[n]-→- └-→- 1
Expand All @@ -71,16 +71,16 @@ Similarly, we can express M̃ using Qb
Then `M̃[n]` (n = 2, ..., N) satisfies the (generalized) right-orthogonal condition
```
-←-M̃[n]-←┐ 1 -←-┐
↓ | |
↓ | |
* s[n] = s[n-1] (s[N] = 1)
↓ | |
-→M̃†[n]-→┘ 2 -→-┘
```
Here `-*-` is the twist on the physical axis.
Here `-*-` is the twist on the physical axis.

# Truncation of a bond on OBC-MPS

Suppose we want to truncate the bond between
Suppose we want to truncate the bond between
the n-th and the (n+1)-th sites such that the truncated state
```
|ψ̃⟩ = M[1]-←-...-←-M̃[n]-←-M̃[n+1]-←-...-←-M[N]
Expand Down Expand Up @@ -157,7 +157,7 @@ function lq_through(
@assert !isdual(codomain(M, 1)) && !isdual(domain(M, 1))
pM = (codomainind(M), domainind(M))
pL = (codomainind(L1), domainind(L1))
pML = ((1,), Tuple(2:(N + 1)))
pML = ((1,), ntuple(i -> i + 1, N))
A = tensorcontract(M, pM, false, L1, pL, false, pML)
l, _ = right_orth!(A; positive = true)
normalize && normalize!(l, Inf)
Expand All @@ -168,7 +168,7 @@ function lq_through(
M::GenericMPSTensor{S, N}, ::Nothing; normalize::Bool = true
) where {S, N}
@assert !isdual(codomain(M, 1))
A = permute(M, ((1,), Tuple(2:(N + 1))); copy = true)
A = permute(M, ((1,), ntuple(i -> i + 1, N)); copy = true)
l, _ = right_orth!(A; positive = true)
normalize && normalize!(l, Inf)
return l
Expand All @@ -177,26 +177,26 @@ end
"""
Given a cluster `Ms`, find all `R`, `L` matrices on each internal bond
"""
function _get_allRLs(Ms::Vector{T}) where {T <: GenericMPSTensor}
function _get_allRLs(vertices::Vector{T}) where {T <: GenericMPSTensor}
# M1 -- (R1,L1) -- M2 -- (R2,L2) -- M3
N = length(Ms)
N = length(vertices)
# get the first R and the last L
R_first = qr_through(nothing, Ms[1]; normalize = true)
L_last = lq_through(Ms[N], nothing; normalize = true)
R_first = qr_through(nothing, first(vertices); normalize = true)
L_last = lq_through(last(vertices), nothing; normalize = true)
Rs = Vector{typeof(R_first)}(undef, N - 1)
Ls = Vector{typeof(L_last)}(undef, N - 1)
Rs[1], Ls[end] = R_first, L_last
# get remaining R, L matrices
for n in 2:(N - 1)
m = N - n + 1
Rs[n] = qr_through(Rs[n - 1], Ms[n]; normalize = true)
Ls[m - 1] = lq_through(Ms[m], Ls[m]; normalize = true)
Rs[n] = qr_through(Rs[n - 1], vertices[n]; normalize = true)
Ls[m - 1] = lq_through(vertices[m], Ls[m]; normalize = true)
end
return Rs, Ls
end

"""
Given the tensors `R`, `L` on a bond, construct
Given the tensors `R`, `L` on a bond, construct
the projectors `Pa`, `Pb` and the new bond weight `s`
such that the contraction of `Pa`, `s`, `Pb` is identity when `trunc = notrunc`,

Expand All @@ -219,30 +219,30 @@ function _proj_from_RL(
return Pa, s, Pb, ϵ
end


get_proj_trunc(t::TruncationStrategy, ::ElementarySpace) = t
function get_proj_trunc(::FixedSpaceTruncation, v::ElementarySpace)
return isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace)
end
"""
Given a cluster `Ms`, find all projectors `Pa`, `Pb`
Given a cluster `vertices`, find all projectors `Pa`, `Pb`
and Schmidt weights `wts` on internal bonds.
"""
function _get_allprojs(
Ms::Vector{T}, truncs::Vector{E}
vertices::Vector{T}, truncs::Vector{E}
) where {T <: GenericMPSTensor, E <: TruncationStrategy}
N = length(Ms)
Rs, Ls = _get_allRLs(Ms)
N = length(vertices)
Rs, Ls = _get_allRLs(vertices)
@assert length(truncs) == N - 1
projs_errs = map(1:(N - 1)) do i
trunc = if isa(truncs[i], FixedSpaceTruncation)
tspace = space(Ms[i + 1], 1)
isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace)
else
truncs[i]
end
trunc = get_proj_trunc(truncs[i], space(vertices[i + 1], 1))
return _proj_from_RL(Rs[i], Ls[i]; trunc)
end
Pas = map(Base.Fix2(getindex, 1), projs_errs)
wts = map(Base.Fix2(getindex, 2), projs_errs)
Pbs = map(Base.Fix2(getindex, 3), projs_errs)
Pas = map(t -> t[1], projs_errs)
wts = map(t -> t[2], projs_errs)
Pbs = map(t -> t[3], projs_errs)
# local truncation error on each bond
ϵs = map(Base.Fix2(getindex, 4), projs_errs)
ϵs = map(t -> t[4], projs_errs)
return Pas, Pbs, wts, ϵs
end

Expand All @@ -266,17 +266,17 @@ end
Find projectors to truncate internal bonds of the cluster `Ms`.
"""
function _cluster_truncate!(
Ms::Vector{T}, truncs::Vector{E}
vertices::Vector{T}, truncs::Vector{E}
) where {T <: GenericMPSTensor, E <: TruncationStrategy}
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, truncs)
Pas, Pbs, wts, ϵs = _get_allprojs(vertices, truncs)
# apply projectors
# M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3
for (i, (Pa, Pb)) in enumerate(zip(Pas, Pbs))
Ms[i] = Ms[i] * twistdual(Pa, 1)
vertices[i] = vertices[i] * twistdual(Pa, 1)
pP = ((1,), (2,))
pM = ((1,), Tuple(2:numind(Ms[i + 1])))
pPM = (codomainind(Ms[i + 1]), domainind(Ms[i + 1]))
Ms[i + 1] = tensorcontract(Pb, pP, false, Ms[i + 1], pM, false, pPM)
pM = ((1,), ntuple(i -> i + 1, numind(eltype(vertices)) - 1))
pPM = (codomainind(vertices[i + 1]), domainind(vertices[i + 1]))
vertices[i + 1] = tensorcontract(Pb, pP, false, vertices[i + 1], pM, false, pPM)
end
return wts, ϵs, Pas, Pbs
end
Expand Down
Loading
Loading