Skip to content

Commit ee5045b

Browse files
committed
Fix
1 parent 227a8b0 commit ee5045b

1 file changed

Lines changed: 83 additions & 16 deletions

File tree

deps/src/onemkl_dft.cpp

Lines changed: 83 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -186,44 +186,111 @@ int onemklDftGetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigPara
186186
} \
187187
} while (0)
188188

189+
// Pointer (USM) dispatch with proper element typing rather than using void* directly.
190+
// Using void* caused instantiation of compute_forward/backward with <void> template
191+
// parameters on some oneMKL versions, leading to unresolved symbols at runtime.
189192
int onemklDftComputeForward(onemklDftDescriptor_t desc, void *inout) {
190193
if (!desc || !inout) return -2;
191194
try {
192-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, inout));
195+
if (desc->dom == domain::REAL) {
196+
if (desc->prec == precision::SINGLE) {
197+
auto *p = static_cast<float*>(inout);
198+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p));
199+
} else {
200+
auto *p = static_cast<double*>(inout);
201+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p));
202+
}
203+
} else { // COMPLEX
204+
if (desc->prec == precision::SINGLE) {
205+
auto *p = static_cast<std::complex<float>*>(inout);
206+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p));
207+
} else {
208+
auto *p = static_cast<std::complex<double>*>(inout);
209+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, p));
210+
}
211+
}
193212
return 0;
194-
} catch (...) {
195-
return -1;
196-
}
213+
} catch (...) { return -1; }
197214
}
198215

199216
int onemklDftComputeForwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out) {
200217
if (!desc || !in || !out) return -2;
201218
try {
202-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, in, out));
219+
if (desc->dom == domain::REAL) {
220+
if (desc->prec == precision::SINGLE) {
221+
auto *pi = static_cast<float*>(in);
222+
auto *po = static_cast<float*>(out);
223+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po));
224+
} else {
225+
auto *pi = static_cast<double*>(in);
226+
auto *po = static_cast<double*>(out);
227+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po));
228+
}
229+
} else { // COMPLEX
230+
if (desc->prec == precision::SINGLE) {
231+
auto *pi = static_cast<std::complex<float>*>(in);
232+
auto *po = static_cast<std::complex<float>*>(out);
233+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po));
234+
} else {
235+
auto *pi = static_cast<std::complex<double>*>(in);
236+
auto *po = static_cast<std::complex<double>*>(out);
237+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_forward(*d, pi, po));
238+
}
239+
}
203240
return 0;
204-
} catch (...) {
205-
return -1;
206-
}
241+
} catch (...) { return -1; }
207242
}
208243

209244
int onemklDftComputeBackward(onemklDftDescriptor_t desc, void *inout) {
210245
if (!desc || !inout) return -2;
211246
try {
212-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, inout));
247+
if (desc->dom == domain::REAL) {
248+
if (desc->prec == precision::SINGLE) {
249+
auto *p = static_cast<float*>(inout);
250+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p));
251+
} else {
252+
auto *p = static_cast<double*>(inout);
253+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p));
254+
}
255+
} else { // COMPLEX
256+
if (desc->prec == precision::SINGLE) {
257+
auto *p = static_cast<std::complex<float>*>(inout);
258+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p));
259+
} else {
260+
auto *p = static_cast<std::complex<double>*>(inout);
261+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, p));
262+
}
263+
}
213264
return 0;
214-
} catch (...) {
215-
return -1;
216-
}
265+
} catch (...) { return -1; }
217266
}
218267

219268
int onemklDftComputeBackwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out) {
220269
if (!desc || !in || !out) return -2;
221270
try {
222-
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, in, out));
271+
if (desc->dom == domain::REAL) {
272+
if (desc->prec == precision::SINGLE) {
273+
auto *pi = static_cast<float*>(in);
274+
auto *po = static_cast<float*>(out);
275+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po));
276+
} else {
277+
auto *pi = static_cast<double*>(in);
278+
auto *po = static_cast<double*>(out);
279+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po));
280+
}
281+
} else { // COMPLEX
282+
if (desc->prec == precision::SINGLE) {
283+
auto *pi = static_cast<std::complex<float>*>(in);
284+
auto *po = static_cast<std::complex<float>*>(out);
285+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po));
286+
} else {
287+
auto *pi = static_cast<std::complex<double>*>(in);
288+
auto *po = static_cast<std::complex<double>*>(out);
289+
ONEMKL_DFT_DISPATCH(desc->ptr, compute_backward(*d, pi, po));
290+
}
291+
}
223292
return 0;
224-
} catch (...) {
225-
return -1;
226-
}
293+
} catch (...) { return -1; }
227294
}
228295

229296
// Keep dispatch macros defined for buffer variants below; undef at end of file.

0 commit comments

Comments
 (0)