@@ -3,33 +3,9 @@ module BlockTensorKitGPUArraysExt
 using BlockTensorKit, BlockArrays, GPUArrays, Strided
 using Strided: StridedViews
 using GPUArrays: KernelAbstractions
-import BlockTensorKit: _full
 
 function KernelAbstractions.get_backend(BA::BlockArrays.BlockArray{T, N, A}) where {T, N, A <: AbstractArray{<:StridedView{T, N, <:AnyGPUArray}}}
     return KernelAbstractions.get_backend(first(BA.blocks))
 end
 
-function BlockTensorKit._full(A::BM) where {T <: Number, TA <: AnyGPUMatrix{T}, BM <: BlockMatrix{T, Matrix{TA}}}
-    arr = similar(first(A.blocks), size(A))
-    # TODO -- should we use Threads here to parallelize these
-    # transfers in streams if possible?
-    for block_index in Iterators.product(blockaxes(A)...)
-        indices = getindex.(axes(A), block_index)
-        arr[indices...] = @view A[block_index...]
-    end
-    return arr
-end
-
-# awful piracy but defined here as BlockArrays doesn't support this well
-function Base.copyto!(dest::BM, src::TA) where {T <: Number, TA <: AnyGPUMatrix{T}, BM <: BlockMatrix{T, Matrix{TA}}}
-    # TODO -- should we use Threads here to parallelize these
-    # transfers in streams if possible?
-    for block_index in Iterators.product(blockaxes(dest)...)
-        indices = getindex.(axes(dest), block_index)
-        dest_view = @view dest[block_index...]
-        dest_view = src[indices...]
-    end
-    return dest
-end
-
 end
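Not part of the diff above, but for reference: a minimal, hypothetical usage sketch of the `get_backend` method that the extension keeps. It assumes CUDA.jl as the concrete GPU backend and that `StridedView` can wrap a `CuArray` (via the StridedViews CUDA extension); `mortar` is BlockArrays' block-matrix constructor.

```julia
# Hypothetical illustration (not from this commit), assuming CUDA.jl is available.
using CUDA, BlockArrays, Strided, GPUArrays
using GPUArrays: KernelAbstractions

# Build a 2x2-block matrix whose blocks are StridedViews over GPU arrays,
# matching the signature BlockArray{T, N, <:AbstractArray{<:StridedView{T, N, <:AnyGPUArray}}}.
blocks = [StridedView(CUDA.rand(Float32, 2, 2)) for _ in 1:2, _ in 1:2]
BA = mortar(blocks)

# Delegates to the backend of the first block, i.e. the CUDA backend here.
backend = KernelAbstractions.get_backend(BA)
```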