Skip to content

Commit 227a8b0

Browse files
committed
FFT Support
1 parent 717bc89 commit 227a8b0

5 files changed

Lines changed: 392 additions & 0 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ authors = ["Tim Besard <tim.besard@gmail.com>"]
44
version = "2.0.3"
55

66
[deps]
7+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
910
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
11+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1012
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1113
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
1214
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
@@ -29,6 +31,7 @@ oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
2931
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
3032

3133
[compat]
34+
AbstractFFTs = "1.5.0"
3235
Adapt = "4"
3336
CEnum = "0.4, 0.5"
3437
ExprTools = "0.1"

deps/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ add_library(oneapi_support SHARED
1818

1919
target_link_libraries(oneapi_support
2020
mkl_sycl
21+
# DFT component libraries needed for oneMKL DFT template instantiations
22+
mkl_sycl_dft
23+
mkl_cdft_core
2124
mkl_intel_ilp64
2225
mkl_sequential
2326
mkl_core

lib/mkl/fft.jl

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# oneMKL FFT (DFT) high-level Julia interface
2+
# Inspired by AMDGPU ROCFFT interface style, adapted to oneMKL DFT C wrapper.
3+
4+
module FFT
5+
6+
using ..oneMKL
7+
using ..oneMKL: oneAPI, SYCL, syclQueue_t
8+
using ..Support
9+
using ..SYCL
10+
using LinearAlgebra
11+
using GPUArrays
12+
using AbstractFFTs
13+
import AbstractFFTs: complexfloat, realfloat
14+
import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!
15+
import AbstractFFTs: plan_rfft, plan_brfft, plan_inv, normalization
16+
import AbstractFFTs: fft, bfft, ifft, rfft, Plan, ScaledPlan
17+
export MKLFFTPlan
18+
19+
# Low-level enums mirroring C API (subset)
20+
# (We can just re-use integer constants; C wrappers return 0 on success.)
21+
const DFT_PREC_SINGLE = 0
22+
const DFT_PREC_DOUBLE = 1
23+
const DFT_DOM_REAL = 0
24+
const DFT_DOM_COMPLEX = 1
25+
26+
# Configuration parameter identifiers (passed as `param` to the onemklDftSetValue* wrappers)
27+
const DFT_PARAM_DIMENSION = 1
28+
const DFT_PARAM_LENGTHS = 2
29+
const DFT_PARAM_PRECISION = 3
30+
const DFT_PARAM_FORWARD_SCALE = 4
31+
const DFT_PARAM_BACKWARD_SCALE = 5
32+
33+
# Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed)
34+
# We'll declare ccall prototypes manually until generator exposes them.
35+
36+
# NOTE: The liboneapi_support.jl generated file currently doesn't have DFT entries; add manual ccalls.
37+
const lib = :liboneapi_support
38+
39+
# Allow implicit conversion of SYCL queue object to raw handle when storing/passing
40+
# Convert a `SYCL.syclQueue` wrapper into the raw `syclQueue_t` handle by delegating to
# `Base.unsafe_convert`, so the wrapper can be assigned to `syclQueue_t` fields and passed
# to `ccall` arguments declared as `syclQueue_t`.
function Base.convert(::Type{syclQueue_t}, q::SYCL.syclQueue)
    return Base.unsafe_convert(syclQueue_t, q)
end
41+
42+
# Creation / destruction
43+
# Thin wrappers around the oneMKL DFT entry points exported by `liboneapi_support`.
# Every wrapper returns the library's `Cint` status unchanged; callers in this file
# treat 0 as success.

# Create a 1-D descriptor with `len` points; the new handle is written into `desc_ref`.
function ccall_create1d(desc_ref, prec::Int32, dom::Int32, len::Int64)
    return ccall((:onemklDftCreate1D, lib), Cint,
                 (Ref{Ptr{Cvoid}}, Cint, Cint, Int64),
                 desc_ref, prec, dom, len)
end

# Create a `dim`-dimensional descriptor; `lengths` points at `dim` Int64 extents.
function ccall_creatend(desc_ref, prec::Int32, dom::Int32, dim::Int64, lengths::Ptr{Int64})
    return ccall((:onemklDftCreateND, lib), Cint,
                 (Ref{Ptr{Cvoid}}, Cint, Cint, Int64, Ptr{Int64}),
                 desc_ref, prec, dom, dim, lengths)
end

# Destroy a descriptor handle.
function ccall_destroy(desc)
    return ccall((:onemklDftDestroy, lib), Cint, (Ptr{Cvoid},), desc)
end

# Commit a descriptor to the SYCL queue `q`.
function ccall_commit(desc, q)
    return ccall((:onemklDftCommit, lib), Cint, (Ptr{Cvoid}, syclQueue_t), desc, q)
end

# In-place forward transform of the data at `ptr`.
function ccall_fwd(desc, ptr)
    return ccall((:onemklDftComputeForward, lib), Cint,
                 (Ptr{Cvoid}, Ptr{Cvoid}),
                 desc, ptr)
end

# Out-of-place forward transform from `pin` into `pout`.
function ccall_fwd_oop(desc, pin, pout)
    return ccall((:onemklDftComputeForwardOutOfPlace, lib), Cint,
                 (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
                 desc, pin, pout)
end

# In-place backward transform of the data at `ptr`.
function ccall_bwd(desc, ptr)
    return ccall((:onemklDftComputeBackward, lib), Cint,
                 (Ptr{Cvoid}, Ptr{Cvoid}),
                 desc, ptr)
end

# Out-of-place backward transform from `pin` into `pout`.
function ccall_bwd_oop(desc, pin, pout)
    return ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint,
                 (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
                 desc, pin, pout)
end

# Set a double-valued configuration parameter (one of the `DFT_PARAM_*` identifiers).
function ccall_set_double(desc, param::Int32, value::Float64)
    return ccall((:onemklDftSetValueDouble, lib), Cint,
                 (Ptr{Cvoid}, Cint, Float64),
                 desc, param, value)
end
52+
53+
# Common supertype of the oneMKL FFT plans defined in this module.
#   T       - element type, forwarded to `AbstractFFTs.Plan{T}` and returned by `eltype`
#   K       - direction flag (`MKLFFT_FORWARD` or `MKLFFT_INVERSE`)
#   inplace - `true` for plans that overwrite their argument
abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end

# Element type the plan was created for.
function Base.eltype(::MKLFFTPlan{T}) where {T}
    return T
end

# Whether the plan was created as an in-place plan (third type parameter).
function is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where {inplace}
    return inplace
end
57+
58+
# Forward / inverse flags
59+
const MKLFFT_FORWARD = true
60+
const MKLFFT_INVERSE = false
61+
62+
# Plan for complex-to-complex transforms.
# NOTE(review): nothing in this file destroys `handle`; plans derived through `plan_inv`
# share it, so any future finalizer needs ownership tracking.
mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
    handle::Ptr{Cvoid}      # oneMKL DFT descriptor returned by `_create_descriptor`
    queue::syclQueue_t      # raw SYCL queue handle the descriptor was committed to
    sz::NTuple{N,Int}       # input array size
    osz::NTuple{N,Int}      # output array size
    realdomain::Bool        # the constructors in this file always pass `false`
    region::NTuple{R,Int}   # dimensions requested by the caller
    buffer::B               # staging buffer; `nothing` for the plans built in this file
    pinv::Any               # linked inverse plan, `nothing` until `plan_inv` sets it
end
72+
73+
# Real transforms use separate struct (mirroring AMDGPU style) for buffer staging
74+
# Plan for real-data transforms (real -> complex forward, complex -> real backward).
# NOTE(review): nothing in this file destroys `handle`; plans derived through `plan_inv`
# share it, so any future finalizer needs ownership tracking.
mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
    handle::Ptr{Cvoid}      # oneMKL DFT descriptor returned by `_create_descriptor`
    queue::syclQueue_t      # raw SYCL queue handle the descriptor was committed to
    sz::NTuple{N,Int}       # size recorded for the plan's input side
    osz::NTuple{N,Int}      # size recorded for the plan's output side
    xtype::Symbol           # `:rfft` or `:brfft`
    region::NTuple{R,Int}   # dimensions requested by the caller
    buffer::B               # complex device array allocated by the plan constructor
    pinv::Any               # linked inverse plan, `nothing` until `plan_inv` sets it
end
84+
85+
# Inverse plan constructors (derive from existing plan)
86+
# Derive the inverse-direction plan from a forward complex plan. The new plan reuses the
# descriptor handle, queue, sizes, region and buffer of `p`; the two plans reference each
# other through their `pinv` fields.
function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
    InversePlan = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}
    inverse = InversePlan(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region,
                          p.buffer, p)
    p.pinv = inverse
    return inverse
end
91+
# Derive the forward-direction plan from an inverse complex plan. The new plan reuses the
# descriptor handle, queue, sizes, region and buffer of `p`; the two plans reference each
# other through their `pinv` fields.
function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
    ForwardPlan = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}
    forward = ForwardPlan(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region,
                          p.buffer, p)
    p.pinv = forward
    return forward
end
96+
97+
# Derive the complex -> real (`:brfft`) plan from a forward real plan.
#
# The execution methods for inverse real plans in this file dispatch on a complex element
# type (`T<:Complex{R}`) and read the input size from `sz` and the output size from `osz`.
# The derived plan therefore uses the complex counterpart of `T` and swaps `sz`/`osz`;
# keeping the real `T` would leave the plan without any applicable `*`/`mul!` method.
# The descriptor handle, queue, region and buffer are shared with `p`, and both plans
# reference each other through `pinv`.
function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
    InversePlan = rMKLFFTPlan{complex(T),MKLFFT_INVERSE,inplace,N,R,B}
    # input of the inverse is the forward plan's output, and vice versa
    inverse = InversePlan(p.handle, p.queue, p.osz, p.sz, :brfft, p.region, p.buffer, p)
    p.pinv = inverse
    return inverse
end
103+
# Derive the real -> complex (`:rfft`) plan from a complex -> real plan.
#
# The execution methods for forward real plans in this file dispatch on a real element
# type (`T<:Union{Float32,Float64}`) and read the input size from `sz` and the output size
# from `osz`. The derived plan therefore uses `real(T)` and swaps `sz`/`osz`; keeping the
# complex `T` would leave the plan without any applicable `*`/`mul!` method.
# The descriptor handle, queue, region and buffer are shared with `p`, and both plans
# reference each other through `pinv`.
function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
    ForwardPlan = rMKLFFTPlan{real(T),MKLFFT_FORWARD,inplace,N,R,B}
    # input of the forward plan is the inverse plan's output, and vice versa
    forward = ForwardPlan(p.handle, p.queue, p.osz, p.sz, :rfft, p.region, p.buffer, p)
    p.pinv = forward
    return forward
end
109+
110+
# Print a one-line summary of the plan: placement, direction, array size and element type.
function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
    prefix = inplace ? "oneMKL FFT in-place " : "oneMKL FFT "
    direction = K ? "forward" : "inverse"
    dims = isempty(p.sz) ? "0-dimensional" : join(p.sz, "×")
    print(io, prefix, direction, " plan for ", dims, " oneArray of ", T)
    return nothing
end
115+
116+
# Plan constructors
117+
# Create and commit a oneMKL DFT descriptor for arrays of size `sz`.
#
# Arguments
#   sz        - transform lengths, taken from the array size
#   T         - element type; Float64/ComplexF64 select double precision, anything else
#               single precision
#   complex   - `true` for the complex domain, `false` for the real domain
#   normalize - when `true`, set the forward scale to 1 and the backward scale to
#               `1/prod(sz)`
#
# Returns `(desc, q)`: the committed descriptor handle and the SYCL queue it was committed
# to. Throws an `ErrorException` if any library call reports a non-zero status; in that
# case (or if building the queue throws) the descriptor is destroyed before the error
# propagates, so a failed setup does not leak it.
#
# NOTE(review): lengths are passed in Julia dimension order — confirm that this matches
# the dimension ordering expected by the C wrapper.
# NOTE(review): with `normalize=true` every backward computation on this descriptor is
# scaled, including plans created through `plan_bfft` — confirm that this is intended.
function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool; normalize=true) where N
    prec = Int32(T<:Float64 || T<:ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE)
    dom = Int32(complex ? DFT_DOM_COMPLEX : DFT_DOM_REAL)
    desc_ref = Ref{Ptr{Cvoid}}()
    lengths = collect(Int64, sz)
    st = if length(lengths) == 1
        ccall_create1d(desc_ref, prec, dom, lengths[1])
    else
        # keep `lengths` rooted while the library reads through the raw pointer
        GC.@preserve lengths ccall_creatend(desc_ref, prec, dom, length(lengths),
                                            pointer(lengths))
    end
    st == 0 || error("onemkl DFT create failed (status $st)")
    desc = desc_ref[]
    try
        # Forward is left unscaled, backward multiplies by 1/volume.
        if normalize
            vol = prod(sz)
            stfs = ccall_set_double(desc, Int32(DFT_PARAM_FORWARD_SCALE), 1.0)
            stfs == 0 || error("set forward scale failed ($stfs)")
            stbs = ccall_set_double(desc, Int32(DFT_PARAM_BACKWARD_SCALE), 1.0/vol)
            stbs == 0 || error("set backward scale failed ($stbs)")
        end
        # Wrap the global Level Zero queue of the current context/device in a SYCL queue.
        ze_ctx = oneAPI.context()
        ze_dev = oneAPI.device()
        sycl_dev = SYCL.syclDevice(SYCL.syclPlatform(oneAPI.driver()), ze_dev)
        sycl_ctx = SYCL.syclContext([sycl_dev], ze_ctx)
        q = SYCL.syclQueue(sycl_ctx, sycl_dev, oneAPI.global_queue(ze_ctx, ze_dev))
        stc = ccall_commit(desc, q)
        stc == 0 || error("onemkl DFT commit failed (status $stc)")
        return desc, q
    catch
        # release the half-configured descriptor, then surface the original error
        ccall_destroy(desc)
        rethrow()
    end
