@@ -19,6 +19,7 @@ fn main() -> Result<(), String> {
1919 println ! ( "cargo::rustc-check-cfg=cfg(kernel_support, values(\" avx512\" ))" ) ;
2020
2121 println ! ( "cargo:rerun-if-changed=src/simd/f16.c" ) ;
22+ println ! ( "cargo:rerun-if-changed=src/simd/bf16.c" ) ;
2223 println ! ( "cargo:rerun-if-changed=src/simd/dist_table.c" ) ;
2324
2425 // Important: we don't use `cfg!(target_arch)` here because that is the target_arch
@@ -37,13 +38,16 @@ fn main() -> Result<(), String> {
3738 if target_arch == "aarch64" && target_os == "macos" {
3839 // Build a version with NEON
3940 build_f16_with_flags ( "neon" , & [ "-mtune=apple-m1" ] ) . unwrap ( ) ;
41+ build_bf16_with_flags ( "neon" , & [ "-mtune=apple-m1" ] ) . unwrap ( ) ;
4042 } else if target_arch == "aarch64" && target_os == "ios" {
4143 // Build version with NEON
4244 // A13 bionic is the earliest supported iOS SOC
4345 build_f16_with_flags ( "neon" , & [ "-mtune=apple-a13" ] ) . unwrap ( ) ;
46+ build_bf16_with_flags ( "neon" , & [ "-mtune=apple-a13" ] ) . unwrap ( ) ;
4447 } else if target_arch == "aarch64" && ( target_os == "linux" || target_os == "android" ) {
4548 // Build a version with NEON
4649 build_f16_with_flags ( "neon" , & [ "-march=armv8.2-a+fp16" ] ) . unwrap ( ) ;
50+ build_bf16_with_flags ( "neon" , & [ "-march=armv8.2-a+fp16" ] ) . unwrap ( ) ;
4751 } else if target_arch == "x86_64" {
4852 // Build a version with AVX512
4953 if let Err ( err) = build_f16_with_flags ( "avx512" , & [ "-march=sapphirerapids" , "-mavx512fp16" ] )
@@ -59,6 +63,17 @@ fn main() -> Result<(), String> {
5963 // generated the AVX512 version of the f16 kernels.
6064 println ! ( "cargo:rustc-cfg=kernel_support=\" avx512\" " ) ;
6165 } ;
66+ // Build AVX-512 bf16 kernels (sapphirerapids has native vdpbf16ps)
67+ if let Err ( err) =
68+ build_bf16_with_flags ( "avx512" , & [ "-march=sapphirerapids" , "-mavx512fp16" ] )
69+ {
70+ println ! (
71+ "cargo:warning=Skipping build of AVX-512 bf16 kernels. Error: {}" ,
72+ err
73+ ) ;
74+ } else {
75+ println ! ( "cargo:rustc-cfg=kernel_support=\" avx512\" " ) ;
76+ } ;
6277 if let Err ( err) = build_dist_table_with_flags ( "avx512" , & [ "-march=native" ] ) {
6378 println ! (
6479 "cargo:warning=Skipping build of AVX-512 dist_table. Error: {}" ,
@@ -77,11 +92,20 @@ fn main() -> Result<(), String> {
7792 err
7893 ) ) ;
7994 } ;
95+ // Build AVX2 bf16 kernels (bf16-to-f32 is just a shift, auto-vectorizes well)
96+ if let Err ( err) = build_bf16_with_flags ( "avx2" , & [ "-march=haswell" ] ) {
97+ return Err ( format ! (
98+ "Unable to build AVX2 bf16 kernels. Received error: {}" ,
99+ err
100+ ) ) ;
101+ } ;
80102 // There is no SSE instruction set for f16 -> f32 float conversion
81103 } else if target_arch == "loongarch64" {
82104 // Build a version with LSX and LASX
83105 build_f16_with_flags ( "lsx" , & [ "-mlsx" ] ) . unwrap ( ) ;
84106 build_f16_with_flags ( "lasx" , & [ "-mlasx" ] ) . unwrap ( ) ;
107+ build_bf16_with_flags ( "lsx" , & [ "-mlsx" ] ) . unwrap ( ) ;
108+ build_bf16_with_flags ( "lasx" , & [ "-mlasx" ] ) . unwrap ( ) ;
85109 } else {
86110 // Only error if fp16kernels was explicitly requested on unsupported platform.
87111 // This allows builds on iOS, Android, etc. when the feature is disabled.
@@ -128,6 +152,32 @@ fn build_f16_with_flags(suffix: &str, flags: &[&str]) -> Result<(), cc::Error> {
128152 builder. try_compile ( & format ! ( "f16_{}" , suffix) )
129153}
130154
155+ fn build_bf16_with_flags ( suffix : & str , flags : & [ & str ] ) -> Result < ( ) , cc:: Error > {
156+ if cfg ! ( not( feature = "fp16kernels" ) ) {
157+ println ! (
158+ "cargo:warning=fp16kernels feature is not enabled, skipping build of bf16 kernels"
159+ ) ;
160+ return Ok ( ( ) ) ;
161+ }
162+
163+ let mut builder = cc:: Build :: new ( ) ;
164+ builder
165+ . std ( "c17" )
166+ . file ( "src/simd/bf16.c" )
167+ . flag ( "-ffast-math" )
168+ . flag ( "-funroll-loops" )
169+ . flag ( "-O3" )
170+ . flag ( "-Wall" )
171+ . flag ( "-Wextra" )
172+ . flag ( format ! ( "-DSUFFIX=_{}" , suffix) . as_str ( ) ) ;
173+
174+ for flag in flags {
175+ builder. flag ( flag) ;
176+ }
177+
178+ builder. try_compile ( & format ! ( "bf16_{}" , suffix) )
179+ }
180+
131181fn build_dist_table_with_flags ( suffix : & str , flags : & [ & str ] ) -> Result < ( ) , cc:: Error > {
132182 let mut builder = cc:: Build :: new ( ) ;
133183 builder
0 commit comments