Skip to content

Commit 717bc89

Browse files
committed
Add oneMKL DFT
1 parent a0ee705 commit 717bc89

5 files changed

Lines changed: 586 additions & 5 deletions

File tree

deps/CMakeLists.txt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
66

77
project(oneAPISupport)
88

9-
add_library(oneapi_support SHARED src/sycl.h src/sycl.hpp src/sycl.cpp src/onemkl.h src/onemkl.cpp)
9+
add_library(oneapi_support SHARED
10+
src/sycl.h
11+
src/sycl.hpp
12+
src/sycl.cpp
13+
src/onemkl.h
14+
src/onemkl.cpp
15+
src/onemkl_dft.h
16+
src/onemkl_dft.cpp
17+
)
1018

1119
target_link_libraries(oneapi_support
1220
mkl_sycl

deps/src/onemkl_dft.cpp

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
#include "onemkl_dft.h"
2+
#include "sycl.hpp" // internal struct definitions
3+
4+
#include <oneapi/mkl/dft.hpp>
5+
#include <vector>
6+
#include <complex>
7+
#include <new>
8+
#include <exception>
9+
#include <cstring>
10+
11+
using namespace oneapi::mkl::dft;
12+
13+
struct onemklDftDescriptor_st {
14+
precision prec;
15+
domain dom;
16+
void *ptr; // pointer to concrete descriptor<prec, dom>
17+
};
18+
19+
static inline precision to_prec(onemklDftPrecision p) {
20+
return (p == ONEMKL_DFT_PRECISION_DOUBLE) ? precision::DOUBLE : precision::SINGLE;
21+
}
22+
23+
static inline domain to_dom(onemklDftDomain d) {
24+
return (d == ONEMKL_DFT_DOMAIN_COMPLEX) ? domain::COMPLEX : domain::REAL;
25+
}
26+
27+
// Helper to allocate descriptor depending on precision/domain
28+
static int allocate_descriptor(onemklDftDescriptor_t *out, precision p, domain d, const std::vector<int64_t> &lengths) {
29+
try {
30+
auto *desc = new onemklDftDescriptor_st();
31+
desc->prec = p;
32+
desc->dom = d;
33+
if (p == precision::SINGLE && d == domain::REAL) {
34+
desc->ptr = new descriptor<precision::SINGLE, domain::REAL>(lengths);
35+
} else if (p == precision::SINGLE && d == domain::COMPLEX) {
36+
desc->ptr = new descriptor<precision::SINGLE, domain::COMPLEX>(lengths);
37+
} else if (p == precision::DOUBLE && d == domain::REAL) {
38+
desc->ptr = new descriptor<precision::DOUBLE, domain::REAL>(lengths);
39+
} else { // DOUBLE COMPLEX
40+
desc->ptr = new descriptor<precision::DOUBLE, domain::COMPLEX>(lengths);
41+
}
42+
*out = desc;
43+
return 0;
44+
} catch (...) {
45+
return -1;
46+
}
47+
}
48+
49+
int onemklDftCreate1D(onemklDftDescriptor_t *desc,
50+
onemklDftPrecision precision,
51+
onemklDftDomain domain,
52+
int64_t length) {
53+
std::vector<int64_t> dims{length};
54+
return allocate_descriptor(desc, to_prec(precision), to_dom(domain), dims);
55+
}
56+
57+
int onemklDftCreateND(onemklDftDescriptor_t *desc,
58+
onemklDftPrecision precision,
59+
onemklDftDomain domain,
60+
int64_t dim,
61+
const int64_t *lengths) {
62+
if (dim <= 0 || lengths == nullptr) return -2;
63+
std::vector<int64_t> dims(lengths, lengths + dim);
64+
return allocate_descriptor(desc, to_prec(precision), to_dom(domain), dims);
65+
}
66+
67+
int onemklDftDestroy(onemklDftDescriptor_t desc) {
68+
if (!desc) return 0;
69+
try {
70+
if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) {
71+
delete static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc->ptr);
72+
} else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) {
73+
delete static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc->ptr);
74+
} else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) {
75+
delete static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc->ptr);
76+
} else {
77+
delete static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc->ptr);
78+
}
79+
delete desc;
80+
return 0;
81+
} catch (...) {
82+
return -1;
83+
}
84+
}
85+
86+
int onemklDftCommit(onemklDftDescriptor_t desc, syclQueue_t queue) {
87+
if (!desc || !queue) return -2;
88+
try {
89+
if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) {
90+
static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc->ptr)->commit(queue->val);
91+
} else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) {
92+
static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc->ptr)->commit(queue->val);
93+
} else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) {
94+
static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc->ptr)->commit(queue->val);
95+
} else {
96+
static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc->ptr)->commit(queue->val);
97+
}
98+
return 0;
99+
} catch (...) {
100+
return -1;
101+
}
102+
}
103+
104+
// Internal mapping helpers for config params/values; rely on enum ordering matching header.
105+
static inline config_param to_param(onemklDftConfigParam p) { return static_cast<config_param>(p); }
106+
static inline config_value to_cvalue(onemklDftConfigValue v) { return static_cast<config_value>(v); }
107+
108+
// Dispatch macro re-used for configuration
109+
#define ONEMKL_DFT_DISPATCH_CFG(desc_expr, CALL) \
110+
do { \
111+
if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) { \
112+
auto *d = static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc_expr); \
113+
CALL; \
114+
} else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) { \
115+
auto *d = static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc_expr); \
116+
CALL; \
117+
} else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) { \
118+
auto *d = static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc_expr); \
119+
CALL; \
120+
} else { \
121+
auto *d = static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc_expr); \
122+
CALL; \
123+
} \
124+
} while (0)
125+
126+
int onemklDftSetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t value) {
127+
if (!desc) return -2;
128+
try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), value)); return 0; } catch (...) { return -1; }
129+
}
130+
131+
int onemklDftSetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double value) {
132+
if (!desc) return -2;
133+
try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), value)); return 0; } catch (...) { return -1; }
134+
}
135+
136+
int onemklDftSetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, const int64_t *values, int64_t n) {
137+
if (!desc || !values || n < 0) return -2;
138+
try { std::vector<int64_t> v(values, values + n); ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), v)); return 0; } catch (...) { return -1; }
139+
}
140+
141+
int onemklDftSetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue value) {
142+
if (!desc) return -2;
143+
try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), to_cvalue(value))); return 0; } catch (...) { return -1; }
144+
}
145+
146+
int onemklDftGetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *value) {
147+
if (!desc || !value) return -2;
148+
try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), value)); return 0; } catch (...) { return -1; }
149+
}
150+
151+
int onemklDftGetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double *value) {
152+
if (!desc || !value) return -2;
153+
try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), value)); return 0; } catch (...) { return -1; }
154+
}
155+
156+
int onemklDftGetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *values, int64_t *n) {
157+
if (!desc || !values || !n || *n <= 0) return -2;
158+
try {
159+
std::vector<int64_t> v; ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), &v));
160+
int64_t to_copy = (*n < (int64_t)v.size()) ? *n : (int64_t)v.size();
161+
std::memcpy(values, v.data(), sizeof(int64_t)*to_copy);
162+
*n = to_copy; return 0;
163+
} catch (...) { return -1; }
164+
}
165+
166+
int onemklDftGetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue *value) {
167+
if (!desc || !value) return -2;
168+
try { config_value cv; ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), &cv)); *value = static_cast<onemklDftConfigValue>(cv); return 0; } catch (...) { return -1; }
169+
}
170+
171+
// Helper macro to dispatch compute operations
172+
#define ONEMKL_DFT_DISPATCH(desc_expr, CALL) \
173+
do { \
174+
if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) { \
175+
auto *d = static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc_expr); \
176+
CALL; \
177+
} else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) { \
178+
auto *d = static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc_expr); \
179+
CALL; \
180+
} else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) { \
181+
auto *d = static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc_expr); \
182+
CALL; \
183+
} else { \
184+
auto *d = static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc_expr); \
185+
CALL; \
186+
} \
187+
} while (0)
188+
189+
int onemklDftComputeForward(onemklDftDescriptor_t desc, void *inout) {
190+
if (!desc || !inout) return -2;
191+
try {
192+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, inout));
193+
return 0;
194+
} catch (...) {
195+
return -1;
196+
}
197+
}
198+
199+
int onemklDftComputeForwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out) {
200+
if (!desc || !in || !out) return -2;
201+
try {
202+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, in, out));
203+
return 0;
204+
} catch (...) {
205+
return -1;
206+
}
207+
}
208+
209+
int onemklDftComputeBackward(onemklDftDescriptor_t desc, void *inout) {
210+
if (!desc || !inout) return -2;
211+
try {
212+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, inout));
213+
return 0;
214+
} catch (...) {
215+
return -1;
216+
}
217+
}
218+
219+
int onemklDftComputeBackwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out) {
220+
if (!desc || !in || !out) return -2;
221+
try {
222+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, in, out));
223+
return 0;
224+
} catch (...) {
225+
return -1;
226+
}
227+
}
228+
229+
// Keep dispatch macros defined for buffer variants below; undef at end of file.
230+
231+
// Buffer API helpers: create temporary buffers referencing host memory.
232+
// NOTE: This assumes the memory is accessible and sized appropriately.
233+
template <typename T>
234+
static inline sycl::buffer<T,1> make_buffer(T *ptr, int64_t n) {
235+
return sycl::buffer<T,1>(ptr, sycl::range<1>(static_cast<size_t>(n)));
236+
}
237+
238+
// Query total element count from LENGTHS config (product of lengths).
239+
static int64_t get_element_count(onemklDftDescriptor_t desc) {
240+
int64_t n = 0; int64_t dims = 0; if (onemklDftGetValueInt64(desc, ONEMKL_DFT_PARAM_DIMENSION, &dims) != 0) return -1; if (dims <= 0 || dims > 8) return -1; int64_t lens[16]; int64_t want = dims; if (onemklDftGetValueInt64Array(desc, ONEMKL_DFT_PARAM_LENGTHS, lens, &want) != 0) return -1; if (want != dims) return -1; int64_t total = 1; for (int i=0;i<dims;i++){ if (lens[i]<=0) return -1; total *= lens[i]; } return total; }
241+
242+
// Select real/complex element size variant for pointers.
243+
int onemklDftComputeForwardBuffer(onemklDftDescriptor_t desc, void *inout) {
244+
if (!desc || !inout) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try {
245+
if (desc->dom == domain::REAL) {
246+
if (desc->prec == precision::SINGLE) { auto buf = make_buffer((float*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); }
247+
else { auto buf = make_buffer((double*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); }
248+
} else { // COMPLEX
249+
if (desc->prec == precision::SINGLE) { auto buf = make_buffer((std::complex<float>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); }
250+
else { auto buf = make_buffer((std::complex<double>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); }
251+
}
252+
return 0; } catch (...) { return -1; }
253+
}
254+
255+
int onemklDftComputeForwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out) {
256+
if (!desc || !in || !out) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try {
257+
if (desc->dom == domain::REAL) {
258+
if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((float*)in, n); auto bufo = make_buffer((float*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); }
259+
else { auto bufi = make_buffer((double*)in, n); auto bufo = make_buffer((double*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); }
260+
} else {
261+
if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((std::complex<float>*)in, n); auto bufo = make_buffer((std::complex<float>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); }
262+
else { auto bufi = make_buffer((std::complex<double>*)in, n); auto bufo = make_buffer((std::complex<double>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); }
263+
}
264+
return 0; } catch (...) { return -1; }
265+
}
266+
267+
int onemklDftComputeBackwardBuffer(onemklDftDescriptor_t desc, void *inout) {
268+
if (!desc || !inout) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try {
269+
if (desc->dom == domain::REAL) {
270+
if (desc->prec == precision::SINGLE) { auto buf = make_buffer((float*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); }
271+
else { auto buf = make_buffer((double*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); }
272+
} else {
273+
if (desc->prec == precision::SINGLE) { auto buf = make_buffer((std::complex<float>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); }
274+
else { auto buf = make_buffer((std::complex<double>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); }
275+
}
276+
return 0; } catch (...) { return -1; }
277+
}
278+
279+
int onemklDftComputeBackwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out) {
280+
if (!desc || !in || !out) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try {
281+
if (desc->dom == domain::REAL) {
282+
if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((float*)in, n); auto bufo = make_buffer((float*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); }
283+
else { auto bufi = make_buffer((double*)in, n); auto bufo = make_buffer((double*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); }
284+
} else {
285+
if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((std::complex<float>*)in, n); auto bufo = make_buffer((std::complex<float>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); }
286+
else { auto bufi = make_buffer((std::complex<double>*)in, n); auto bufo = make_buffer((std::complex<double>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); }
287+
}
288+
return 0; } catch (...) { return -1; }
289+
}
290+
291+
#undef ONEMKL_DFT_DISPATCH
292+
#undef ONEMKL_DFT_DISPATCH_CFG

0 commit comments

Comments
 (0)