@@ -23,15 +23,26 @@ const DFT_PREC_DOUBLE = 1
2323const DFT_DOM_REAL = 0
2424const DFT_DOM_COMPLEX = 1
2525
26- # Placement values
27- const DFT_PARAM_DIMENSION = 1
28- const DFT_PARAM_LENGTHS = 2
29- const DFT_PARAM_PRECISION = 3
30- const DFT_PARAM_FORWARD_SCALE = 4
31- const DFT_PARAM_BACKWARD_SCALE = 5
32- const DFT_PARAM_PLACEMENT = 11
33- # oneMKL config_value enum (subset) indices (not raw DFTI_* numbers)
34- const DFT_CFG_INPLACE = 4
26+ # Configuration parameter indices (must match onemkl_dft.h enum ordering)
27+ const DFT_PARAM_DIMENSION = 1
28+ const DFT_PARAM_LENGTHS = 2
29+ const DFT_PARAM_PRECISION = 3
30+ const DFT_PARAM_FORWARD_SCALE = 4
31+ const DFT_PARAM_BACKWARD_SCALE = 5
32+ const DFT_PARAM_NUMBER_OF_TRANSFORMS = 6
33+ const DFT_PARAM_COMPLEX_STORAGE = 7
34+ const DFT_PARAM_PLACEMENT = 8
35+ const DFT_PARAM_INPUT_STRIDES = 9
36+ const DFT_PARAM_OUTPUT_STRIDES = 10
37+ const DFT_PARAM_FWD_DISTANCE = 11
38+ const DFT_PARAM_BWD_DISTANCE = 12
39+ const DFT_PARAM_WORKSPACE = 13
40+ const DFT_PARAM_WORKSPACE_ESTIMATE_BYTES = 14
41+ const DFT_PARAM_WORKSPACE_BYTES = 15
42+ const DFT_PARAM_FWD_STRIDES = 16
43+ const DFT_PARAM_BWD_STRIDES = 17
44+ # Config value logical indices (ordering per onemkl_dft.h)
45+ const DFT_CFG_INPLACE = 4
3546const DFT_CFG_NOT_INPLACE = 5
3647
3748# Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed)
@@ -54,6 +65,7 @@ ccall_bwd(desc, ptr) = ccall((:onemklDftComputeBackward, lib), Cint, (Ptr{Cvoid}
# Invoke `onemklDftComputeBackwardOutOfPlace` from `lib` with the descriptor handle and
# the input/output pointers (all passed as `Ptr{Cvoid}`); returns the C call's `Cint` status.
function ccall_bwd_oop(desc, pin, pout)
    return ccall(
        (:onemklDftComputeBackwardOutOfPlace, lib),
        Cint,
        (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
        desc, pin, pout,
    )
end
# Invoke `onemklDftSetValueDouble` from `lib` to set the descriptor parameter with index
# `param` to the `Float64` `value`; returns the C call's `Cint` status.
function ccall_set_double(desc, param::Int32, value::Float64)
    return ccall(
        (:onemklDftSetValueDouble, lib),
        Cint,
        (Ptr{Cvoid}, Cint, Float64),
        desc, param, value,
    )
end
# Invoke `onemklDftSetValueInt64` from `lib` to set the descriptor parameter with index
# `param` to the `Int64` `value`; returns the C call's `Cint` status.
function ccall_set_int(desc, param::Int32, value::Int64)
    return ccall(
        (:onemklDftSetValueInt64, lib),
        Cint,
        (Ptr{Cvoid}, Cint, Int64),
        desc, param, value,
    )
end
# Invoke `onemklDftSetValueInt64Array` from `lib` to set the descriptor parameter with
# index `param` from the contents of `values`; the element count is passed as the trailing
# `Int64` argument. Returns the C call's `Cint` status.
#
# `values` is handed to `ccall` directly (not via `pointer(values)`): for a `Ptr{Int64}`
# argument slot `ccall` converts the array itself and keeps it rooted for the duration of
# the foreign call, whereas a raw `pointer` does not protect the array from the GC.
function ccall_set_int64_array(desc, param::Int32, values::Vector{Int64})
    return ccall(
        (:onemklDftSetValueInt64Array, lib),
        Cint,
        (Ptr{Cvoid}, Cint, Ptr{Int64}, Int64),
        desc, param, values, length(values),
    )
end
# Invoke `onemklDftSetValueConfigValue` from `lib` to set the descriptor parameter with
# index `param` to the config-value index `value`; returns the C call's `Cint` status.
function ccall_set_cfg(desc, param::Int32, value::Int32)
    return ccall(
        (:onemklDftSetValueConfigValue, lib),
        Cint,
        (Ptr{Cvoid}, Cint, Cint),
        desc, param, value,
    )
end
5870
5971abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
@@ -127,6 +139,9 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool; normalize
127139 prec = T<: Float64 || T<: ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE
128140 dom = complex ? DFT_DOM_COMPLEX : DFT_DOM_REAL
129141 desc_ref = Ref {Ptr{Cvoid}} ()
142+ # Provide lengths in Julia's native (column-major) order; we'll explicitly
143+ # set column-major strides via FWD/BWD_STRIDES when N>1 to compensate for
144+ # oneMKL's row-major interpretation.
130145 lengths = collect (Int64, sz)
131146 iprec = Int32 (prec); idom = Int32 (dom)
132147 st = length (lengths) == 1 ? ccall_create1d (desc_ref, iprec, idom, lengths[1 ]) : ccall_creatend (desc_ref, iprec, idom, length (lengths), pointer (lengths))
@@ -147,13 +162,32 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
147162 R = length (region); reg = NTuple {R,Int} (region)
148163 desc, q = _create_descriptor (size (X), T, true )
149164 ccall_set_cfg (desc, Int32 (DFT_PARAM_PLACEMENT), Int32 (DFT_CFG_NOT_INPLACE))
165+ if N > 1
166+ # Column-major strides: stride along dimension i is product of sizes of previous dims
167+ strides = Vector {Int64} (undef, N+ 1 ); strides[1 ]= 0
168+ prod = 1
169+ @inbounds for i in 1 : N
170+ strides[i+ 1 ] = prod
171+ prod *= size (X,i)
172+ end
173+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_FWD_STRIDES), strides)
174+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_BWD_STRIDES), strides)
175+ end
150176 stc = ccall_commit (desc, q); stc == 0 || error (" commit failed ($stc )" )
151177 return cMKLFFTPlan {T,MKLFFT_FORWARD,false,N,R,Nothing} (desc,q,size (X),size (X),false ,reg,nothing ,nothing )
152178end
"""
    plan_bfft(X::oneAPI.oneArray{T,N}, region)

Create an out-of-place backward (`MKLFFT_INVERSE`) complex plan for arrays shaped like `X`.

A complex-domain descriptor is created for `size(X)`, its placement is set to
`DFT_CFG_NOT_INPLACE`, and, when `N > 1`, explicit forward/backward strides are installed
before the descriptor is committed. `region` is stored in the plan as an `NTuple` of `Int`.

Throws an `ErrorException` if the commit call returns a non-zero status.
"""
function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    nregion = length(region)
    reg = NTuple{nregion,Int}(region)
    desc, q = _create_descriptor(size(X), T, true)
    ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
    if N > 1
        # Layout vector is [offset, stride_1, ..., stride_N]: offset 0, then the running
        # product of the preceding extents (stride 1 for the first dimension).
        layout = Vector{Int64}(undef, N + 1)
        layout[1] = 0
        step = 1
        for (axis, extent) in enumerate(size(X))
            layout[axis + 1] = step
            step *= extent
        end
        for param in (DFT_PARAM_FWD_STRIDES, DFT_PARAM_BWD_STRIDES)
            ccall_set_int64_array(desc, Int32(param), layout)
        end
    end
    stc = ccall_commit(desc, q)
    if stc != 0
        error(" commit failed ($stc )")
    end
    return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,nregion,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing,
    )
