Skip to content

Commit 629b569

Browse files
committed
Adding the cmake based build support for oneDNN BGGeMM
1 parent d3db72b commit 629b569

6 files changed

Lines changed: 46 additions & 15 deletions

File tree

CMakeLists.txt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ set(CMAKE_CXX_STANDARD 17)
2222
set(CMAKE_CXX_STANDARD_REQUIRED ON)
2323
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
2424

25+
# Optional: OneDNN BRGeMM micro-kernel support (x86-64 only).
26+
# Enable with: cmake -DGEMMA_ONEDNN_BRGEMM=ON ...
27+
option(GEMMA_ONEDNN_BRGEMM "Enable OneDNN BRGeMM micro-kernel for MatMul (x86-64)" OFF)
28+
2529
if(EMSCRIPTEN)
2630
add_compile_options("-sMEMORY64")
2731
add_compile_options("-msimd128")
@@ -85,6 +89,23 @@ if(EMSCRIPTEN)
8589
target_compile_options(benchmark PRIVATE -Wno-c2y-extensions)
8690
endif()
8791

92+
# OneDNN BRGeMM micro-kernel support (optional, x86-64 only).
93+
if(GEMMA_ONEDNN_BRGEMM)
94+
set(DNNL_BUILD_TESTS OFF CACHE BOOL "" FORCE)
95+
set(DNNL_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
96+
set(DNNL_CPU_RUNTIME "SEQ" CACHE STRING "" FORCE)
97+
set(DNNL_GPU_RUNTIME "NONE" CACHE STRING "" FORCE)
98+
set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "" FORCE)
99+
set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE)
100+
FetchContent_Declare(onednn
101+
GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git
102+
GIT_TAG v3.11
103+
EXCLUDE_FROM_ALL
104+
)
105+
FetchContent_MakeAvailable(onednn)
106+
message(STATUS "OneDNN BRGeMM micro-kernel support enabled")
107+
endif()
108+
88109
# Base source files
89110
set(SOURCES
90111
compression/compress-inl.h
@@ -141,6 +162,8 @@ set(SOURCES
141162
ops/matmul-inl.h
142163
ops/matmul.cc
143164
ops/matmul.h
165+
ops/brgemm.h
166+
ops/brgemm-inl.h
144167
ops/ops-inl.h
145168
ops/ops.h
146169
ops/sum-inl.h
@@ -191,6 +214,10 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
191214
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
192215
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
193216
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
217+
if(GEMMA_ONEDNN_BRGEMM)
218+
target_compile_definitions(libgemma PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL)
219+
target_link_libraries(libgemma dnnl)
220+
endif()
194221
install(TARGETS libgemma DESTINATION lib)
195222

196223
# Shared library target for C# interop
@@ -215,6 +242,10 @@ target_compile_definitions(gemma_shared
215242
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
216243
)
217244
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
245+
if(GEMMA_ONEDNN_BRGEMM)
246+
target_compile_definitions(gemma_shared PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL)
247+
target_link_libraries(gemma_shared PRIVATE dnnl)
248+
endif()
218249
install(TARGETS gemma_shared DESTINATION lib)
219250
install(FILES gemma/c_api.h DESTINATION include/gemma)
220251
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)

ops/bench_matmul.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
131131

132132
// Only record times after autotuning finished.
133133
bool done = per_key->autotune.Best();
134-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
134+
#if GEMMA_ONEDNN_BRGEMM
135135
done = done || per_key->brgemm_autotune.Best();
136136
#endif
137137
if (done) times.push_back(elapsed);

ops/brgemm-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
// BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE.
1717

18-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
18+
#if GEMMA_ONEDNN_BRGEMM
1919

2020
static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n,
2121
int64_t k, int64_t batch, int64_t lda, int64_t ldb,
@@ -489,4 +489,4 @@ static HWY_NOINLINE void DoMatMul_BRGeMM(
489489
main_bufs.hw_ctx_kernel = nullptr;
490490
}
491491

492-
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
492+
#endif // GEMMA_ONEDNN_BRGEMM

ops/brgemm.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
// limitations under the License.
1515

1616
// OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512.
17-
// Enabled at compile time via GEMMA_ONEDNN_BRGEMM_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1).
17+
// Enabled at compile time via GEMMA_ONEDNN_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1).
1818

1919
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_
2020
#define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_
@@ -29,12 +29,12 @@
2929

3030
#include "hwy/base.h"
3131

32-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
32+
#if GEMMA_ONEDNN_BRGEMM
3333
#include <sys/mman.h>
3434

3535
#include "oneapi/dnnl/dnnl.hpp"
3636
#include "oneapi/dnnl/dnnl_ukernel.hpp"
37-
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
37+
#endif // GEMMA_ONEDNN_BRGEMM
3838

3939
namespace gcpp {
4040

@@ -46,7 +46,7 @@ struct BRGeMMConfig {
4646
int64_t par_m;
4747
};
4848

49-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
49+
#if GEMMA_ONEDNN_BRGEMM
5050

5151
// Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16).
5252
// Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}.
@@ -281,7 +281,7 @@ inline auto& GetBRGeMMPackedBCache() {
281281
return cache;
282282
}
283283

284-
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
284+
#endif // GEMMA_ONEDNN_BRGEMM
285285

286286
} // namespace gcpp
287287

ops/matmul-inl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ namespace gcpp {
4747
namespace HWY_NAMESPACE {
4848
namespace hn = hwy::HWY_NAMESPACE;
4949

50-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
50+
#if GEMMA_ONEDNN_BRGEMM
5151
#include "ops/brgemm-inl.h" // DoMatMul_BRGeMM
52-
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
52+
#endif // GEMMA_ONEDNN_BRGEMM
5353

5454
// Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register.
5555
template <class DF, class DBF = hn::Repartition<BF16, DF>>
@@ -1081,7 +1081,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
10811081
MMPerKey& per_key = MMImpl::FindOrAddPerKey(
10821082
M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);
10831083

1084-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
1084+
#if GEMMA_ONEDNN_BRGEMM
10851085
// BRGeMM path for BF16×BF16 on Intel AMX/AVX-512.
10861086
// Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint).
10871087
if constexpr (IsBF16<TA>() && IsBF16<TB>()) {
@@ -1119,7 +1119,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
11191119
return &per_key;
11201120
}
11211121
} // if constexpr BF16/float
1122-
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
1122+
#endif // GEMMA_ONEDNN_BRGEMM
11231123

11241124
// (Also auto-tunes, hence outside the timed section to prevent interference.)
11251125
const StridedViewBF A_view =

ops/matmul.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#include "hwy/base.h"
3333
#include "hwy/bit_set.h"
3434
#include "hwy/profiler.h"
35-
#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM_BRGEMM
35+
#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM
3636
// IWYU pragma: end_exports
3737

3838
namespace gcpp {
@@ -640,9 +640,9 @@ class MMKeys {
640640
struct MMPerKey {
641641
MMAutoTune<MMConfig> autotune;
642642
MMAutoTune<MMParA> autotune_par_a;
643-
#if GEMMA_ONEDNN_BRGEMM_BRGEMM
643+
#if GEMMA_ONEDNN_BRGEMM
644644
MMAutoTune<BRGeMMConfig> brgemm_autotune;
645-
#endif // GEMMA_ONEDNN_BRGEMM_BRGEMM
645+
#endif // GEMMA_ONEDNN_BRGEMM
646646
};
647647

648648
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive

0 commit comments

Comments
 (0)