end
143+
144+
# Complex plans
145+
# Build an out-of-place forward complex plan for arrays shaped like `X`.
#
# The descriptor is created from `size(X)` alone, so the transform always covers every
# dimension of `X`. A `region` that is not a permutation of `1:N` is therefore rejected
# with an `ArgumentError` instead of silently computing a full N-dimensional transform.
function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    sort!(collect(Int, reg)) == 1:N || throw(ArgumentError(
        "oneMKL FFT plans transform all dimensions; region $(reg) is not supported for a $(N)-dimensional array"))
    desc, q = _create_descriptor(size(X), T, true)
    return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing)
end
150+
# Build an out-of-place backward complex plan for arrays shaped like `X`.
#
# The descriptor is created from `size(X)` alone, so the transform always covers every
# dimension of `X`. A `region` that is not a permutation of `1:N` is therefore rejected
# with an `ArgumentError` instead of silently computing a full N-dimensional transform.
#
# NOTE(review): the descriptor is created with the default `normalize=true`, so the
# backward result is scaled by 1/prod(size(X)) — confirm that this is intended here.
function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    sort!(collect(Int, reg)) == 1:N || throw(ArgumentError(
        "oneMKL FFT plans transform all dimensions; region $(reg) is not supported for a $(N)-dimensional array"))
    desc, q = _create_descriptor(size(X), T, true)
    return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing)
