Skip to content

Commit 09ddbf4

Browse files
committed
Tested and benchmarked OneDNN BRGeMM integration against dev branch
1 parent a29e2fc commit 09ddbf4

10 files changed

Lines changed: 1110 additions & 6 deletions

File tree

BUILD.bazel

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,14 @@ test_suite(
313313
cc_library(
314314
name = "matmul_env",
315315
srcs = ["ops/matmul.cc"],
316-
hdrs = ["ops/matmul.h"],
316+
hdrs = [
317+
"ops/brgemm.h",
318+
"ops/matmul.h",
319+
],
320+
defines = select({
321+
"@platforms//cpu:x86_64": ["GEMMA_ONEDNN=1", "DNNL_EXPERIMENTAL_UKERNEL"],
322+
"//conditions:default": [],
323+
}),
317324
deps = [
318325
":allocator",
319326
":basics",
@@ -324,14 +331,20 @@ cc_library(
324331
"@highway//:hwy",
325332
"@highway//:nanobenchmark",
326333
"@highway//:profiler",
327-
],
334+
] + select({
335+
"@platforms//cpu:x86_64": ["@onednn//:onednn"],
336+
"//conditions:default": [],
337+
}),
328338
)
329339

330340
cc_library(
331341
name = "matmul",
332342
# allow depending only on this target, without also matmul_env.
333343
hdrs = ["ops/matmul.h"],
334-
textual_hdrs = ["ops/matmul-inl.h"],
344+
textual_hdrs = [
345+
"ops/brgemm-inl.h",
346+
"ops/matmul-inl.h",
347+
],
335348
deps = [
336349
":allocator",
337350
":basics",
@@ -345,7 +358,10 @@ cc_library(
345358
"@highway//:hwy",
346359
"@highway//:nanobenchmark",
347360
"@highway//:profiler",
348-
],
361+
] + select({
362+
"@platforms//cpu:x86_64": ["@onednn//:onednn"],
363+
"//conditions:default": [],
364+
}),
349365
)
350366

351367
cc_library(
@@ -362,6 +378,7 @@ cc_library(
362378
"ops/matmul_static.h",
363379
],
364380
textual_hdrs = [
381+
"ops/brgemm-inl.h",
365382
"ops/matmul_static-inl.h",
366383
"ops/matmul-inl.h",
367384
],
@@ -378,7 +395,10 @@ cc_library(
378395
"@highway//:hwy",
379396
"@highway//:profiler",
380397
"@highway//:timer",
381-
],
398+
] + select({
399+
"@platforms//cpu:x86_64": ["@onednn//:onednn"],
400+
"//conditions:default": [],
401+
}),
382402
)
383403

