Skip to content

Commit bcb35e9

Browse files
committed
Don't duplicate Adapt methods
1 parent ddccde8 commit bcb35e9

3 files changed

Lines changed: 4 additions & 11 deletions

File tree

ext/StridedAMDGPUExt.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@ using AMDGPU: GPUArrays
66

77
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
88

9-
function Adapt.adapt_storage(to::AMDGPU.Runtime.Adaptor, xs::StridedView{T,N,TA,F}) where {T,N,TA<:ROCArray{T},F <: ALL_FS}
10-
return StridedView(Adapt.adapt(to, parent(xs)), xs.size, xs.strides, xs.offset, xs.op)
11-
end
12-
139
function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD, ND, TAD <: ROCArray{TD}, FD <: ALL_FS, TS, NS, TAS <: ROCArray{TS}, FS <: ALL_FS}
14-
bc_style = Base.Broadcast.BroadcastStyle(TAS)
10+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
1511
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
1612
GPUArrays._copyto!(dst, bc)
1713
return dst

ext/StridedCUDAExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
module StridedCUDAExt
22

33
using Strided, CUDA
4+
using Strided: StridedViews
45
using CUDA: Adapt, KernelAdaptor
56
using CUDA: GPUArrays
67

78
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
89

9-
function Adapt.adapt_storage(to::KernelAdaptor, xs::StridedView{T,N,TA,F}) where {T,N,TA<:CuArray{T},F <: ALL_FS}
10-
return StridedView(Adapt.adapt(to, parent(xs)), xs.size, xs.strides, xs.offset, xs.op)
11-
end
12-
1310
function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD, ND, TAD <: CuArray{TD}, FD <: ALL_FS, TS, NS, TAS <: CuArray{TS}, FS <: ALL_FS}
14-
bc_style = Base.Broadcast.BroadcastStyle(TAS)
11+
bc_style = Base.Broadcast.BroadcastStyle(TAS)
1512
bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
1613
GPUArrays._copyto!(dst, bc)
1714
return dst

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ if !is_buildkite
2828
include("blasmultests.jl")
2929
Strided.disable_threaded_mul()
3030

31-
Aqua.test_all(Strided; piracies=false)
31+
Aqua.test_all(Strided; piracies = false)
3232
end
3333

3434
if CUDA.functional()

0 commit comments

Comments
 (0)