Skip to content

Commit 4913563

Browse files
committed
Changes for non-CPU array support
1 parent b3427bc commit 4913563

8 files changed

Lines changed: 37 additions & 22 deletions

File tree

src/algorithms/expval.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ function expectation_value(
192192
ψ::FiniteMPS, O::ProjectionOperator,
193193
envs::FiniteEnvironments = environments(ψ, O)
194194
)
195-
ens = zeros(scalartype(ψ), length(ψ))
195+
ens = zeros(storagetype(ψ), length(ψ))
196196
for i in 1:length(ψ)
197197
operator = AC_hamiltonian(i, ψ, O, ψ, envs)
198198
ens[i] = dot.AC[i], operator * ψ.AC[i])

src/environments/abstract_envs.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,48 +15,52 @@ Base.unlock(envs::AbstractMPSEnvironments) = unlock(envs.lock);
1515
# ------------------
1616
function allocate_GL(bra::AbstractMPS, mpo::AbstractMPO, ket::AbstractMPS, i::Int)
1717
T = Base.promote_type(scalartype(bra), scalartype(mpo), scalartype(ket))
18+
TA = similarstoragetype(storagetype(mpo), T)
1819
V = left_virtualspace(bra, i) left_virtualspace(mpo, i)'
1920
left_virtualspace(ket, i)
2021
if V isa BlockTensorKit.TensorMapSumSpace
21-
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), T)
22+
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), TA)
2223
else
23-
TT = TensorMap{T}
24+
TT = TensorKit.TensorMapWithStorage{T, TA}
2425
end
2526
return TT(undef, V)
2627
end
2728

2829
function allocate_GR(bra::AbstractMPS, mpo::AbstractMPO, ket::AbstractMPS, i::Int)
2930
T = Base.promote_type(scalartype(bra), scalartype(mpo), scalartype(ket))
31+
TA = similarstoragetype(storagetype(mpo), T)
3032
V = right_virtualspace(ket, i) right_virtualspace(mpo, i)
3133
right_virtualspace(bra, i)
3234
if V isa BlockTensorKit.TensorMapSumSpace
33-
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), T)
35+
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), TA)
3436
else
35-
TT = TensorMap{T}
37+
TT = TensorKit.TensorMapWithStorage{T, TA}
3638
end
3739
return TT(undef, V)
3840
end
3941

4042
function allocate_GBL(bra::QP, mpo::AbstractMPO, ket::QP, i::Int)
4143
T = Base.promote_type(scalartype(bra), scalartype(mpo), scalartype(ket))
44+
TA = similarstoragetype(storagetype(mpo), T)
4245
V = left_virtualspace(bra.left_gs, i) left_virtualspace(mpo, i)'
4346
auxiliaryspace(ket)' left_virtualspace(ket.right_gs, i)
4447
if V isa BlockTensorKit.TensorMapSumSpace
45-
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), T)
48+
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), TA)
4649
else
47-
TT = TensorMap{T}
50+
TT = TensorKit.TensorMapWithStorage{T, TA}
4851
end
4952
return TT(undef, V)
5053
end
5154

5255
function allocate_GBR(bra::QP, mpo::AbstractMPO, ket::QP, i::Int)
5356
T = Base.promote_type(scalartype(bra), scalartype(mpo), scalartype(ket))
57+
TA = similarstoragetype(storagetype(mpo), T)
5458
V = right_virtualspace(ket.left_gs, i) right_virtualspace(mpo, i)
5559
auxiliaryspace(ket)' right_virtualspace(bra.right_gs, i)
5660
if V isa BlockTensorKit.TensorMapSumSpace
57-
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), T)
61+
TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), TA)
5862
else
59-
TT = TensorMap{T}
63+
TT = TensorKit.TensorMapWithStorage{T, TA}
6064
end
6165
return TT(undef, V)
6266
end

src/operators/mpohamiltonian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ function Base.:*(H::FiniteMPOHamiltonian, mps::FiniteMPS)
824824
)
825825
)
826826
# left to middle
827-
U = ones(scalartype(H), left_virtualspace(H, 1))
827+
U = ones(storagetype(H), left_virtualspace(H, 1))
828828
@plansor a[-1 -2; -3 -4] := A[1][-1 2; -3] * H[1][1 -2; 2 -4] * conj(U[1])
829829
Q, R = qr_compact!(a)
830830
A′[1] = TensorMap(Q)
@@ -836,7 +836,7 @@ function Base.:*(H::FiniteMPOHamiltonian, mps::FiniteMPS)
836836
end
837837

838838
# right to middle
839-
U = ones(scalartype(H), right_virtualspace(H, N))
839+
U = ones(storagetype(H), right_virtualspace(H, N))
840840
@plansor a[-1 -2; -3 -4] := A[end][-1 2; -3] * H[end][-2 -4; 2 1] * U[1]
841841
L, Q = lq_compact!(a)
842842
A′[end] = transpose(TensorMap(Q), ((1, 3), (2,)))

src/operators/multilinempo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ end
2222
MultilineMPO(mpos::AbstractVector{<:AbstractMPO}) = Multiline(mpos)
2323
MultilineMPO(t::MPOTensor) = MultilineMPO(PeriodicMatrix(fill(t, 1, 1)))
2424

25+
TensorKit.storagetype(M::MultilineMPO) = storagetype(M.data)
26+
2527
# allow indexing with two indices
2628
Base.getindex(t::MultilineMPO, ::Colon, j::Int) = Base.getindex.(t.data, j)
2729
Base.getindex(t::MultilineMPO, i::Int, j) = Base.getindex(t[i], j)

