Skip to content

Commit f7b2d1d

Browse files
Merge pull request #903 from Intel-tensorflow:feature/onednn-brgemm
PiperOrigin-RevId: 925952741
2 parents 4013e52 + c3787f8 commit f7b2d1d

11 files changed

Lines changed: 1213 additions & 6 deletions

File tree

BUILD.bazel

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ exports_files([
2727
".github/workflows/build.yml",
2828
])
2929

30+
# To enable OneDNN BRGeMM support, build with:
31+
# bazel build --define gemma_onednn_brgemm=1 ...
32+
config_setting(
33+
name = "gemma_onednn_brgemm",
34+
define_values = {"gemma_onednn_brgemm": "1"},
35+
)
36+
3037
cc_library(
3138
name = "basics",
3239
srcs = ["util/basics.cc"],
@@ -314,7 +321,17 @@ test_suite(
314321
cc_library(
315322
name = "matmul_env",
316323
srcs = ["ops/matmul.cc"],
317-
hdrs = ["ops/matmul.h"],
324+
hdrs = [
325+
"ops/brgemm.h",
326+
"ops/matmul.h",
327+
],
328+
defines = select({
329+
":gemma_onednn_brgemm": [
330+
"GEMMA_ONEDNN_BRGEMM=1",
331+
"DNNL_EXPERIMENTAL_UKERNEL",
332+
],
333+
"//conditions:default": [],
334+
}),
318335
deps = [
319336
":allocator",
320337
":basics",
@@ -325,14 +342,20 @@ cc_library(
325342
"@highway//:hwy",
326343
"@highway//:nanobenchmark",
327344
"@highway//:profiler",
328-
],
345+
] + select({
346+
":gemma_onednn_brgemm": ["@onednn"],
347+
"//conditions:default": [],
348+
}),
329349
)
330350

331351
cc_library(
332352
name = "matmul",
333353
# allow depending only on this target, without also matmul_env.
334354
hdrs = ["ops/matmul.h"],
335-
textual_hdrs = ["ops/matmul-inl.h"],
355+
textual_hdrs = [
356+
"ops/brgemm-inl.h",
357+
"ops/matmul-inl.h",
358+
],
336359
deps = [
337360
":allocator",
338361
":basics",
@@ -346,7 +369,10 @@ cc_library(
346369
"@highway//:hwy",
347370
"@highway//:nanobenchmark",
348371
"@highway//:profiler",
349-
],
372+
] + select({
373+
":gemma_onednn_brgemm": ["@onednn"],
374+
"//conditions:default": [],
375+
}),
350376
)
351377

352378
cc_library(
@@ -363,6 +389,7 @@ cc_library(
363389
"ops/matmul_static.h",
364390
],
365391
textual_hdrs = [
392+
"ops/brgemm-inl.h",
366393
"ops/matmul_static-inl.h",
367394
"ops/matmul-inl.h",
368395
],
@@ -379,7 +406,10 @@ cc_library(
379406
"@highway//:hwy",
380407
"@highway//:profiler",
381408
"@highway//:timer",
382-
],
409+
] + select({
410+
":gemma_onednn_brgemm": ["@onednn"],
411+
"//conditions:default": [],
412+
}),
383413
)
384414

385415
cc_library(

CMakeLists.txt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ set(CMAKE_CXX_STANDARD 17)
2222
set(CMAKE_CXX_STANDARD_REQUIRED ON)
2323
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
2424

25+
# Optional: OneDNN BRGeMM micro-kernel support (x86-64 only).
26+
# Enable with: cmake -DGEMMA_ONEDNN_BRGEMM=ON ...
27+
option(GEMMA_ONEDNN_BRGEMM "Enable OneDNN BRGeMM micro-kernel for MatMul (x86-64)" OFF)
28+
2529
if(EMSCRIPTEN)
2630
add_compile_options("-sMEMORY64")
2731
add_compile_options("-msimd128")
@@ -85,6 +89,23 @@ if(EMSCRIPTEN)
8589
target_compile_options(benchmark PRIVATE -Wno-c2y-extensions)
8690
endif()
8791

92+
# OneDNN BRGeMM micro-kernel support (optional, x86-64 only).
93+
if(GEMMA_ONEDNN_BRGEMM)
94+
set(DNNL_BUILD_TESTS OFF CACHE BOOL "" FORCE)
95+
set(DNNL_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
96+
set(DNNL_CPU_RUNTIME "SEQ" CACHE STRING "" FORCE)
97+
set(DNNL_GPU_RUNTIME "NONE" CACHE STRING "" FORCE)
98+
set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "" FORCE)
99+
set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE)
100+
FetchContent_Declare(onednn
101+
GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git
102+
GIT_TAG v3.11
103+
EXCLUDE_FROM_ALL
104+
)
105+
FetchContent_MakeAvailable(onednn)
106+
message(STATUS "OneDNN BRGeMM micro-kernel support enabled")
107+
endif()
108+
88109
# Base source files
89110
set(SOURCES
90111
compression/compress-inl.h
@@ -143,6 +164,8 @@ set(SOURCES
143164
ops/matmul-inl.h
144165
ops/matmul.cc
145166
ops/matmul.h
167+
ops/brgemm.h
168+
ops/brgemm-inl.h
146169
ops/ops-inl.h
147170
ops/ops.h
148171
ops/sum-inl.h
@@ -195,6 +218,10 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
195218
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
196219
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
197220
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
221+
if(GEMMA_ONEDNN_BRGEMM)
222+
target_compile_definitions(libgemma PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL)
223+
target_link_libraries(libgemma dnnl)
224+
endif()
198225
install(TARGETS libgemma DESTINATION lib)
199226

200227
# Shared library target for C# interop
@@ -219,6 +246,10 @@ target_compile_definitions(gemma_shared
219246
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
220247
)
221248
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
249+
if(GEMMA_ONEDNN_BRGEMM)
250+
target_compile_definitions(gemma_shared PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL)
251+
target_link_libraries(gemma_shared PRIVATE dnnl)
252+
endif()
222253
install(TARGETS gemma_shared DESTINATION lib)
223254
install(FILES gemma/c_api.h DESTINATION include/gemma)
224255
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)

MODULE.bazel

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ git_override(
2525

2626
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
2727

28+
# OneDNN v3.11 for BRGeMM micro-kernel support (optional, x86-64 only).
29+
http_archive(
30+
name = "onednn",
31+
build_file = "@//bazel:onednn.BUILD",
32+
sha256 = "04df98b18300daf6c3aa7cc2d5e7ce8a8f430fed1787151daed0254d8dd4e64e",
33+
strip_prefix = "oneDNN-3.11",
34+
urls = [
35+
"https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v3.11.tar.gz",
36+
],
37+
)
38+
2839
http_archive(
2940
name = "com_google_absl_py",
3041
sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758",

bazel/onednn.BUILD

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
2+
3+
exports_files(["LICENSE"])
4+
5+
expand_template(
6+
name = "dnnl_config_h",
7+
out = "include/oneapi/dnnl/dnnl_config.h",
8+
substitutions = {
9+
"#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1",
10+
"#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP",
11+
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ",
12+
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_SEQ",
13+
"#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS",
14+
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
15+
"#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE",
16+
"#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE",
17+
"#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
18+
"#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO",
19+
"#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
20+
"#cmakedefine DNNL_SYCL_GENERIC": "#undef DNNL_SYCL_GENERIC",
21+
"#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
22+
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
23+
"#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
24+
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE",
25+
"#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING",
26+
"#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING",
27+
"#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER",
28+
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
29+
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
30+
"#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
31+
"#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
32+
"#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0",
33+
"#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0",
34+
"#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0",
35+
"#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0",
36+
"#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0",
37+
"#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0",
38+
"#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1",
39+
"#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0",
40+
"#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1",
41+
"#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1",
42+
"#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1",
43+
"#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1",
44+
"#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0",
45+
"#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0",
46+
"#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0",
47+
"#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0",
48+
"#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0",
49+
"#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0",
50+
"#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0",
51+
"#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
52+
"#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
53+
"#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
54+
"#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
55+
"#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
56+
"#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
57+
"#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1",
58+
"#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0",
59+
"#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0",
60+
"#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
61+
"#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
62+
"#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0",
63+
"#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0",
64+
"#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
65+
"#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
66+
"#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
67+
"#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
68+
"#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1",
69+
"#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0",
70+
},
71+
template = "include/oneapi/dnnl/dnnl_config.h.in",
72+
)
73+
74+
expand_template(
75+
name = "dnnl_version_h",
76+
out = "include/oneapi/dnnl/dnnl_version.h",
77+
substitutions = {
78+
"@DNNL_VERSION_MAJOR@": "3",
79+
"@DNNL_VERSION_MINOR@": "11",
80+
"@DNNL_VERSION_PATCH@": "0",
81+
},
82+
template = "include/oneapi/dnnl/dnnl_version.h.in",
83+
)
84+
85+
expand_template(
86+
name = "dnnl_version_hash_h",
87+
out = "include/oneapi/dnnl/dnnl_version_hash.h",
88+
substitutions = {
89+
"@DNNL_VERSION_HASH@": "fc6151651a4577beae5ffac5a4132e75d39e1409",
90+
},
91+
template = "include/oneapi/dnnl/dnnl_version_hash.h.in",
92+
)
93+
94+
cc_library(
95+
name = "onednn_autogen",
96+
srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]),
97+
copts = [
98+
"-O1",
99+
"-U_FORTIFY_SOURCE",
100+
"-fexceptions",
101+
"-UUSE_MKL",
102+
"-UUSE_CBLAS",
103+
"-DDNNL_ENABLE_MAX_CPU_ISA",
104+
"-DDNNL_ENABLE_ITT_TASKS",
105+
"-DDNNL_ENABLE_GRAPH_DUMP",
106+
"-DDNNL_EXPERIMENTAL_UKERNEL",
107+
],
108+
includes = [
109+
"include",
110+
"src",
111+
"src/common",
112+
"src/cpu",
113+
"src/cpu/gemm",
114+
"src/graph",
115+
"third_party",
116+
"third_party/ittnotify",
117+
"third_party/xbyak",
118+
],
119+
textual_hdrs = glob([
120+
"include/**/*",
121+
"src/common/*.hpp",
122+
"src/cpu/*.hpp",
123+
"src/cpu/**/*.hpp",
124+
"src/cpu/jit_utils/**/*.hpp",
125+
"src/graph/interface/*.hpp",
126+
"src/graph/backend/*.hpp",
127+
"src/graph/backend/dnnl/*.hpp",
128+
"src/graph/backend/dnnl/executables/*.hpp",
129+
"src/graph/backend/fake/*.hpp",
130+
"src/graph/backend/dnnl/passes/*.hpp",
131+
"src/graph/backend/dnnl/patterns/*.hpp",
132+
"src/graph/backend/dnnl/kernels/*.hpp",
133+
"src/graph/utils/*.hpp",
134+
"src/graph/utils/pm/*.hpp",
135+
"third_party/ittnotify/**/*.h",
136+
"third_party/spdlog/**/*.h",
137+
"third_party/xbyak/*.h",
138+
]) + [
139+
":dnnl_config_h",
140+
":dnnl_version_h",
141+
":dnnl_version_hash_h",
142+
],
143+
visibility = ["//visibility:public"],
144+
)
145+
146+
cc_library(
147+
name = "onednn",
148+
srcs = glob(
149+
[
150+
"src/common/*.cpp",
151+
"src/cpu/*.cpp",
152+
"src/cpu/**/*.cpp",
153+
"src/cpu/jit_utils/**/*.cpp",
154+
"src/cpu/x64/**/*.cpp",
155+
"src/graph/interface/*.cpp",
156+
"src/graph/backend/*.cpp",
157+
"src/graph/backend/dnnl/*.cpp",
158+
"src/graph/backend/dnnl/executables/*.cpp",
159+
"src/graph/backend/fake/*.cpp",
160+
"src/graph/backend/dnnl/passes/*.cpp",
161+
"src/graph/backend/dnnl/patterns/*.cpp",
162+
"src/graph/backend/dnnl/kernels/*.cpp",
163+
"src/graph/utils/*.cpp",
164+
"src/graph/utils/pm/*.cpp",
165+
"third_party/ittnotify/*.c",
166+
],
167+
exclude = [
168+
"src/cpu/aarch64/**",
169+
"src/cpu/rv64/**",
170+
"src/cpu/ppc64/**",
171+
"src/cpu/s390x/**",
172+
"src/cpu/x64/gemm/**/*_kern_autogen.cpp",
173+
"src/cpu/sycl/**",
174+
],
175+
),
176+
copts = [
177+
"-fexceptions",
178+
"-UUSE_MKL",
179+
"-UUSE_CBLAS",
180+
"-DDNNL_ENABLE_MAX_CPU_ISA",
181+
"-DDNNL_ENABLE_ITT_TASKS",
182+
"-DDNNL_ENABLE_GRAPH_DUMP",
183+
"-DDNNL_EXPERIMENTAL_UKERNEL",
184+
],
185+
includes = [
186+
"include",
187+
"src",
188+
"src/common",
189+
"src/cpu",
190+
"src/cpu/gemm",
191+
"src/graph",
192+
"third_party",
193+
"third_party/ittnotify",
194+
"third_party/xbyak",
195+
],
196+
linkopts = [
197+
"-lrt",
198+
"-Wl,--allow-multiple-definition",
199+
],
200+
textual_hdrs = glob([
201+
"include/**/*",
202+
"src/common/*.hpp",
203+
"src/cpu/*.hpp",
204+
"src/cpu/**/*.hpp",
205+
"src/cpu/jit_utils/**/*.hpp",
206+
"src/graph/interface/*.hpp",
207+
"src/graph/backend/*.hpp",
208+
"src/graph/backend/dnnl/*.hpp",
209+
"src/graph/backend/fake/*.hpp",
210+
"src/graph/backend/dnnl/passes/*.hpp",
211+
"src/graph/backend/dnnl/patterns/*.hpp",
212+
"src/graph/backend/dnnl/kernels/*.hpp",
213+
"src/graph/utils/*.hpp",
214+
"src/graph/utils/pm/*.hpp",
215+
"third_party/ittnotify/**/*.h",
216+
"third_party/spdlog/**/*.h",
217+
"third_party/xbyak/*.h",
218+
]) + [
219+
":dnnl_config_h",
220+
":dnnl_version_h",
221+
":dnnl_version_hash_h",
222+
],
223+
visibility = ["//visibility:public"],
224+
deps = [
225+
":onednn_autogen",
226+
],
227+
)

0 commit comments

Comments
 (0)