@@ -271,6 +271,135 @@ pub fn simdSoftmax(output: []f32, input: []const f32) void {
271271 simdScale (output , output , inv_sum );
272272}
273273
274+ // ═══════════════════════════════════════════════════════════════════════════════
275+ // PARALLEL MATRIX-VECTOR MULTIPLICATION
276+ // ═══════════════════════════════════════════════════════════════════════════════
277+
/// Work descriptor handed to one `parallelMatVecWorker` task.
///
/// Every task sees the same `output`, `mat` and `vec` slices; the half-open
/// row range `[start_row, end_row)` selects which rows that task computes.
const ParallelMatVecContext = struct {
    /// First row (inclusive) this task is responsible for.
    start_row: usize,
    /// One past the last row this task is responsible for.
    end_row: usize,
    /// Number of elements per matrix row; row `i` starts at `mat[i * cols]`.
    cols: usize,
    /// Matrix data, indexed as `mat[row * cols + col]`.
    mat: []const f32,
    /// Input vector; elements `0..cols` are read.
    vec: []const f32,
    /// Destination slice; element `i` receives the result for row `i`.
    output: []f32,
};
287+
/// Pool task: computes `output[i] = dot(mat[i * cols ..][0..cols], vec[0..cols])`
/// for every row `i` in `[ctx.start_row, ctx.end_row)`.
///
/// `wg.finish()` is called when the function returns, so the spawner must
/// have called `wg.start()` for this task beforehand.
fn parallelMatVecWorker(ctx: *ParallelMatVecContext, wg: *std.Thread.WaitGroup) void {
    defer wg.finish();

    // The main loop consumes `unroll` vectors per step; each has its own accumulator.
    const unroll = 4;
    const block_width = SIMD_WIDTH * unroll;

    // Largest multiples of the block width / vector width that fit in `cols`.
    const blocked_end = ctx.cols & ~@as(usize, block_width - 1);
    const vector_end = ctx.cols & ~@as(usize, SIMD_WIDTH - 1);

    const zero: Vec8f = @splat(0.0);

    for (ctx.start_row..ctx.end_row) |row| {
        const base = row * ctx.cols;
        var acc = [_]Vec8f{zero} ** unroll;

        // Phase 1: four vectors per iteration, independent accumulators.
        var col: usize = 0;
        while (col < blocked_end) : (col += block_width) {
            inline for (0..unroll) |k| {
                const off = col + k * SIMD_WIDTH;
                const m: Vec8f = ctx.mat[base + off ..][0..SIMD_WIDTH].*;
                const v: Vec8f = ctx.vec[off..][0..SIMD_WIDTH].*;
                acc[k] += m * v;
            }
        }

        // Fold the accumulators pairwise: (0 + 1) + (2 + 3).
        acc[0] += acc[1];
        acc[2] += acc[3];
        acc[0] += acc[2];
        var total = acc[0];

        // Phase 2: remaining whole vectors, one at a time.
        while (col < vector_end) : (col += SIMD_WIDTH) {
            const m: Vec8f = ctx.mat[base + col ..][0..SIMD_WIDTH].*;
            const v: Vec8f = ctx.vec[col..][0..SIMD_WIDTH].*;
            total += m * v;
        }

        // Horizontal sum, lane by lane in index order.
        const lanes: [SIMD_WIDTH]f32 = total;
        var dot: f32 = 0.0;
        inline for (lanes) |lane| {
            dot += lane;
        }

        // Phase 3: scalar tail for the last `cols % SIMD_WIDTH` elements.
        for (col..ctx.cols) |k| {
            dot += ctx.mat[base + k] * ctx.vec[k];
        }

        ctx.output[row] = dot;
    }
}
341+
/// File-level thread pool used by the parallel routines here.
/// Holds `undefined` until `initThreadPool` has succeeded, so it must only be
/// touched while `pool_initialized` is true.
var global_pool: std.Thread.Pool = undefined;

/// Set by `initThreadPool` on success and cleared by `deinitThreadPool`.
var pool_initialized: bool = false;
345+
/// Sets up `global_pool` with `allocator` unless it is already set up.
///
/// Calling this again while the pool is initialized is a no-op. If pool
/// initialization fails the error is returned and `pool_initialized`
/// stays false.
pub fn initThreadPool(allocator: std.mem.Allocator) !void {
    if (pool_initialized) return;

    try global_pool.init(.{ .allocator = allocator });
    pool_initialized = true;
}
353+
/// Tears down `global_pool` if `initThreadPool` previously succeeded.
///
/// Does nothing when the pool is not initialized, so repeated calls are safe.
pub fn deinitThreadPool() void {
    if (!pool_initialized) return;

    global_pool.deinit();
    pool_initialized = false;
}
361+
/// Matrix-vector product that optionally splits the rows across the global
/// thread pool.
///
/// `mat` is indexed row-major (`mat[row * cols + col]`); `output[row]`
/// receives the result for each of the `rows` rows.
///
/// The call is forwarded unchanged to `simdMatVec` when `rows` is below
/// 10000 or when `initThreadPool` has not been called. From 10000 rows
/// upward the rows are divided into two contiguous ranges (the last range
/// absorbs any remainder) and each range is submitted to `global_pool`.
/// If a range cannot be submitted, it is computed on the calling thread
/// with `simdMatVec` instead. The function returns once every range is done.
pub fn parallelMatVec(output: []f32, mat: []const f32, vec: []const f32, rows: usize, cols: usize) void {
    // Below this row count the work stays single-threaded; the original
    // rationale is that threading overhead dominates on a 2-core machine
    // for anything smaller than a vocab-sized projection.
    const min_parallel_rows: usize = 10000;
    if (rows < min_parallel_rows or !pool_initialized) {
        simdMatVec(output, mat, vec, rows, cols);
        return;
    }

    // Number of row ranges; also sizes the context array below.
    const num_threads = 2;
    const rows_per_thread = rows / num_threads;

    // Contexts live on this stack frame; `wg.wait()` below keeps the frame
    // alive until every spawned task has finished using them.
    var contexts: [num_threads]ParallelMatVecContext = undefined;
    var wg = std.Thread.WaitGroup{};

    for (&contexts, 0..) |*ctx, t| {
        const start = t * rows_per_thread;
        const end = if (t == num_threads - 1) rows else (t + 1) * rows_per_thread;

        ctx.* = ParallelMatVecContext{
            .output = output,
            .mat = mat,
            .vec = vec,
            .cols = cols,
            .start_row = start,
            .end_row = end,
        };

        wg.start();
        global_pool.spawn(parallelMatVecWorker, .{ ctx, &wg }) catch {
            // The task was never queued: balance the `start()` above and
            // compute this row range on the calling thread.
            wg.finish();
            simdMatVec(output[start..end], mat[start * cols ..], vec, end - start, cols);
        };
    }

    wg.wait();
}
402+
274403// ═══════════════════════════════════════════════════════════════════════════════
275404// TESTS
276405// ═══════════════════════════════════════════════════════════════════════════════
0 commit comments