end
155+
156+
# In-place (provide separate methods)
157+
# Build an in-place forward complex plan for arrays shaped like `X`.
#
# The descriptor is created from `size(X)` alone, so the transform always covers every
# dimension of `X`. A `region` that is not a permutation of `1:N` is therefore rejected
# with an `ArgumentError` instead of silently computing a full N-dimensional transform.
function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    sort!(collect(Int, reg)) == 1:N || throw(ArgumentError(
        "oneMKL FFT plans transform all dimensions; region $(reg) is not supported for a $(N)-dimensional array"))
    desc, q = _create_descriptor(size(X), T, true)
    return cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing)
end
162+
# Build an in-place backward complex plan for arrays shaped like `X`.
#
# The descriptor is created from `size(X)` alone, so the transform always covers every
# dimension of `X`. A `region` that is not a permutation of `1:N` is therefore rejected
# with an `ArgumentError` instead of silently computing a full N-dimensional transform.
#
# NOTE(review): the descriptor is created with the default `normalize=true`, so the
# backward result is scaled by 1/prod(size(X)) — confirm that this is intended here.
function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    sort!(collect(Int, reg)) == 1:N || throw(ArgumentError(
        "oneMKL FFT plans transform all dimensions; region $(reg) is not supported for a $(N)-dimensional array"))
    desc, q = _create_descriptor(size(X), T, true)
    return cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing)