384404
cc_library(

MODULE.bazel

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ bazel_dep(name = "bazel_skylib", version = "1.8.1")
88
bazel_dep(name = "googletest", version = "1.17.0")
99
bazel_dep(name = "highway", version = "1.1.0")
1010
bazel_dep(name = "nlohmann_json", version = "3.11.3")
11+
bazel_dep(name = "onetbb", version = "2021.13.0")
1112
bazel_dep(name = "protobuf", version = "33.4")
1213
bazel_dep(name = "platforms", version = "1.0.0")
1314
bazel_dep(name = "pybind11_bazel", version = "2.13.6")
@@ -25,6 +26,17 @@ git_override(
2526

2627
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
2728

29+
# OneDNN v3.11 for BRGeMM micro-kernel support (optional, x86-64 only).
30+
http_archive(
31+
name = "onednn",
32+
build_file = "@//bazel:onednn.BUILD",
33+
sha256 = "04df98b18300daf6c3aa7cc2d5e7ce8a8f430fed1787151daed0254d8dd4e64e",
34+
strip_prefix = "oneDNN-3.11",
35+
urls = [
36+
"https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v3.11.tar.gz",
37+
],
38+
)
39+
2840
http_archive(
2941
name = "com_google_absl_py",
3042
sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758",

bazel/onednn.BUILD

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
2+
3+
exports_files(["LICENSE"])
4+
5+
expand_template(
6+
name = "dnnl_config_h",
7+
out = "include/oneapi/dnnl/dnnl_config.h",
8+
substitutions = {
9+
"#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1",
10+
"#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP",
11+
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_TBB",
12+
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_TBB",
13+
"#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS",
14+
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
15+
"#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE",
16+
"#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE",
17+
"#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
18+
"#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO",
19+
"#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
20+
"#cmakedefine DNNL_SYCL_GENERIC": "#undef DNNL_SYCL_GENERIC",
21+
"#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
22+
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
23+
"#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
24+
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE",
25+
"#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING",
26+
"#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING",
27+
"#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER",
28+
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
29+
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
30+
"#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
31+
"#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
32+
"#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0",
33+
"#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0",
34+
"#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0",
35+
"#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0",
36+
"#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0",
37+
"#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0",
38+
"#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1",
39+
"#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0",
40+
"#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1",
41+
"#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1",
42+
"#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1",
43+
"#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1",
44+
"#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0",
45+
"#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0",
46+
"#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0",
47+
"#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0",
48+
"#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0",
49+
"#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0",
50+
"#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0",
51+
"#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
52+
"#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
53+
"#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
54+
"#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
55+
"#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
56+
"#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
57+
"#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1",
58+
"#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0",
59+
"#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0",
60+
"#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
61+
"#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
62+
"#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0",
63+
"#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0",
64+
"#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
65+
"#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
66+
"#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
67+
"#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
68+
"#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1",
69+
"#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0",
70+
},
71+
template = "include/oneapi/dnnl/dnnl_config.h.in",
72+
)
73+
74+
expand_template(
75+
name = "dnnl_version_h",
76+
out = "include/oneapi/dnnl/dnnl_version.h",
77+
substitutions = {
78+
"@DNNL_VERSION_MAJOR@": "3",
79+
"@DNNL_VERSION_MINOR@": "11",
80+
"@DNNL_VERSION_PATCH@": "0",
81+
},
82+
template = "include/oneapi/dnnl/dnnl_version.h.in",
83+
)
84+
85+
expand_template(
86+
name = "dnnl_version_hash_h",
87+
out = "include/oneapi/dnnl/dnnl_version_hash.h",
88+
substitutions = {
89+
"@DNNL_VERSION_HASH@": "fc6151651a4577beae5ffac5a4132e75d39e1409",
90+
},
91+
template = "include/oneapi/dnnl/dnnl_version_hash.h.in",
92+
)
93+
94+
cc_library(
95+
name = "onednn_autogen",
96+
srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]),
97+
copts = [
98+
"-O1",
99+
"-U_FORTIFY_SOURCE",
100+
"-fexceptions",
101+
"-UUSE_MKL",
102+
"-UUSE_CBLAS",
103+
"-DDNNL_ENABLE_MAX_CPU_ISA",
104+
"-DDNNL_ENABLE_ITT_TASKS",
105+
"-DDNNL_ENABLE_GRAPH_DUMP",
106+
"-DDNNL_EXPERIMENTAL_UKERNEL",
107+
],
108+
includes = [
109+
"include",
110+
"src",
111+
"src/common",
112+
"src/cpu",
113+
"src/cpu/gemm",
114+
"src/graph",
115+
"third_party",
116+
"third_party/ittnotify",
117+
"third_party/xbyak",
118+
],
119+
textual_hdrs = glob([
120+
"include/**/*",
121+
"src/common/*.hpp",
122+
"src/cpu/*.hpp",
123+
"src/cpu/**/*.hpp",
124+
"src/cpu/jit_utils/**/*.hpp",
125+
"src/graph/interface/*.hpp",
126+
"src/graph/backend/*.hpp",
127+
"src/graph/backend/dnnl/*.hpp",
128+
"src/graph/backend/dnnl/executables/*.hpp",
129+
"src/graph/backend/fake/*.hpp",
130+
"src/graph/backend/dnnl/passes/*.hpp",
131+
"src/graph/backend/dnnl/patterns/*.hpp",
132+
"src/graph/backend/dnnl/kernels/*.hpp",
133+
"src/graph/utils/*.hpp",
134+
"src/graph/utils/pm/*.hpp",
135+
"third_party/ittnotify/**/*.h",
136+
"third_party/spdlog/**/*.h",
137+
"third_party/xbyak/*.h",
138+
]) + [
139+
":dnnl_config_h",
140+
":dnnl_version_h",
141+
":dnnl_version_hash_h",
142+
],
143+
visibility = ["//visibility:public"],
144+
)
145+
146+
cc_library(
147+
name = "onednn",
148+
srcs = glob(
149+
[
150+
"src/common/*.cpp",
151+
"src/cpu/*.cpp",
152+
"src/cpu/**/*.cpp",
153+
"src/cpu/jit_utils/**/*.cpp",
154+
"src/cpu/x64/**/*.cpp",
155+
"src/graph/interface/*.cpp",
156+
"src/graph/backend/*.cpp",
157+
"src/graph/backend/dnnl/*.cpp",
158+
"src/graph/backend/dnnl/executables/*.cpp",
159+
"src/graph/backend/fake/*.cpp",
160+
"src/graph/backend/dnnl/passes/*.cpp",
161+
"src/graph/backend/dnnl/patterns/*.cpp",
162+
"src/graph/backend/dnnl/kernels/*.cpp",
163+
"src/graph/utils/*.cpp",
164+
"src/graph/utils/pm/*.cpp",
165+
"third_party/ittnotify/*.c",
166+
],
167+
exclude = [
168+
"src/cpu/aarch64/**",
169+
"src/cpu/rv64/**",
170+
"src/cpu/ppc64/**",
171+
"src/cpu/s390x/**",
172+
"src/cpu/x64/gemm/**/*_kern_autogen.cpp",
173+
"src/cpu/sycl/**",
174+
],
175+
),
176+
copts = [
177+
"-fexceptions",
178+
"-UUSE_MKL",
179+
"-UUSE_CBLAS",
180+
"-DDNNL_ENABLE_MAX_CPU_ISA",
181+
"-DDNNL_ENABLE_ITT_TASKS",
182+
"-DDNNL_ENABLE_GRAPH_DUMP",
183+
"-DDNNL_EXPERIMENTAL_UKERNEL",
184+
],
185+
includes = [
186+
"include",
187+
"src",
188+
"src/common",
189+
"src/cpu",
190+
"src/cpu/gemm",
191+
"src/graph",
192+
"third_party",
193+
"third_party/ittnotify",
194+
"third_party/xbyak",
195+
],
196+
linkopts = [
197+
"-lrt",
198+
"-Wl,--allow-multiple-definition",
199+
],
200+
textual_hdrs = glob([
201+
"include/**/*",
202+
"src/common/*.hpp",
203+
"src/cpu/*.hpp",
204+
"src/cpu/**/*.hpp",
205+
"src/cpu/jit_utils/**/*.hpp",
206+
"src/graph/interface/*.hpp",
207+
"src/graph/backend/*.hpp",
208+
"src/graph/backend/dnnl/*.hpp",
209+
"src/graph/backend/fake/*.hpp",
210+
"src/graph/backend/dnnl/passes/*.hpp",
211+
"src/graph/backend/dnnl/patterns/*.hpp",
212+
"src/graph/backend/dnnl/kernels/*.hpp",
213+
"src/graph/utils/*.hpp",
214+
"src/graph/utils/pm/*.hpp",
215+
"third_party/ittnotify/**/*.h",
216+
"third_party/spdlog/**/*.h",
217+
"third_party/xbyak/*.h",
218+
]) + [
219+
":dnnl_config_h",
220+
":dnnl_version_h",
221+
":dnnl_version_hash_h",
222+
],
223+
visibility = ["//visibility:public"],
224+
deps = [
225+
":onednn_autogen",
226+
"@onetbb//:tbb",
227+
],
228+
)

ops/bench_matmul.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
130130
keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);
131131

132132
// Only record times after autotuning finished.
133-
if (per_key->autotune.Best()) times.push_back(elapsed);
133+
bool done = per_key->autotune.Best();
134+
#if GEMMA_ONEDNN
135+
done = done || per_key->brgemm_autotune.Best();
136+
#endif
137+
if (done) times.push_back(elapsed);
134138
}
135139
hwy::PreventElision(keep);
136140
env.ctx.pools.MaybeStopSpinning(use_spinning);

0 commit comments

Comments
 (0)