@@ -29,6 +29,10 @@ const DFT_PARAM_LENGTHS = 2
# Descriptor parameter ids; these are the `param` values handed to the
# ccall_set_* wrappers in this file.
const DFT_PARAM_PRECISION      = 3
const DFT_PARAM_FORWARD_SCALE  = 4
const DFT_PARAM_BACKWARD_SCALE = 5
const DFT_PARAM_PLACEMENT      = 11

# Values for configuration parameters such as DFT_PARAM_PLACEMENT. They are
# indices into the (partial) oneMKL config_value enum, not the raw DFTI_*
# numbers.
const DFT_CFG_INPLACE     = 4
const DFT_CFG_NOT_INPLACE = 5
3236
3337# Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed)
3438# We'll declare ccall prototypes manually until generator exposes them.
@@ -49,6 +53,8 @@ ccall_fwd_oop(desc, pin, pout) = ccall((:onemklDftComputeForwardOutOfPlace, lib)
# Manual bindings for the native DFT entry points in `lib`. Each wrapper
# forwards its arguments unchanged and returns the Cint produced by the native
# call (callers in this file treat 0 as success).

# Backward transform, in place on `ptr`.
function ccall_bwd(desc, ptr)
    return ccall((:onemklDftComputeBackward, lib), Cint,
                 (Ptr{Cvoid}, Ptr{Cvoid}),
                 desc, ptr)
end

# Backward transform reading from `pin` and writing to `pout`.
function ccall_bwd_oop(desc, pin, pout)
    return ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint,
                 (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
                 desc, pin, pout)
end

# Set a floating-point descriptor parameter.
function ccall_set_double(desc, param::Int32, value::Float64)
    return ccall((:onemklDftSetValueDouble, lib), Cint,
                 (Ptr{Cvoid}, Cint, Float64),
                 desc, param, value)
end

# Set a 64-bit integer descriptor parameter.
function ccall_set_int(desc, param::Int32, value::Int64)
    return ccall((:onemklDftSetValueInt64, lib), Cint,
                 (Ptr{Cvoid}, Cint, Int64),
                 desc, param, value)
end

# Set an enum-valued descriptor parameter (e.g. placement) from a DFT_CFG_* index.
function ccall_set_cfg(desc, param::Int32, value::Int32)
    return ccall((:onemklDftSetValueConfigValue, lib), Cint,
                 (Ptr{Cvoid}, Cint, Cint),
                 desc, param, value)
end
5258
5359abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
5460
@@ -83,28 +89,31 @@ mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
8389end
8490
8591# Inverse plan constructors (derive from existing plan)
# Number of points covered by a transform over `region`: the product of the
# entries of `sz` selected by the dimensions listed in `region`.
# AbstractFFTs expects an inverse transform to scale by the reciprocal of this
# value. An empty `region` yields 1.
function normalization_factor(sz, region)
    return foldl((acc, dim) -> acc * sz[dim], region; init = 1)
end
96+
# Inverse of a forward complex plan.
#
# The inverse plan `q` reuses the handle, queue, sizes, region and buffer of `p`
# and differs only in its direction tag (MKLFFT_INVERSE). `p` and `q` are
# linked to each other through their `pinv` fields. The result is wrapped in a
# `ScaledPlan` carrying 1/prod(lengths along region), which is the
# normalization AbstractFFTs expects from an inverse.
function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
    q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(
        p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p,
    )
    p.pinv = q
    # Keep the scale in the plan's own real precision (Float32 for ComplexF32
    # plans) so that applying it never introduces Float64 arithmetic on
    # single-precision device arrays.
    scale = one(real(T)) / normalization_factor(p.sz, p.region)
    return ScaledPlan(q, scale)
end
# Inverse of an (unnormalized) backward complex plan.
#
# The forward plan `q` reuses the handle, queue, sizes, region and buffer of
# `p` and differs only in its direction tag (MKLFFT_FORWARD). `p` and `q` are
# linked to each other through their `pinv` fields. The result is wrapped in a
# `ScaledPlan` carrying 1/prod(lengths along region).
function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
    q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(
        p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p,
    )
    p.pinv = q
    # Keep the scale in the plan's own real precision (Float32 for ComplexF32
    # plans) so that applying it never introduces Float64 arithmetic on
    # single-precision device arrays.
    scale = one(real(T)) / normalization_factor(p.sz, p.region)
    return ScaledPlan(q, scale)
end
96107
# Inverse of a forward real plan (rfft): a complex-to-real backward plan
# tagged `:brfft`, scaled by 1/prod(lengths along region).
#
# The inverse plan `q` reuses the handle, queue, sizes, region and buffer of
# `p`; `p` and `q` are linked to each other through their `pinv` fields.
function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
    q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(
        p.handle, p.queue, p.sz, p.osz, :brfft, p.region, p.buffer, p,
    )
    p.pinv = q
    # Keep the scale in the plan's own real precision so that applying it never
    # introduces Float64 arithmetic on single-precision device arrays.
    # NOTE(review): this uses `p.sz` as the real-domain size, which holds for
    # plans built by plan_rfft — confirm for forward plans obtained otherwise.
    scale = one(real(T)) / normalization_factor(p.sz, p.region)
    return ScaledPlan(q, scale)
end
# Inverse of a backward real plan (brfft): a real-to-complex forward plan
# tagged `:rfft`, scaled by 1/prod(real-domain lengths along region).
#
# The forward plan `q` reuses the handle, queue, sizes, region and buffer of
# `p`; `p` and `q` are linked to each other through their `pinv` fields.
function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
    q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(
        p.handle, p.queue, p.sz, p.osz, :rfft, p.region, p.buffer, p,
    )
    p.pinv = q
    # The normalization must be taken over the real-domain lengths. Inverse
    # plans built by plan_brfft store the complex (half-spectrum) size in `sz`
    # and the real size in `osz`, while inverse plans derived from a forward
    # plan carry them the other way round. A real-domain length is never
    # smaller than its half-spectrum counterpart, so the elementwise maximum
    # recovers the real extent in both cases.
    # NOTE(review): `q` still inherits `sz`/`osz` in the order stored on `p` —
    # confirm that execution of `q` does not depend on that order.
    real_sz = map(max, p.sz, p.osz)
    # Keep the scale in the plan's own real precision so that applying it never
    # introduces Float64 arithmetic on single-precision device arrays.
    scale = one(real(T)) / normalization_factor(real_sz, p.region)
    return ScaledPlan(q, scale)
end
109118
110119function Base. show (io:: IO , p:: MKLFFTPlan{T,K,inplace} ) where {T,K,inplace}
@@ -123,45 +132,45 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool; normalize
123132 st = length (lengths) == 1 ? ccall_create1d (desc_ref, iprec, idom, lengths[1 ]) : ccall_creatend (desc_ref, iprec, idom, length (lengths), pointer (lengths))
124133 st == 0 || error (" onemkl DFT create failed (status $st )" )
125134 desc = desc_ref[]
126- # Set scaling so that forward is unscaled, inverse multiplies by 1/volume (AbstractFFTs convention)
127- if normalize
128- vol = prod (sz)
129- stfs = ccall_set_double (desc, Int32 (DFT_PARAM_FORWARD_SCALE), 1.0 )
130- stfs == 0 || error (" set forward scale failed ($stfs )" )
131- stbs = ccall_set_double (desc, Int32 (DFT_PARAM_BACKWARD_SCALE), 1.0 / vol)
132- stbs == 0 || error (" set backward scale failed ($stbs )" )
133- end
135+ # Do not program descriptor scaling; we'll perform inverse normalization manually.
136+ # Set placement explicitly based on plan type later
134137 # Construct a SYCL queue from current Level Zero context/device (reuse global queue)
135138 ze_ctx = oneAPI. context (); ze_dev = oneAPI. device ()
136139 sycl_dev = SYCL. syclDevice (SYCL. syclPlatform (oneAPI. driver ()), ze_dev)
137140 sycl_ctx = SYCL. syclContext ([sycl_dev], ze_ctx)
138141 q = SYCL. syclQueue (sycl_ctx, sycl_dev, oneAPI. global_queue (ze_ctx, ze_dev))
139- stc = ccall_commit (desc, q)
140- stc == 0 || error (" onemkl DFT commit failed (status $stc )" )
141142 return desc, q
142143end
143144
144145# Complex plans
# Plan an out-of-place forward complex FFT for arrays shaped like `X`, over the
# dimensions listed in `region`.
#
# Creates a descriptor for `size(X)`, configures it as not-in-place, commits it
# on the queue returned by `_create_descriptor`, and wraps everything in a
# forward `cMKLFFTPlan`. Raises an error if configuring or committing the
# descriptor reports a non-zero status.
function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    desc, q = _create_descriptor(size(X), T, true)
    # A failed placement request must not go unnoticed: the descriptor would
    # otherwise be committed with a placement that does not match this plan.
    stp = ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
    stp == 0 || error("set placement failed ($stp)")
    stc = ccall_commit(desc, q)
    stc == 0 || error("commit failed ($stc)")
    return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing,
    )
