4646#include " src/externals/service_math.h"
4747#include " src/services/service_profiler.h"
4848
49+ #include < mutex>
4950#if defined(DAAL_INTEL_CPP_COMPILER)
5051 #include " immintrin.h"
5152#endif
@@ -131,7 +132,7 @@ class EuclideanDistances : public PairwiseDistances<FPType, cpu>
131132 : _a(a), _b(b), _squared(squared), _isSqrtNorm(isSqrtNorm)
132133 {}
133134
134- ~EuclideanDistances () override {}
135+ virtual ~EuclideanDistances () override {}
135136
136137 PairwiseDistanceType getType () override { return PairwiseDistanceType::euclidean; }
137138
@@ -292,8 +293,32 @@ class EuclideanDistances : public PairwiseDistances<FPType, cpu>
292293
293294 return safeStat.detach ();
294295 }
295-
296- // compute (A x B')
296+ #ifndef DAAL_REF
// Runtime check for usable AMX-BF16 support (referenced from MKL builds only).
//
// Returns true only when all of the following hold:
//   - CPUID leaf 7 is available and CPUID.(EAX=7,ECX=0):EDX[22] (AMX-BF16) is set;
//   - CPUID.1:ECX[27] (OSXSAVE) is set, so XGETBV can be executed without faulting;
//   - XCR0[17] (TILECFG) and XCR0[18] (TILEDATA) are both enabled by the OS.
// On non-x86 targets the function returns false.
//
// The probe runs once; its result is cached in a function-local static.
// NOTE(review): on Linux a process must also be granted permission for the tile-data state
// (arch_prctl ARCH_REQ_XCOMP_PERM); presumably the BLAS library requests it -- confirm.
static bool knn_has_amx_bf16()
{
    static const bool supported = []() -> bool {
#if defined(_M_X64) || defined(_M_IX86) || defined(__x86_64__) || defined(__i386__)
        // Executes CPUID for (leaf, subleaf); regs receives EAX, EBX, ECX, EDX in that order.
        const auto cpuid = [](unsigned int leaf, unsigned int subleaf, unsigned int (&regs)[4]) {
#if defined(_MSC_VER)
            int info[4] = { 0, 0, 0, 0 };
            __cpuidex(info, static_cast<int>(leaf), static_cast<int>(subleaf));
            for (int i = 0; i < 4; ++i) regs[i] = static_cast<unsigned int>(info[i]);
#else
            __asm__ volatile("cpuid" : "=a"(regs[0]), "=b"(regs[1]), "=c"(regs[2]), "=d"(regs[3]) : "a"(leaf), "c"(subleaf));
#endif
        };

        const unsigned int structuredFeaturesLeaf = 7u;
        const unsigned int osxsaveBit             = 27u; // CPUID.1:ECX
        const unsigned int amxBf16Bit             = 22u; // CPUID.(7,0):EDX
        const unsigned long long tileStateMask    = (1ull << 17) | (1ull << 18); // XCR0: TILECFG | TILEDATA

        unsigned int regs[4] = { 0u, 0u, 0u, 0u };

        // Leaf 0 reports the highest basic leaf; querying an unsupported leaf returns unrelated data.
        cpuid(0u, 0u, regs);
        if (regs[0] < structuredFeaturesLeaf) return false;

        // XGETBV raises #UD unless the OS has enabled XSAVE (OSXSAVE).
        cpuid(1u, 0u, regs);
        if (((regs[2] >> osxsaveBit) & 1u) == 0u) return false;

        cpuid(structuredFeaturesLeaf, 0u, regs);
        if (((regs[3] >> amxBf16Bit) & 1u) == 0u) return false;

#if defined(_MSC_VER)
        const unsigned long long xcr0 = _xgetbv(0);
#else
        unsigned int lo = 0u;
        unsigned int hi = 0u;
        __asm__ volatile("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));
        const unsigned long long xcr0 = (static_cast<unsigned long long>(hi) << 32) | lo;
#endif
        // Both tile state components must be enabled for AMX instructions to be usable.
        return (xcr0 & tileStateMask) == tileStateMask;
#else
        return false;
#endif
    }();
    return supported;
}
318+ #endif // !DAAL_REF
319+
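// compute (A x B') -- EuclideanDistances inner GEMM
// Column-major GEMM with M = nRowsB, N = nRowsA, K = nColsA: out = op(b) * op(a), where the first
// operand b is passed transposed (transa = 't') and out has leading dimension nRowsB.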
297322 void computeABt (const FPType * const a, const FPType * const b, const size_t nRowsA, const size_t nColsA, const size_t nRowsB, FPType * const out)
298323 {
299324 const char transa = ' t' ;
@@ -306,7 +331,45 @@ class EuclideanDistances : public PairwiseDistances<FPType, cpu>
306331 const DAAL_INT ldy = nColsA;
307332 const FPType beta = 0.0 ;
308333 const DAAL_INT ldaty = nRowsB;
309-
334+ #ifndef DAAL_REF
335+ // AMX BF16 path: float only, all dims >= 64, MKL builds only
336+ if constexpr (std::is_same<FPType, float >::value)
337+ {
338+ if (knn_has_amx_bf16 () && _m >= 64 && _n >= 64 && _k >= 64 )
339+ {
340+ union
341+ {
342+ float f;
343+ unsigned int u;
344+ } cv;
345+ const size_t szA = (size_t )_n * (size_t )_k;
346+ const size_t szB = (size_t )_m * (size_t )_k;
347+ MKL_BF16 * a16 = (MKL_BF16 *)mkl_malloc (szA * sizeof (MKL_BF16), 64 );
348+ MKL_BF16 * b16 = (MKL_BF16 *)mkl_malloc (szB * sizeof (MKL_BF16), 64 );
349+ if (a16 && b16)
350+ {
351+ for (size_t i = 0 ; i < szA; i++)
352+ {
353+ cv.f = a[i];
354+ a16[i] = (MKL_BF16)(cv.u >> 16 );
355+ }
356+ for (size_t i = 0 ; i < szB; i++)
357+ {
358+ cv.f = b[i];
359+ b16[i] = (MKL_BF16)(cv.u >> 16 );
360+ }
361+ // col-major: C = B16 * A16^T => cblas: NoTrans B, Trans A
362+ cblas_gemm_bf16bf16f32 (CblasColMajor, CblasNoTrans, CblasTrans, (MKL_INT)_m, (MKL_INT)_n, (MKL_INT)_k, 1 .0f , b16, (MKL_INT)_m,
363+ a16, (MKL_INT)_n, 0 .0f , out, (MKL_INT)_m);
364+ mkl_free (a16);
365+ mkl_free (b16);
366+ return ;
367+ }
368+ if (a16) mkl_free (a16);
369+ if (b16) mkl_free (b16);
370+ }
371+ }
372+ #endif // !DAAL_REF
310373 BlasInst<FPType, cpu>::xxgemm (&transa, &transb, &_m, &_n, &_k, &alpha, b, &lda, a, &ldy, &beta, out, &ldaty);
311374 }
312375
@@ -329,7 +392,7 @@ class CosineDistances : public EuclideanDistances<FPType, cpu>
329392public:
330393 CosineDistances (const NumericTable & a, const NumericTable & b) : super(a, b, true , true ) {}
331394
332- ~CosineDistances () override {}
395+ virtual ~CosineDistances () override {}
333396
334397 PairwiseDistanceType getType () override { return PairwiseDistanceType::cosine; }
335398
@@ -372,7 +435,7 @@ class MinkowskiDistances : public PairwiseDistances<FPType, cpu>
372435 : _a(a), _b(b), _powered(powered), _p(p)
373436 {}
374437
375- ~MinkowskiDistances () override {}
438+ virtual ~MinkowskiDistances () override {}
376439
377440 PairwiseDistanceType getType () override { return PairwiseDistanceType::minkowski; }
378441
@@ -471,7 +534,7 @@ class ChebyshevDistances : public PairwiseDistances<FPType, cpu>
471534public:
472535 ChebyshevDistances (const NumericTable & a, const NumericTable & b) : _a(a), _b(b) {}
473536
474- ~ChebyshevDistances () override {}
537+ virtual ~ChebyshevDistances () override {}
475538
476539 PairwiseDistanceType getType () override { return PairwiseDistanceType::chebyshev; }
477540
0 commit comments