Skip to content

Commit 53c46bf

Browse files
authored
Add some mul overrides too (#46)
* Add some mul overrides too
* Cleanup imports
* Fix and test for fill
* Run the tests on GPU too
1 parent c4025a3 commit 53c46bf

6 files changed

Lines changed: 33 additions & 4 deletions

File tree

ext/StridedGPUArraysExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module StridedGPUArraysExt
22

33
using Strided, GPUArrays
44
using GPUArrays: Adapt, KernelAbstractions
5+
using GPUArrays.KernelAbstractions: @kernel, @index
56

67
# Element-wise ops a StridedView may carry; used to constrain the `F` type
# parameter of the methods below. Declared `const` so the binding is not an
# untyped mutable global (and to match the sibling JLArrays extension).
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
78

@@ -19,4 +20,27 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS
1920
return dst
2021
end
2122

23+
# lifted from GPUArrays.jl
24+
# Fill every element of a GPU-backed StridedView `A` with `x`, in place, by
# launching a KernelAbstractions kernel over the view's linear indices.
# Returns `A`.
function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS}
    # Nothing to launch for an empty view.
    if isempty(A)
        return A
    end

    # One work-item per element; each writes `value` at its own linear index.
    @kernel function fill_kernel!(dest, value)
        i = @index(Global, Linear)
        @inbounds dest[i] = value
    end

    launch! = fill_kernel!(KernelAbstractions.get_backend(A))

    # Views whose op is `conj` or `adjoint` are handed the conjugated value;
    # `identity` and `transpose` views are handed `x` unchanged.
    # NOTE(review): if `setindex!` on the view already applies the view's op to
    # the stored value, pre-conjugating here would conjugate twice — confirm
    # against StridedViews' `setindex!`.
    fillvalue = if F <: Union{typeof(conj), typeof(adjoint)}
        conj(x)
    else
        x
    end

    launch!(A, fillvalue; ndrange = length(A))
    return A
end
36+
37+
# Matrix-matrix multiply override for 2-D StridedViews whose parents are GPU
# arrays: forwards all five arguments unchanged to
# `GPUArrays.generic_matmatmul!` and returns whatever that call returns.
function Strided.__mul!(
        C::StridedView{TC, 2, <:AnyGPUArray{TC}},
        A::StridedView{TA, 2, <:AnyGPUArray{TA}},
        B::StridedView{TB, 2, <:AnyGPUArray{TB}},
        α::Number,
        β::Number
    ) where {TC, TA, TB}
    result = GPUArrays.generic_matmatmul!(C, A, B, α, β)
    return result
end
45+
2246
end

ext/StridedJLArraysExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ module StridedJLArraysExt
22

33
using Strided, StridedViews, JLArrays
44
using JLArrays: Adapt
5-
using JLArrays: GPUArrays
65

76
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
87

98
# Copy a JLArray-backed StridedView `src` into a JLArray-backed StridedView
# `dst`. `src` is wrapped in an identity `Broadcasted` spanning `axes(dst)` and
# handed to GPUArrays' `_copyto!`. Returns `dst`.
function Base.copy!(
        dst::StridedView{TD, ND, TAD, FD},
        src::StridedView{TS, NS, TAS, FS}
    ) where {
        TD <: Number, ND, TAD <: JLArray{TD}, FD <: ALL_FS,
        TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS
    }
    bc_style = Base.Broadcast.BroadcastStyle(TAS)
    bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
    # Reach GPUArrays through JLArrays so this method does not depend on a
    # separate `GPUArrays` binding being imported into the module; the copy is
    # issued exactly once.
    JLArrays.GPUArrays._copyto!(dst, bc)
    return dst
end
1514

test/amd.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
1616
axes(f1(A1)) == axes(f2(A2)) || continue
1717
@test collect(ROCMatrix(copy!(f2(A2), f1(A1)))) == AMDGPU.Adapt.adapt(Vector{T}, copy!(B2, B1))
1818
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
19+
x = rand(T)
20+
@test f1(StridedView(AMDGPU.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == AMDGPU.Adapt.adapt(Vector{T}, fill!(B1, x))
1921
end
2022
end
2123
end

test/cuda.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
1212
axes(f1(A1)) == axes(f2(A2)) || continue
1313
@test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDA.Adapt.adapt(Vector{T}, copy!(B2, B1))
1414
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
15+
x = rand(T)
16+
@test f1(StridedView(CUDA.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == CUDA.Adapt.adapt(Vector{T}, fill!(B1, x))
1517
end
1618
end
1719
end

test/jlarrays.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
1+
@testset for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
22
@testset "Copy with JLArrayStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
33
for m1 in (0, 16, 32), m2 in (0, 16, 32)
44
A1 = JLArray(randn(T, (m1, m2)))
@@ -12,6 +12,8 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
1212
axes(f1(A1)) == axes(f2(A2)) || continue
1313
@test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
1414
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
15+
x = rand(T)
16+
@test f1(StridedView(JLArrays.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == JLArrays.Adapt.adapt(Vector{T}, fill!(B1, x))
1517
end
1618
end
1719
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Random.seed!(1234)
1111
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1212

1313
if !is_buildkite
14+
include("jlarrays.jl")
1415
println("Base.Threads.nthreads() = $(Base.Threads.nthreads())")
1516

1617
println("Running tests single-threaded:")
@@ -28,7 +29,6 @@ if !is_buildkite
2829
include("blasmultests.jl")
2930
Strided.disable_threaded_mul()
3031

31-
include("jlarrays.jl")
3232
Aqua.test_all(Strided; piracies = false)
3333
end
3434

0 commit comments

Comments
 (0)