Skip to content

Commit 0c58d9e

Browse files
committed
remove unnecessary auxiliary code
1 parent c89afa3 commit 0c58d9e

1 file changed

Lines changed: 0 additions & 24 deletions

File tree

ext/BlockTensorKitGPUArraysExt.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,9 @@ module BlockTensorKitGPUArraysExt
33
using BlockTensorKit, BlockArrays, GPUArrays, Strided
44
using Strided: StridedViews
55
using GPUArrays: KernelAbstractions
6-
import BlockTensorKit: _full
76

87
function KernelAbstractions.get_backend(BA::BlockArrays.BlockArray{T, N, A}) where {T, N, A <: AbstractArray{<:StridedView{T, N, <:AnyGPUArray}}}
98
return KernelAbstractions.get_backend(first(BA.blocks))
109
end
1110

12-
function BlockTensorKit._full(A::BM) where {T <: Number, TA <: AnyGPUMatrix{T}, BM <: BlockMatrix{T, Matrix{TA}}}
13-
arr = similar(first(A.blocks), size(A))
14-
# TODO -- should we use Threads here to parallelize these
15-
# transfers in streams if possible?
16-
for block_index in Iterators.product(blockaxes(A)...)
17-
indices = getindex.(axes(A), block_index)
18-
arr[indices...] = @view A[block_index...]
19-
end
20-
return arr
21-
end
22-
23-
# awful piracy but defined here as BlockArrays doesn't support this well
24-
function Base.copyto!(dest::BM, src::TA) where {T <: Number, TA <: AnyGPUMatrix{T}, BM <: BlockMatrix{T, Matrix{TA}}}
25-
# TODO -- should we use Threads here to parallelize these
26-
# transfers in streams if possible?
27-
for block_index in Iterators.product(blockaxes(dest)...)
28-
indices = getindex.(axes(dest), block_index)
29-
dest_view = @view dest[block_index...]
30-
dest_view = src[indices...]
31-
end
32-
return dest
33-
end
34-
3511
end

0 commit comments

Comments
 (0)