@@ -168,6 +168,267 @@ fn spread_32_to_64(val: u32) -> u64 {
168168 out
169169}
170170
171+ // ============================================================================
172+ // Multi-versioned attend kernel: AVX-512 → AVX2 → scalar.
173+ // ============================================================================
174+
175+ /// Return type for attend kernel: (best_idx, distance, scores, fires).
176+ type AttendFn = unsafe fn ( & [ u64 ; 64 ] , u64 , u8 ) -> ( u8 , u8 , [ u8 ; 64 ] , u64 ) ;
177+
178+ #[ cfg( target_arch = "x86_64" ) ]
179+ #[ target_feature( enable = "avx512f" ) ]
180+ unsafe fn attend_avx512 ( rows : & [ u64 ; 64 ] , query : u64 , gamma : u8 ) -> ( u8 , u8 , [ u8 ; 64 ] , u64 ) {
181+ use std:: arch:: x86_64:: * ;
182+ let mut best_idx = 0u8 ;
183+ let mut best_score = 0u8 ;
184+ let mut scores = [ 0u8 ; 64 ] ;
185+ let mut fires = 0u64 ;
186+
187+ let q = _mm512_set1_epi64 ( query as i64 ) ;
188+ // Process 8 rows per chunk, 8 chunks = 64 rows
189+ for chunk in 0 ..8 {
190+ let base = chunk * 8 ;
191+ // SAFETY: rows is [u64; 64], base..base+8 is in bounds, Palette64 is 64-byte aligned.
192+ let r = _mm512_loadu_si512 ( rows[ base..] . as_ptr ( ) as * const __m512i ) ;
193+ let anded = _mm512_and_si512 ( r, q) ;
194+ // Extract 8 u64s and scalar popcount (no VPOPCNTDQ dependency)
195+ let vals: [ u64 ; 8 ] = std:: mem:: transmute ( anded) ;
196+ for j in 0 ..8 {
197+ let score = vals[ j] . count_ones ( ) as u8 ;
198+ let idx = base + j;
199+ scores[ idx] = score;
200+ if score > best_score {
201+ best_score = score;
202+ best_idx = idx as u8 ;
203+ }
204+ if score >= gamma {
205+ fires |= 1u64 << idx;
206+ }
207+ }
208+ }
209+ ( best_idx, 64 - best_score, scores, fires)
210+ }
211+
212+ #[ cfg( target_arch = "x86_64" ) ]
213+ #[ target_feature( enable = "avx2" ) ]
214+ unsafe fn attend_avx2 ( rows : & [ u64 ; 64 ] , query : u64 , gamma : u8 ) -> ( u8 , u8 , [ u8 ; 64 ] , u64 ) {
215+ use std:: arch:: x86_64:: * ;
216+ let mut best_idx = 0u8 ;
217+ let mut best_score = 0u8 ;
218+ let mut scores = [ 0u8 ; 64 ] ;
219+ let mut fires = 0u64 ;
220+
221+ let q = _mm256_set1_epi64x ( query as i64 ) ;
222+ // Process 4 rows per chunk, 16 chunks = 64 rows
223+ for chunk in 0 ..16 {
224+ let base = chunk * 4 ;
225+ // SAFETY: rows is [u64; 64], base..base+4 is in bounds.
226+ let r = _mm256_loadu_si256 ( rows[ base..] . as_ptr ( ) as * const __m256i ) ;
227+ let anded = _mm256_and_si256 ( r, q) ;
228+ let vals: [ u64 ; 4 ] = std:: mem:: transmute ( anded) ;
229+ for j in 0 ..4 {
230+ let score = vals[ j] . count_ones ( ) as u8 ;
231+ let idx = base + j;
232+ scores[ idx] = score;
233+ if score > best_score {
234+ best_score = score;
235+ best_idx = idx as u8 ;
236+ }
237+ if score >= gamma {
238+ fires |= 1u64 << idx;
239+ }
240+ }
241+ }
242+ ( best_idx, 64 - best_score, scores, fires)
243+ }
244+
245+ fn attend_scalar ( rows : & [ u64 ; 64 ] , query : u64 , gamma : u8 ) -> ( u8 , u8 , [ u8 ; 64 ] , u64 ) {
246+ let mut best_idx = 0u8 ;
247+ let mut best_score = 0u8 ;
248+ let mut scores = [ 0u8 ; 64 ] ;
249+ let mut fires = 0u64 ;
250+ for i in 0 ..64 {
251+ let score = ( query & rows[ i] ) . count_ones ( ) as u8 ;
252+ scores[ i] = score;
253+ if score > best_score {
254+ best_score = score;
255+ best_idx = i as u8 ;
256+ }
257+ if score >= gamma {
258+ fires |= 1u64 << i;
259+ }
260+ }
261+ ( best_idx, 64 - best_score, scores, fires)
262+ }
263+
264+ static ATTEND_KERNEL : std:: sync:: LazyLock < AttendFn > = std:: sync:: LazyLock :: new ( || {
265+ #[ cfg( target_arch = "x86_64" ) ]
266+ {
267+ if is_x86_feature_detected ! ( "avx512f" ) {
268+ return attend_avx512 as AttendFn ;
269+ }
270+ if is_x86_feature_detected ! ( "avx2" ) {
271+ return attend_avx2 as AttendFn ;
272+ }
273+ }
274+ attend_scalar as AttendFn
275+ } ) ;
276+
277+ // ============================================================================
278+ // Multi-versioned nearest_k kernel: AVX-512 → AVX2 → scalar.
279+ // ============================================================================
280+
281+ /// Compute all 64 Hamming distances in one pass.
282+ type NearestKFn = unsafe fn ( & [ u64 ; 64 ] , u64 ) -> [ u8 ; 64 ] ;
283+
284+ #[ cfg( target_arch = "x86_64" ) ]
285+ #[ target_feature( enable = "avx512f" ) ]
286+ unsafe fn nearest_k_avx512 ( rows : & [ u64 ; 64 ] , query : u64 ) -> [ u8 ; 64 ] {
287+ use std:: arch:: x86_64:: * ;
288+ let mut dists = [ 0u8 ; 64 ] ;
289+ let q = _mm512_set1_epi64 ( query as i64 ) ;
290+ for chunk in 0 ..8 {
291+ let base = chunk * 8 ;
292+ // SAFETY: rows is [u64; 64], base..base+8 is in bounds.
293+ let r = _mm512_loadu_si512 ( rows[ base..] . as_ptr ( ) as * const __m512i ) ;
294+ let xored = _mm512_xor_si512 ( r, q) ;
295+ let vals: [ u64 ; 8 ] = std:: mem:: transmute ( xored) ;
296+ for j in 0 ..8 {
297+ dists[ base + j] = vals[ j] . count_ones ( ) as u8 ;
298+ }
299+ }
300+ dists
301+ }
302+
303+ #[ cfg( target_arch = "x86_64" ) ]
304+ #[ target_feature( enable = "avx2" ) ]
305+ unsafe fn nearest_k_avx2 ( rows : & [ u64 ; 64 ] , query : u64 ) -> [ u8 ; 64 ] {
306+ use std:: arch:: x86_64:: * ;
307+ let mut dists = [ 0u8 ; 64 ] ;
308+ let q = _mm256_set1_epi64x ( query as i64 ) ;
309+ for chunk in 0 ..16 {
310+ let base = chunk * 4 ;
311+ // SAFETY: rows is [u64; 64], base..base+4 is in bounds.
312+ let r = _mm256_loadu_si256 ( rows[ base..] . as_ptr ( ) as * const __m256i ) ;
313+ let xored = _mm256_xor_si256 ( r, q) ;
314+ let vals: [ u64 ; 4 ] = std:: mem:: transmute ( xored) ;
315+ for j in 0 ..4 {
316+ dists[ base + j] = vals[ j] . count_ones ( ) as u8 ;
317+ }
318+ }
319+ dists
320+ }
321+
322+ fn nearest_k_scalar ( rows : & [ u64 ; 64 ] , query : u64 ) -> [ u8 ; 64 ] {
323+ let mut dists = [ 0u8 ; 64 ] ;
324+ for i in 0 ..64 {
325+ dists[ i] = ( query ^ rows[ i] ) . count_ones ( ) as u8 ;
326+ }
327+ dists
328+ }
329+
330+ static NEAREST_K_KERNEL : std:: sync:: LazyLock < NearestKFn > = std:: sync:: LazyLock :: new ( || {
331+ #[ cfg( target_arch = "x86_64" ) ]
332+ {
333+ if is_x86_feature_detected ! ( "avx512f" ) {
334+ return nearest_k_avx512 as NearestKFn ;
335+ }
336+ if is_x86_feature_detected ! ( "avx2" ) {
337+ return nearest_k_avx2 as NearestKFn ;
338+ }
339+ }
340+ nearest_k_scalar as NearestKFn
341+ } ) ;
342+
343+ // ============================================================================
344+ // Multi-versioned moe_gate kernel: AVX-512 → AVX2 → scalar.
345+ // ============================================================================
346+
347+ /// Return type: (active_mask, strength[8], combined).
348+ type MoeGateFn = unsafe fn ( & [ u64 ; 8 ] , u64 , u8 ) -> ( u8 , [ u8 ; 8 ] , u64 ) ;
349+
350+ #[ cfg( target_arch = "x86_64" ) ]
351+ #[ target_feature( enable = "avx512f" ) ]
352+ unsafe fn moe_gate_avx512 ( planes : & [ u64 ; 8 ] , query : u64 , threshold : u8 ) -> ( u8 , [ u8 ; 8 ] , u64 ) {
353+ use std:: arch:: x86_64:: * ;
354+ // Load all 8 planes into one zmm register, AND with broadcast query
355+ // SAFETY: planes is [u64; 8] = 64 bytes, fits in one zmm.
356+ let p = _mm512_loadu_si512 ( planes. as_ptr ( ) as * const __m512i ) ;
357+ let q = _mm512_set1_epi64 ( query as i64 ) ;
358+ let anded = _mm512_and_si512 ( p, q) ;
359+ let vals: [ u64 ; 8 ] = std:: mem:: transmute ( anded) ;
360+
361+ let mut active = 0u8 ;
362+ let mut strength = [ 0u8 ; 8 ] ;
363+ let mut combined = 0u64 ;
364+ for i in 0 ..8 {
365+ let score = vals[ i] . count_ones ( ) as u8 ;
366+ strength[ i] = score;
367+ if score >= threshold {
368+ active |= 1 << i;
369+ combined |= planes[ i] ;
370+ }
371+ }
372+ ( active, strength, combined)
373+ }
374+
375+ #[ cfg( target_arch = "x86_64" ) ]
376+ #[ target_feature( enable = "avx2" ) ]
377+ unsafe fn moe_gate_avx2 ( planes : & [ u64 ; 8 ] , query : u64 , threshold : u8 ) -> ( u8 , [ u8 ; 8 ] , u64 ) {
378+ use std:: arch:: x86_64:: * ;
379+ let q = _mm256_set1_epi64x ( query as i64 ) ;
380+ let mut active = 0u8 ;
381+ let mut strength = [ 0u8 ; 8 ] ;
382+ let mut combined = 0u64 ;
383+
384+ // Process 4 planes at a time, 2 chunks = 8 planes
385+ for chunk in 0 ..2 {
386+ let base = chunk * 4 ;
387+ // SAFETY: planes is [u64; 8], base..base+4 is in bounds.
388+ let p = _mm256_loadu_si256 ( planes[ base..] . as_ptr ( ) as * const __m256i ) ;
389+ let anded = _mm256_and_si256 ( p, q) ;
390+ let vals: [ u64 ; 4 ] = std:: mem:: transmute ( anded) ;
391+ for j in 0 ..4 {
392+ let score = vals[ j] . count_ones ( ) as u8 ;
393+ let idx = base + j;
394+ strength[ idx] = score;
395+ if score >= threshold {
396+ active |= 1 << idx;
397+ combined |= planes[ idx] ;
398+ }
399+ }
400+ }
401+ ( active, strength, combined)
402+ }
403+
404+ fn moe_gate_scalar ( planes : & [ u64 ; 8 ] , query : u64 , threshold : u8 ) -> ( u8 , [ u8 ; 8 ] , u64 ) {
405+ let mut active = 0u8 ;
406+ let mut strength = [ 0u8 ; 8 ] ;
407+ let mut combined = 0u64 ;
408+ for i in 0 ..8 {
409+ let score = ( query & planes[ i] ) . count_ones ( ) as u8 ;
410+ strength[ i] = score;
411+ if score >= threshold {
412+ active |= 1 << i;
413+ combined |= planes[ i] ;
414+ }
415+ }
416+ ( active, strength, combined)
417+ }
418+
419+ static MOE_GATE_KERNEL : std:: sync:: LazyLock < MoeGateFn > = std:: sync:: LazyLock :: new ( || {
420+ #[ cfg( target_arch = "x86_64" ) ]
421+ {
422+ if is_x86_feature_detected ! ( "avx512f" ) {
423+ return moe_gate_avx512 as MoeGateFn ;
424+ }
425+ if is_x86_feature_detected ! ( "avx2" ) {
426+ return moe_gate_avx2 as MoeGateFn ;
427+ }
428+ }
429+ moe_gate_scalar as MoeGateFn
430+ } ) ;
431+
171432// ============================================================================
172433// BNN Attention
173434// ============================================================================
@@ -183,30 +444,16 @@ impl Palette64 {
183444 /// Score = popcount(query AND row[i]).
184445 /// Higher score = more bits in common = better match.
185446 /// Gamma threshold: rows below this score don't "fire."
447+ ///
448+ /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
186449 #[ inline]
187450 pub fn attend ( & self , query : u64 , gamma : u8 ) -> AttentionResult {
188- let mut scores = [ 0u8 ; 64 ] ;
189- let mut best_idx = 0u8 ;
190- let mut best_score = 0u8 ;
191- let mut fires = 0u64 ;
192-
193- for i in 0 ..64 {
194- let score = ( query & self . rows [ i] ) . count_ones ( ) as u8 ;
195- scores[ i] = score;
196-
197- if score > best_score {
198- best_score = score;
199- best_idx = i as u8 ;
200- }
201-
202- if score >= gamma {
203- fires |= 1u64 << i;
204- }
205- }
206-
451+ // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
452+ let ( best_idx, distance, scores, fires) =
453+ unsafe { ATTEND_KERNEL ( & self . rows , query, gamma) } ;
207454 AttentionResult {
208455 best_idx,
209- distance : 64 - best_score ,
456+ distance,
210457 scores,
211458 fires,
212459 }
@@ -228,16 +475,15 @@ impl Palette64 {
228475 /// Palette lookup: find the K nearest rows by Hamming distance.
229476 ///
230477 /// Returns (row_index, hamming_distance) sorted ascending.
478+ ///
479+ /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
231480 pub fn nearest_k ( & self , query : u64 , k : usize ) -> Vec < ( u8 , u8 ) > {
232- let mut dists: Vec < ( u8 , u8 ) > = ( 0 ..64 )
233- . map ( |i| {
234- let dist = ( query ^ self . rows [ i] ) . count_ones ( ) as u8 ;
235- ( i as u8 , dist)
236- } )
237- . collect ( ) ;
238- dists. sort_by_key ( |& ( _, d) | d) ;
239- dists. truncate ( k) ;
240- dists
481+ // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
482+ let dists = unsafe { NEAREST_K_KERNEL ( & self . rows , query) } ;
483+ let mut pairs: Vec < ( u8 , u8 ) > = ( 0 ..64u8 ) . map ( |i| ( i, dists[ i as usize ] ) ) . collect ( ) ;
484+ pairs. sort_by_key ( |& ( _, d) | d) ;
485+ pairs. truncate ( k) ;
486+ pairs
241487 }
242488
243489 /// Row density: popcount of each row. Sparse rows = abstract; dense = concrete.
@@ -281,22 +527,13 @@ impl HeelPlanes {
281527 ///
282528 /// Each HEEL plane is an expert. The query's match against each expert
283529 /// determines which experts activate and with what strength.
530+ ///
531+ /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
284532 #[ inline]
285533 pub fn moe_gate ( & self , query : u64 , threshold : u8 ) -> MoeGate {
286- let mut active = 0u8 ;
287- let mut strength = [ 0u8 ; 8 ] ;
288- let mut combined = 0u64 ;
289-
290- for i in 0 ..8 {
291- let score = ( query & self . planes [ i] ) . count_ones ( ) as u8 ;
292- strength[ i] = score;
293-
294- if score >= threshold {
295- active |= 1 << i;
296- combined |= self . planes [ i] ;
297- }
298- }
299-
534+ // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
535+ let ( active, strength, combined) =
536+ unsafe { MOE_GATE_KERNEL ( & self . planes , query, threshold) } ;
300537 MoeGate {
301538 active,
302539 strength,
0 commit comments