Skip to content

Commit 2f7a17a

Browse files
leburgellkdvos
andauthored
Split off norm-preserving vector retractions (#160)
* Split off vector retractions, remove diffset * Add dedicated retraction test * Apply suggestions from code review Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Address remaining review comments * Fix typos * Fix typo * Distinguishing fancy ticks is hard --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 0fdcd15 commit 2f7a17a

6 files changed

Lines changed: 134 additions & 94 deletions

File tree

src/PEPSKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ include("utility/diffable_threads.jl")
1919
include("utility/svd.jl")
2020
include("utility/rotations.jl")
2121
include("utility/mirror.jl")
22-
include("utility/diffset.jl")
2322
include("utility/hook_pullback.jl")
2423
include("utility/autoopt.jl")
24+
include("utility/retractions.jl")
2525

2626
include("networks/tensors.jl")
2727
include("networks/local_sandwich.jl")

src/algorithms/optimization/peps_optimization.jl

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ function fixedpoint(
197197
if isnothing(alg.symmetrization)
198198
retract = peps_retract
199199
else
200-
retract, symm_finalize! = symmetrize_retract_and_finalize!(alg.symmetrization)
201-
fin! = finalize! # Previous finalize!
202-
finalize! = (x, f, g, numiter) -> fin!(symm_finalize!(x, f, g, numiter)..., numiter)
200+
retract, finalize! = symmetrize_retract_and_finalize!(
201+
alg.symmetrization, peps_retract, finalize!
202+
)
203203
end
204204

205205
# :fixed mode compatibility
@@ -284,30 +284,21 @@ end
284284
Performs a norm-preserving retraction of an infinite PEPS `A = x[1]` along `η` with step
285285
size `α`, giving a new PEPS `A´`,
286286
```math
287-
A' \\leftarrow \\cos \\left( α \\frac{||η||}{||A||} \\right) A + \\sin \\left( α \\frac{||η||}{||A||} \\right) ||A|| \\frac{η}{||η||},
287+
A' \\cos ( α ‖η‖ / ‖A‖ ) A + \\sin ( α ‖η‖ / ‖A‖ ) ‖A‖ η / ‖η‖,
288288
```
289289
and corresponding directional derivative `ξ`,
290290
```math
291-
ξ = \\cos \\left( α \\frac{||η||}{||A||} \\right) η - \\sin \\left( α \\frac{||η||}{||A||} \\right) ||η|| \\frac{A}{||A||},
291+
ξ = \\cos ( α ‖η‖ / ‖A‖ ) η - \\sin ( α ‖η‖ / ‖A‖ ) ‖η‖ A / ‖A‖,
292292
```
293-
such that ``\\langle A', ξ \\rangle = 0`` and ``||A'|| = ||A||``.
293+
such that `` A', ξ = 0`` and ``‖A'‖ = ‖A‖``.
294294
"""
295295
function peps_retract(x, η, α)
296296
peps = x[1]
297-
norms_peps = norm.(peps.A)
298-
norms_η = norm.(η.A)
299-
300-
peps´ = similar(x[1])
301-
peps´.A .=
302-
cos.(α .* norms_η ./ norms_peps) .* peps.A .+
303-
sin.(α .* norms_η ./ norms_peps) .* norms_peps .* η.A ./ norms_η
304-
305297
env = deepcopy(x[2])
306298

307-
ξ = similar(η)
308-
ξ.A .=
309-
cos.(α .* norms_η ./ norms_peps) .* η.A .-
310-
sin.(α .* norms_η ./ norms_peps) .* norms_η .* peps.A ./ norms_peps
299+
retractions = norm_preserving_retract.(unitcell(peps), unitcell(η), α)
300+
peps´ = InfinitePEPS(map(first, retractions))
301+
ξ = InfinitePEPS(map(last, retractions))
311302

312303
return (peps´, env), ξ
313304
end
@@ -319,31 +310,21 @@ Transports a direction at `A = x[1]` to a valid direction at `A´ = x´[1]` corr
319310
the norm-preserving retraction of `A` along `η` with step size `α`. In particular, starting
320311
from a direction `η` of the form
321312
```math
322-
ξ = \\left\\langle \\frac{η}{||η||}, ξ \\right\\rangle \\frac{η}{||η||} + Δξ
313+
ξ = ⟨ η / ‖η‖, ξ ⟩ η / ‖η‖ + Δξ
323314
```
324-
where ``\\langle Δξ, A \\rangle = \\langle Δξ, η \\rangle = 0``, it returns
315+
where `` Δξ, A = Δξ, η = 0``, it returns
325316
```math
326-
ξ(α) = \\left\\langle \\frac{η}{||η||}, ξ \\right \\rangle \\left( \\cos \\left( α \\frac{||η||}{||A||} \\right) \\frac{η}{||η||} - \\sin( \\left( α \\frac{||η||}{||A||} \\right) \\frac{A}{||A||} \\right) + Δξ
317+
ξ(α) = ⟨ η / ‖η‖, ξ ( \\cos ( α ‖η‖ / ‖A‖ ) η / ‖η‖ - \\sin( α ‖η‖ / ‖A‖ ) A / ‖A‖ ) + Δξ
327318
```
328-
such that ``||ξ(α)|| = ||ξ||, \\langle A', ξ(α) \\rangle = 0``.
319+
such that ``ξ(α) = ‖ξ‖, ⟨ A', ξ(α) = 0``.
329320
"""
330321
function peps_transport!(ξ, x, η, α, x´)
331322
peps = x[1]
332-
norms_peps = norm.(peps.A)
333-
334-
norms_η = norm.(η.A)
335-
normalized_η = η.A ./ norms_η
336-
overlaps_η_ξ = inner.(normalized_η, ξ.A)
323+
peps´ = x´[1]
337324

338-
# isolate the orthogonal component
339-
Δξ = ξ.A .- overlaps_η_ξ .* normalized_η
340-
341-
# keep orthogonal component fixed, modify the rest by the proper directional derivative
342-
ξ.A .=
343-
overlaps_η_ξ .* (
344-
cos.(α .* norms_η ./ norms_peps) .* normalized_η .-
345-
sin.(α .* norms_η ./ norms_peps) .* peps.A ./ norms_peps
346-
) .+ Δξ
325+
norm_preserving_transport!.(
326+
unitcell(ξ), unitcell(peps), unitcell(η), α, unitcell(peps´)
327+
)
347328

348329
return ξ
349330
end
@@ -356,17 +337,22 @@ real_inner(_, η₁, η₂) = real(dot(η₁, η₂))
356337
357338
Return the `retract` and `finalize!` function for symmetrizing the `peps` and `grad` tensors.
358339
"""
359-
function symmetrize_retract_and_finalize!(symm::SymmetrizationStyle)
360-
finf = function symmetrize_finalize!((peps, env), E, grad, _)
340+
function symmetrize_retract_and_finalize!(
341+
symm::SymmetrizationStyle, retract=peps_retract, (finalize!)=OptimKit._finalize!
342+
)
343+
function symmetrize_then_finalize!((peps, env), E, grad, numiter)
344+
# symmetrize the gradient
361345
grad_symm = symmetrize!(grad, symm)
362-
return (peps, env), E, grad_symm
346+
# then finalize
347+
return finalize!((peps, env), E, grad_symm, numiter)
363348
end
364-
retf = function symmetrize_retract((peps, env), η, α)
365-
peps_symm = deepcopy(peps)
366-
peps_symm.A .+= η.A .* α
367-
env′ = deepcopy(env)
368-
symmetrize!(peps_symm, symm)
369-
return (peps_symm, env′), η
349+
function retract_then_symmetrize((peps, env), η, α)
350+
# retract
351+
(peps´, env´), ξ = retract((peps, env), η, α)
352+
# symmetrize retracted point and directional derivative
353+
peps´_symm = symmetrize!(peps´, symm)
354+
ξ_symm = symmetrize!(ξ, symm)
355+
return (peps´_symm, env´), ξ_symm
370356
end
371-
return retf, finf
357+
return retract_then_symmetrize, symmetrize_then_finalize!
372358
end

src/utility/diffset.jl

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/utility/retractions.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#=
2+
Utilities for preserving the norm of (VectorInterface-compliant) vectors during optimization.
3+
=#
4+
5+
"""
6+
norm_preserving_retract(A, η, α)
7+
8+
Performs a norm-preserving retraction of vector `A` along the direction `η` with step size
9+
`α`, giving a new vector `A´`,
10+
```math
11+
A' ← \\cos ( α ‖η‖ / ‖A‖ ) A + \\sin ( α ‖η‖ / ‖A‖ ) ‖A‖ η / ‖η‖,
12+
```
13+
and corresponding directional derivative `ξ`,
14+
```math
15+
ξ = \\cos ( α ‖η‖ / ‖A‖ ) η - \\sin ( α ‖η‖ / ‖A‖ ) ‖η‖ A / ‖A‖,
16+
```
17+
such that ``⟨ A', ξ ⟩ = 0`` and ``‖A'‖ = ‖A‖``.
18+
19+
!!! note
20+
The vectors `A` and `η` should satisfy the interface specified by
21+
[VectorInterface.jl](https://github.com/Jutho/VectorInterface.jl)
22+
23+
"""
24+
function norm_preserving_retract(A, η, α)
25+
n_A = norm(A)
26+
n_η = norm(η)
27+
sn, cs = sincos* n_η / n_A)
28+
29+
= add(A, η, sn * n_A / n_η, cs)
30+
ξ = add(A, η, cs, -sn * n_η / n_A)
31+
32+
return A´, ξ
33+
end
34+
35+
"""
36+
norm_preserving_transport!(ξ, A, η, α, A′)
37+
38+
Transports a direction `ξ` at `A` to a valid direction at `A´` corresponding to
39+
the norm-preserving retraction of `A` along `η` with step size `α`. In particular, starting
40+
from a direction `η` of the form
41+
```math
42+
ξ = ⟨ η / ‖η‖, ξ ⟩ η / ‖η‖ + Δξ
43+
```
44+
where ``⟨ Δξ, A ⟩ = ⟨ Δξ, η ⟩ = 0``, it returns
45+
```math
46+
ξ(α) = ⟨ η / ‖η‖, ξ ⟩ ( \\cos ( α ‖η‖ / ‖A‖ ) η / ‖η‖ - \\sin ( α ‖η‖ / ‖A‖ ) A / ‖A‖ ) + Δξ
47+
```
48+
such that ``‖ξ(α)‖ = ‖ξ‖, ⟨ A', ξ(α) ⟩ = 0``.
49+
50+
!!! note
51+
The vectors `A` and `η` should satisfy the interface specified by
52+
[VectorInterface.jl](https://github.com/Jutho/VectorInterface.jl)
53+
54+
"""
55+
function norm_preserving_transport!(ξ, A, η, α, A´)
56+
n_A = norm(A)
57+
n_η = norm(η)
58+
sn, cs = sincos* n_η / n_A)
59+
60+
overlaps_η_ξ = inner(η, ξ) / n_η
61+
add!(ξ, η, (cs - 1) * overlaps_η_ξ / n_η)
62+
add!(ξ, A, -sn * overlaps_η_ξ / n_A)
63+
64+
return ξ
65+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ end
5858
@time @safetestset "Differentiable tmap" begin
5959
include("utility/diff_maps.jl")
6060
end
61+
@time @safetestset "Norm-preserving retractions" begin
62+
include("utility/retractions.jl")
63+
end
6164
end
6265
if GROUP == "ALL" || GROUP == "EXAMPLES"
6366
@time @safetestset "Transverse Field Ising model" begin

test/utility/retractions.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using Test
2+
using Random
3+
using LinearAlgebra
4+
using TensorKit
5+
using VectorInterface
6+
using PEPSKit
7+
8+
dtype = ComplexF64
9+
Vphyss = [ℂ^2, U1Space(0 => 1, -1 => 1, 1 => 1)]
10+
Vpepss = [ℂ^4, U1Space(0 => 2, -1 => 1, 1 => 1)]
11+
12+
@testset "Norm-preserving tensor retractions for sectortype $(sectortype(Vphyss[i]))" for i in
13+
eachindex(
14+
Vphyss
15+
)
16+
Vphys = Vphyss[i]
17+
Vpeps = Vpepss[i]
18+
peps_space = Vphys Vpeps Vpeps Vpeps' Vpeps'
19+
20+
α = 1e-1 * randn(Float64)
21+
A = randn(dtype, peps_space)
22+
normalized_A = scale(A, inv(norm(A)))
23+
η = randn(dtype, peps_space)
24+
ζ = randn(dtype, peps_space)
25+
add!(η, normalized_A, -inner(normalized_A, η))
26+
add!(ζ, normalized_A, -inner(normalized_A, ζ))
27+
28+
A´, ξ = PEPSKit.norm_preserving_retract(A, η, α)
29+
@test norm(A´) norm(A) rtol = 1e-12
30+
31+
PEPSKit.norm_preserving_transport!(ζ, A, η, α, A´)
32+
@test inner(ζ, A´) 0 atol = 1e-12
33+
end

0 commit comments

Comments
 (0)