end
# Plan an out-of-place backward complex FFT for arrays shaped like `X`, over
# the dimensions listed in `region`.
#
# Creates a descriptor for `size(X)`, configures it as not-in-place, commits it
# on the queue returned by `_create_descriptor`, and wraps everything in an
# inverse-direction `cMKLFFTPlan`. Raises an error if configuring or committing
# the descriptor reports a non-zero status.
function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    desc, q = _create_descriptor(size(X), T, true)
    # A failed placement request must not go unnoticed: the descriptor would
    # otherwise be committed with a placement that does not match this plan.
    stp = ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
    stp == 0 || error("set placement failed ($stp)")
    stc = ccall_commit(desc, q)
    stc == 0 || error("commit failed ($stc)")
    return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing,
    )
end
155160
156161# In-place (provide separate methods)
# Plan an in-place forward complex FFT for arrays shaped like `X`, over the
# dimensions listed in `region`.
#
# Creates a descriptor for `size(X)`, configures it as in-place, commits it on
# the queue returned by `_create_descriptor`, and wraps everything in a forward
# `cMKLFFTPlan` whose `inplace` type parameter is `true`. Raises an error if
# configuring or committing the descriptor reports a non-zero status.
function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    desc, q = _create_descriptor(size(X), T, true)
    # A failed placement request must not go unnoticed: the descriptor would
    # otherwise be committed with a placement that does not match this plan.
    stp = ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
    stp == 0 || error("set placement failed ($stp)")
    stc = ccall_commit(desc, q)
    stc == 0 || error("commit failed ($stc)")
    return cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing,
    )