end
@@ -163,13 +197,29 @@ function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
163197 R = length (region); reg = NTuple {R,Int} (region)
164198 desc,q = _create_descriptor (size (X),T,true )
165199 ccall_set_cfg (desc, Int32 (DFT_PARAM_PLACEMENT), Int32 (DFT_CFG_INPLACE))
200+ if N > 1
201+ strides = Vector {Int64} (undef, N+ 1 ); strides[1 ]= 0 ; prod= 1
202+ @inbounds for i in 1 : N
203+ strides[i+ 1 ]= prod; prod*= size (X,i)
204+ end
205+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_FWD_STRIDES), strides)
206+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_BWD_STRIDES), strides)
207+ end
166208 stc = ccall_commit (desc, q); stc == 0 || error (" commit failed ($stc )" )
167209 cMKLFFTPlan {T,MKLFFT_FORWARD,true,N,R,Nothing} (desc,q,size (X),size (X),false ,reg,nothing ,nothing )
168210end
"""
    plan_bfft!(X::oneAPI.oneArray{T,N}, region)

Create an in-place backward (`MKLFFT_INVERSE`) complex plan for arrays shaped like `X`.

A complex-domain descriptor is created for `size(X)`, its placement is set to
`DFT_CFG_INPLACE`, and, when `N > 1`, explicit forward/backward strides are installed
before the descriptor is committed. `region` is stored in the plan as an `NTuple` of `Int`.

Throws an `ErrorException` if the commit call returns a non-zero status.
"""
function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
    R = length(region)
    reg = NTuple{R,Int}(region)
    desc, q = _create_descriptor(size(X), T, true)
    ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
    if N > 1
        # Column-major layout [offset, stride_1, ..., stride_N]: the stride of dimension i
        # is the product of the extents of dimensions 1..i-1. The loop must ascend from the
        # first dimension (as in the other plan constructors); walking it in reverse would
        # yield row-major strides that do not match how the array data is laid out.
        dimstrides = Vector{Int64}(undef, N + 1)
        dimstrides[1] = 0
        acc = 1
        @inbounds for i in 1:N
            dimstrides[i + 1] = acc
            acc *= size(X, i)
        end
        ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), dimstrides)
        ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), dimstrides)
    end
    stc = ccall_commit(desc, q)
    stc == 0 || error(" commit failed ($stc )")
    return cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(
        desc, q, size(X), size(X), false, reg, nothing, nothing,
    )