end
167+
168+
# Real forward (out-of-place)
169+
# Build an out-of-place real -> complex plan for arrays shaped like `X`.
#
# The output size equals the input size except along `region[1]`, where it becomes
# `n ÷ 2 + 1`. A complex device array of the output size is allocated and stored in the
# plan's `buffer` field.
#
# NOTE(review): the descriptor is created from the full `size(X)`; only `region[1]` is
# used here, and only to compute the output size — confirm the intended handling of
# multi-dimensional regions.
function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    desc, q = _create_descriptor(size(X), T, false)
    insize = size(X)
    halved = reg[1]
    outsize = Base.setindex(insize, insize[halved] ÷ 2 + 1, halved)
    buffer = oneAPI.oneArray{Complex{T}}(undef, outsize)
    return rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(
        desc, q, insize, outsize, :rfft, reg, buffer, nothing)
end
179+
180+
# Real inverse (complex->real) requires complex input shape
181+
# Build an out-of-place complex -> real plan for complex arrays shaped like `X`.
#
# Arguments
#   X      - complex input array (only its size and element type are used)
#   d      - length of the real output along `region[1]`
#   region - transform dimensions; `region[1]` is the halved dimension
#
# The real output size equals `size(X)` with `d` substituted along `region[1]`. Throws an
# `ArgumentError` when `d ÷ 2 + 1` differs from the input extent along that dimension,
# because the descriptor is created for the real size and the input would not match it.
# A complex device array of the input size is allocated and stored in `buffer`.
function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    xdims = size(X)
    ax = reg[1]
    d ÷ 2 + 1 == xdims[ax] || throw(ArgumentError(
        "real output length $d along dimension $ax requires $(d ÷ 2 + 1) complex inputs, got $(xdims[ax])"))
    # `Int(d)` keeps the size tuple homogeneous for any `Integer` subtype
    ydims = Base.setindex(xdims, Int(d), ax)
    # `real(T)` is the public way to obtain the real type underlying `Complex{...}`
    desc, q = _create_descriptor(ydims, real(T), false)
    buffer = oneAPI.oneArray{T}(undef, xdims)
    return rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(
        desc, q, xdims, ydims, :brfft, reg, buffer, nothing)