end
# Plan an in-place backward complex FFT for arrays shaped like `X`, over the
# dimensions listed in `region`.
#
# Creates a descriptor for `size(X)`, configures it as in-place, commits it on
# the queue returned by `_create_descriptor`, and wraps everything in an
# inverse-direction `cMKLFFTPlan` whose `inplace` type parameter is `true`.
# Raises an error if configuring or committing the descriptor reports a
# non-zero status.
function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    desc, q = _create_descriptor(size(X), T, true)
    # A failed placement request must not go unnoticed: the descriptor would
    # otherwise be committed with a placement that does not match this plan.
    stp = ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
    stp == 0 || error("set placement failed ($stp)")
    stc = ccall_commit(desc, q)
    stc == 0 || error("commit failed ($stc)")
    return cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing,
    )
end
167176
@@ -174,6 +183,8 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
174183 ax = reg[1 ]
175184 ydims = Base. setindex (xdims, div (xdims[ax],2 )+ 1 , ax)
176185 buffer = oneAPI. oneArray {Complex{T}} (undef, ydims)
186+ ccall_set_cfg (desc, Int32 (DFT_PARAM_PLACEMENT), Int32 (DFT_CFG_NOT_INPLACE))
187+ stc = ccall_commit (desc, q); stc == 0 || error (" commit failed ($stc )" )
177188 rMKLFFTPlan {T,MKLFFT_FORWARD,false,N,R,typeof(buffer)} (desc,q,xdims,ydims,:rfft ,reg,buffer,nothing )
178189end
179190
@@ -188,6 +199,8 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
188199 RT = T. parameters[1 ]
189200 desc,q = _create_descriptor (ydims, RT, false )
190201 buffer = oneAPI. oneArray {T} (undef, xdims) # copy for safety
202+ ccall_set_cfg (desc, Int32 (DFT_PARAM_PLACEMENT), Int32 (DFT_CFG_NOT_INPLACE))
203+ stc = ccall_commit (desc, q); stc == 0 || error (" commit failed ($stc )" )
191204 rMKLFFTPlan {T,MKLFFT_INVERSE,false,N,R,typeof(buffer)} (desc,q,xdims,ydims,:brfft ,reg,buffer,nothing )
192205end
193206
0 commit comments