Skip to content

Commit 2db4cb1

Browse files
authored
Adapt extension (#384)
* add adapt extension * add tests * centralize storagetype definition * accept storagetype for `jordanmpotensortype` * small fixes
1 parent bb559eb commit 2db4cb1

9 files changed

Lines changed: 161 additions & 14 deletions

File tree

Project.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,15 @@ TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
2323
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2424
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2525

26+
[weakdeps]
27+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
28+
29+
[extensions]
30+
MPSKitAdaptExt = "Adapt"
31+
2632
[compat]
2733
Accessors = "0.1"
34+
Adapt = "4"
2835
Aqua = "0.8.9"
2936
BlockTensorKit = "0.3.4"
3037
Combinatorics = "1"
@@ -53,6 +60,7 @@ VectorInterface = "0.2, 0.3, 0.4, 0.5"
5360
julia = "1.10"
5461

5562
[extras]
63+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5664
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5765
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5866
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
@@ -63,4 +71,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6371
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
6472

6573
[targets]
66-
test = ["Aqua", "Pkg", "Test", "TestExtras", "Plots", "Combinatorics", "ParallelTestRunner", "TensorKitTensors"]
74+
test = ["Aqua", "Adapt", "Pkg", "Test", "TestExtras", "Plots", "Combinatorics", "ParallelTestRunner", "TensorKitTensors"]

ext/MPSKitAdaptExt.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
module MPSKitAdaptExt
2+
3+
using TensorKit: space, spacetype
4+
using MPSKit
5+
using BlockTensorKit: nonzero_pairs
6+
using Adapt
7+
8+
function Adapt.adapt_structure(to, mps::FiniteMPS)
9+
ad = adapt(to)
10+
adapt_not_missing(x) = ismissing(x) ? x : ad(x)
11+
12+
TA = Base.promote_op(ad, MPSKit.site_type(mps))
13+
TB = Base.promote_op(ad, MPSKit.bond_type(mps))
14+
15+
ALs = map!(adapt_not_missing, similar(mps.ALs, Union{Missing, TA}), mps.ALs)
16+
ARs = map!(adapt_not_missing, similar(mps.ARs, Union{Missing, TA}), mps.ARs)
17+
ACs = map!(adapt_not_missing, similar(mps.ACs, Union{Missing, TA}), mps.ACs)
18+
Cs = map!(adapt_not_missing, similar(mps.Cs, Union{Missing, TB}), mps.Cs)
19+
20+
return FiniteMPS{TA, TB}(ALs, ARs, ACs, Cs)
21+
end
22+
23+
function Adapt.adapt_structure(to, mps::InfiniteMPS)
24+
ad = adapt(to)
25+
AL = map(ad, mps.AL)
26+
AR = map(ad, mps.AR)
27+
C = map(ad, mps.C)
28+
AC = map(ad, mps.AC)
29+
30+
return InfiniteMPS{eltype(AL), eltype(C)}(AL, AR, C, AC)
31+
end
32+
33+
Adapt.adapt_structure(to, mpo::MPO) = MPO(map(adapt(to), mpo.O))
34+
35+
function Adapt.adapt_structure(::Type{TorA}, W::MPSKit.JordanMPOTensor) where {TorA <: Union{Number, DenseVector{<:Number}}}
36+
TT = MPSKit.jordanmpotensortype(spacetype(W), TorA)
37+
W′ = TT(undef, space(W))
38+
ad = adapt(TorA)
39+
40+
for (k, v) in nonzero_pairs(W.A)
41+
W′.A[k] = ad(v)
42+
end
43+
for (k, v) in nonzero_pairs(W.B)
44+
W′.B[k] = ad(v)
45+
end
46+
for (k, v) in nonzero_pairs(W.C)
47+
W′.C[k] = ad(v)
48+
end
49+
for (k, v) in nonzero_pairs(W.D)
50+
W′.D[k] = ad(v)
51+
end
52+
53+
return W′
54+
end
55+
Adapt.adapt_structure(to, mpo::MPOHamiltonian) = MPOHamiltonian(map(adapt(to), mpo.W))
56+
57+
end

src/operators/jordanmpotensor.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ function JordanMPOTensor(W::SparseBlockTensorMap{TT, E, S, 2, 2}) where {TT, E,
121121
)
122122
end
123123

124-
function jordanmpotensortype(::Type{S}, ::Type{E}) where {S <: VectorSpace, E <: Number}
125-
TA = Union{tensormaptype(S, 2, 2, E), BraidingTensor{E, S}}
126-
TB = tensormaptype(S, 2, 1, E)
127-
TC = tensormaptype(S, 1, 2, E)
128-
TD = tensormaptype(S, 1, 1, E)
129-
return JordanMPOTensor{E, S, TA, TB, TC, TD}
124+
function jordanmpotensortype(::Type{S}, ::Type{TorA}) where {S <: VectorSpace, TorA}
125+
TA = Union{tensormaptype(S, 2, 2, TorA), BraidingTensor{scalartype(TorA), S}}
126+
TB = tensormaptype(S, 2, 1, TorA)
127+
TC = tensormaptype(S, 1, 2, TorA)
128+
TD = tensormaptype(S, 1, 1, TorA)
129+
return JordanMPOTensor{scalartype(TorA), S, TA, TB, TC, TD}
130130
end
131131
function jordanmpotensortype(::Type{O}) where {O <: MPOTensor}
132132
return jordanmpotensortype(spacetype(O), scalartype(O))

src/states/abstractmps.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ TensorKit.spacetype(ψ::AbstractMPS) = spacetype(typeof(ψ))
199199
TensorKit.spacetype(ψtype::Type{<:AbstractMPS}) = spacetype(site_type(ψtype))
200200
TensorKit.sectortype::AbstractMPS) = sectortype(typeof(ψ))
201201
TensorKit.sectortype(ψtype::Type{<:AbstractMPS}) = sectortype(site_type(ψtype))
202+
TensorKit.storagetype(ψtype::Type{<:AbstractMPS}) = storagetype(site_type(ψtype))
202203

