Skip to content

Commit acf7592

Browse files
committed
Added HWY_WARN and fallback instead of exiting
1 parent 45708ea commit acf7592

2 files changed

Lines changed: 45 additions & 23 deletions

File tree

ops/brgemm-inl.h

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,25 @@ static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n,
5454
try {
5555
brg = dnnl::ukernel::brgemm(m, n, k, batch, lda, ldb, ldc, a_dt, b_dt,
5656
c_dt, true);
57-
if (!brg) return false;
57+
if (!brg) {
58+
HWY_WARN("BRGeMM: kernel creation failed m=%lld n=%lld k=%lld.",
59+
static_cast<long long>(m), static_cast<long long>(n),
60+
static_cast<long long>(k));
61+
return false;
62+
}
5863
brg.set_add_C(add_C);
59-
if (!brg.finalize()) return false;
64+
if (!brg.finalize()) {
65+
HWY_WARN("BRGeMM: kernel finalize failed m=%lld n=%lld k=%lld.",
66+
static_cast<long long>(m), static_cast<long long>(n),
67+
static_cast<long long>(k));
68+
return false;
69+
}
6070
brg.generate();
6171
return true;
6272
} catch (...) {
73+
HWY_WARN("BRGeMM: kernel JIT exception m=%lld n=%lld k=%lld.",
74+
static_cast<long long>(m), static_cast<long long>(n),
75+
static_cast<long long>(k));
6376
return false;
6477
}
6578
}
@@ -295,7 +308,7 @@ static HWY_NOINLINE bool InitBRGeMMKernels(
295308
}
296309

297310
template <typename TA, typename TB, typename TC>
298-
static HWY_NOINLINE void DoMatMul_BRGeMM(
311+
static HWY_NOINLINE bool DoMatMul_BRGeMM(
299312
const MatPtrT<TA>& A, const MatPtrT<TB>& B, RowPtrs<TC> C, size_t M,
300313
size_t K, size_t N, float scale, const float* HWY_RESTRICT add,
301314
const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) {
@@ -310,7 +323,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
310323
if (kern_it == kern_cache.end()) {
311324
BRGeMMKernelEntry ke;
312325
if (!InitBRGeMMKernels(cfg, M, K, N, A.Stride(), B.Stride(), ke)) {
313-
return;
326+
return false;
314327
}
315328
kern_it = kern_cache.emplace(kern_key, std::move(ke)).first;
316329
}
@@ -344,7 +357,10 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
344357

345358
pe.B_packed_buf.Resize(total_packed);
346359
uint8_t* B_packed = pe.B_packed_buf.data();
347-
if (!B_packed) return;
360+
if (!B_packed) {
361+
HWY_WARN("BRGeMM: packed B allocation failed.");
362+
return false;
363+
}
348364

349365
for (size_t nt = 0; nt < ke.N_total_tiles; ++nt) {
350366
const int ni = (nt < ke.N_full_tiles) ? 0 : 1;
@@ -366,7 +382,8 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
366382
B_packed + pe.B_ktail_offset[nt]);
367383
}
368384
} catch (...) {
369-
return;
385+
HWY_WARN("BRGeMM: B-packing execution failed.");
386+
return false;
370387
}
371388
}
372389
}
@@ -548,6 +565,7 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
548565
dnnl::ukernel::brgemm::release_hw_context();
549566
auto& main_bufs = GetBRGeMMThreadBufs();
550567
main_bufs.hw_ctx_kernel = nullptr;
568+
return true;
551569
}
552570

553571
#endif // GEMMA_ONEDNN_BRGEMM

ops/matmul-inl.h

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,9 +1089,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
10891089
MMAutoTune<BRGeMMConfig>& brg_tuner = per_key.brgemm_autotune;
10901090

10911091
if (HWY_LIKELY(brg_tuner.Best())) {
1092-
DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, *brg_tuner.Best(),
1093-
env.ctx, cluster_idx);
1094-
return &per_key;
1092+
if (DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add,
1093+
*brg_tuner.Best(), env.ctx, cluster_idx)) {
1094+
return &per_key;
1095+
}
1096+
// BRGeMM failed; fall through to standard matmul.
10951097
}
10961098

10971099
if (HWY_UNLIKELY(!brg_tuner.HasCandidates())) {
@@ -1100,21 +1102,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
11001102

11011103
const BRGeMMConfig& cfg = brg_tuner.NextConfig();
11021104
const uint64_t t0 = hwy::timer::Start();
1103-
DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx,
1104-
cluster_idx);
1105-
const uint64_t t1 =
1106-
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
1107-
brg_tuner.NotifyTicks(t1 - t0);
1108-
1109-
if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) {
1110-
const BRGeMMConfig& best = *brg_tuner.Best();
1111-
fprintf(stderr,
1112-
"BRGeMM best: %zux%zux%zu M_blk=%zu N_blk=%zu K_blk=%zu "
1113-
"batch=%zu\n",
1114-
M, K, N, best.M_blk, best.N_blk, best.K_blk,
1115-
best.batch_size);
1105+
if (DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx,
1106+
cluster_idx)) {
1107+
const uint64_t t1 =
1108+
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
1109+
brg_tuner.NotifyTicks(t1 - t0);
1110+
1111+
if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) {
1112+
const BRGeMMConfig& best = *brg_tuner.Best();
1113+
fprintf(stderr,
1114+
"BRGeMM best: %zux%zux%zu M_blk=%zu N_blk=%zu K_blk=%zu "
1115+
"batch=%zu\n",
1116+
M, K, N, best.M_blk, best.N_blk, best.K_blk,
1117+
best.batch_size);
1118+
}
1119+
return &per_key;
11161120
}
1117-
return &per_key;
1121+
// BRGeMM failed; fall through to standard matmul.
11181122
}
11191123
} // if constexpr BF16/float
11201124
#endif // GEMMA_ONEDNN_BRGEMM

0 commit comments

Comments
 (0)