|
| 1 | +#include "onemkl_dft.h" |
| 2 | +#include "sycl.hpp" // internal struct definitions |
| 3 | + |
| 4 | +#include <oneapi/mkl/dft.hpp> |
| 5 | +#include <vector> |
| 6 | +#include <complex> |
| 7 | +#include <new> |
| 8 | +#include <exception> |
| 9 | +#include <cstring> |
| 10 | + |
| 11 | +using namespace oneapi::mkl::dft; |
| 12 | + |
| 13 | +struct onemklDftDescriptor_st { |
| 14 | + precision prec; |
| 15 | + domain dom; |
| 16 | + void *ptr; // pointer to concrete descriptor<prec, dom> |
| 17 | +}; |
| 18 | + |
| 19 | +static inline precision to_prec(onemklDftPrecision p) { |
| 20 | + return (p == ONEMKL_DFT_PRECISION_DOUBLE) ? precision::DOUBLE : precision::SINGLE; |
| 21 | +} |
| 22 | + |
| 23 | +static inline domain to_dom(onemklDftDomain d) { |
| 24 | + return (d == ONEMKL_DFT_DOMAIN_COMPLEX) ? domain::COMPLEX : domain::REAL; |
| 25 | +} |
| 26 | + |
| 27 | +// Helper to allocate descriptor depending on precision/domain |
| 28 | +static int allocate_descriptor(onemklDftDescriptor_t *out, precision p, domain d, const std::vector<int64_t> &lengths) { |
| 29 | + try { |
| 30 | + auto *desc = new onemklDftDescriptor_st(); |
| 31 | + desc->prec = p; |
| 32 | + desc->dom = d; |
| 33 | + if (p == precision::SINGLE && d == domain::REAL) { |
| 34 | + desc->ptr = new descriptor<precision::SINGLE, domain::REAL>(lengths); |
| 35 | + } else if (p == precision::SINGLE && d == domain::COMPLEX) { |
| 36 | + desc->ptr = new descriptor<precision::SINGLE, domain::COMPLEX>(lengths); |
| 37 | + } else if (p == precision::DOUBLE && d == domain::REAL) { |
| 38 | + desc->ptr = new descriptor<precision::DOUBLE, domain::REAL>(lengths); |
| 39 | + } else { // DOUBLE COMPLEX |
| 40 | + desc->ptr = new descriptor<precision::DOUBLE, domain::COMPLEX>(lengths); |
| 41 | + } |
| 42 | + *out = desc; |
| 43 | + return 0; |
| 44 | + } catch (...) { |
| 45 | + return -1; |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +int onemklDftCreate1D(onemklDftDescriptor_t *desc, |
| 50 | + onemklDftPrecision precision, |
| 51 | + onemklDftDomain domain, |
| 52 | + int64_t length) { |
| 53 | + std::vector<int64_t> dims{length}; |
| 54 | + return allocate_descriptor(desc, to_prec(precision), to_dom(domain), dims); |
| 55 | +} |
| 56 | + |
| 57 | +int onemklDftCreateND(onemklDftDescriptor_t *desc, |
| 58 | + onemklDftPrecision precision, |
| 59 | + onemklDftDomain domain, |
| 60 | + int64_t dim, |
| 61 | + const int64_t *lengths) { |
| 62 | + if (dim <= 0 || lengths == nullptr) return -2; |
| 63 | + std::vector<int64_t> dims(lengths, lengths + dim); |
| 64 | + return allocate_descriptor(desc, to_prec(precision), to_dom(domain), dims); |
| 65 | +} |
| 66 | + |
| 67 | +int onemklDftDestroy(onemklDftDescriptor_t desc) { |
| 68 | + if (!desc) return 0; |
| 69 | + try { |
| 70 | + if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) { |
| 71 | + delete static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc->ptr); |
| 72 | + } else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) { |
| 73 | + delete static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc->ptr); |
| 74 | + } else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) { |
| 75 | + delete static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc->ptr); |
| 76 | + } else { |
| 77 | + delete static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc->ptr); |
| 78 | + } |
| 79 | + delete desc; |
| 80 | + return 0; |
| 81 | + } catch (...) { |
| 82 | + return -1; |
| 83 | + } |
| 84 | +} |
| 85 | + |
| 86 | +int onemklDftCommit(onemklDftDescriptor_t desc, syclQueue_t queue) { |
| 87 | + if (!desc || !queue) return -2; |
| 88 | + try { |
| 89 | + if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) { |
| 90 | + static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc->ptr)->commit(queue->val); |
| 91 | + } else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) { |
| 92 | + static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc->ptr)->commit(queue->val); |
| 93 | + } else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) { |
| 94 | + static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc->ptr)->commit(queue->val); |
| 95 | + } else { |
| 96 | + static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc->ptr)->commit(queue->val); |
| 97 | + } |
| 98 | + return 0; |
| 99 | + } catch (...) { |
| 100 | + return -1; |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +// Internal mapping helpers for config params/values; rely on enum ordering matching header. |
| 105 | +static inline config_param to_param(onemklDftConfigParam p) { return static_cast<config_param>(p); } |
| 106 | +static inline config_value to_cvalue(onemklDftConfigValue v) { return static_cast<config_value>(v); } |
| 107 | + |
| 108 | +// Dispatch macro re-used for configuration |
| 109 | +#define ONEMKL_DFT_DISPATCH_CFG(desc_expr, CALL) \ |
| 110 | + do { \ |
| 111 | + if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) { \ |
| 112 | + auto *d = static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc_expr); \ |
| 113 | + CALL; \ |
| 114 | + } else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) { \ |
| 115 | + auto *d = static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc_expr); \ |
| 116 | + CALL; \ |
| 117 | + } else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) { \ |
| 118 | + auto *d = static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc_expr); \ |
| 119 | + CALL; \ |
| 120 | + } else { \ |
| 121 | + auto *d = static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc_expr); \ |
| 122 | + CALL; \ |
| 123 | + } \ |
| 124 | + } while (0) |
| 125 | + |
| 126 | +int onemklDftSetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t value) { |
| 127 | + if (!desc) return -2; |
| 128 | + try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), value)); return 0; } catch (...) { return -1; } |
| 129 | +} |
| 130 | + |
| 131 | +int onemklDftSetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double value) { |
| 132 | + if (!desc) return -2; |
| 133 | + try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), value)); return 0; } catch (...) { return -1; } |
| 134 | +} |
| 135 | + |
| 136 | +int onemklDftSetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, const int64_t *values, int64_t n) { |
| 137 | + if (!desc || !values || n < 0) return -2; |
| 138 | + try { std::vector<int64_t> v(values, values + n); ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), v)); return 0; } catch (...) { return -1; } |
| 139 | +} |
| 140 | + |
| 141 | +int onemklDftSetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue value) { |
| 142 | + if (!desc) return -2; |
| 143 | + try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->set_value(to_param(param), to_cvalue(value))); return 0; } catch (...) { return -1; } |
| 144 | +} |
| 145 | + |
| 146 | +int onemklDftGetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *value) { |
| 147 | + if (!desc || !value) return -2; |
| 148 | + try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), value)); return 0; } catch (...) { return -1; } |
| 149 | +} |
| 150 | + |
| 151 | +int onemklDftGetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double *value) { |
| 152 | + if (!desc || !value) return -2; |
| 153 | + try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), value)); return 0; } catch (...) { return -1; } |
| 154 | +} |
| 155 | + |
| 156 | +int onemklDftGetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *values, int64_t *n) { |
| 157 | + if (!desc || !values || !n || *n <= 0) return -2; |
| 158 | + try { |
| 159 | + std::vector<int64_t> v; ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), &v)); |
| 160 | + int64_t to_copy = (*n < (int64_t)v.size()) ? *n : (int64_t)v.size(); |
| 161 | + std::memcpy(values, v.data(), sizeof(int64_t)*to_copy); |
| 162 | + *n = to_copy; return 0; |
| 163 | + } catch (...) { return -1; } |
| 164 | +} |
| 165 | + |
| 166 | +int onemklDftGetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue *value) { |
| 167 | + if (!desc || !value) return -2; |
| 168 | + try { config_value cv; ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), &cv)); *value = static_cast<onemklDftConfigValue>(cv); return 0; } catch (...) { return -1; } |
| 169 | +} |
| 170 | + |
| 171 | +// Helper macro to dispatch compute operations |
| 172 | +#define ONEMKL_DFT_DISPATCH(desc_expr, CALL) \ |
| 173 | + do { \ |
| 174 | + if (desc->prec == precision::SINGLE && desc->dom == domain::REAL) { \ |
| 175 | + auto *d = static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc_expr); \ |
| 176 | + CALL; \ |
| 177 | + } else if (desc->prec == precision::SINGLE && desc->dom == domain::COMPLEX) { \ |
| 178 | + auto *d = static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc_expr); \ |
| 179 | + CALL; \ |
| 180 | + } else if (desc->prec == precision::DOUBLE && desc->dom == domain::REAL) { \ |
| 181 | + auto *d = static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc_expr); \ |
| 182 | + CALL; \ |
| 183 | + } else { \ |
| 184 | + auto *d = static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc_expr); \ |
| 185 | + CALL; \ |
| 186 | + } \ |
| 187 | + } while (0) |
| 188 | + |
| 189 | +int onemklDftComputeForward(onemklDftDescriptor_t desc, void *inout) { |
| 190 | + if (!desc || !inout) return -2; |
| 191 | + try { |
| 192 | + ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, inout)); |
| 193 | + return 0; |
| 194 | + } catch (...) { |
| 195 | + return -1; |
| 196 | + } |
| 197 | +} |
| 198 | + |
| 199 | +int onemklDftComputeForwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out) { |
| 200 | + if (!desc || !in || !out) return -2; |
| 201 | + try { |
| 202 | + ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, in, out)); |
| 203 | + return 0; |
| 204 | + } catch (...) { |
| 205 | + return -1; |
| 206 | + } |
| 207 | +} |
| 208 | + |
| 209 | +int onemklDftComputeBackward(onemklDftDescriptor_t desc, void *inout) { |
| 210 | + if (!desc || !inout) return -2; |
| 211 | + try { |
| 212 | + ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, inout)); |
| 213 | + return 0; |
| 214 | + } catch (...) { |
| 215 | + return -1; |
| 216 | + } |
| 217 | +} |
| 218 | + |
| 219 | +int onemklDftComputeBackwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out) { |
| 220 | + if (!desc || !in || !out) return -2; |
| 221 | + try { |
| 222 | + ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, in, out)); |
| 223 | + return 0; |
| 224 | + } catch (...) { |
| 225 | + return -1; |
| 226 | + } |
| 227 | +} |
| 228 | + |
| 229 | +// Keep dispatch macros defined for buffer variants below; undef at end of file. |
| 230 | + |
| 231 | +// Buffer API helpers: create temporary buffers referencing host memory. |
| 232 | +// NOTE: This assumes the memory is accessible and sized appropriately. |
| 233 | +template <typename T> |
| 234 | +static inline sycl::buffer<T,1> make_buffer(T *ptr, int64_t n) { |
| 235 | + return sycl::buffer<T,1>(ptr, sycl::range<1>(static_cast<size_t>(n))); |
| 236 | +} |
| 237 | + |
| 238 | +// Query total element count from LENGTHS config (product of lengths). |
| 239 | +static int64_t get_element_count(onemklDftDescriptor_t desc) { |
| 240 | + int64_t n = 0; int64_t dims = 0; if (onemklDftGetValueInt64(desc, ONEMKL_DFT_PARAM_DIMENSION, &dims) != 0) return -1; if (dims <= 0 || dims > 8) return -1; int64_t lens[16]; int64_t want = dims; if (onemklDftGetValueInt64Array(desc, ONEMKL_DFT_PARAM_LENGTHS, lens, &want) != 0) return -1; if (want != dims) return -1; int64_t total = 1; for (int i=0;i<dims;i++){ if (lens[i]<=0) return -1; total *= lens[i]; } return total; } |
| 241 | + |
| 242 | +// Select real/complex element size variant for pointers. |
| 243 | +int onemklDftComputeForwardBuffer(onemklDftDescriptor_t desc, void *inout) { |
| 244 | + if (!desc || !inout) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try { |
| 245 | + if (desc->dom == domain::REAL) { |
| 246 | + if (desc->prec == precision::SINGLE) { auto buf = make_buffer((float*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); } |
| 247 | + else { auto buf = make_buffer((double*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); } |
| 248 | + } else { // COMPLEX |
| 249 | + if (desc->prec == precision::SINGLE) { auto buf = make_buffer((std::complex<float>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); } |
| 250 | + else { auto buf = make_buffer((std::complex<double>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, buf)); } |
| 251 | + } |
| 252 | + return 0; } catch (...) { return -1; } |
| 253 | +} |
| 254 | + |
| 255 | +int onemklDftComputeForwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out) { |
| 256 | + if (!desc || !in || !out) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try { |
| 257 | + if (desc->dom == domain::REAL) { |
| 258 | + if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((float*)in, n); auto bufo = make_buffer((float*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); } |
| 259 | + else { auto bufi = make_buffer((double*)in, n); auto bufo = make_buffer((double*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); } |
| 260 | + } else { |
| 261 | + if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((std::complex<float>*)in, n); auto bufo = make_buffer((std::complex<float>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); } |
| 262 | + else { auto bufi = make_buffer((std::complex<double>*)in, n); auto bufo = make_buffer((std::complex<double>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, bufi, bufo)); } |
| 263 | + } |
| 264 | + return 0; } catch (...) { return -1; } |
| 265 | +} |
| 266 | + |
| 267 | +int onemklDftComputeBackwardBuffer(onemklDftDescriptor_t desc, void *inout) { |
| 268 | + if (!desc || !inout) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try { |
| 269 | + if (desc->dom == domain::REAL) { |
| 270 | + if (desc->prec == precision::SINGLE) { auto buf = make_buffer((float*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); } |
| 271 | + else { auto buf = make_buffer((double*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); } |
| 272 | + } else { |
| 273 | + if (desc->prec == precision::SINGLE) { auto buf = make_buffer((std::complex<float>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); } |
| 274 | + else { auto buf = make_buffer((std::complex<double>*)inout, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, buf)); } |
| 275 | + } |
| 276 | + return 0; } catch (...) { return -1; } |
| 277 | +} |
| 278 | + |
| 279 | +int onemklDftComputeBackwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out) { |
| 280 | + if (!desc || !in || !out) return -2; int64_t n = get_element_count(desc); if (n <= 0) return -3; try { |
| 281 | + if (desc->dom == domain::REAL) { |
| 282 | + if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((float*)in, n); auto bufo = make_buffer((float*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); } |
| 283 | + else { auto bufi = make_buffer((double*)in, n); auto bufo = make_buffer((double*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); } |
| 284 | + } else { |
| 285 | + if (desc->prec == precision::SINGLE) { auto bufi = make_buffer((std::complex<float>*)in, n); auto bufo = make_buffer((std::complex<float>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); } |
| 286 | + else { auto bufi = make_buffer((std::complex<double>*)in, n); auto bufo = make_buffer((std::complex<double>*)out, n); ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, bufi, bufo)); } |
| 287 | + } |
| 288 | + return 0; } catch (...) { return -1; } |
| 289 | +} |
| 290 | + |
| 291 | +#undef ONEMKL_DFT_DISPATCH |
| 292 | +#undef ONEMKL_DFT_DISPATCH_CFG |
0 commit comments