end
@@ -184,6 +234,14 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
184234 ydims = Base. setindex (xdims, div (xdims[ax],2 )+ 1 , ax)
185235 buffer = oneAPI. oneArray {Complex{T}} (undef, ydims)
186236 ccall_set_cfg (desc, Int32 (DFT_PARAM_PLACEMENT), Int32 (DFT_CFG_NOT_INPLACE))
237+ if N > 1
238+ strides = Vector {Int64} (undef, N+ 1 ); strides[1 ]= 0 ; prod= 1
239+ @inbounds for i in 1 : N
240+ strides[i+ 1 ]= prod; prod*= xdims[i]
241+ end
242+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_FWD_STRIDES), strides)
243+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_BWD_STRIDES), strides)
244+ end
187245 stc = ccall_commit (desc, q); stc == 0 || error (" commit failed ($stc )" )
188246 rMKLFFTPlan {T,MKLFFT_FORWARD,false,N,R,typeof(buffer)} (desc,q,xdims,ydims,:rfft ,reg,buffer,nothing )
189247end
@@ -200,6 +258,14 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
200258 desc,q = _create_descriptor (ydims, RT, false )
201259 buffer = oneAPI. oneArray {T} (undef, xdims) # copy for safety
202260 ccall_set_cfg (desc, Int32 (DFT_PARAM_PLACEMENT), Int32 (DFT_CFG_NOT_INPLACE))
261+ if N > 1
262+ strides = Vector {Int64} (undef, N+ 1 ); strides[1 ]= 0 ; prod= 1
263+ @inbounds for i in 1 : N
264+ strides[i+ 1 ]= prod; prod*= xdims[i]
265+ end
266+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_FWD_STRIDES), strides)
267+ ccall_set_int64_array (desc, Int32 (DFT_PARAM_BWD_STRIDES), strides)
268+ end
203269 stc = ccall_commit (desc, q); stc == 0 || error (" commit failed ($stc )" )
204270 rMKLFFTPlan {T,MKLFFT_INVERSE,false,N,R,typeof(buffer)} (desc,q,xdims,ydims,:brfft ,reg,buffer,nothing )
205271end
0 commit comments