@@ -61,6 +61,66 @@ inline fn simdScaleAdd(out: []f32, vec: []const f32, scale: f32) void {
6161 }
6262}
6363
64+ // ═══════════════════════════════════════════════════════════════════════════════
65+ // PARALLEL ATTENTION - Multi-threaded head processing
66+ // ═══════════════════════════════════════════════════════════════════════════════
67+
68+ /// Context for parallel attention head computation
69+ const AttentionHeadContext = struct {
70+ head_idx : usize ,
71+ q : []const f32 ,
72+ k_cache : []const f32 ,
73+ v_cache : []const f32 ,
74+ attn_out : []f32 ,
75+ scores_buf : []f32 ,
76+ head_dim : usize ,
77+ num_kv_heads : usize ,
78+ kv_group_size : usize ,
79+ cache_len : usize ,
80+ scale : f32 ,
81+ };
82+
83+ /// Process single attention head (called from thread)
84+ fn processAttentionHead (ctx : * AttentionHeadContext ) void {
85+ const h = ctx .head_idx ;
86+ const kv_h = h / ctx .kv_group_size ;
87+ const q_vec = ctx .q [h * ctx .head_dim .. ][0.. ctx .head_dim ];
88+ const scores = ctx .scores_buf ;
89+
90+ // Compute attention scores with SIMD dot product
91+ var max_score : f32 = - std .math .inf (f32 );
92+ for (0.. ctx .cache_len ) | t | {
93+ const k_offset = t * ctx .num_kv_heads * ctx .head_dim + kv_h * ctx .head_dim ;
94+ const k_vec = ctx .k_cache [k_offset .. ][0.. ctx .head_dim ];
95+ const score = simdDotProduct (q_vec , k_vec ) * ctx .scale ;
96+ scores [t ] = score ;
97+ if (score > max_score ) max_score = score ;
98+ }
99+
100+ // Softmax
101+ var sum_exp : f32 = 0.0 ;
102+ for (scores [0.. ctx .cache_len ]) | * s | {
103+ s .* = @exp (s .* - max_score );
104+ sum_exp += s .* ;
105+ }
106+ for (scores [0.. ctx .cache_len ]) | * s | {
107+ s .* /= sum_exp ;
108+ }
109+
110+ // SIMD weighted sum of V
111+ const head_out = ctx .attn_out [h * ctx .head_dim .. ][0.. ctx .head_dim ];
112+ @memset (head_out , 0.0 );
113+ for (0.. ctx .cache_len ) | t | {
114+ const v_offset = t * ctx .num_kv_heads * ctx .head_dim + kv_h * ctx .head_dim ;
115+ const v_vec = ctx .v_cache [v_offset .. ][0.. ctx .head_dim ];
116+ simdScaleAdd (head_out , v_vec , scores [t ]);
117+ }
118+ }
119+
120+ /// Number of threads for parallel attention (configurable)
121+ /// Set to 2 for environments with limited cores
122+ pub const NUM_ATTENTION_THREADS : usize = 2 ;
123+
64124// ═══════════════════════════════════════════════════════════════════════════════
65125// CONSTANTS - BitNet 2B Architecture
66126// ═══════════════════════════════════════════════════════════════════════════════
@@ -277,49 +337,81 @@ pub const Attention = struct {
277337 // Append K, V to cache
278338 kv_cache .append (k , v );
279339
280- // Compute attention for each head
340+ // Compute attention for each head (parallel when possible)
281341 const attn_out = try allocator .alloc (f32 , q_size );
282342 defer allocator .free (attn_out );
283343 @memset (attn_out , 0.0 );
284344
285345 const scale = 1.0 / @sqrt (@as (f32 , @floatFromInt (cfg .head_dim )));
286346 const kv_group_size = cfg .num_heads / cfg .num_kv_heads ;
287347
288- for (0.. cfg .num_heads ) | h | {
289- const kv_h = h / kv_group_size ;
290- const q_vec = q [h * cfg .head_dim .. ][0.. cfg .head_dim ];
291-
292- // Compute attention scores
293- const scores = try allocator .alloc (f32 , kv_cache .len );
294- defer allocator .free (scores );
348+ // Allocate score buffers for all heads
349+ const max_cache_len = @max (kv_cache .len , 1 );
350+ const scores_bufs = try allocator .alloc (f32 , cfg .num_heads * max_cache_len );
351+ defer allocator .free (scores_bufs );
352+
353+ // Determine number of threads (min of available and heads)
354+ const num_threads = @min (NUM_ATTENTION_THREADS , cfg .num_heads );
355+
356+ if (num_threads > 1 and cfg .num_heads >= 2 ) {
357+ // Parallel execution with threads
358+ var contexts : [8 ]AttentionHeadContext = undefined ;
359+ var threads : [8 ]std.Thread = undefined ;
360+ var active_threads : usize = 0 ;
295361
296- var max_score : f32 = - std . math . inf ( f32 ) ;
297- for (0 .. kv_cache . len ) | t | {
298- const k_offset = t * cfg . num_kv_heads * cfg . head_dim + kv_h * cfg . head_dim ;
299- const k_vec = kv_cache . k [ k_offset .. ][0 .. cfg .head_dim ] ;
362+ var h : usize = 0 ;
363+ while ( h < cfg . num_heads ) {
364+ // Launch batch of threads
365+ const batch_size = @min ( num_threads , cfg .num_heads - h ) ;
300366
301- // SIMD dot product for Q @ K^T
302- const score = simdDotProduct (q_vec , k_vec ) * scale ;
303- scores [t ] = score ;
304- if (score > max_score ) max_score = score ;
305- }
306-
307- // Softmax
308- var sum_exp : f32 = 0.0 ;
309- for (scores ) | * s | {
310- s .* = @exp (s .* - max_score );
311- sum_exp += s .* ;
312- }
313- for (scores ) | * s | {
314- s .* /= sum_exp ;
367+ for (0.. batch_size ) | t | {
368+ const head_idx = h + t ;
369+ contexts [t ] = AttentionHeadContext {
370+ .head_idx = head_idx ,
371+ .q = q ,
372+ .k_cache = kv_cache .k ,
373+ .v_cache = kv_cache .v ,
374+ .attn_out = attn_out ,
375+ .scores_buf = scores_bufs [head_idx * max_cache_len .. ][0.. max_cache_len ],
376+ .head_dim = cfg .head_dim ,
377+ .num_kv_heads = cfg .num_kv_heads ,
378+ .kv_group_size = kv_group_size ,
379+ .cache_len = kv_cache .len ,
380+ .scale = scale ,
381+ };
382+ threads [t ] = std .Thread .spawn (.{}, processAttentionHead , .{& contexts [t ]}) catch {
383+ // Fallback to sequential if spawn fails
384+ processAttentionHead (& contexts [t ]);
385+ continue ;
386+ };
387+ active_threads += 1 ;
388+ }
389+
390+ // Wait for batch to complete
391+ for (0.. active_threads ) | t | {
392+ threads [t ].join ();
393+ }
394+ active_threads = 0 ;
395+
396+ h += batch_size ;
315397 }
316-
317- // SIMD weighted sum of V
318- const head_out = attn_out [h * cfg .head_dim .. ][0.. cfg .head_dim ];
319- for (0.. kv_cache .len ) | t | {
320- const v_offset = t * cfg .num_kv_heads * cfg .head_dim + kv_h * cfg .head_dim ;
321- const v_vec = kv_cache .v [v_offset .. ][0.. cfg .head_dim ];
322- simdScaleAdd (head_out , v_vec , scores [t ]);
398+ } else {
399+ // Sequential fallback for single head or single thread
400+ for (0.. cfg .num_heads ) | h | {
401+ var ctx = AttentionHeadContext {
402+ .head_idx = h ,
403+ .q = q ,
404+ .k_cache = kv_cache .k ,
405+ .v_cache = kv_cache .v ,
406+ .attn_out = attn_out ,
407+ .scores_buf = scores_bufs [h * max_cache_len .. ][0.. max_cache_len ],
408+ .head_dim = cfg .head_dim ,
409+ .num_kv_heads = cfg .num_kv_heads ,
410+ .kv_group_size = kv_group_size ,
411+ .cache_len = kv_cache .len ,
412+ .scale = scale ,
413+ };
414+ processAttentionHead (& ctx );
323415 }
324416 }
325417
0 commit comments