Skip to content

Commit 72da1f7

Browse files
committed
Complex32 works
1 parent ee5045b commit 72da1f7

2 files changed

Lines changed: 37 additions & 24 deletions

File tree

deps/src/onemkl_dft.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,10 @@ int onemklDftComputeForward(onemklDftDescriptor_t desc, void *inout) {
195195
if (desc->dom == domain::REAL) {
196196
if (desc->prec == precision::SINGLE) {
197197
auto *p = static_cast<float*>(inout);
198-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p));
198+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p).wait());
199199
} else {
200200
auto *p = static_cast<double*>(inout);
201-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p));
201+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p).wait());
202202
}
203203
} else { // COMPLEX
204204
if (desc->prec == precision::SINGLE) {
@@ -220,11 +220,11 @@ int onemklDftComputeForwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void
220220
if (desc->prec == precision::SINGLE) {
221221
auto *pi = static_cast<float*>(in);
222222
auto *po = static_cast<float*>(out);
223-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po));
223+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po).wait());
224224
} else {
225225
auto *pi = static_cast<double*>(in);
226226
auto *po = static_cast<double*>(out);
227-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po));
227+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po).wait());
228228
}
229229
} else { // COMPLEX
230230
if (desc->prec == precision::SINGLE) {
@@ -247,10 +247,10 @@ int onemklDftComputeBackward(onemklDftDescriptor_t desc, void *inout) {
247247
if (desc->dom == domain::REAL) {
248248
if (desc->prec == precision::SINGLE) {
249249
auto *p = static_cast<float*>(inout);
250-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p));
250+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p).wait());
251251
} else {
252252
auto *p = static_cast<double*>(inout);
253-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p));
253+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p).wait());
254254
}
255255
} else { // COMPLEX
256256
if (desc->prec == precision::SINGLE) {
@@ -272,11 +272,11 @@ int onemklDftComputeBackwardOutOfPlace(onemklDftDescriptor_t desc, void *in, voi
272272
if (desc->prec == precision::SINGLE) {
273273
auto *pi = static_cast<float*>(in);
274274
auto *po = static_cast<float*>(out);
275-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po));
275+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po).wait());
276276
} else {
277277
auto *pi = static_cast<double*>(in);
278278
auto *po = static_cast<double*>(out);
279-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po));
279+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po).wait());
280280
}
281281
} else { // COMPLEX
282282
if (desc->prec == precision::SINGLE) {

lib/mkl/fft.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ const DFT_PARAM_LENGTHS = 2
2929
const DFT_PARAM_PRECISION = 3
3030
const DFT_PARAM_FORWARD_SCALE = 4
3131
const DFT_PARAM_BACKWARD_SCALE = 5
32+
const DFT_PARAM_PLACEMENT = 11
33+
# oneMKL config_value enum (subset) indices (not raw DFTI_* numbers)
34+
const DFT_CFG_INPLACE = 4
35+
const DFT_CFG_NOT_INPLACE = 5
3236

3337
# Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed)
3438
# We'll declare ccall prototypes manually until generator exposes them.
@@ -49,6 +53,8 @@ ccall_fwd_oop(desc, pin, pout) = ccall((:onemklDftComputeForwardOutOfPlace, lib)
4953
ccall_bwd(desc, ptr) = ccall((:onemklDftComputeBackward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr)
5054
ccall_bwd_oop(desc, pin, pout) = ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout)
5155
ccall_set_double(desc, param::Int32, value::Float64) = ccall((:onemklDftSetValueDouble, lib), Cint, (Ptr{Cvoid}, Cint, Float64), desc, param, value)
56+
ccall_set_int(desc, param::Int32, value::Int64) = ccall((:onemklDftSetValueInt64, lib), Cint, (Ptr{Cvoid}, Cint, Int64), desc, param, value)
57+
ccall_set_cfg(desc, param::Int32, value::Int32) = ccall((:onemklDftSetValueConfigValue, lib), Cint, (Ptr{Cvoid}, Cint, Cint), desc, param, value)
5258

5359
abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
5460

@@ -83,28 +89,31 @@ mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
8389
end
8490

8591
# Inverse plan constructors (derive from existing plan)
92+
function normalization_factor(sz, region)
93+
# AbstractFFTs expects inverse to scale by 1/prod(lengths along region)
94+
prod(ntuple(i-> sz[region[i]], length(region)))
95+
end
96+
8697
function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
8798
q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
8899
p.pinv = q
89-
q
100+
ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
90101
end
91102
function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
92103
q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
93104
p.pinv = q
94-
q
105+
ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
95106
end
96107

97108
function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
98-
# forward real -> inverse complex->real (brfft)
99109
q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p)
100110
p.pinv = q
101-
q
111+
ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
102112
end
103113
function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
104-
# inverse real -> forward real (rfft)
105114
q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p)
106115
p.pinv = q
107-
q
116+
ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
108117
end
109118

110119
function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
@@ -123,45 +132,45 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool; normalize
123132
st = length(lengths) == 1 ? ccall_create1d(desc_ref, iprec, idom, lengths[1]) : ccall_creatend(desc_ref, iprec, idom, length(lengths), pointer(lengths))
124133
st == 0 || error("onemkl DFT create failed (status $st)")
125134
desc = desc_ref[]
126-
# Set scaling so that forward is unscaled, inverse multiplies by 1/volume (AbstractFFTs convention)
127-
if normalize
128-
vol = prod(sz)
129-
stfs = ccall_set_double(desc, Int32(DFT_PARAM_FORWARD_SCALE), 1.0)
130-
stfs == 0 || error("set forward scale failed ($stfs)")
131-
stbs = ccall_set_double(desc, Int32(DFT_PARAM_BACKWARD_SCALE), 1.0/vol)
132-
stbs == 0 || error("set backward scale failed ($stbs)")
133-
end
135+
# Do not program descriptor scaling; we'll perform inverse normalization manually.
136+
# Set placement explicitly based on plan type later
134137
# Construct a SYCL queue from current Level Zero context/device (reuse global queue)
135138
ze_ctx = oneAPI.context(); ze_dev = oneAPI.device()
136139
sycl_dev = SYCL.syclDevice(SYCL.syclPlatform(oneAPI.driver()), ze_dev)
137140
sycl_ctx = SYCL.syclContext([sycl_dev], ze_ctx)
138141
q = SYCL.syclQueue(sycl_ctx, sycl_dev, oneAPI.global_queue(ze_ctx, ze_dev))
139-
stc = ccall_commit(desc, q)
140-
stc == 0 || error("onemkl DFT commit failed (status $stc)")
141142
return desc, q
142143
end
143144

144145
# Complex plans
145146
function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
146147
R = length(region); reg = NTuple{R,Int}(region)
147148
desc, q = _create_descriptor(size(X), T, true)
149+
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
150+
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
148151
return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
149152
end
150153
function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
151154
R = length(region); reg = NTuple{R,Int}(region)
152155
desc, q = _create_descriptor(size(X), T, true)
156+
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
157+
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
153158
return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
154159
end
155160

156161
# In-place (provide separate methods)
157162
function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
158163
R = length(region); reg = NTuple{R,Int}(region)
159164
desc,q = _create_descriptor(size(X),T,true)
165+
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
166+
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
160167
cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
161168
end
162169
function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
163170
R = length(region); reg = NTuple{R,Int}(region)
164171
desc,q = _create_descriptor(size(X),T,true)
172+
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
173+
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
165174
cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
166175
end
167176

@@ -174,6 +183,8 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
174183
ax = reg[1]
175184
ydims = Base.setindex(xdims, div(xdims[ax],2)+1, ax)
176185
buffer = oneAPI.oneArray{Complex{T}}(undef, ydims)
186+
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
187+
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
177188
rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing)
178189
end
179190

@@ -188,6 +199,8 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
188199
RT = T.parameters[1]
189200
desc,q = _create_descriptor(ydims, RT, false)
190201
buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety
202+
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
203+
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
191204
rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing)
192205
end
193206

0 commit comments

Comments
 (0)