|
| 1 | +# oneMKL FFT (DFT) high-level Julia interface |
| 2 | +# Inspired by AMDGPU ROCFFT interface style, adapted to oneMKL DFT C wrapper. |
| 3 | + |
| 4 | +module FFT |
| 5 | + |
| 6 | +using ..oneMKL |
| 7 | +using ..oneMKL: oneAPI, SYCL, syclQueue_t |
| 8 | +using ..Support |
| 9 | +using ..SYCL |
| 10 | +using LinearAlgebra |
| 11 | +using GPUArrays |
| 12 | +using AbstractFFTs |
| 13 | +import AbstractFFTs: complexfloat, realfloat |
| 14 | +import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft! |
| 15 | +import AbstractFFTs: plan_rfft, plan_brfft, plan_inv, normalization |
| 16 | +import AbstractFFTs: fft, bfft, ifft, rfft, Plan, ScaledPlan |
| 17 | +export MKLFFTPlan |
| 18 | + |
| 19 | +# Low-level enums mirroring C API (subset) |
| 20 | +# (We can just re-use integer constants; C wrappers return 0 on success.) |
| 21 | +const DFT_PREC_SINGLE = 0 |
| 22 | +const DFT_PREC_DOUBLE = 1 |
| 23 | +const DFT_DOM_REAL = 0 |
| 24 | +const DFT_DOM_COMPLEX = 1 |
| 25 | + |
| 26 | +# Placement values |
| 27 | +const DFT_PARAM_DIMENSION = 1 |
| 28 | +const DFT_PARAM_LENGTHS = 2 |
| 29 | +const DFT_PARAM_PRECISION = 3 |
| 30 | +const DFT_PARAM_FORWARD_SCALE = 4 |
| 31 | +const DFT_PARAM_BACKWARD_SCALE = 5 |
| 32 | + |
| 33 | +# Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed) |
| 34 | +# We'll declare ccall prototypes manually until generator exposes them. |
| 35 | + |
| 36 | +# NOTE: The liboneapi_support.jl generated file currently doesn't have DFT entries; add manual ccalls. |
| 37 | +const lib = :liboneapi_support |
| 38 | + |
| 39 | +# Allow implicit conversion of SYCL queue object to raw handle when storing/passing |
| 40 | +Base.convert(::Type{syclQueue_t}, q::SYCL.syclQueue) = Base.unsafe_convert(syclQueue_t, q) |
| 41 | + |
| 42 | +# Creation / destruction |
| 43 | +ccall_create1d(desc_ref, prec::Int32, dom::Int32, length::Int64) = ccall((:onemklDftCreate1D, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64), desc_ref, prec, dom, length) |
| 44 | +ccall_creatend(desc_ref, prec::Int32, dom::Int32, dim::Int64, lengths::Ptr{Int64}) = ccall((:onemklDftCreateND, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64, Ptr{Int64}), desc_ref, prec, dom, dim, lengths) |
| 45 | +ccall_destroy(desc) = ccall((:onemklDftDestroy, lib), Cint, (Ptr{Cvoid},), desc) |
| 46 | +ccall_commit(desc, q) = ccall((:onemklDftCommit, lib), Cint, (Ptr{Cvoid}, syclQueue_t), desc, q) |
| 47 | +ccall_fwd(desc, ptr) = ccall((:onemklDftComputeForward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr) |
| 48 | +ccall_fwd_oop(desc, pin, pout) = ccall((:onemklDftComputeForwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout) |
| 49 | +ccall_bwd(desc, ptr) = ccall((:onemklDftComputeBackward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr) |
| 50 | +ccall_bwd_oop(desc, pin, pout) = ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout) |
| 51 | +ccall_set_double(desc, param::Int32, value::Float64) = ccall((:onemklDftSetValueDouble, lib), Cint, (Ptr{Cvoid}, Cint, Float64), desc, param, value) |
| 52 | + |
| 53 | +abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end |
| 54 | + |
| 55 | +Base.eltype(::MKLFFTPlan{T}) where T = T |
| 56 | +is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where inplace = inplace |
| 57 | + |
| 58 | +# Forward / inverse flags |
| 59 | +const MKLFFT_FORWARD = true |
| 60 | +const MKLFFT_INVERSE = false |
| 61 | + |
| 62 | +mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace} |
| 63 | + handle::Ptr{Cvoid} |
| 64 | + queue::syclQueue_t |
| 65 | + sz::NTuple{N,Int} |
| 66 | + osz::NTuple{N,Int} |
| 67 | + realdomain::Bool |
| 68 | + region::NTuple{R,Int} |
| 69 | + buffer::B |
| 70 | + pinv::Any |
| 71 | +end |
| 72 | + |
| 73 | +# Real transforms use separate struct (mirroring AMDGPU style) for buffer staging |
| 74 | +mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace} |
| 75 | + handle::Ptr{Cvoid} |
| 76 | + queue::syclQueue_t |
| 77 | + sz::NTuple{N,Int} |
| 78 | + osz::NTuple{N,Int} |
| 79 | + xtype::Symbol |
| 80 | + region::NTuple{R,Int} |
| 81 | + buffer::B |
| 82 | + pinv::Any |
| 83 | +end |
| 84 | + |
| 85 | +# Inverse plan constructors (derive from existing plan) |
| 86 | +function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B} |
| 87 | + q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p) |
| 88 | + p.pinv = q |
| 89 | + q |
| 90 | +end |
| 91 | +function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B} |
| 92 | + q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p) |
| 93 | + p.pinv = q |
| 94 | + q |
| 95 | +end |
| 96 | + |
| 97 | +function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B} |
| 98 | + # forward real -> inverse complex->real (brfft) |
| 99 | + q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p) |
| 100 | + p.pinv = q |
| 101 | + q |
| 102 | +end |
| 103 | +function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B} |
| 104 | + # inverse real -> forward real (rfft) |
| 105 | + q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p) |
| 106 | + p.pinv = q |
| 107 | + q |
| 108 | +end |
| 109 | + |
| 110 | +function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace} |
| 111 | + print(io, inplace ? "oneMKL FFT in-place " : "oneMKL FFT ", K ? "forward" : "inverse", " plan for ") |
| 112 | + if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end |
| 113 | + print(io, " oneArray of ", T) |
| 114 | +end |
| 115 | + |
| 116 | +# Plan constructors |
| 117 | +function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool; normalize=true) where N |
| 118 | + prec = T<:Float64 || T<:ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE |
| 119 | + dom = complex ? DFT_DOM_COMPLEX : DFT_DOM_REAL |
| 120 | + desc_ref = Ref{Ptr{Cvoid}}() |
| 121 | + lengths = collect(Int64, sz) |
| 122 | + iprec = Int32(prec); idom = Int32(dom) |
| 123 | + st = length(lengths) == 1 ? ccall_create1d(desc_ref, iprec, idom, lengths[1]) : ccall_creatend(desc_ref, iprec, idom, length(lengths), pointer(lengths)) |
| 124 | + st == 0 || error("onemkl DFT create failed (status $st)") |
| 125 | + desc = desc_ref[] |
| 126 | + # Set scaling so that forward is unscaled, inverse multiplies by 1/volume (AbstractFFTs convention) |
| 127 | + if normalize |
| 128 | + vol = prod(sz) |
| 129 | + stfs = ccall_set_double(desc, Int32(DFT_PARAM_FORWARD_SCALE), 1.0) |
| 130 | + stfs == 0 || error("set forward scale failed ($stfs)") |
| 131 | + stbs = ccall_set_double(desc, Int32(DFT_PARAM_BACKWARD_SCALE), 1.0/vol) |
| 132 | + stbs == 0 || error("set backward scale failed ($stbs)") |
| 133 | + end |
| 134 | + # Construct a SYCL queue from current Level Zero context/device (reuse global queue) |
| 135 | + ze_ctx = oneAPI.context(); ze_dev = oneAPI.device() |
| 136 | + sycl_dev = SYCL.syclDevice(SYCL.syclPlatform(oneAPI.driver()), ze_dev) |
| 137 | + sycl_ctx = SYCL.syclContext([sycl_dev], ze_ctx) |
| 138 | + q = SYCL.syclQueue(sycl_ctx, sycl_dev, oneAPI.global_queue(ze_ctx, ze_dev)) |
| 139 | + stc = ccall_commit(desc, q) |
| 140 | + stc == 0 || error("onemkl DFT commit failed (status $stc)") |
| 141 | + return desc, q |
| 142 | +end |
| 143 | + |
| 144 | +# Complex plans |
| 145 | +function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N} |
| 146 | + R = length(region); reg = NTuple{R,Int}(region) |
| 147 | + desc, q = _create_descriptor(size(X), T, true) |
| 148 | + return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing) |
| 149 | +end |
| 150 | +function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N} |
| 151 | + R = length(region); reg = NTuple{R,Int}(region) |
| 152 | + desc, q = _create_descriptor(size(X), T, true) |
| 153 | + return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing) |
| 154 | +end |
| 155 | + |
| 156 | +# In-place (provide separate methods) |
| 157 | +function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N} |
| 158 | + R = length(region); reg = NTuple{R,Int}(region) |
| 159 | + desc,q = _create_descriptor(size(X),T,true) |
| 160 | + cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing) |
| 161 | +end |
| 162 | +function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N} |
| 163 | + R = length(region); reg = NTuple{R,Int}(region) |
| 164 | + desc,q = _create_descriptor(size(X),T,true) |
| 165 | + cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing) |
| 166 | +end |
| 167 | + |
| 168 | +# Real forward (out-of-place) |
| 169 | +function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N} |
| 170 | + R = length(region); reg = NTuple{R,Int}(region) |
| 171 | + desc,q = _create_descriptor(size(X),T,false) |
| 172 | + xdims = size(X) |
| 173 | + # output last region dim becomes N/2+1 (assume first region index) |
| 174 | + ax = reg[1] |
| 175 | + ydims = Base.setindex(xdims, div(xdims[ax],2)+1, ax) |
| 176 | + buffer = oneAPI.oneArray{Complex{T}}(undef, ydims) |
| 177 | + rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing) |
| 178 | +end |
| 179 | + |
| 180 | +# Real inverse (complex->real) requires complex input shape |
| 181 | +function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N} |
| 182 | + R = length(region); reg = NTuple{R,Int}(region) |
| 183 | + # output real size 'd' along region[1] |
| 184 | + xdims = size(X) |
| 185 | + ydims = Base.setindex(xdims, d, reg[1]) |
| 186 | + # Extract underlying real type R from Complex{R} |
| 187 | + @assert T <: Complex |
| 188 | + RT = T.parameters[1] |
| 189 | + desc,q = _create_descriptor(ydims, RT, false) |
| 190 | + buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety |
| 191 | + rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing) |
| 192 | +end |
| 193 | + |
| 194 | +# Convenience no-region methods use all dimensions in order |
| 195 | +plan_fft(X::oneAPI.oneArray) = plan_fft(X, ntuple(identity, ndims(X))) |
| 196 | +plan_bfft(X::oneAPI.oneArray) = plan_bfft(X, ntuple(identity, ndims(X))) |
| 197 | +plan_fft!(X::oneAPI.oneArray) = plan_fft!(X, ntuple(identity, ndims(X))) |
| 198 | +plan_bfft!(X::oneAPI.oneArray) = plan_bfft!(X, ntuple(identity, ndims(X))) |
| 199 | +plan_rfft(X::oneAPI.oneArray) = plan_rfft(X, (1,)) # default first dim like Base.rfft |
| 200 | +plan_brfft(X::oneAPI.oneArray, d::Integer) = plan_brfft(X, d, (1,)) |
| 201 | + |
| 202 | +# Alias names to mirror AMDGPU / AbstractFFTs style |
| 203 | +const plan_ifft = plan_bfft |
| 204 | +const plan_ifft! = plan_bfft! |
| 205 | +const plan_irfft = plan_brfft |
| 206 | + |
| 207 | +# Inversion |
| 208 | +Base.inv(p::MKLFFTPlan) = plan_inv(p) |
| 209 | + |
| 210 | +# High-level wrappers operating like CPU FFTW versions. |
| 211 | +function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}} |
| 212 | + (plan_fft(X) * X) |
| 213 | +end |
| 214 | +function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}} |
| 215 | + (plan_bfft(X) * X) |
| 216 | +end |
| 217 | +function fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}} |
| 218 | + (plan_fft!(X) * X; X) |
| 219 | +end |
| 220 | +function ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}} |
| 221 | + (plan_bfft!(X) * X; X) |
| 222 | +end |
| 223 | +function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}} |
| 224 | + (plan_rfft(X) * X) |
| 225 | +end |
| 226 | +function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}} |
| 227 | + (plan_brfft(X, d) * X) |
| 228 | +end |
| 229 | + |
| 230 | +# Execution helpers |
| 231 | +_rawptr(a::oneAPI.oneArray{T}) where T = reinterpret(Ptr{Cvoid}, pointer(a)) |
| 232 | + |
| 233 | +function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T |
| 234 | + st = ccall_fwd(p.handle, _rawptr(X)); st==0 || error("forward FFT failed ($st)"); X |
| 235 | +end |
| 236 | +function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T |
| 237 | + st = ccall_bwd(p.handle, _rawptr(X)); st==0 || error("inverse FFT failed ($st)"); X |
| 238 | +end |
| 239 | +function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K} |
| 240 | + st = (K==MKLFFT_FORWARD ? ccall_fwd_oop : ccall_bwd_oop)(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("FFT failed ($st)"); Y |
| 241 | +end |
| 242 | + |
| 243 | +# Real forward |
| 244 | +function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T |
| 245 | + st = ccall_fwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("rfft failed ($st)"); Y |
| 246 | +end |
| 247 | +# Real inverse (complex -> real) |
| 248 | +function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}} |
| 249 | + st = ccall_bwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("brfft failed ($st)"); Y |
| 250 | +end |
| 251 | + |
| 252 | +# Public API similar to AMDGPU |
| 253 | +function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K} |
| 254 | + _exec!(p,X) |
| 255 | +end |
| 256 | +function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K} |
| 257 | + Y = oneAPI.oneArray{T}(undef, p.osz); _exec!(p,X,Y) |
| 258 | +end |
| 259 | +function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K} |
| 260 | + _exec!(p,X,Y) |
| 261 | +end |
| 262 | + |
| 263 | +# Real forward |
| 264 | +function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}} |
| 265 | + Y = oneAPI.oneArray{Complex{T}}(undef, p.osz); _exec!(p,X,Y) |
| 266 | +end |
| 267 | +function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}} |
| 268 | + _exec!(p,X,Y) |
| 269 | +end |
| 270 | +# Real inverse |
| 271 | +function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}} |
| 272 | + Y = oneAPI.oneArray{R}(undef, p.osz); _exec!(p,X,Y) |
| 273 | +end |
| 274 | +function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}} |
| 275 | + _exec!(p,X,Y) |
| 276 | +end |
| 277 | + |
| 278 | +end # module FFT |
0 commit comments