src/states/abstractmps.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ Construct an `MPSTensor` with given physical and virtual spaces.
3333
- `right_D::Int`: right virtual dimension
3434
"""
3535
function MPSTensor(
36-
::UndefInitializer, eltype, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
37-
) where {S <: ElementarySpace}
38-
return TensorMap{eltype}(undef, Vₗ P Vᵣ)
36+
::UndefInitializer, ::Type{TorA}, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
37+
) where {S <: ElementarySpace, TorA}
38+
return TensorKit.TensorMapWithStorage{TorA}(undef, Vₗ P Vᵣ)
3939
end
4040
function MPSTensor(
41-
f, eltype, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
42-
) where {S <: ElementarySpace}
43-
A = MPSTensor(undef, eltype, P, Vₗ, Vᵣ)
41+
f, ::Type{TorA}, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
42+
) where {S <: ElementarySpace, TorA}
43+
A = MPSTensor(undef, TorA, P, Vₗ, Vᵣ)
4444
if f === rand
4545
return rand!(A)
4646
elseif f === randn
@@ -70,18 +70,18 @@ Construct an `MPSTensor` with given physical and virtual dimensions.
7070
- `Dₗ::Int`: left virtual dimension
7171
- `Dᵣ::Int`: right virtual dimension
7272
"""
73-
MPSTensor(f, eltype, d::Int, Dₗ::Int, Dᵣ::Int = Dₗ) = MPSTensor(f, eltype, ℂ^d, ℂ^Dₗ, ℂ^Dᵣ)
73+
MPSTensor(f, ::Type{TorA}, d::Int, Dₗ::Int, Dᵣ::Int = Dₗ) where {TorA} = MPSTensor(f, TorA, ℂ^d, ℂ^Dₗ, ℂ^Dᵣ)
7474
MPSTensor(d::Int, Dₗ::Int; Dᵣ::Int = Dₗ) = MPSTensor(ℂ^d, ℂ^Dₗ, ℂ^Dᵣ)
7575

7676
"""
7777
MPSTensor(A::AbstractArray)
7878
7979
Convert an array to an `MPSTensor`.
8080
"""
81-
function MPSTensor(A::AbstractArray{T}) where {T <: Number}
81+
function MPSTensor(A::AA) where {T <: Number, AA <: AbstractArray{T}}
8282
@assert ndims(A) > 2 "MPSTensor should have at least 3 dims, but has $ndims(A)"
8383
sz = size(A)
84-
t = TensorMap(undef, T, foldl(, ComplexSpace.(sz[1:(end - 1)])) ^sz[end])
84+
t = TensorKit.TensorMapWithStorage{T, AA}(undef, foldl(, ComplexSpace.(sz[1:(end - 1)])) ^sz[end])
8585
t[] .= A
8686
return t
8787
end

src/states/ortho.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ $(TYPEDFIELDS)
2121
verbosity::Int = VERBOSE_WARN
2222

2323
"algorithm used for orthogonalization of the tensors"
24-
alg_orth = LAPACK_HouseholderQR(; positive = true)
24+
alg_orth = Defaults.alg_qr()
2525
"algorithm used for the eigensolver"
2626
alg_eigsolve = _GAUGE_ALG_EIGSOLVE
2727
"minimal amount of iterations before using the eigensolver steps"
@@ -46,7 +46,7 @@ $(TYPEDFIELDS)
4646
verbosity::Int = VERBOSE_WARN
4747

4848
"algorithm used for orthogonalization of the tensors"
49-
alg_orth = LAPACK_HouseholderLQ(; positive = true)
49+
alg_orth = Defaults.alg_lq()
5050
"algorithm used for the eigensolver"
5151
alg_eigsolve = _GAUGE_ALG_EIGSOLVE
5252
"minimal amount of iterations before using the eigensolver steps"
@@ -80,9 +80,15 @@ function MixedCanonical(;
8080
if alg_orth isa LAPACK_HouseholderQR
8181
alg_leftorth = alg_orth
8282
alg_rightorth = LAPACK_HouseholderLQ(; alg_orth.kwargs...)
83+
elseif alg_orth isa CUSOLVER_HouseholderQR
84+
alg_leftorth = alg_orth
85+
alg_rightorth = LQViaTransposedQR(CUSOLVER_HouseholderQR(; alg_orth.kwargs...))
8386
elseif alg_orth isa LAPACK_HouseholderLQ
8487
alg_leftorth = LAPACK_HouseholderQR(; alg_orth.kwargs...)
8588
alg_rightorth = alg_orth
89+
elseif alg_orth isa LQViaTransposedQR
90+
alg_leftorth = alg_orth
91+
alg_rightorth = alg_orth.qr_alg
8692
else
8793
alg_leftorth = alg_rightorth = alg_orth
8894
end

src/utility/defaults.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using OhMyThreads
1010
using ..MPSKit: DynamicTol
1111
using TensorKit: TensorKit
1212
using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_DivideAndConquer
13+
using MatrixAlgebraKit: CUSOLVER_HouseholderQR, LQViaTransposedQR, CUSOLVER_Jacobi
1314

1415
const VERBOSE_NONE = 0
1516
const VERBOSE_WARN = 1

src/utility/periodicarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ function PeriodicArray{T, N}(initializer, args...) where {T, N}
3838
return PeriodicArray(Array{T, N}(initializer, args...))
3939
end
4040

41+
TensorKit.storagetype(PA::PeriodicArray{T, N}) where {T, N} = storagetype(T)
42+
4143
"""
4244
PeriodicVector{T}
4345

0 commit comments

Comments
 (0)