@@ -271,6 +271,30 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17)
271271target_include_directories (bitsandbytes PUBLIC csrc include )
272272
273273if (BUILD_CPU)
274+ include (CheckCXXSourceRuns )
275+ set (AVX512F_TEST_CODE "
276+ #include <immintrin.h>
277+ int main() {
278+ __m512 a = _mm512_setzero_ps();
279+ __m512 b = _mm512_add_ps(a, a);
280+ return 0;
281+ }
282+ " )
283+ set (AVX512BF16_TEST_CODE "
284+ #include <immintrin.h>
285+ int main() {
286+ __m512 a = _mm512_setzero_ps();
287+ __m256bh b = _mm512_cvtneps_pbh(a);
288+ return 0;
289+ }
290+ " )
291+ set (CMAKE_REQUIRED_FLAGS "-mavx512f" )
292+ check_cxx_source_runs ("${AVX512F_TEST_CODE} " HOST_HAS_AVX512F )
293+ unset (CMAKE_REQUIRED_FLAGS)
294+ set (CMAKE_REQUIRED_FLAGS "-mavx512bf16" )
295+ check_cxx_source_runs ("${AVX512BF16_TEST_CODE} " HOST_HAS_AVX512BF16 )
296+ unset (CMAKE_REQUIRED_FLAGS)
297+
274298 if (OpenMP_CXX_FOUND)
275299 target_link_libraries (bitsandbytes PRIVATE OpenMP::OpenMP_CXX )
276300 add_definitions (-DHAS_OPENMP )
@@ -280,13 +304,13 @@ if (BUILD_CPU)
280304 include (CheckCXXCompilerFlag )
281305 check_cxx_compiler_flag (-mavx512f HAS_AVX512F_FLAG )
282306 check_cxx_compiler_flag (-mavx512bf16 HAS_AVX512BF16_FLAG )
283- if (HAS_AVX512F_FLAG)
307+ if (HAS_AVX512F_FLAG AND HOST_HAS_AVX512F )
284308 target_compile_options (bitsandbytes PRIVATE -mavx512f )
285309 target_compile_options (bitsandbytes PRIVATE -mavx512dq )
286310 target_compile_options (bitsandbytes PRIVATE -mavx512bw )
287311 target_compile_options (bitsandbytes PRIVATE -mavx512vl )
288312 endif ()
289- if (HAS_AVX512BF16_FLAG)
313+ if (HAS_AVX512BF16_FLAG AND HOST_HAS_AVX512BF16 )
290314 target_compile_options (bitsandbytes PRIVATE -mavx512bf16 )
291315 endif ()
292316 target_compile_options (
0 commit comments