203204
"""
204205
left_virtualspace(ψ::AbstractMPS, [pos=1:length(ψ)])

src/states/finitemps.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,6 @@ end
376376

377377
site_type(::Type{<:FiniteMPS{A}}) where {A} = A
378378
bond_type(::Type{<:FiniteMPS{<:Any, B}}) where {B} = B
379-
function TensorKit.storagetype(::Union{MPS, Type{MPS}}) where {A, MPS <: FiniteMPS{A}}
380-
return storagetype(A)
381-
end
382379

383380
function left_virtualspace::FiniteMPS, n::Integer)
384381
checkbounds(ψ, n)

test/operators/mpo.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ println("
44
--------------------
55
")
66

7-
using .TestSetup
87
using Test, TestExtras
8+
using Adapt
99
using MPSKit
1010
using MPSKit: GeometryStyle, FiniteChainStyle, InfiniteChainStyle, OperatorStyle, MPOStyle
1111
using TensorKit
@@ -83,6 +83,7 @@ using TensorKit: ℙ
8383
end
8484
end
8585

86+
8687
@testset "InfiniteMPO" begin
8788
P =^2
8889
T = Float64
@@ -98,3 +99,26 @@ end
9899
@test OperatorStyle(typeof(H)) == MPOStyle()
99100
@test OperatorStyle(H) == MPOStyle()
100101
end
102+
103+
@testset "Adapt" for V in (ℂ^2, U1Space(-1 => 1, 0 => 1, 1 => 1))
104+
L = 3
105+
o = rand(Float32, V^L V^L)
106+
mpo1 = FiniteMPO(o)
107+
for T in (Float64, ComplexF64)
108+
mpo2 = @testinferred adapt(Vector{T}, mpo1)
109+
@test mpo2 isa FiniteMPO
110+
@test scalartype(mpo2) == T
111+
@test storagetype(mpo2) == Vector{T}
112+
@test convert(TensorMap, mpo2) o
113+
end
114+
115+
mpo3 = InfiniteMPO(mpo1[2:2])
116+
for T in (Float64, ComplexF64)
117+
mpo4 = @testinferred adapt(Vector{T}, mpo3)
118+
@test mpo4 isa InfiniteMPO
119+
@test scalartype(mpo4) == T
120+
@test storagetype(mpo4) == Vector{T}
121+
@test dot(mpo3, mpo4) 1 atol = 1.0e-4
122+
end
123+
124+
end

test/operators/mpohamiltonian.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ println("
44
----------------------------
55
")
66

7-
using .TestSetup
87
using Test, TestExtras
8+
using Adapt
99
using MPSKit
1010
using MPSKit: GeometryStyle, FiniteChainStyle, InfiniteChainStyle, OperatorStyle, HamiltonianStyle
1111
using TensorKit
@@ -239,3 +239,36 @@ end
239239
h4 = H4 * H4
240240
@test real(expectation_value(ψ2, H4)) >= 0
241241
end
242+
243+
@testset "Adapt" for V in (ℂ^2, U1Space(-1 => 1, 0 => 1, 1 => 1))
244+
h = rand(Float32, V^2 V^2)
245+
h += h'
246+
247+
L = 4
248+
H1 = FiniteMPOHamiltonian(
249+
fill(V, L),
250+
((i, i + 1) => h for i in 1:(L - 1))...,
251+
((i, i + 2) => h for i in 1:(L - 2))...,
252+
((i, i + 3) => h for i in 1:(L - 3))...,
253+
)
254+
mps1 = FiniteMPS(physicalspace(H1), oneunit(V))
255+
256+
for T in (Float64, ComplexF64)
257+
H2 = @testinferred adapt(Vector{T}, H1)
258+
@test H2 isa FiniteMPOHamiltonian
259+
@test scalartype(H2) == T
260+
@test storagetype(H2) == Vector{T}
261+
@test expectation_value(mps1, H1) expectation_value(mps1, H2)
262+
end
263+
264+
H3 = InfiniteMPOHamiltonian(fill(V, L), (1, 2) => h, (1, 3) => h, (1, 4) => h)
265+
mps2 = InfiniteMPS(physicalspace(H3), [oneunit(V)])
266+
for T in (Float64, ComplexF64)
267+
H4 = @testinferred adapt(Vector{T}, H3)
268+
@test H4 isa InfiniteMPOHamiltonian
269+
@test scalartype(H4) == T
270+
@test storagetype(H4) == Vector{T}
271+
@test expectation_value(mps2, H3) expectation_value(mps2, H4)
272+
end
273+
274+
end

test/states/finitemps.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ println("
44
----------------------
55
")
66

7-
using .TestSetup
87
using Test, TestExtras
8+
using Adapt
99
using MPSKit
1010
using MPSKit: _transpose_front, _transpose_tail
1111
using MPSKit: GeometryStyle, FiniteChainStyle
@@ -76,6 +76,20 @@ end
7676
end
7777
end
7878

79+
@testset "Adapt" begin
80+
for (d, D) in [(ℂ^2, ℂ^4), (ℙ^2, ℙ^4)]
81+
mps1 = FiniteMPS(rand, Float32, 3, d, D)
82+
t1 = convert(TensorMap, mps1)
83+
for T in (Float64, ComplexF64)
84+
mps2 = @testinferred adapt(Vector{T}, mps1)
85+
@test mps2 isa FiniteMPS
86+
@test scalartype(mps2) == T
87+
@test storagetype(mps2) == Vector{T}
88+
@test convert(TensorMap, mps2) t1
89+
end
90+
end
91+
end
92+
7993
@testset "FiniteMPS center + (slice) indexing" begin
8094
L = 11
8195
ψ = FiniteMPS(L, ℂ^2, ℂ^16)

test/states/infinitemps.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ println("
44
------------------------
55
")
66

7-
using .TestSetup
87
using Test, TestExtras
8+
using Adapt
99
using MPSKit
1010
using MPSKit: GeometryStyle, InfiniteChainStyle, TransferMatrix
1111
using TensorKit
@@ -45,3 +45,16 @@ using TensorKit: ℙ
4545
@test TransferMatrix.AR[i], ψ.AR[i]) * r_RR(ψ, i) r_RR(ψ, i + 1)
4646
end
4747
end
48+
49+
@testset "Adapt" begin
50+
for (d, D) in [(ℂ^2, ℂ^4), (ℙ^2, ℙ^4)]
51+
mps1 = InfiniteMPS(rand, Float32, d, D)
52+
for T in (Float64, ComplexF64)
53+
mps2 = @testinferred adapt(Vector{T}, mps1)
54+
@test mps2 isa InfiniteMPS
55+
@test scalartype(mps2) == T
56+
@test storagetype(mps2) == Vector{T}
57+
@test dot(mps1, mps2) 1 atol = 1.0e-4
58+
end
59+
end
60+
end

0 commit comments

Comments
 (0)