@@ -122,6 +122,87 @@ fn processAttentionHead(ctx: *AttentionHeadContext) void {
122122/// Set to 2 for environments with limited cores
123123pub const NUM_ATTENTION_THREADS : usize = 2 ;
124124
125+ // ═══════════════════════════════════════════════════════════════════════════════
126+ // PERSISTENT THREAD POOL - Reusable across forward passes
127+ // ═══════════════════════════════════════════════════════════════════════════════
128+
/// Bookkeeping record for the process-wide thread pool.
/// It holds only an "initialized" flag and a thread count; no thread handles
/// are stored in this struct.
pub const ThreadPoolState = struct {
    /// Thread count the pool is configured for; 0 by default.
    num_threads: usize = 0,
    /// Whether a thread count has been recorded; false by default.
    initialized: bool = false,
};
134+
// Single file-level pool record; every field starts at its declared default.
var global_pool_state = ThreadPoolState{};
136+
/// Record the thread count for the global pool.
///
/// Passing 0 requests auto-detection through `std.Thread.getCpuCount()`,
/// falling back to 2 when that query fails. Any other value is stored as-is.
/// If the pool is already marked initialized the call does nothing and the
/// previously stored count is kept.
/// Intended to be called once at application startup.
pub fn initThreadPool(num_threads: usize) void {
    if (!global_pool_state.initialized) {
        const resolved: usize = switch (num_threads) {
            0 => std.Thread.getCpuCount() catch 2,
            else => num_threads,
        };
        global_pool_state.num_threads = resolved;
        global_pool_state.initialized = true;
    }
}
147+
/// Return the global pool record to its uninitialized state.
/// Intended to be called at application shutdown; a later `initThreadPool`
/// call will configure the pool again.
pub fn deinitThreadPool() void {
    // The struct defaults are `initialized = false` and `num_threads = 0`,
    // so assigning an empty literal clears both fields at once.
    global_pool_state = .{};
}
154+
/// Report how many threads the global pool is configured for.
/// When the pool has not been set up yet, it is first initialized lazily with
/// auto-detection (`initThreadPool(0)`), then the stored count is returned.
pub fn getPoolThreadCount() usize {
    if (!isPoolInitialized()) initThreadPool(0);
    return global_pool_state.num_threads;
}
162+
/// Report whether the global pool record is currently marked initialized.
pub fn isPoolInitialized() bool {
    const ready = global_pool_state.initialized;
    return ready;
}
167+
/// Atomic counter that hands out work-item indices `0..total_items`, one per
/// call, intended for dynamic load balancing between workers.
pub const WorkQueue = struct {
    /// Index of the next item that has not been handed out yet.
    next_item: std.atomic.Value(usize),
    /// Upper bound (exclusive) on the indices this queue returns.
    total_items: usize,

    const Self = @This();

    /// Build a queue that hands out the indices `0..total`.
    pub fn init(total: usize) Self {
        return Self{
            .total_items = total,
            .next_item = std.atomic.Value(usize).init(0),
        };
    }

    /// Claim the next unclaimed index, or return null once every index has
    /// been handed out.
    pub fn getNext(self: *Self) ?usize {
        while (true) {
            const candidate = self.next_item.load(.acquire);
            if (candidate >= self.total_items) break;

            // cmpxchgWeak yields null when the swap succeeded. A non-null
            // result means the counter was not advanced by this call, so the
            // loop reloads the counter and tries again.
            const swap_result = self.next_item.cmpxchgWeak(
                candidate,
                candidate + 1,
                .release,
                .monotonic,
            );
            if (swap_result == null) return candidate;
        }
        return null;
    }

    /// Rewind the counter to 0 so the queue can be drained again.
    pub fn reset(self: *Self) void {
        self.next_item.store(0, .release);
    }
};
205+
125206// ═══════════════════════════════════════════════════════════════════════════════
126207// CONSTANTS - BitNet 2B Architecture
127208// ═══════════════════════════════════════════════════════════════════════════════
@@ -351,7 +432,7 @@ pub const Attention = struct {
351432 const scores_bufs = try allocator .alloc (f32 , cfg .num_heads * max_cache_len );
352433 defer allocator .free (scores_bufs );
353434
354- // Determine number of threads (min of available and heads )
435+ // Determine number of threads (use constant for minimal overhead )
355436 const num_threads = @min (NUM_ATTENTION_THREADS , cfg .num_heads );
356437
357438 if (num_threads > 1 and cfg .num_heads >= 2 ) {
@@ -1334,6 +1415,174 @@ test "KV cache grows correctly" {
13341415 std .debug .print ("\n ✅ KV cache grows correctly!\n " , .{});
13351416}
13361417
test "thread pool initialization" {
    // The pool record is global state shared by every test in this file, and
    // initThreadPool ignores calls once the pool is marked initialized. Start
    // from a known-uninitialized pool so the explicit count below is actually
    // stored, and clear the pool again on exit so no state leaks out.
    deinitThreadPool();
    defer deinitThreadPool();

    // An explicit thread count is stored as given.
    initThreadPool(4);
    try std.testing.expect(isPoolInitialized());
    try std.testing.expectEqual(@as(usize, 4), getPoolThreadCount());

    // A second init while initialized must not overwrite the stored count.
    initThreadPool(8);
    try std.testing.expectEqual(@as(usize, 4), getPoolThreadCount());

    deinitThreadPool();
    try std.testing.expect(!isPoolInitialized());

    // Re-init with auto-detect: the exact count depends on the machine, but
    // it must be at least one.
    initThreadPool(0);
    try std.testing.expect(isPoolInitialized());
    try std.testing.expect(getPoolThreadCount() >= 1);

    std.debug.print("\n ✅ Thread pool init/deinit works!\n ", .{});
}
1434+
test "work queue atomic operations" {
    var queue = WorkQueue.init(10);

    // Drain the queue from a single thread. Each call must return the next
    // index in order (0, 1, 2, ...), so duplicates or gaps fail immediately
    // instead of being hidden behind a bare item count.
    var count: usize = 0;
    while (queue.getNext()) |index| {
        try std.testing.expectEqual(count, index);
        count += 1;
    }
    try std.testing.expectEqual(@as(usize, 10), count);

    // Queue should stay exhausted on further calls.
    try std.testing.expect(queue.getNext() == null);

    // After a reset the queue starts handing out indices from 0 again.
    queue.reset();
    try std.testing.expectEqual(@as(?usize, 0), queue.getNext());

    std.debug.print("\n ✅ Work queue atomic operations work!\n ", .{});
}
1454+
test "long running generation (100 tokens)" {
    // Uses the testing allocator so that any allocation leaked during the
    // generation run fails the test.
    const allocator = std.testing.allocator;

    // Mini config for long-running test
    const config = Config{
        .hidden_size = 32,
        .intermediate_size = 64,
        .num_layers = 1,
        .num_heads = 4,
        .num_kv_heads = 2,
        .head_dim = 8,
        .vocab_size = 100,
        .max_seq_len = 128,
    };

    // Create dummy weights. Packed weight buffers are sized as
    // rows * cols / 4 bytes (presumably four weights per byte; confirm
    // against the weight-unpacking code).
    const q_size = config.num_heads * config.head_dim * config.hidden_size / 4;
    const kv_size = config.num_kv_heads * config.head_dim * config.hidden_size / 4;
    const o_size = config.hidden_size * config.num_heads * config.head_dim / 4;
    const gate_size = config.intermediate_size * config.hidden_size / 4;
    const down_size = config.hidden_size * config.intermediate_size / 4;
    const lm_head_size = config.vocab_size * config.hidden_size / 4;
    const embed_size = config.vocab_size * config.hidden_size;

    const w_q = try allocator.alloc(u8, q_size);
    defer allocator.free(w_q);
    @memset(w_q, 0x55);

    const w_k = try allocator.alloc(u8, kv_size);
    defer allocator.free(w_k);
    @memset(w_k, 0x55);

    const w_v = try allocator.alloc(u8, kv_size);
    defer allocator.free(w_v);
    @memset(w_v, 0x55);

    const w_o = try allocator.alloc(u8, o_size);
    defer allocator.free(w_o);
    @memset(w_o, 0x55);

    const w_gate = try allocator.alloc(u8, gate_size);
    defer allocator.free(w_gate);
    @memset(w_gate, 0x55);

    const w_up = try allocator.alloc(u8, gate_size);
    defer allocator.free(w_up);
    @memset(w_up, 0x55);

    const w_down = try allocator.alloc(u8, down_size);
    defer allocator.free(w_down);
    @memset(w_down, 0x55);

    const lm_head = try allocator.alloc(u8, lm_head_size);
    defer allocator.free(lm_head);
    @memset(lm_head, 0x55);

    // Embedding table filled with a repeating 0.00 .. 0.99 ramp.
    const embed = try allocator.alloc(f32, embed_size);
    defer allocator.free(embed);
    for (embed, 0..) |*e, i| e.* = @as(f32, @floatFromInt(i % 100)) * 0.01;

    // One all-ones norm weight vector, shared by every RMSNorm below.
    const norm_weight = try allocator.alloc(f32, config.hidden_size);
    defer allocator.free(norm_weight);
    for (norm_weight) |*w| w.* = 1.0;

    // Create layers; every layer references the same dummy weight buffers.
    const layers = try allocator.alloc(BitNetLayer, config.num_layers);
    defer allocator.free(layers);

    for (layers) |*layer| {
        layer.* = BitNetLayer{
            .attention = Attention{
                .config = config,
                .w_q = w_q,
                .w_k = w_k,
                .w_v = w_v,
                .w_o = w_o,
            },
            .mlp = MLP{
                .config = config,
                .w_gate = w_gate,
                .w_up = w_up,
                .w_down = w_down,
            },
            .input_norm = RMSNorm{ .weight = norm_weight, .eps = config.rms_norm_eps },
            .post_attn_norm = RMSNorm{ .weight = norm_weight, .eps = config.rms_norm_eps },
        };
    }

    // Create KV caches. Track how many entries were actually initialized so
    // that, if KVCache.init fails part-way, cleanup only calls deinit on
    // initialized caches and never on undefined memory.
    const kv_caches = try allocator.alloc(KVCache, config.num_layers);
    var kv_initialized: usize = 0;
    defer {
        for (kv_caches[0..kv_initialized]) |*cache| cache.deinit(allocator);
        allocator.free(kv_caches);
    }
    for (kv_caches) |*cache| {
        cache.* = try KVCache.init(allocator, config);
        kv_initialized += 1;
    }

    // Create RoPE
    var rope = try RoPE.init(allocator, config.head_dim, config.max_seq_len, config.rope_theta);
    defer rope.deinit(allocator);

    // Create model
    var model = BitNetModel{
        .config = config,
        .allocator = allocator,
        .embed = embed,
        .layers = layers,
        .final_norm = RMSNorm{ .weight = norm_weight, .eps = config.rms_norm_eps },
        .lm_head = lm_head,
        .rope = rope,
        .kv_caches = kv_caches,
    };

    // Generate 100 tokens (long-running test)
    const prompt = [_]u32{ 1, 5, 10 };
    const generated = try model.generate(&prompt, 100, 1.0, 0.9);
    defer allocator.free(generated);

    // Require the prompt plus at least 50 new tokens. The bound is looser
    // than the 100 requested (presumably generation may stop early; confirm
    // against `generate`).
    try std.testing.expect(generated.len >= prompt.len + 50);

    // All tokens should be valid vocabulary ids
    for (generated) |token| {
        try std.testing.expect(token < config.vocab_size);
    }

    std.debug.print("\n ✅ Long-running generation (100 tokens) completed with no leaks!\n ", .{});
    std.debug.print(" Generated {d} tokens total\n ", .{generated.len});
}
1585+
13371586test "RoPE rotates vectors" {
13381587 const allocator = std .testing .allocator ;
13391588
0 commit comments