end
193+
194+
# Convenience no-region methods use all dimensions in order
195+
# Region-less entry points. The complex planners default to every dimension in order;
# the real planners default to the first dimension.
function plan_fft(X::oneAPI.oneArray)
    return plan_fft(X, ntuple(identity, ndims(X)))
end

function plan_bfft(X::oneAPI.oneArray)
    return plan_bfft(X, ntuple(identity, ndims(X)))
end

function plan_fft!(X::oneAPI.oneArray)
    return plan_fft!(X, ntuple(identity, ndims(X)))
end

function plan_bfft!(X::oneAPI.oneArray)
    return plan_bfft!(X, ntuple(identity, ndims(X)))
end

function plan_rfft(X::oneAPI.oneArray)
    return plan_rfft(X, (1,))
end

function plan_brfft(X::oneAPI.oneArray, d::Integer)
    return plan_brfft(X, d, (1,))
end
201+
202+
# Alias names to mirror AMDGPU / AbstractFFTs style
203+
const plan_ifft = plan_bfft
204+
const plan_ifft! = plan_bfft!
205+
const plan_irfft = plan_brfft
206+
207+
# Inversion
208+
# Return the inverse of plan `p`.
#
# `plan_inv` stores the derived plan in `p.pinv` (and links it back to `p`), so a cached
# inverse is returned when one exists instead of constructing a new plan object on every
# call. The plan constructors in this file initialise `pinv` to `nothing`.
function Base.inv(p::MKLFFTPlan)
    cached = p.pinv
    return cached === nothing ? plan_inv(p) : cached
end
209+
210+
# High-level wrappers operating like CPU FFTW versions.
211+
# Forward transform of a complex device array over all of its dimensions; returns a new
# array.
function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
    plan = plan_fft(X)
    return plan * X
end

# Backward transform of a complex device array over all of its dimensions, using the plan
# built by `plan_bfft`; returns a new array.
function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
    plan = plan_bfft(X)
    return plan * X
end
217+
# In-place forward transform of `X` over all dimensions; returns `X`.
#
# `fft!` is not part of this module's `import AbstractFFTs: ...` list, so the method is
# attached to `AbstractFFTs.fft!` explicitly; an unqualified definition would not extend
# the generic function that callers of `fft!` use.
function AbstractFFTs.fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
    plan_fft!(X) * X
    return X
end

# In-place backward transform of `X` over all dimensions, using the plan built by
# `plan_bfft!`; returns `X`. Qualified for the same reason as `fft!` above.
function AbstractFFTs.ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
    plan_bfft!(X) * X
    return X
end
223+
# Real -> complex transform of a real device array, using the default plan from
# `plan_rfft(X)`; returns a new complex array.
function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
    plan = plan_rfft(X)
    return plan * X
end
226+
# Complex -> real transform of `X` producing `d` real points along the first dimension,
# using the plan from `plan_brfft(X, d)`; returns a new real array.
#
# `irfft` is not part of this module's `import AbstractFFTs: ...` list, so the method is
# attached to `AbstractFFTs.irfft` explicitly; an unqualified definition would not extend
# the generic function that callers of `irfft` use.
function AbstractFFTs.irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}}
    plan = plan_brfft(X, d)
    return plan * X
