66const std = @import ("std" );
77const simd_matmul = @import ("simd_ternary_matmul.zig" );
88
9+ // ═══════════════════════════════════════════════════════════════════════════════
10+ // SIMD TYPES AND HELPERS
11+ // ═══════════════════════════════════════════════════════════════════════════════
12+
13+ const Vec8f32 = @Vector (8 , f32 );
14+ const Vec16f32 = @Vector (16 , f32 );
15+
16+ /// SIMD dot product for attention Q@K^T
17+ /// Processes 8 elements at a time for AVX2-style performance
18+ inline fn simdDotProduct (a : []const f32 , b : []const f32 ) f32 {
19+ const len = @min (a .len , b .len );
20+ const aligned_len = len & ~ @as (usize , 7 ); // Round down to multiple of 8
21+
22+ var sum_vec : Vec8f32 = @splat (0.0 );
23+ var i : usize = 0 ;
24+
25+ // SIMD loop - 8 elements at a time
26+ while (i < aligned_len ) : (i += 8 ) {
27+ const a_vec : Vec8f32 = a [i .. ][0.. 8].* ;
28+ const b_vec : Vec8f32 = b [i .. ][0.. 8].* ;
29+ sum_vec += a_vec * b_vec ;
30+ }
31+
32+ // Reduce SIMD vector
33+ var sum : f32 = @reduce (.Add , sum_vec );
34+
35+ // Scalar tail
36+ while (i < len ) : (i += 1 ) {
37+ sum += a [i ] * b [i ];
38+ }
39+
40+ return sum ;
41+ }
42+
43+ /// SIMD scale-add for attention weighted sum: out += scale * vec
44+ inline fn simdScaleAdd (out : []f32 , vec : []const f32 , scale : f32 ) void {
45+ const len = @min (out .len , vec .len );
46+ const aligned_len = len & ~ @as (usize , 7 );
47+
48+ const scale_vec : Vec8f32 = @splat (scale );
49+ var i : usize = 0 ;
50+
51+ // SIMD loop
52+ while (i < aligned_len ) : (i += 8 ) {
53+ const out_vec : Vec8f32 = out [i .. ][0.. 8].* ;
54+ const v_vec : Vec8f32 = vec [i .. ][0.. 8].* ;
55+ out [i .. ][0.. 8].* = out_vec + v_vec * scale_vec ;
56+ }
57+
58+ // Scalar tail
59+ while (i < len ) : (i += 1 ) {
60+ out [i ] += scale * vec [i ];
61+ }
62+ }
63+
964// ═══════════════════════════════════════════════════════════════════════════════
1065// CONSTANTS - BitNet 2B Architecture
1166// ═══════════════════════════════════════════════════════════════════════════════
@@ -243,12 +298,10 @@ pub const Attention = struct {
243298 const k_offset = t * cfg .num_kv_heads * cfg .head_dim + kv_h * cfg .head_dim ;
244299 const k_vec = kv_cache .k [k_offset .. ][0.. cfg .head_dim ];
245300
246- var score : f32 = 0.0 ;
247- for (0.. cfg .head_dim ) | i | {
248- score += q_vec [i ] * k_vec [i ];
249- }
250- scores [t ] = score * scale ;
251- if (scores [t ] > max_score ) max_score = scores [t ];
301+ // SIMD dot product for Q @ K^T
302+ const score = simdDotProduct (q_vec , k_vec ) * scale ;
303+ scores [t ] = score ;
304+ if (score > max_score ) max_score = score ;
252305 }
253306
254307 // Softmax
@@ -261,13 +314,12 @@ pub const Attention = struct {
261314 s .* /= sum_exp ;
262315 }
263316
264- // Weighted sum of V
317+ // SIMD weighted sum of V
318+ const head_out = attn_out [h * cfg .head_dim .. ][0.. cfg .head_dim ];
265319 for (0.. kv_cache .len ) | t | {
266320 const v_offset = t * cfg .num_kv_heads * cfg .head_dim + kv_h * cfg .head_dim ;
267321 const v_vec = kv_cache .v [v_offset .. ][0.. cfg .head_dim ];
268- for (0.. cfg .head_dim ) | i | {
269- attn_out [h * cfg .head_dim + i ] += scores [t ] * v_vec [i ];
270- }
322+ simdScaleAdd (head_out , v_vec , scores [t ]);
271323 }
272324 }
273325
0 commit comments