Skip to content

Commit 2654b1c

Browse files
lkdvos20akshay00AFeuerpfeil
authored
Fix Adapt extension (#389)
* Revert "Adapt extension (#384)" This reverts commit 2db4cb1. * Deep copying MPS with copy (#387) * broadcast copy over arrays * add basic tests for copying mps * fix variable naming mismatch * small fixes * format --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Refactor entropy function to also use spectrum directly (#377) * Refactor entropy function to also use spectrum directly * specialize to SectorVector and add test * add infinite test * update docstring --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Deep copying MPS with copy (#387) * broadcast copy over arrays * add basic tests for copying mps * fix variable naming mismatch * small fixes * format --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com> * fix wrong testcase * alter JordanMPOTensor Adapt implementation * type stability chenanigans * bump TensorKit minimal version * more bypassing of type stability issues... * fix some tests * lost of git struggles * more git misery! * handle storagetype better * disable lts type stability tests --------- Co-authored-by: Akshay Shankar <sakshays.2000@gmail.com> Co-authored-by: Andreas Feuerpfeil <andreas.feuerpfeil@gmail.com>
1 parent 0b58b84 commit 2654b1c

8 files changed

Lines changed: 157 additions & 6 deletions

File tree

Project.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,17 @@ TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
2323
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2424
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2525

26+
[weakdeps]
27+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
28+
29+
[extensions]
30+
MPSKitAdaptExt = "Adapt"
31+
2632
[compat]
2733
Accessors = "0.1"
34+
Adapt = "4"
2835
Aqua = "0.8.9"
29-
BlockTensorKit = "0.3.4"
36+
BlockTensorKit = "0.3.8"
3037
Combinatorics = "1"
3138
Compat = "3.47, 4.10"
3239
DocStringExtensions = "0.9.3"
@@ -43,7 +50,7 @@ Plots = "1.40"
4350
Printf = "1"
4451
Random = "1"
4552
RecipesBase = "1.1"
46-
TensorKit = "0.16"
53+
TensorKit = "0.16.3"
4754
TensorKitManifolds = "0.7"
4855
TensorKitTensors = "0.2"
4956
TensorOperations = "5"
@@ -53,6 +60,7 @@ VectorInterface = "0.2, 0.3, 0.4, 0.5"
5360
julia = "1.10"
5461

5562
[extras]
63+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5664
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5765
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5866
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
@@ -63,4 +71,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6371
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
6472

6573
[targets]
66-
test = ["Aqua", "Pkg", "Test", "TestExtras", "Plots", "Combinatorics", "ParallelTestRunner", "TensorKitTensors"]
74+
test = ["Aqua", "Adapt", "Pkg", "Test", "TestExtras", "Plots", "Combinatorics", "ParallelTestRunner", "TensorKitTensors"]

ext/MPSKitAdaptExt.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
module MPSKitAdaptExt
2+
3+
using TensorKit: space, spacetype
4+
using MPSKit
5+
using BlockTensorKit: nonzero_pairs
6+
using Adapt
7+
8+
function Adapt.adapt_structure(to, mps::FiniteMPS)
9+
ad = adapt(to)
10+
adapt_not_missing(x) = ismissing(x) ? x : ad(x)
11+
12+
TA = Base.promote_op(ad, MPSKit.site_type(mps))
13+
TB = Base.promote_op(ad, MPSKit.bond_type(mps))
14+
15+
ALs = map!(adapt_not_missing, similar(mps.ALs, Union{Missing, TA}), mps.ALs)
16+
ARs = map!(adapt_not_missing, similar(mps.ARs, Union{Missing, TA}), mps.ARs)
17+
ACs = map!(adapt_not_missing, similar(mps.ACs, Union{Missing, TA}), mps.ACs)
18+
Cs = map!(adapt_not_missing, similar(mps.Cs, Union{Missing, TB}), mps.Cs)
19+
20+
return FiniteMPS{TA, TB}(ALs, ARs, ACs, Cs)
21+
end
22+
23+
function Adapt.adapt_structure(to, mps::InfiniteMPS)
24+
ad = adapt(to)
25+
AL = map(ad, mps.AL)
26+
AR = map(ad, mps.AR)
27+
C = map(ad, mps.C)
28+
AC = map(ad, mps.AC)
29+
return InfiniteMPS{eltype(AL), eltype(C)}(AL, AR, C, AC)
30+
end
31+
32+
# inline to improve type stability with closures
33+
@inline Adapt.adapt_structure(to, mpo::MPO) = MPO(map(adapt(to), mpo.O))
34+
@inline Adapt.adapt_structure(to, W::MPSKit.JordanMPOTensor) =
35+
MPSKit.JordanMPOTensor(space(W), adapt(to, W.A), adapt(to, W.B), adapt(to, W.C), adapt(to, W.D))
36+
@inline Adapt.adapt_structure(to, mpo::MPOHamiltonian) =
37+
MPOHamiltonian(map(x -> adapt(to, x), mpo.W))
38+
39+
end

src/states/abstractmps.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ TensorKit.spacetype(ψtype::Type{<:AbstractMPS}) = spacetype(site_type(ψtype))
200200
TensorKit.sectortype::AbstractMPS) = sectortype(typeof(ψ))
201201
TensorKit.sectortype(ψtype::Type{<:AbstractMPS}) = sectortype(site_type(ψtype))
202202

203+
TensorKit.storagetype(ψtype::Type{<:AbstractMPS}) = storagetype(site_type(ψtype))
204+
203205
"""
204206
left_virtualspace(ψ::AbstractMPS, [pos=1:length(ψ)])
205207

src/states/finitemps.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,6 @@ end
387387

388388
site_type(::Type{<:FiniteMPS{A}}) where {A} = A
389389
bond_type(::Type{<:FiniteMPS{<:Any, B}}) where {B} = B
390-
function TensorKit.storagetype(::Union{MPS, Type{MPS}}) where {A, MPS <: FiniteMPS{A}}
391-
return storagetype(A)
392-
end
393390

394391
function left_virtualspace::FiniteMPS, n::Integer)
395392
checkbounds(ψ, n)

test/operators/mpo.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using MPSKit
1010
using MPSKit: GeometryStyle, FiniteChainStyle, InfiniteChainStyle, OperatorStyle, MPOStyle
1111
using TensorKit
1212
using TensorKit:
13+
using Adapt
1314

1415
@testset "FiniteMPO" begin
1516
# start from random operators
@@ -98,3 +99,25 @@ end
9899
@test OperatorStyle(typeof(H)) == MPOStyle()
99100
@test OperatorStyle(H) == MPOStyle()
100101
end
102+
103+
@testset "Adapt" for V in (ℂ^2, U1Space(-1 => 1, 0 => 1, 1 => 1))
104+
L = 3
105+
o = rand(Float32, V^L V^L)
106+
mpo1 = FiniteMPO(o)
107+
for T in (Float64, ComplexF64)
108+
mpo2 = @testinferred adapt(Vector{T}, mpo1)
109+
@test mpo2 isa FiniteMPO
110+
@test scalartype(mpo2) == T
111+
@test storagetype(mpo2) == Vector{T}
112+
@test convert(TensorMap, mpo2) o
113+
end
114+
115+
mpo3 = InfiniteMPO(mpo1[2:2])
116+
for T in (Float64, ComplexF64)
117+
mpo4 = @testinferred adapt(Vector{T}, mpo3)
118+
@test mpo4 isa InfiniteMPO
119+
@test scalartype(mpo4) == T
120+
@test storagetype(mpo4) == Vector{T}
121+
@test dot(mpo3, mpo4) norm(mpo3)^2 atol = 1.0e-4
122+
end
123+
end

test/operators/mpohamiltonian.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using MPSKit
1010
using MPSKit: GeometryStyle, FiniteChainStyle, InfiniteChainStyle, OperatorStyle, HamiltonianStyle
1111
using TensorKit
1212
using TensorKit:
13+
using Adapt
1314

1415
pspaces = (ℙ^4, Rep[U₁](0 => 2), Rep[SU₂](1 => 1))
1516
vspaces = (ℙ^10, Rep[U₁]((0 => 20)), Rep[SU₂](1 // 2 => 10, 3 // 2 => 5, 5 // 2 => 1))
@@ -239,3 +240,44 @@ end
239240
h4 = H4 * H4
240241
@test real(expectation_value(ψ2, H4)) >= 0
241242
end
243+
244+
@testset "Adapt" for V in (ℂ^2, U1Space(-1 => 1, 0 => 1, 1 => 1))
245+
h = rand(Float32, V^2 V^2)
246+
h += h'
247+
248+
L = 4
249+
H1 = FiniteMPOHamiltonian(
250+
fill(V, L),
251+
((i, i + 1) => h for i in 1:(L - 1))...,
252+
((i, i + 2) => h for i in 1:(L - 2))...,
253+
((i, i + 3) => h for i in 1:(L - 3))...,
254+
)
255+
mps1 = FiniteMPS(physicalspace(H1), oneunit(V))
256+
257+
for T in (Float64, ComplexF64)
258+
H2 = if VERSION <= v"1.12"
259+
adapt(Vector{T}, H1)
260+
else
261+
@testinferred adapt(Vector{T}, H1)
262+
end
263+
@test H2 isa FiniteMPOHamiltonian
264+
@test scalartype(H2) == T
265+
@test storagetype(H2) == Vector{T}
266+
@test expectation_value(mps1, H1) expectation_value(mps1, H2)
267+
end
268+
269+
H3 = InfiniteMPOHamiltonian(fill(V, L), (1, 2) => h, (1, 3) => h, (1, 4) => h)
270+
mps2 = InfiniteMPS(physicalspace(H3), [oneunit(V)])
271+
for T in (Float64, ComplexF64)
272+
H4 = if VERSION <= v"1.12"
273+
# this is type unstable for LTS for some reason
274+
adapt(Vector{T}, H3)
275+
else
276+
@testinferred adapt(Vector{T}, H3)
277+
end
278+
@test H4 isa InfiniteMPOHamiltonian
279+
@test scalartype(H4) == T
280+
@test storagetype(H4) == Vector{T}
281+
@test expectation_value(mps2, H3) expectation_value(mps2, H4)
282+
end
283+
end

test/states/finitemps.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using MPSKit: _transpose_front, _transpose_tail
1111
using MPSKit: GeometryStyle, FiniteChainStyle
1212
using TensorKit
1313
using TensorKit:
14+
using Adapt
1415

1516
@testset "FiniteMPS ($(sectortype(D)), $elt)" for (D, d, elt) in [
1617
(ℙ^10, ℙ^2, ComplexF64),

test/states/infinitemps.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using MPSKit
1010
using MPSKit: GeometryStyle, InfiniteChainStyle, TransferMatrix
1111
using TensorKit
1212
using TensorKit:
13+
using Adapt
1314

1415
@testset "InfiniteMPS ($(sectortype(D)), $elt)" for (D, d, elt) in
1516
[(ℙ^10, ℙ^2, ComplexF64), (Rep[U₁](1 => 3), Rep[U₁](0 => 1), ComplexF64)]
@@ -71,6 +72,19 @@ end
7172
@test mps1.C[1] !== mps2.C[1]
7273
end
7374

75+
@testset "Adapt" begin
76+
for (d, D) in [(ℂ^2, ℂ^4), (ℙ^2, ℙ^4)]
77+
mps1 = InfiniteMPS(rand, Float32, d, D)
78+
for T in (Float64, ComplexF64)
79+
mps2 = @testinferred adapt(Vector{T}, mps1)
80+
@test mps2 isa InfiniteMPS
81+
@test scalartype(mps2) == T
82+
@test storagetype(mps2) == Vector{T}
83+
@test dot(mps1, mps2) 1 atol = 1.0e-4
84+
end
85+
end
86+
end
87+
7488
@testset "InfiniteMPS entropy ($(sectortype(D)), $elt)" for (D, d, elt) in
7589
[(ℙ^10, ℙ^2, ComplexF64), (Rep[U₁](1 => 3), Rep[U₁](0 => 1), ComplexF64)]
7690
ψ = InfiniteMPS([d, d], [D, D])
@@ -92,3 +106,28 @@ end
92106
Ss_product = entropy(ψ_product)
93107
@test all(S -> isapprox(S, 0; atol = 1.0e-10), Ss_product)
94108
end
109+
110+
@testset "InfiniteMPS copying" begin
111+
mps1 = InfiniteMPS(rand, ComplexF64, ℂ^2, ℂ^5)
112+
mps2 = copy(mps1)
113+
114+
@test mps1 !== mps2
115+
116+
# elements are equal
117+
@test mps1.AL[1] == mps2.AL[1]
118+
@test mps1.AR[1] == mps2.AR[1]
119+
@test mps1.AC[1] == mps2.AC[1]
120+
@test mps1.C[1] == mps2.C[1]
121+
122+
# arrays are distinct
123+
@test mps1.AL !== mps2.AL
124+
@test mps1.AR !== mps2.AR
125+
@test mps1.AC !== mps2.AC
126+
@test mps1.C !== mps2.C
127+
128+
# tensors are distinct
129+
@test mps1.AL[1] !== mps2.AL[1]
130+
@test mps1.AR[1] !== mps2.AR[1]
131+
@test mps1.AC[1] !== mps2.AC[1]
132+
@test mps1.C[1] !== mps2.C[1]
133+
end

0 commit comments

Comments
 (0)