Skip to content

Commit d3db72b

Browse files
committed
Fixed the compile time flag to designate BRGEMM path
1 parent cd9436d commit d3db72b

6 files changed

Lines changed: 27 additions & 29 deletions

File tree

BUILD.bazel

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ exports_files([
2727
".github/workflows/build.yml",
2828
])
2929

30+
# To enable OneDNN BRGeMM support, build with:
31+
# bazel build --define gemma_onednn_brgemm=1 ...
32+
config_setting(
33+
name = "gemma_onednn_brgemm",
34+
define_values = {"gemma_onednn_brgemm": "1"},
35+
)
36+
3037
cc_library(
3138
name = "basics",
3239
srcs = ["util/basics.cc"],
@@ -318,7 +325,7 @@ cc_library(
318325
"ops/matmul.h",
319326
],
320327
defines = select({
321-
"@platforms//cpu:x86_64": ["GEMMA_ONEDNN=1", "DNNL_EXPERIMENTAL_UKERNEL"],
328+
":gemma_onednn_brgemm": ["GEMMA_ONEDNN_BRGEMM_BRGEMM=1", "DNNL_EXPERIMENTAL_UKERNEL"],
322329
"//conditions:default": [],
323330
}),
324331
deps = [
@@ -332,7 +339,7 @@ cc_library(
332339
"@highway//:nanobenchmark",
333340
"@highway//:profiler",
334341
] + select({
335-
"@platforms//cpu:x86_64": ["@onednn//:onednn"],
342+
":gemma_onednn_brgemm": ["@onednn//:onednn"],
336343
"//conditions:default": [],
337344
}),
338345
)
@@ -359,7 +366,7 @@ cc_library(
359366
"@highway//:nanobenchmark",
360367
"@highway//:profiler",
361368
] + select({
362-
"@platforms//cpu:x86_64": ["@onednn//:onednn"],
369+
":gemma_onednn_brgemm": ["@onednn//:onednn"],
363370
"//conditions:default": [],
364371
}),
365372
)
@@ -396,7 +403,7 @@ cc_library(
396403
"@highway//:profiler",
397404
"@highway//:timer",
398405
] + select({
399-
"@platforms//cpu:x86_64": ["@onednn//:onednn"],
406+
":gemma_onednn_brgemm": ["@onednn//:onednn"],
400407
"//conditions:default": [],
401408
}),
402409
)

ops/bench_matmul.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
131131

132132
// Only record times after autotuning finished.
133133
bool done = per_key->autotune.Best();
134-
#if GEMMA_ONEDNN
134+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
135135
done = done || per_key->brgemm_autotune.Best();
136136
#endif
137137
if (done) times.push_back(elapsed);

ops/brgemm-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
// BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE.
1717

18-
#if GEMMA_ONEDNN
18+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
1919

2020
static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n,
2121
int64_t k, int64_t batch, int64_t lda, int64_t ldb,
@@ -489,4 +489,4 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
489489
main_bufs.hw_ctx_kernel = nullptr;
490490
}
491491

492-
#endif // GEMMA_ONEDNN
492+
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM

ops/brgemm.h

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
// limitations under the License.
1515

1616
// OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512.
17-
// Enabled at runtime via GEMMA_USE_ONEDNN_BRGEMM=1.
17+
// Enabled at compile time via GEMMA_ONEDNN_BRGEMM_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1).
1818

1919
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_
2020
#define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_
@@ -23,30 +23,21 @@
2323
#include <stdint.h>
2424

2525
#include <algorithm>
26-
#include <cstdlib>
2726
#include <unordered_map>
2827
#include <utility>
2928
#include <vector>
3029

3130
#include "hwy/base.h"
3231

33-
#if GEMMA_ONEDNN
32+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
3433
#include <sys/mman.h>
3534

3635
#include "oneapi/dnnl/dnnl.hpp"
3736
#include "oneapi/dnnl/dnnl_ukernel.hpp"
38-
#endif // GEMMA_ONEDNN
37+
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
3938

4039
namespace gcpp {
4140

42-
inline bool UseOneDnnBrgemm() {
43-
static const bool enabled = [] {
44-
const char* env = std::getenv("GEMMA_USE_ONEDNN_BRGEMM");
45-
return env != nullptr && env[0] == '1' && env[1] == '\0';
46-
}();
47-
return enabled;
48-
}
49-
5041
struct BRGeMMConfig {
5142
int64_t M_blk;
5243
int64_t N_blk;
@@ -55,7 +46,7 @@ struct BRGeMMConfig {
5546
int64_t par_m;
5647
};
5748

58-
#if GEMMA_ONEDNN
49+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
5950

6051
// Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16).
6152
// Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}.
@@ -290,7 +281,7 @@ inline auto& GetBRGeMMPackedBCache() {
290281
return cache;
291282
}
292283

293-
#endif // GEMMA_ONEDNN
284+
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
294285

295286
} // namespace gcpp
296287

ops/matmul-inl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ namespace gcpp {
4747
namespace HWY_NAMESPACE {
4848
namespace hn = hwy::HWY_NAMESPACE;
4949

50-
#if GEMMA_ONEDNN
50+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
5151
#include "ops/brgemm-inl.h" // DoMatMul_BRGeMM
52-
#endif // GEMMA_ONEDNN
52+
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
5353

5454
// Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register.
5555
template <class DF, class DBF = hn::Repartition<BF16, DF>>
@@ -1081,11 +1081,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
10811081
MMPerKey& per_key = MMImpl::FindOrAddPerKey(
10821082
M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);
10831083

1084-
#if GEMMA_ONEDNN
1084+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
10851085
// BRGeMM path for BF16×BF16 on Intel AMX/AVX-512.
10861086
// Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint).
10871087
if constexpr (IsBF16<TA>() && IsBF16<TB>()) {
1088-
if (UseOneDnnBrgemm() && M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) {
1088+
if (M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) {
10891089
const float scale = A.Scale() * B.Scale();
10901090
MMAutoTune<BRGeMMConfig>& brg_tuner = per_key.brgemm_autotune;
10911091

@@ -1119,7 +1119,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
11191119
return &per_key;
11201120
}
11211121
} // if constexpr BF16/float
1122-
#endif // GEMMA_ONEDNN
1122+
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
11231123

11241124
// (Also auto-tunes, hence outside the timed section to prevent interference.)
11251125
const StridedViewBF A_view =

ops/matmul.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#include "hwy/base.h"
3333
#include "hwy/bit_set.h"
3434
#include "hwy/profiler.h"
35-
#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN
35+
#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM_BRGEMM
3636
// IWYU pragma: end_exports
3737

3838
namespace gcpp {
@@ -640,9 +640,9 @@ class MMKeys {
640640
struct MMPerKey {
641641
MMAutoTune<MMConfig> autotune;
642642
MMAutoTune<MMParA> autotune_par_a;
643-
#if GEMMA_ONEDNN
643+
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
644644
MMAutoTune<BRGeMMConfig> brgemm_autotune;
645-
#endif // GEMMA_ONEDNN
645+
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
646646
};
647647

648648
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive

0 commit comments

Comments
 (0)