Skip to content

Commit 888de7e

Browse files
authored
Changes for non-CPU array support (#375)
1 parent c5ea656 commit 888de7e

16 files changed

Lines changed: 207 additions & 41 deletions

src/algorithms/toolbox.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@ function entropy(spectrum::TensorKit.SectorVector{T}) where {T}
1919
S = zero(T)
2020
tol = eps(T)
2121
for (c, b) in pairs(spectrum)
22-
s = zero(S)
23-
for x in b
24-
x < tol && break
22+
s = sum(b; init = zero(S)) do x
2523
= x^2
26-
s +=* log(x²)
24+
return x < tol ? zero(x) :* log(x²)
2725
end
2826
S += oftype(S, dim(c) * s)
2927
end

src/operators/abstractmpo.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ end
151151
Compute the mpo tensor that arises from multiplying MPOs.
152152
"""
153153
function fuse_mul_mpo(O1, O2)
154-
T = promote_type(scalartype(O1), scalartype(O2))
154+
TT = promote_type(scalartype(O1), scalartype(O2))
155+
T = TensorKit.similarstoragetype(storagetype(O1), TT)
155156
F_left = fuser(T, left_virtualspace(O2), left_virtualspace(O1))
156157
F_right = fuser(T, right_virtualspace(O2), right_virtualspace(O1))
157158
return _fuse_mpo_mpo(O1, O2, F_left, F_right)

src/operators/mpohamiltonian.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct MPOHamiltonian{TO <: JordanMPOTensor, V <: AbstractVector{TO}} <: Abstrac
3232
W::V
3333
end
3434
OperatorStyle(::Type{<:MPOHamiltonian}) = HamiltonianStyle()
35+
TensorKit.storagetype(::Type{MPOHamiltonian{O, V}}) where {O, V} = storagetype(O)
3536

3637
const FiniteMPOHamiltonian{O <: MPOTensor} = MPOHamiltonian{O, Vector{O}}
3738
Base.isfinite(::Type{<:FiniteMPOHamiltonian}) = true

src/states/abstractmps.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ Construct an `MPSTensor` with given physical and virtual spaces.
3333
- `right_D::Int`: right virtual dimension
3434
"""
3535
function MPSTensor(
36-
::UndefInitializer, eltype, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
36+
::UndefInitializer, T, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
3737
) where {S <: ElementarySpace}
38-
TT = tensormaptype(S, 1 + (P isa S ? 1 : length(P)), 1, eltype)
38+
TT = tensormaptype(S, 1 + (P isa S ? 1 : length(P)), 1, T)
3939
return TT(undef, Vₗ P Vᵣ)
4040
end
4141
function MPSTensor(
42-
f, eltype, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
42+
f, T, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
4343
) where {S <: ElementarySpace}
44-
A = MPSTensor(undef, eltype, P, Vₗ, Vᵣ)
44+
A = MPSTensor(undef, T, P, Vₗ, Vᵣ)
4545
if f === rand
4646
return rand!(A)
4747
elseif f === randn
@@ -71,15 +71,15 @@ Construct an `MPSTensor` with given physical and virtual dimensions.
7171
- `Dₗ::Int`: left virtual dimension
7272
- `Dᵣ::Int`: right virtual dimension
7373
"""
74-
MPSTensor(f, eltype, d::Int, Dₗ::Int, Dᵣ::Int = Dₗ) = MPSTensor(f, eltype, ℂ^d, ℂ^Dₗ, ℂ^Dᵣ)
74+
MPSTensor(f, T, d::Int, Dₗ::Int, Dᵣ::Int = Dₗ) = MPSTensor(f, T, ℂ^d, ℂ^Dₗ, ℂ^Dᵣ)
7575
MPSTensor(d::Int, Dₗ::Int; Dᵣ::Int = Dₗ) = MPSTensor(ℂ^d, ℂ^Dₗ, ℂ^Dᵣ)
7676

7777
"""
7878
MPSTensor(A::AbstractArray)
7979
8080
Convert an array to an `MPSTensor`.
8181
"""
82-
function MPSTensor(A::AbstractArray{T}) where {T <: Number}
82+
function MPSTensor(A::AbstractArray{<:Number})
8383
@assert ndims(A) > 2 "MPSTensor should have at least 3 dims, but has $ndims(A)"
8484
sz = size(A)
8585
V = foldl(, ComplexSpace.(sz[1:(end - 1)])) ^sz[end]

src/states/infinitemps.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,35 +116,42 @@ function InfiniteMPS(
116116
convert(PeriodicVector{B}, C), convert(PeriodicVector{A}, AC)
117117
)
118118
end
119-
119+
function InfiniteMPS(
120+
T::Type,
121+
pspaces::AbstractVector{S}, Dspaces::AbstractVector{S};
122+
kwargs...
123+
) where {S <: IndexSpace}
124+
return InfiniteMPS(MPSTensor.(rand, T, pspaces, circshift(Dspaces, 1), Dspaces); kwargs...)
125+
end
120126
function InfiniteMPS(
121127
pspaces::AbstractVector{S}, Dspaces::AbstractVector{S};
122128
kwargs...
123129
) where {S <: IndexSpace}
124-
return InfiniteMPS(MPSTensor.(pspaces, circshift(Dspaces, 1), Dspaces); kwargs...)
130+
return InfiniteMPS(MPSTensor.(rand, ComplexF64, pspaces, circshift(Dspaces, 1), Dspaces); kwargs...)
125131
end
126132
function InfiniteMPS(
127-
f, elt::Type{<:Number}, pspaces::AbstractVector{S}, Dspaces::AbstractVector{S};
133+
f, T::Type, pspaces::AbstractVector{S}, Dspaces::AbstractVector{S};
128134
kwargs...
129135
) where {S <: IndexSpace}
130136
return InfiniteMPS(
131-
MPSTensor.(f, elt, pspaces, circshift(Dspaces, 1), Dspaces);
137+
MPSTensor.(f, T, pspaces, circshift(Dspaces, 1), Dspaces);
132138
kwargs...
133139
)
134140
end
135-
InfiniteMPS(d::S, D::S) where {S <: Union{Int, <:IndexSpace}} = InfiniteMPS([d], [D])
141+
InfiniteMPS(T::Type, d::S, D::S; kwargs...) where {S <: Union{Int, <:IndexSpace}} = InfiniteMPS(T, [d], [D]; kwargs...)
142+
InfiniteMPS(d::S, D::S; kwargs...) where {S <: Union{Int, <:IndexSpace}} = InfiniteMPS([d], [D]; kwargs...)
136143
function InfiniteMPS(
137-
f, elt::Type{<:Number}, d::S, D::S
144+
f, T::Type, d::S, D::S; kwargs...
138145
) where {S <: Union{Int, <:IndexSpace}}
139-
return InfiniteMPS(f, elt, [d], [D])
146+
return InfiniteMPS(f, T, [d], [D]; kwargs...)
140147
end
141-
function InfiniteMPS(ds::AbstractVector{Int}, Ds::AbstractVector{Int})
142-
return InfiniteMPS(ComplexSpace.(ds), ComplexSpace.(Ds))
148+
function InfiniteMPS(ds::AbstractVector{Int}, Ds::AbstractVector{Int}; kwargs...)
149+
return InfiniteMPS(ComplexSpace.(ds), ComplexSpace.(Ds); kwargs...)
143150
end
144151
function InfiniteMPS(
145-
f, elt::Type{<:Number}, ds::AbstractVector{Int}, Ds::AbstractVector{Int}, kwargs...
152+
f, T::Type, ds::AbstractVector{Int}, Ds::AbstractVector{Int}, kwargs...
146153
)
147-
return InfiniteMPS(f, elt, ComplexSpace.(ds), ComplexSpace.(Ds); kwargs...)
154+
return InfiniteMPS(f, T, ComplexSpace.(ds), ComplexSpace.(Ds); kwargs...)
148155
end
149156

150157
function InfiniteMPS(A::AbstractVector{<:GenericMPSTensor}; kwargs...)

src/states/multilinemps.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,6 @@ for f in (:r_RR, :r_RL, :r_LR, :r_LL)
8181
@eval $f(t::MultilineMPS, i, j = size(t, 2)) = $f(t[i], j)
8282
end
8383

84-
site_type(::Type{Multiline{S}}) where {S} = site_type(S)
85-
bond_type(::Type{Multiline{S}}) where {S} = bond_type(S)
86-
site_type(st::Multiline) = site_type(typeof(st))
87-
bond_type(st::Multiline) = bond_type(typeof(st))
88-
VectorInterface.scalartype(::Multiline{T}) where {T} = scalartype(T)
89-
TensorKit.sectortype(t::Multiline) = sectortype(typeof(t))
90-
TensorKit.sectortype(::Type{Multiline{T}}) where {T} = sectortype(T)
91-
TensorKit.spacetype(t::Multiline) = spacetype(typeof(t))
92-
TensorKit.spacetype(::Type{Multiline{T}}) where {T} = spacetype(T)
93-
9484
function TensorKit.dot(a::MultilineMPS, b::MultilineMPS; kwargs...)
9585
return sum(dot.(parent(a), parent(b); kwargs...))
9686
end

src/states/ortho.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ $(TYPEDFIELDS)
2121
verbosity::Int = VERBOSE_WARN
2222

2323
"algorithm used for orthogonalization of the tensors"
24-
alg_orth = LAPACK_HouseholderQR(; positive = true)
24+
alg_orth = Defaults.alg_qr()
2525
"algorithm used for the eigensolver"
2626
alg_eigsolve = _GAUGE_ALG_EIGSOLVE
2727
"minimal amount of iterations before using the eigensolver steps"
@@ -46,7 +46,7 @@ $(TYPEDFIELDS)
4646
verbosity::Int = VERBOSE_WARN
4747

4848
"algorithm used for orthogonalization of the tensors"
49-
alg_orth = LAPACK_HouseholderLQ(; positive = true)
49+
alg_orth = Defaults.alg_lq()
5050
"algorithm used for the eigensolver"
5151
alg_eigsolve = _GAUGE_ALG_EIGSOLVE
5252
"minimal amount of iterations before using the eigensolver steps"
@@ -73,16 +73,22 @@ end
7373

7474
function MixedCanonical(;
7575
tol::Real = Defaults.tolgauge, maxiter::Int = Defaults.maxiter,
76-
verbosity::Int = VERBOSE_WARN, alg_orth = LAPACK_HouseholderQR(; positive = true),
76+
verbosity::Int = VERBOSE_WARN, alg_orth = Defaults.alg_qr(),
7777
alg_eigsolve = _GAUGE_ALG_EIGSOLVE,
7878
eig_miniter::Int = 10, order::Symbol = :LR
7979
)
8080
if alg_orth isa LAPACK_HouseholderQR
8181
alg_leftorth = alg_orth
8282
alg_rightorth = LAPACK_HouseholderLQ(; alg_orth.kwargs...)
83+
elseif alg_orth isa CUSOLVER_HouseholderQR
84+
alg_leftorth = alg_orth
85+
alg_rightorth = LQViaTransposedQR(CUSOLVER_HouseholderQR(; alg_orth.kwargs...))
8386
elseif alg_orth isa LAPACK_HouseholderLQ
8487
alg_leftorth = LAPACK_HouseholderQR(; alg_orth.kwargs...)
8588
alg_rightorth = alg_orth
89+
elseif alg_orth isa LQViaTransposedQR
90+
alg_leftorth = alg_orth
91+
alg_rightorth = alg_orth.qr_alg
8692
else
8793
alg_leftorth = alg_rightorth = alg_orth
8894
end

src/states/quasiparticle_state.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#=
22
Should not be constructed by the user - acts like a vector (used in eigsolve)
33
I think it makes sense to see these things as an actual state instead of return an array of B tensors (what we used to do)
4-
This will allow us to plot energy density (finite qp) and measure observeables.
4+
This will allow us to plot energy density (finite qp) and measure observables.
55
=#
66

77
struct LeftGaugedQP{S, T1, T2, E <: Number}

src/utility/multiline.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,14 @@ function VectorInterface.inner(x::Multiline, y::Multiline)
106106
end
107107

108108
LinearAlgebra.norm(x::Multiline) = sqrt(real(inner(x, x)))
109+
110+
# TensorKit
111+
#----------
112+
113+
site_type(::Type{Multiline{S}}) where {S} = site_type(S)
114+
bond_type(::Type{Multiline{S}}) where {S} = bond_type(S)
115+
site_type(st::Multiline) = site_type(typeof(st))
116+
bond_type(st::Multiline) = bond_type(typeof(st))
117+
TensorKit.sectortype(::Type{Multiline{T}}) where {T} = sectortype(T)
118+
TensorKit.spacetype(::Type{Multiline{T}}) where {T} = spacetype(T)
119+
TensorKit.storagetype(::Type{Multiline{T}}) where {T} = storagetype(T)

src/utility/utility.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ function check_length(a, b...)
123123
return L
124124
end
125125

126-
function fuser(::Type{T}, V1::S, V2::S) where {T, S <: IndexSpace}
127-
return isomorphism(T, fuse(V1 V2), V1 V2)
126+
function fuser(::Type{TorA}, V1::S, V2::S) where {TorA, S <: IndexSpace}
127+
return isomorphism(TorA, fuse(V1 V2), V1 V2)
128128
end
129129

130130
"""

0 commit comments

Comments
 (0)