Skip to content

Commit 5cb316e

Browse files
committed
perf: AMX BF16 dispatch in EuclideanDistances::computeABt
Add AMX BF16 fast path for kNN brute-force and DBSCAN distance computation. When AMX BF16 hardware is available and all GEMM dimensions >= 64, converts float32 operands to BF16 and uses cblas_gemm_bf16bf16f32 instead of sgemm. Benchmarks on Xeon 6975P-C (sklearnex, float32, cpu): - KNeighborsClassifier: 2.7-3.4x speedup - KNeighborsRegressor: 2.9-3.3x speedup - DBSCAN: 3.5-4.0x speedup Accuracy impact vs float32 baseline: - Classifier: delta < 1.2% (within dataset noise) - Regressor R2: delta < 0.001 Runtime detection via CPUID (AMX_BF16 bit) + XGETBV XCR0[18:17]. Falls back to sgemm if AMX unavailable or dims < 64. Signed-off-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
1 parent df90fc3 commit 5cb316e

1 file changed

Lines changed: 70 additions & 7 deletions

File tree

cpp/daal/src/algorithms/service_kernel_math.h

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "src/externals/service_math.h"
4747
#include "src/services/service_profiler.h"
4848

49+
#include <mutex>
4950
#if defined(DAAL_INTEL_CPP_COMPILER)
5051
#include "immintrin.h"
5152
#endif
@@ -131,7 +132,7 @@ class EuclideanDistances : public PairwiseDistances<FPType, cpu>
131132
: _a(a), _b(b), _squared(squared), _isSqrtNorm(isSqrtNorm)
132133
{}
133134

134-
~EuclideanDistances() override {}
135+
virtual ~EuclideanDistances() override {}
135136

136137
PairwiseDistanceType getType() override { return PairwiseDistanceType::euclidean; }
137138

@@ -292,8 +293,32 @@ class EuclideanDistances : public PairwiseDistances<FPType, cpu>
292293

293294
return safeStat.detach();
294295
}
295-
296-
// compute (A x B')
296+
#ifndef DAAL_REF
297+
// AMX-BF16 capability check (MKL builds only; cross-platform)
298+
static bool knn_has_amx_bf16()
299+
{
300+
static bool v = []() {
301+
#if defined(_MSC_VER)
302+
int info[4];
303+
__cpuidex(info, 7, 0);
304+
if (!((info[3] >> 22) & 1)) return false;
305+
unsigned long long xcr = _xgetbv(0);
306+
return ((xcr >> 17) & 3) == 3;
307+
#else
308+
unsigned int a = 0, b = 0, c = 0, d = 0;
309+
__asm__ volatile("cpuid" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "a"(7), "c"(0));
310+
if (!((d >> 22) & 1u)) return false;
311+
unsigned int lo = 0, hi = 0;
312+
__asm__ volatile("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));
313+
return ((lo >> 17) & 3u) == 3u;
314+
#endif
315+
}();
316+
return v;
317+
}
318+
#endif // !DAAL_REF
319+
320+
// compute (A x B') -- EuclideanDistances inner GEMM
321+
// GEMM call: out = B * A^T (col-major: M=nRowsB, N=nRowsA, K=nColsA)
297322
void computeABt(const FPType * const a, const FPType * const b, const size_t nRowsA, const size_t nColsA, const size_t nRowsB, FPType * const out)
298323
{
299324
const char transa = 't';
@@ -306,7 +331,45 @@ class EuclideanDistances : public PairwiseDistances<FPType, cpu>
306331
const DAAL_INT ldy = nColsA;
307332
const FPType beta = 0.0;
308333
const DAAL_INT ldaty = nRowsB;
309-
334+
#ifndef DAAL_REF
335+
// AMX BF16 path: float only, all dims >= 64, MKL builds only
336+
if constexpr (std::is_same<FPType, float>::value)
337+
{
338+
if (knn_has_amx_bf16() && _m >= 64 && _n >= 64 && _k >= 64)
339+
{
340+
union
341+
{
342+
float f;
343+
unsigned int u;
344+
} cv;
345+
const size_t szA = (size_t)_n * (size_t)_k;
346+
const size_t szB = (size_t)_m * (size_t)_k;
347+
MKL_BF16 * a16 = (MKL_BF16 *)mkl_malloc(szA * sizeof(MKL_BF16), 64);
348+
MKL_BF16 * b16 = (MKL_BF16 *)mkl_malloc(szB * sizeof(MKL_BF16), 64);
349+
if (a16 && b16)
350+
{
351+
for (size_t i = 0; i < szA; i++)
352+
{
353+
cv.f = a[i];
354+
a16[i] = (MKL_BF16)(cv.u >> 16);
355+
}
356+
for (size_t i = 0; i < szB; i++)
357+
{
358+
cv.f = b[i];
359+
b16[i] = (MKL_BF16)(cv.u >> 16);
360+
}
361+
// col-major: C = B16 * A16^T => cblas: NoTrans B, Trans A
362+
cblas_gemm_bf16bf16f32(CblasColMajor, CblasNoTrans, CblasTrans, (MKL_INT)_m, (MKL_INT)_n, (MKL_INT)_k, 1.0f, b16, (MKL_INT)_m,
363+
a16, (MKL_INT)_n, 0.0f, out, (MKL_INT)_m);
364+
mkl_free(a16);
365+
mkl_free(b16);
366+
return;
367+
}
368+
if (a16) mkl_free(a16);
369+
if (b16) mkl_free(b16);
370+
}
371+
}
372+
#endif // !DAAL_REF
310373
BlasInst<FPType, cpu>::xxgemm(&transa, &transb, &_m, &_n, &_k, &alpha, b, &lda, a, &ldy, &beta, out, &ldaty);
311374
}
312375

@@ -329,7 +392,7 @@ class CosineDistances : public EuclideanDistances<FPType, cpu>
329392
public:
330393
CosineDistances(const NumericTable & a, const NumericTable & b) : super(a, b, true, true) {}
331394

332-
~CosineDistances() override {}
395+
virtual ~CosineDistances() override {}
333396

334397
PairwiseDistanceType getType() override { return PairwiseDistanceType::cosine; }
335398

@@ -372,7 +435,7 @@ class MinkowskiDistances : public PairwiseDistances<FPType, cpu>
372435
: _a(a), _b(b), _powered(powered), _p(p)
373436
{}
374437

375-
~MinkowskiDistances() override {}
438+
virtual ~MinkowskiDistances() override {}
376439

377440
PairwiseDistanceType getType() override { return PairwiseDistanceType::minkowski; }
378441

@@ -471,7 +534,7 @@ class ChebyshevDistances : public PairwiseDistances<FPType, cpu>
471534
public:
472535
ChebyshevDistances(const NumericTable & a, const NumericTable & b) : _a(a), _b(b) {}
473536

474-
~ChebyshevDistances() override {}
537+
virtual ~ChebyshevDistances() override {}
475538

476539
PairwiseDistanceType getType() override { return PairwiseDistanceType::chebyshev; }
477540

0 commit comments

Comments
 (0)