Skip to content

Commit 1dd3234

Browse files
committed
Support BlockArrays in KernelAbstractions bcasting with new extension
1 parent 32e2092 commit 1dd3234

4 files changed

Lines changed: 19 additions & 6 deletions

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,19 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1717

1818
[weakdeps]
1919
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
20+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
2021

2122
[extensions]
2223
BlockTensorKitAdaptExt = "Adapt"
24+
BlockTensorKitGPUArraysExt = "GPUArrays"
2325

2426
[compat]
2527
Adapt = "4"
2628
Aqua = "0.8"
2729
BlockArrays = "1"
2830
Combinatorics = "1"
2931
Compat = "4.13"
32+
GPUArrays = "11.4.1"
3033
JLArrays = "0.3"
3134
LinearAlgebra = "1"
3235
MatrixAlgebraKit = "0.6"
@@ -45,6 +48,7 @@ julia = "1.10"
4548
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
4649
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4750
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
51+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
4852
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
4953
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5054
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -55,4 +59,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5559
test = ["Test", "TestExtras", "Random", "Combinatorics", "SafeTestsets", "Aqua", "Adapt", "JLArrays"]
5660

5761
[sources]
58-
Strided = {url = "https://github.com/QuantumKitHub/Strided.jl", rev = "ksh/jlarrays"}
62+
Strided = {path = "/home/kshyatt/.julia/dev/Strided"}

ext/BlockTensorKitGPUArraysExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module BlockTensorKitGPUArraysExt
2+
3+
using BlockTensorKit, BlockArrays, GPUArrays, Strided
4+
using Strided: StridedViews
5+
using GPUArrays: KernelAbstractions
6+
7+
function KernelAbstractions.get_backend(BA::BlockArrays.BlockArray{T, N, A}) where {T, N, A<:AbstractArray{<:StridedView{T, N, <:AnyGPUArray}}}
8+
return KernelAbstractions.get_backend(first(BA.blocks))
9+
end
10+
11+
end

src/tensors/blocktensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function BlockTensorMap(t::AbstractTensorMap, space::TensorMapSumSpace)
114114
TT = tensormaptype(spacetype(t), numout(t), numin(t), storagetype(t))
115115
tdst = BlockTensorMap{TT}(undef, space)
116116
for (f₁, f₂) in fusiontrees(tdst)
117-
tdst[f₁, f₂] .= t[f₁, f₂]
117+
copy!(tdst[f₁, f₂], t[f₁, f₂])
118118
end
119119
return tdst
120120
end

test/abstracttensor/blocktensor.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using BlockTensorKit
55
using Random
66
using Combinatorics
77
using Adapt
8-
using Strided, JLArrays
8+
using JLArrays
99

1010
Vtr = (
1111
SumSpace(ℂ^3),
@@ -87,9 +87,7 @@ end
8787
jl_bt1 = rand(JLVector{T}, W)
8888
TT = TensorKit.TensorMap{T, spacetype(t1′), numout(t1′), numin(t1′), JLVector{T}}
8989
jl_bt1′ = @constinferred convert(TT, jl_bt1)
90-
JLArrays.@allowscalar begin
91-
jl_bt1″ = @inferred BlockTensorMap(jl_bt1′, W)
92-
end # still need some logic for copying to a BlockArray of StridedViews
90+
jl_bt1″ = @inferred BlockTensorMap(jl_bt1′, W)
9391
@test jl_bt1 jl_bt1″
9492
end
9593
end

0 commit comments

Comments
 (0)