diff --git a/ext/AMDGPUExt.jl b/ext/AMDGPUExt.jl index 912525bff..a79af33f7 100644 --- a/ext/AMDGPUExt.jl +++ b/ext/AMDGPUExt.jl @@ -2,22 +2,10 @@ module AMDGPUExt import MPI isdefined(Base, :get_extension) ? (import AMDGPU) : (import ..AMDGPU) -import MPI: MPIPtr, Buffer, Datatype +import MPI: MPIPtr, Buffer, Datatype, CConvWrapper -function Base.cconvert(::Type{MPIPtr}, A::AMDGPU.ROCArray{T}) where T - A -end - -function Base.unsafe_convert(::Type{MPIPtr}, X::AMDGPU.ROCArray{T}) where T - reinterpret(MPIPtr, Base.unsafe_convert(Ptr{T}, X)) -end - -# only need to define this for strided arrays: all others can be handled by generic machinery -function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:AMDGPU.ROCArray,I} - X = parent(V) - pX = Base.unsafe_convert(Ptr{T}, X) - pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) - return reinterpret(MPIPtr, pV) +function Base.cconvert(::Type{MPIPtr}, x::AMDGPU.ROCArray{T}) where T + CConvWrapper(Ptr{T}, x) end function Buffer(arr::AMDGPU.ROCArray) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index f86dcd18a..ffafdfc60 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -2,22 +2,14 @@ module CUDAExt import MPI isdefined(Base, :get_extension) ? 
(import CUDA) : (import ..CUDA) -import MPI: MPIPtr, Buffer, Datatype +import MPI: MPIPtr, Buffer, Datatype, CConvWrapper function Base.cconvert(::Type{MPIPtr}, buf::CUDA.CuArray{T}) where T - Base.cconvert(CUDA.CuPtr{T}, buf) # returns DeviceBuffer + CConvWrapper(CUDA.CuPtr{T}, buf) end -function Base.unsafe_convert(::Type{MPIPtr}, X::CUDA.CuArray{T}) where T - reinterpret(MPIPtr, Base.unsafe_convert(CUDA.CuPtr{T}, X)) -end - -# only need to define this for strided arrays: all others can be handled by generic machinery -function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CUDA.CuArray,I} - X = parent(V) - pX = Base.unsafe_convert(CUDA.CuPtr{T}, X) - pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) - return reinterpret(MPIPtr, pV) +function Base.cconvert(::Type{MPIPtr}, buf::SubArray{T,N,P,I,true}) where {T,N,P<:CUDA.CuArray,I} + CConvWrapper(CUDA.CuPtr{T}, buf) end function Buffer(arr::CUDA.CuArray) diff --git a/src/api/api.jl b/src/api/api.jl index 5b731c3b6..6fccd3f1f 100644 --- a/src/api/api.jl +++ b/src/api/api.jl @@ -76,9 +76,7 @@ end primitive type MPIPtr Sys.WORD_SIZE end @assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid}) -Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x -Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x) - +Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x) # Initialize the ref constants from the library. 
# This is not `API.__init__`, as it should be called _after_ diff --git a/src/buffers.jl b/src/buffers.jl index 57d848d6c..071c3f2d0 100644 --- a/src/buffers.jl +++ b/src/buffers.jl @@ -1,16 +1,73 @@ MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}} MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr} -Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x) -Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T = Base.cconvert(Ptr{T}, x) -function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T - ptr = Base.unsafe_convert(Ptr{T}, x) +# CConvWrapper: GC-safe adapter for converting Julia objects to MPIPtr in ccall. +# +# Background: ccall's argument conversion protocol works in two steps: +# 1. cconvert(T, x) — called before the ccall. Its return value is GC-rooted +# by ccall for the duration of the foreign call, keeping the underlying +# Julia object alive while a pointer to it is in use. +# 2. unsafe_convert(T, result_of_cconvert) — called on the GC-rooted result +# to extract the raw pointer. Crucially, dispatch is on the *return type* +# of cconvert, not the original argument type. +# +# Problem: because unsafe_convert dispatches on the cconvert return type, the +# unsafe_convert(::Type{MPIPtr}, ...) method must match whatever cconvert +# returned. If cconvert delegates to e.g. Base.cconvert(Ptr{T}, x), the return +# type depends on the Base implementation, so an unsafe_convert method written +# for the original type may never be called. +# +# Solution: CConvWrapper provides a single, predictable return type from +# cconvert(MPIPtr, x). 
The conversion proceeds as: +# +# ccall argument x::Array{Float64} +# │ +# ▼ +# cconvert(MPIPtr, x) +# calls Base.cconvert(Ptr{Float64}, x) — returns the object to root (e.g. the Array) +# wraps it in CConvWrapper{Ptr{Float64}, C}(result), C = typeof(result) +# ◄── ccall GC-roots this CConvWrapper, which holds that object +# │ +# ▼ +# unsafe_convert(MPIPtr, wrapper::CConvWrapper{Ptr{Float64}}) +# calls Base.unsafe_convert(Ptr{Float64}, wrapper.cconv) — extracts raw ptr +# reinterprets to MPIPtr +# ◄── only called while ccall holds the GC root on the wrapper +# +# Types that don't need GC protection (Ptr, Nothing, InPlace, SentinelPtr) skip +# the wrapper and return an MPIPtr directly from cconvert, since they are plain +# bit types with no GC-managed backing memory. +struct CConvWrapper{T, C} + # T: the intermediate pointer type (e.g. Ptr{Float64}, CuPtr{Float64}) + # C: the type of the GC-rooted cconvert result (e.g. Array{Float64,1}) + cconv::C # the GC-rooted object — kept alive by ccall holding the wrapper +end +function CConvWrapper(::Type{T}, x) where T + # Delegate to Base.cconvert(T, x) to get the GC-rootable object, then wrap + # it so unsafe_convert dispatch is predictable. + cconv = Base.cconvert(T, x) + CConvWrapper{T, typeof(cconv)}(cconv) +end + +function Base.unsafe_convert(::Type{MPIPtr}, x::CConvWrapper{T}) where T + # Called by ccall while x (and thus x.cconv) is GC-rooted. + # Delegate to the Base pointer extraction, then reinterpret to MPIPtr. 
+ ptr = Base.unsafe_convert(T, x.cconv) reinterpret(MPIPtr, ptr) end +# --- cconvert methods for types with GC-managed memory (use CConvWrapper) --- + +function Base.cconvert(::Type{MPIPtr}, x::Union{Array{T}, SubArray{T}, Ref{T}}) where T + CConvWrapper(Ptr{T}, x) +end +function Base.cconvert(::Type{MPIPtr}, x::String) + CConvWrapper(Ptr{UInt8}, x) +end + +# --- cconvert methods for plain bit types (no GC protection needed) --- -Base.cconvert(::Type{MPIPtr}, x::String) = x -Base.unsafe_convert(::Type{MPIPtr}, x::String) = reinterpret(MPIPtr, pointer(x)) +Base.cconvert(::Type{MPIPtr}, ptr::Ptr) = reinterpret(MPIPtr, ptr) Base.cconvert(::Type{MPIPtr}, ::Nothing) = reinterpret(MPIPtr, C_NULL) @@ -45,7 +102,7 @@ MPIPtr struct InPlace end -Base.cconvert(::Type{MPIPtr}, ::InPlace) = API.MPI_IN_PLACE[] +Base.cconvert(::Type{MPIPtr}, ::InPlace) = reinterpret(MPIPtr, API.MPI_IN_PLACE[]) """