end
229+
230+
# Execution helpers
231+
# Untyped pointer to the storage of `a`, in the form the `ccall_*` wrappers expect.
function _rawptr(a::oneAPI.oneArray{T}) where {T}
    typed = pointer(a)
    return reinterpret(Ptr{Cvoid}, typed)
end
232+
233+
# Run an in-place forward complex transform on `X`; returns `X`.
function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T
    status = ccall_fwd(p.handle, _rawptr(X))
    status == 0 || error("forward FFT failed ($status)")
    return X
end

# Run an in-place backward complex transform on `X`; returns `X`.
function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T
    status = ccall_bwd(p.handle, _rawptr(X))
    status == 0 || error("inverse FFT failed ($status)")
    return X
end
239+
# Run an out-of-place complex transform from `X` into `Y`; the direction flag `K` selects
# the forward or backward entry point. Returns `Y`.
function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K}
    src = _rawptr(X)
    dst = _rawptr(Y)
    status = if K == MKLFFT_FORWARD
        ccall_fwd_oop(p.handle, src, dst)
    else
        ccall_bwd_oop(p.handle, src, dst)
    end
    status == 0 || error("FFT failed ($status)")
    return Y
end
242+
243+
# Real forward
244+
# Run an out-of-place real -> complex transform from `X` into `Y`; returns `Y`.
function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T
    status = ccall_fwd_oop(p.handle, _rawptr(X), _rawptr(Y))
    status == 0 || error("rfft failed ($status)")
    return Y
end

# Run an out-of-place complex -> real transform from `X` into `Y`; returns `Y`.
function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}}
    status = ccall_bwd_oop(p.handle, _rawptr(X), _rawptr(Y))
    status == 0 || error("brfft failed ($status)")
    return Y
end
251+
252+
# Public API similar to AMDGPU
253+
# Applying complex plans.
#
# The library receives raw device pointers, so array sizes are checked against the sizes
# recorded in the plan before executing; a mismatch throws `DimensionMismatch` instead of
# letting the transform run on a buffer of the wrong extent.

# In-place plan: transforms `X` and returns it.
function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    return _exec!(p, X)
end

# Out-of-place plan: allocates the output array and returns it.
function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    Y = oneAPI.oneArray{T}(undef, p.osz)
    return _exec!(p, X, Y)
end

# Out-of-place plan writing into the caller-provided `Y`; returns `Y`.
function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    size(Y) == p.osz || throw(DimensionMismatch(
        "plan produces output size $(p.osz), got destination of size $(size(Y))"))
    return _exec!(p, X, Y)
end
262+
263+
# Real forward
264+
# Applying real -> complex plans.
#
# The library receives raw device pointers, so array sizes are checked against the sizes
# recorded in the plan before executing; a mismatch throws `DimensionMismatch`.

# Allocates the complex output array (size `p.osz`) and returns it.
function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    Y = oneAPI.oneArray{Complex{T}}(undef, p.osz)
    return _exec!(p, X, Y)
end

# Writes into the caller-provided complex array `Y`; returns `Y`.
function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    size(Y) == p.osz || throw(DimensionMismatch(
        "plan produces output size $(p.osz), got destination of size $(size(Y))"))
    return _exec!(p, X, Y)
end
270+
# Real inverse
271+
# Applying complex -> real plans.
#
# The library receives raw device pointers, so array sizes are checked against the sizes
# recorded in the plan before executing; a mismatch throws `DimensionMismatch`.

# Allocates the real output array (size `p.osz`) and returns it.
function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    Y = oneAPI.oneArray{R}(undef, p.osz)
    return _exec!(p, X, Y)
end

# Writes into the caller-provided real array `Y`; returns `Y`.
function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
    size(X) == p.sz || throw(DimensionMismatch(
        "plan was created for input size $(p.sz), got $(size(X))"))
    size(Y) == p.osz || throw(DimensionMismatch(
        "plan produces output size $(p.osz), got destination of size $(size(Y))"))
    return _exec!(p, X, Y)
end
277+
278+
end # module FFT

lib/mkl/oneMKL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include("wrappers_lapack.jl")
2929
include("wrappers_sparse.jl")
3030
include("linalg.jl")
3131
include("interfaces.jl")
32+
include("fft.jl")
3233

3334
function band(A::StridedArray, kl, ku)
3435
m, n = size(A)

0 commit comments

Comments
 (0)