Skip to content

Commit c89afa3

Browse files
committed
move auxiliary code
1 parent 3f46738 commit c89afa3

3 files changed

Lines changed: 34 additions & 24 deletions

File tree

src/BlockTensorKit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import TensorOperations as TO
3737
import TupleTools as TT
3838
import MatrixAlgebraKit as MAK
3939

40+
include("auxiliary/blockarrays.jl")
41+
4042
# Spaces
4143
include("vectorspaces/sumspace.jl")
4244
include("vectorspaces/sumspaceindices.jl")

src/auxiliary/blockarrays.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
copy_dense(A) = copy_dense!(similar(first(A.blocks), size(A)), A)
2+
function copy_dense!(Adense, A)
3+
for bj in blockaxes(A, 2)
4+
js = axes(A, 2)[bj]
5+
for bi in blockaxes(A, 1)
6+
a = view(A, bi, bj)
7+
is = axes(A, 1)[bi]
8+
Adense[is, js] = @view A[block_index...]
9+
end
10+
end
11+
return Adense
12+
end
13+
14+
const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}
15+
16+
function MAK.zero!(A::BlockBlasMat)
17+
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
18+
a = view(A, bi, bj)
19+
MAK.zero!(a)
20+
end
21+
return A
22+
end
23+
24+
function MAK.one!(A::BlockBlasMat)
25+
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
26+
a = view(A, bi, bj)
27+
bi == bj ? MAK.one!(a) : MAK.zero!(a)
28+
end
29+
return A
30+
end

src/linalg/factorizations.jl

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,6 @@ import MatrixAlgebraKit as MAK
44

55
# Type piracy for defining the MAK rules on BlockArrays!
66
# -----------------------------------------------------
7-
8-
const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}
9-
10-
function MatrixAlgebraKit.zero!(A::BlockBlasMat)
11-
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
12-
a = view(A, bi, bj)
13-
MAK.zero!(a)
14-
end
15-
return A
16-
end
17-
18-
function MatrixAlgebraKit.one!(A::BlockBlasMat)
19-
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
20-
a = view(A, bi, bj)
21-
bi == bj ? MAK.one!(a) : MAK.zero!(a)
22-
end
23-
return A
24-
end
25-
26-
_full(A) = Array(A)
27-
287
for f in
298
[
309
:svd_compact, :svd_full, :svd_vals,
@@ -46,8 +25,7 @@ for f! in (
4625
)
4726
@eval function MAK.$f!(t::AbstractBlockTensorMap, F, alg::AbstractAlgorithm)
4827
TensorKit.foreachblock(t, F...) do _, (tblock, Fblocks...)
49-
full_block = _full(tblock)
50-
Fblocks′ = MAK.$f!(full_block, alg)
28+
Fblocks′ = MAK.$f!(copy_dense(tblock), alg)
5129
# deal with the case where the output is not in-place
5230
for (b′, b) in zip(Fblocks′, Fblocks)
5331
b === b′ || copy!(b, b′)
@@ -66,7 +44,7 @@ for f! in (
6644
)
6745
@eval function MAK.$f!(t::AbstractBlockTensorMap, N, alg::AbstractAlgorithm)
6846
TensorKit.foreachblock(t, N) do _, (tblock, Nblock)
69-
Nblock′ = MAK.$f!(_full(tblock), alg)
47+
Nblock′ = MAK.$f!(copy_dense(tblock), alg)
7048
# deal with the case where the output is not the same as the input
7149
Nblock === Nblock′ || copy!(Nblock, Nblock′)
7250
return nothing

0 commit comments

Comments
 (0)