Skip to content

Commit c82c506

Browse files
authored
Create JLArrays extension and tests for GPU mocking (#44)
* Create JLArrays extension and tests for GPU mocking * No sources * Add an AbstractArray option too
1 parent bd9ece4 commit c82c506

4 files changed

Lines changed: 44 additions & 2 deletions

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,21 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1010

1111
[weakdeps]
1212
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
13+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1314
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1415
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1516

1617
[extensions]
1718
StridedAMDGPUExt = "AMDGPU"
19+
StridedJLArraysExt = "JLArrays"
1820
StridedGPUArraysExt = "GPUArrays"
1921
StridedCUDAExt = "CUDA"
2022

2123
[compat]
2224
AMDGPU = "2"
2325
Aqua = "0.8"
2426
CUDA = "5"
27+
JLArrays = "0.3.1"
2528
GPUArrays = "11.4.1"
2629
LinearAlgebra = "1.6"
2730
Random = "1.6"
@@ -34,9 +37,10 @@ julia = "1.6"
3437
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3538
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3639
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
40+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3741
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3842
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3943
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4044

4145
[targets]
42-
test = ["Test", "Random", "Aqua", "AMDGPU", "CUDA", "GPUArrays"]
46+
test = ["Test", "Random", "Aqua", "AMDGPU", "CUDA", "GPUArrays", "JLArrays"]

ext/StridedJLArraysExt.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module StridedJLArraysExt
2+
3+
using Strided, StridedViews, JLArrays
4+
using JLArrays: Adapt
5+
using JLArrays: GPUArrays
6+
7+
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
8+
9+
function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TAD <: JLArray{TD}, FD <: ALL_FS, TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS}
10+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
11+
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
12+
GPUArrays._copyto!(dst, bc)
13+
return dst
14+
end
15+
16+
function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS}
17+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
18+
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
19+
GPUArrays._copyto!(dst, bc)
20+
return dst
21+
end
22+
23+
end

test/jlarrays.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
2+
@testset "Copy with JLArrayStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
3+
for m1 in (0, 16, 32), m2 in (0, 16, 32)
4+
A1 = JLArray(randn(T, (m1, m2)))
5+
A2 = similar(A1)
6+
A1c = copy(A1)
7+
A2c = copy(A2)
8+
B1 = f1(StridedView(A1c))
9+
B2 = f2(StridedView(A2c))
10+
axes(f1(A1)) == axes(f2(A2)) || continue
11+
@test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
12+
end
13+
end
14+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Random
44
using Strided
55
using Strided: StridedView
66
using Aqua
7-
using AMDGPU, CUDA, GPUArrays
7+
using JLArrays, AMDGPU, CUDA, GPUArrays
88

99
Random.seed!(1234)
1010

@@ -28,6 +28,7 @@ if !is_buildkite
2828
include("blasmultests.jl")
2929
Strided.disable_threaded_mul()
3030

31+
include("jlarrays.jl")
3132
Aqua.test_all(Strided; piracies = false)
3233
end
3334

0 commit comments

Comments
 (0)