void cpu_gemm_blas_bf16(
const bfloat16* A, const bfloat16* B,
size_t M, size_t K, size_t N,
bfloat16* C,
bool transA, bool transB)
{
for (int i = 0; i < M * K; i++) {
printf("A[%d] = %.6f\n", i, static_cast<float>(A[i]));
}
for (int i = 0; i < K * N; i++) {
printf("B[%d] = %.6f\n", i, static_cast<float>(B[i]));
}
std::vector<float> C_float(M * N, 0.0f);
blasint lda = transA ? M : K;
blasint ldb = transB ? K : N;
blasint ldc = N;
cblas_sbgemm(
CblasRowMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
M,
N,
K,
1.0f,
A,
lda,
B,
ldb,
0.0f,
C_float.data(),
ldc
);
for (size_t i = 0; i < M * N; ++i) {
printf("C_float[%zu] = %.6f\n", i, C_float[i]);
}
for (size_t i = 0; i < M * N; ++i) {
C[i] = static_cast<bfloat16>(C_float[i]);
}
}
void test_bf16() {
const int M = 2, K = 3, N = 2;
std::vector<bfloat16> A(M * K);
std::vector<bfloat16> B(N * K);
std::vector<bfloat16> C(M * N);
float A_values[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
for (int i = 0; i < M * K; ++i) A[i] = static_cast<bfloat16>(A_values[i]);
float B_values[] = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
for (int i = 0; i < N * K; ++i) B[i] = static_cast<bfloat16>(B_values[i]);
std::vector<float> C_float(M * N);
cpu_gemm_blas_bf16(
A.data(), B.data(),
M, K, N,
C.data(),
false, true
);
printf("BF16 GEMM Result C (2x2):\n");
for (int i = 0; i < M * N; ++i) {
printf("%.2f ", static_cast<float>(C[i]));
}
printf("\n");
}
A[0] = 1.000000
A[1] = 2.000000
A[2] = 3.000000
A[3] = 4.000000
A[4] = 5.000000
A[5] = 6.000000
B[0] = 7.000000
B[1] = 8.000000
B[2] = 9.000000
B[3] = 10.000000
B[4] = 11.000000
B[5] = 12.000000
C_float[0] = 0.000000
C_float[1] = 0.000000
C_float[2] = 0.000000
C_float[3] = 0.000000
BF16 GEMM Result C (2x2):
0.00 0.00 0.00 0.00
OpenBLAS was built with the compilation option BUILD_BFLOAT16=1, and /proc/cpuinfo shows that the CPU supports the bf16 format, but the final output matrix is all zeros. The program output is shown above.
I have two questions:
1. When using bf16, what causes the result to be all zeros?
2. Besides bf16, does OpenBLAS support fp16 matrix multiplication on ARM64 chips?