@@ -239,25 +239,29 @@ mod reduce_sum_of_xy {
239239
240240 #[ inline]
241241 #[ cfg( target_arch = "x86_64" ) ]
242- #[ crate :: target_cpu( enable = "v4fp16" ) ]
243- pub fn reduce_sum_of_xy_v4fp16 ( lhs : & [ f16 ] , rhs : & [ f16 ] ) -> f32 {
242+ #[ crate :: target_cpu( enable = "v4" ) ]
243+ #[ target_feature( enable = "avx512fp16" ) ]
244+ pub fn reduce_sum_of_xy_v4_avx512fp16 ( lhs : & [ f16 ] , rhs : & [ f16 ] ) -> f32 {
244245 assert ! ( lhs. len( ) == rhs. len( ) ) ;
245246 unsafe {
246247 unsafe extern "C" {
247- unsafe fn fp16_reduce_sum_of_xy_v4fp16 ( a : * const ( ) , b : * const ( ) , n : usize )
248- -> f32 ;
248+ unsafe fn fp16_reduce_sum_of_xy_v4_avx512fp16 (
249+ a : * const ( ) ,
250+ b : * const ( ) ,
251+ n : usize ,
252+ ) -> f32 ;
249253 }
250- fp16_reduce_sum_of_xy_v4fp16 ( lhs. as_ptr ( ) . cast ( ) , rhs. as_ptr ( ) . cast ( ) , lhs. len ( ) )
254+ fp16_reduce_sum_of_xy_v4_avx512fp16 ( lhs. as_ptr ( ) . cast ( ) , rhs. as_ptr ( ) . cast ( ) , lhs. len ( ) )
251255 }
252256 }
253257
254258 #[ cfg( all( target_arch = "x86_64" , test, not( miri) ) ) ]
255259 #[ test]
256- fn reduce_sum_of_xy_v4fp16_test ( ) {
260+ fn reduce_sum_of_xy_v4_avx512fp16_test ( ) {
257261 use rand:: Rng ;
258262 const EPSILON : f32 = 2.0 ;
259- if !crate :: is_cpu_detected!( "v4fp16 " ) {
260- println ! ( "test {} ... skipped (v4fp16 )" , module_path!( ) ) ;
263+ if !crate :: is_cpu_detected!( "v4" ) || ! crate :: is_feature_detected! ( "avx512fp16 ") {
264+ println ! ( "test {} ... skipped (v4:avx512fp16 )" , module_path!( ) ) ;
261265 return ;
262266 }
263267 let mut rng = rand:: rng ( ) ;
@@ -272,7 +276,7 @@ mod reduce_sum_of_xy {
272276 for z in 3984 ..4016 {
273277 let lhs = & lhs[ ..z] ;
274278 let rhs = & rhs[ ..z] ;
275- let specialized = unsafe { reduce_sum_of_xy_v4fp16 ( lhs, rhs) } ;
279+ let specialized = unsafe { reduce_sum_of_xy_v4_avx512fp16 ( lhs, rhs) } ;
276280 let fallback = fallback ( lhs, rhs) ;
277281 assert ! (
278282 ( specialized - fallback) . abs( ) < EPSILON ,
@@ -494,7 +498,7 @@ mod reduce_sum_of_xy {
494498 }
495499 }
496500
497- #[ crate :: multiversion( @"v4fp16 " , @"v4" , @"v3" , @"a3.512" , @"a2:fp16" ) ]
501+ #[ crate :: multiversion( @"v4:avx512fp16 " , @"v4" , @"v3" , @"a3.512" , @"a2:fp16" ) ]
498502 pub fn reduce_sum_of_xy ( lhs : & [ f16 ] , rhs : & [ f16 ] ) -> f32 {
499503 assert ! ( lhs. len( ) == rhs. len( ) ) ;
500504 let n = lhs. len ( ) ;
@@ -511,25 +515,29 @@ mod reduce_sum_of_d2 {
511515
512516 #[ inline]
513517 #[ cfg( target_arch = "x86_64" ) ]
514- #[ crate :: target_cpu( enable = "v4fp16" ) ]
515- pub fn reduce_sum_of_d2_v4fp16 ( lhs : & [ f16 ] , rhs : & [ f16 ] ) -> f32 {
518+ #[ crate :: target_cpu( enable = "v4" ) ]
519+ #[ target_feature( enable = "avx512fp16" ) ]
520+ pub fn reduce_sum_of_d2_v4_avx512fp16 ( lhs : & [ f16 ] , rhs : & [ f16 ] ) -> f32 {
516521 assert ! ( lhs. len( ) == rhs. len( ) ) ;
517522 unsafe {
518523 unsafe extern "C" {
519- unsafe fn fp16_reduce_sum_of_d2_v4fp16 ( a : * const ( ) , b : * const ( ) , n : usize )
520- -> f32 ;
524+ unsafe fn fp16_reduce_sum_of_d2_v4_avx512fp16 (
525+ a : * const ( ) ,
526+ b : * const ( ) ,
527+ n : usize ,
528+ ) -> f32 ;
521529 }
522- fp16_reduce_sum_of_d2_v4fp16 ( lhs. as_ptr ( ) . cast ( ) , rhs. as_ptr ( ) . cast ( ) , lhs. len ( ) )
530+ fp16_reduce_sum_of_d2_v4_avx512fp16 ( lhs. as_ptr ( ) . cast ( ) , rhs. as_ptr ( ) . cast ( ) , lhs. len ( ) )
523531 }
524532 }
525533
526534 #[ cfg( all( target_arch = "x86_64" , test, not( miri) ) ) ]
527535 #[ test]
528- fn reduce_sum_of_d2_v4fp16_test ( ) {
536+ fn reduce_sum_of_d2_v4_avx512fp16_test ( ) {
529537 use rand:: Rng ;
530538 const EPSILON : f32 = 6.4 ;
531- if !crate :: is_cpu_detected!( "v4fp16 " ) {
532- println ! ( "test {} ... skipped (v4fp16 )" , module_path!( ) ) ;
539+ if !crate :: is_cpu_detected!( "v4" ) || ! crate :: is_feature_detected! ( "avx512fp16 ") {
540+ println ! ( "test {} ... skipped (v4:avx512fp16 )" , module_path!( ) ) ;
533541 return ;
534542 }
535543 let mut rng = rand:: rng ( ) ;
@@ -544,7 +552,7 @@ mod reduce_sum_of_d2 {
544552 for z in 3984 ..4016 {
545553 let lhs = & lhs[ ..z] ;
546554 let rhs = & rhs[ ..z] ;
547- let specialized = unsafe { reduce_sum_of_d2_v4fp16 ( lhs, rhs) } ;
555+ let specialized = unsafe { reduce_sum_of_d2_v4_avx512fp16 ( lhs, rhs) } ;
548556 let fallback = fallback ( lhs, rhs) ;
549557 assert ! (
550558 ( specialized - fallback) . abs( ) < EPSILON ,
@@ -774,7 +782,7 @@ mod reduce_sum_of_d2 {
774782 }
775783 }
776784
777- #[ crate :: multiversion( @"v4fp16 " , @"v4" , @"v3" , @"a3.512" , @"a2:fp16" ) ]
785+ #[ crate :: multiversion( @"v4:avx512fp16 " , @"v4" , @"v3" , @"a3.512" , @"a2:fp16" ) ]
778786 pub fn reduce_sum_of_d2 ( lhs : & [ f16 ] , rhs : & [ f16 ] ) -> f32 {
779787 assert ! ( lhs. len( ) == rhs. len( ) ) ;
780788 let n = lhs. len ( ) ;
0 commit comments