Skip to content

Commit cf8a145

Browse files
kshyattlkdvos
andauthored
Some more small changes for GPU support (#48)
* Updates for GPU compatibility for MPSKit * Fixes for BraidingTensor and svd * Formatter * Cleanup removeunit * Move over storagetype logic from TensorKit * Missing module qualifier * Undo ugly formatting changes * move auxiliary code * remove unnecessary auxiliary code * use similar_diagonal * revert tensorcontract_type changes * try and revert storagetype --------- Co-authored-by: lkdvos <ldevos98@gmail.com>
1 parent 0bd8ffa commit cf8a145

6 files changed

Lines changed: 36 additions & 27 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BlockTensorKit"
22
uuid = "5f87ffc2-9cf1-4a46-8172-465d160bd8cd"
3-
version = "0.3.9"
3+
version = "0.3.10"
44
authors = ["Lukas Devos <ldevos98@gmail.com> and contributors"]
55

66
[deps]

src/BlockTensorKit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import TensorOperations as TO
3737
import TupleTools as TT
3838
import MatrixAlgebraKit as MAK
3939

40+
include("auxiliary/blockarrays.jl")
41+
4042
# Spaces
4143
include("vectorspaces/sumspace.jl")
4244
include("vectorspaces/sumspaceindices.jl")

src/auxiliary/blockarrays.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
function copy_dense!(Adense, A)
2+
for block_index in Iterators.product(blockaxes(A)...)
3+
a = view(A, block_index...)
4+
indices = getindex.(axes(A), block_index)
5+
Adense[indices...] .= a
6+
end
7+
return Adense
8+
end
9+
10+
const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}
11+
12+
function MAK.zero!(A::BlockBlasMat)
13+
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
14+
a = view(A, bi, bj)
15+
MAK.zero!(a)
16+
end
17+
return A
18+
end
19+
20+
function MAK.one!(A::BlockBlasMat)
21+
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
22+
a = view(A, bi, bj)
23+
bi == bj ? MAK.one!(a) : MAK.zero!(a)
24+
end
25+
return A
26+
end

src/linalg/factorizations.jl

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,6 @@ import MatrixAlgebraKit as MAK
44

55
# Type piracy for defining the MAK rules on BlockArrays!
66
# -----------------------------------------------------
7-
8-
const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}
9-
10-
function MatrixAlgebraKit.zero!(A::BlockBlasMat)
11-
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
12-
a = view(A, bi, bj)
13-
MAK.zero!(a)
14-
end
15-
return A
16-
end
17-
18-
function MatrixAlgebraKit.one!(A::BlockBlasMat)
19-
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
20-
a = view(A, bi, bj)
21-
bi == bj ? MAK.one!(a) : MAK.zero!(a)
22-
end
23-
return A
24-
end
25-
267
for f in
278
[
289
:svd_compact, :svd_full, :svd_vals,
@@ -44,7 +25,7 @@ for f! in (
4425
)
4526
@eval function MAK.$f!(t::AbstractBlockTensorMap, F, alg::AbstractAlgorithm)
4627
TensorKit.foreachblock(t, F...) do _, (tblock, Fblocks...)
47-
Fblocks′ = MAK.$f!(Array(tblock), alg)
28+
Fblocks′ = MAK.$f!(copy_dense!(similar(tblock, size(tblock)), tblock), alg)
4829
# deal with the case where the output is not in-place
4930
for (b′, b) in zip(Fblocks′, Fblocks)
5031
b === b′ || copy!(b, b′)
@@ -63,7 +44,7 @@ for f! in (
6344
)
6445
@eval function MAK.$f!(t::AbstractBlockTensorMap, N, alg::AbstractAlgorithm)
6546
TensorKit.foreachblock(t, N) do _, (tblock, Nblock)
66-
Nblock′ = MAK.$f!(Array(tblock), alg)
47+
Nblock′ = MAK.$f!(copy_dense!(similar(tblock, size(tblock)), tblock), alg)
6748
# deal with the case where the output is not the same as the input
6849
Nblock === Nblock′ || copy!(Nblock, Nblock′)
6950
return nothing
@@ -144,15 +125,15 @@ end
144125
function MAK.initialize_output(::typeof(eigh_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
145126
V_D = (fuse(domain(t)))
146127
T = real(scalartype(t))
147-
D = DiagonalTensorMap{T}(undef, V_D)
128+
D = TK.similar_diagonal(t, T, V_D)
148129
V = dense_similar(t, codomain(t) V_D)
149130
return D, V
150131
end
151132

152133
function MAK.initialize_output(::typeof(eig_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
153134
V_D = (fuse(domain(t)))
154135
Tc = complex(scalartype(t))
155-
D = DiagonalTensorMap{Tc}(undef, V_D)
136+
D = TK.similar_diagonal(t, Tc, V_D)
156137
V = dense_similar(t, Tc, codomain(t) V_D)
157138
return D, V
158139
end
@@ -168,7 +149,7 @@ end
168149
function MAK.initialize_output(::typeof(svd_compact!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
169150
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
170151
U = dense_similar(t, codomain(t) V_cod)
171-
S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod)
152+
S = TK.similar_diagonal(t, real(scalartype(t)), V_cod)
172153
Vᴴ = dense_similar(t, V_dom domain(t))
173154
return U, S, Vᴴ
174155
end

src/tensors/abstractblocktensor/abstractarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ Base.eltypeof(t::AbstractBlockTensorMap) = eltype(t)
328328
) where {T <: AbstractTensorMap}
329329
catdims = Base.dims2cat(dims)
330330
V = space(Base._cat(dims, eachspace.(ts)...))
331-
A = similar(ts[1], T, V)
331+
A = similar(ts[1], TK.storagetype(ts[1]), V)
332332
shape = size(A)
333333
if count(!iszero, catdims)::Int > 1
334334
zerovector!(A)

src/tensors/abstractblocktensor/abstracttensormap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,4 @@ function Base.iterate(iter::TK.BlockIterator{<:AbstractBlockTensorMap}, state...
106106
end
107107
Base.getindex(iter::TK.BlockIterator{<:AbstractBlockTensorMap}, c::Sector) = block(iter.t, c)
108108

109-
TensorKit.storagetype(::Type{TT}) where {TT <: AbstractBlockTensorMap} = storagetype(eltype(TT))
109+
TK.storagetype(::Type{TT}) where {TT <: AbstractBlockTensorMap} = storagetype(eltype(TT))

0 commit comments

Comments
 (0)