@@ -102,7 +102,8 @@ pub fn ternaryMatVec(
102102 }
103103}
104104
105- /// SIMD-optimized ternary matmul (AVX2)
105+ /// SIMD-optimized ternary matmul (AVX2/AVX-512)
106+ /// Uses lookup tables and vectorized operations for maximum throughput
106107pub fn simdTernaryMatVec (
107108 output : []f32 ,
108109 weights : []const u8 ,
@@ -113,6 +114,10 @@ pub fn simdTernaryMatVec(
113114 const Vec8f32 = @Vector (8 , f32 );
114115 const cols_packed = (cols + 3 ) / 4 ;
115116
117+ // Precompute sign lookup: trit -> {-1, 0, +1}
118+ // 00 = 0, 01 = +1, 10 = -1
119+ const sign_lut = [4 ]f32 { 0.0 , 1.0 , -1.0 , 0.0 };
120+
116121 for (0.. rows ) | row | {
117122 var sum_vec : Vec8f32 = @splat (0.0 );
118123 var sum_scalar : f32 = 0.0 ;
@@ -121,56 +126,233 @@ pub fn simdTernaryMatVec(
121126 var col : usize = 0 ;
122127
123128 // Process 8 floats at a time with SIMD
124- while (col + 8 <= cols ) {
129+ while (col + 8 <= cols and row_start + col / 4 + 1 < weights . len ) {
125130 // Load 8 input values
126131 const in_vec : Vec8f32 = input [col .. ][0.. 8].* ;
127132
128133 // Load 2 bytes = 8 trits
129134 const byte0 = weights [row_start + col / 4 ];
130135 const byte1 = weights [row_start + col / 4 + 1 ];
131136
132- // Decode trits and create masks
133- var add_mask : Vec8f32 = @splat (0.0 );
134- var sub_mask : Vec8f32 = @splat (0.0 );
135-
136- inline for (0.. 4) | i | {
137- const trit0 = (byte0 >> @intCast (i * 2 )) & 0x3 ;
138- const trit1 = (byte1 >> @intCast (i * 2 )) & 0x3 ;
139-
140- if (trit0 == 0b01 ) add_mask [i ] = 1.0 ;
141- if (trit0 == 0b10 ) sub_mask [i ] = 1.0 ;
142- if (trit1 == 0b01 ) add_mask [4 + i ] = 1.0 ;
143- if (trit1 == 0b10 ) sub_mask [4 + i ] = 1.0 ;
144- }
145-
146- sum_vec += in_vec * add_mask ;
147- sum_vec -= in_vec * sub_mask ;
137+ // Decode trits using lookup table - vectorized
138+ const signs : Vec8f32 = .{
139+ sign_lut [(byte0 >> 0 ) & 0x3 ],
140+ sign_lut [(byte0 >> 2 ) & 0x3 ],
141+ sign_lut [(byte0 >> 4 ) & 0x3 ],
142+ sign_lut [(byte0 >> 6 ) & 0x3 ],
143+ sign_lut [(byte1 >> 0 ) & 0x3 ],
144+ sign_lut [(byte1 >> 2 ) & 0x3 ],
145+ sign_lut [(byte1 >> 4 ) & 0x3 ],
146+ sign_lut [(byte1 >> 6 ) & 0x3 ],
147+ };
148+
149+ // Multiply and accumulate: sum += input * sign
150+ // This is the key optimization: no branches, pure SIMD
151+ sum_vec += in_vec * signs ;
148152
149153 col += 8 ;
150154 }
151155
152- // Reduce SIMD vector
156+ // Reduce SIMD vector to scalar
153157 sum_scalar = @reduce (.Add , sum_vec );
154158
155- // Handle remaining elements
159+ // Handle remaining elements (scalar fallback)
156160 while (col < cols ) : (col += 1 ) {
157161 const byte_idx = row_start + col / 4 ;
158162 if (byte_idx >= weights .len ) break ;
159163
160164 const shift : u3 = @intCast ((col % 4 ) * 2 );
161165 const trit = (weights [byte_idx ] >> shift ) & 0x3 ;
162-
163- switch (trit ) {
164- 0b01 = > sum_scalar += input [col ],
165- 0b10 = > sum_scalar -= input [col ],
166- else = > {},
167- }
166+ sum_scalar += input [col ] * sign_lut [trit ];
167+ }
168+
169+ output [row ] = sum_scalar ;
170+ }
171+ }
172+
/// Ternary matrix-vector product using 16-lane `f32` vectors (four packed
/// weight bytes per step).
///
/// `weights` stores 2-bit trits, four per byte, lowest bit pair first:
/// `00` = 0, `01` = +1, `10` = -1, and `11` is treated as 0. Every row
/// occupies `(cols + 3) / 4` bytes. `output[row]` is written for each row in
/// `0..rows`.
///
/// Zero trits contribute nothing to the sum: lanes are masked with `@select`
/// rather than multiplied by 0.0, so an infinite or NaN input that sits under
/// a zero weight cannot turn the row into NaN.
///
/// If `weights` is shorter than the matrix needs, a row stops accumulating at
/// the first missing byte. Accesses to `input` and `output` are not guarded
/// here. NOTE(review): callers presumably guarantee `input.len >= cols` and
/// `output.len >= rows` — confirm.
pub fn simd16TernaryMatVec(
    output: []f32,
    weights: []const u8,
    input: []const f32,
    rows: usize,
    cols: usize,
) void {
    const lanes = 16;
    const bytes_per_step = lanes / 4;
    const Vec16f32 = @Vector(lanes, f32);
    const cols_packed = (cols + 3) / 4;
    const sign_lut = [4]f32{ 0.0, 1.0, -1.0, 0.0 };
    const zeros: Vec16f32 = @splat(0.0);

    for (0..rows) |row| {
        const row_start = row * cols_packed;
        var sum_vec: Vec16f32 = zeros;
        var col: usize = 0;

        // Vector path: 16 inputs against 4 weight bytes per iteration. The
        // second condition keeps all four byte reads inside `weights`.
        while (col + lanes <= cols and
            row_start + col / 4 + (bytes_per_step - 1) < weights.len) : (col += lanes)
        {
            const in_vec: Vec16f32 = input[col..][0..lanes].*;
            const trit_bytes = weights[row_start + col / 4 ..][0..bytes_per_step];

            // Unrolled at compile time: lane i reads bit pair (i % 4) of byte (i / 4).
            var signs: Vec16f32 = undefined;
            inline for (0..lanes) |lane| {
                const shift: u3 = @intCast((lane % 4) * 2);
                signs[lane] = sign_lut[(trit_bytes[lane / 4] >> shift) & 0x3];
            }

            // Mask instead of relying on `x * 0.0`, which is NaN for x = inf/NaN.
            sum_vec += @select(f32, signs != zeros, in_vec * signs, zeros);
        }

        var sum_scalar: f32 = @reduce(.Add, sum_vec);

        // Scalar tail for the columns the vector loop did not cover.
        while (col < cols) : (col += 1) {
            const byte_idx = row_start + col / 4;
            if (byte_idx >= weights.len) break;
            const shift: u3 = @intCast((col % 4) * 2);
            switch ((weights[byte_idx] >> shift) & 0x3) {
                0b01 => sum_scalar += input[col],
                0b10 => sum_scalar -= input[col],
                else => {},
            }
        }

        output[row] = sum_scalar;
    }
}
173232
/// Ternary matrix-vector product that walks four rows at a time so each
/// 8-wide slice of `input` is loaded once and reused for all four rows.
///
/// `weights` stores 2-bit trits, four per byte, lowest bit pair first:
/// `00` = 0, `01` = +1, `10` = -1, and `11` is treated as 0. Every row
/// occupies `(cols + 3) / 4` bytes. `output[row]` is written for each row in
/// `0..rows`.
///
/// A group of four rows takes the vector path only when all of its bytes lie
/// inside `weights`; that single check replaces per-iteration bounds tests.
/// Rows left over (fewer than four, or rows whose bytes are missing because
/// `weights` is too short) use a scalar loop that stops at the first missing
/// byte.
///
/// Zero trits contribute nothing to the sum: lanes are masked with `@select`
/// rather than multiplied by 0.0, so an infinite or NaN input that sits under
/// a zero weight cannot turn a row into NaN.
///
/// Accesses to `input` and `output` are not guarded here. NOTE(review):
/// callers presumably guarantee `input.len >= cols` and
/// `output.len >= rows` — confirm.
pub fn batchTernaryMatVec(
    output: []f32,
    weights: []const u8,
    input: []const f32,
    rows: usize,
    cols: usize,
) void {
    const lanes = 8;
    const batch_rows = 4;
    const Vec8f32 = @Vector(lanes, f32);
    const cols_packed = (cols + 3) / 4;
    const sign_lut = [4]f32{ 0.0, 1.0, -1.0, 0.0 };
    const zeros: Vec8f32 = @splat(0.0);

    var row: usize = 0;

    // Batched path: four rows per iteration, all of them fully backed by `weights`.
    while (row + batch_rows <= rows and
        (row + batch_rows) * cols_packed <= weights.len) : (row += batch_rows)
    {
        var sums: [batch_rows]Vec8f32 = undefined;
        for (&sums) |*acc| acc.* = zeros;

        var col: usize = 0;
        while (col + lanes <= cols) : (col += lanes) {
            const in_vec: Vec8f32 = input[col..][0..lanes].*;
            const col_byte = col / 4;

            inline for (0..batch_rows) |b| {
                // Two bytes = eight trits for this row at this column offset.
                const trit_bytes = weights[(row + b) * cols_packed + col_byte ..][0..2];

                var signs: Vec8f32 = undefined;
                inline for (0..lanes) |lane| {
                    const shift: u3 = @intCast((lane % 4) * 2);
                    signs[lane] = sign_lut[(trit_bytes[lane / 4] >> shift) & 0x3];
                }

                // Mask instead of relying on `x * 0.0`, which is NaN for x = inf/NaN.
                sums[b] += @select(f32, signs != zeros, in_vec * signs, zeros);
            }
        }

        for (0..batch_rows) |b| output[row + b] = @reduce(.Add, sums[b]);

        // Columns past the last full group of eight, added straight into `output`.
        while (col < cols) : (col += 1) {
            const shift: u3 = @intCast((col % 4) * 2);
            for (0..batch_rows) |b| {
                const byte_idx = (row + b) * cols_packed + col / 4;
                switch ((weights[byte_idx] >> shift) & 0x3) {
                    0b01 => output[row + b] += input[col],
                    0b10 => output[row + b] -= input[col],
                    else => {},
                }
            }
        }
    }

    // Leftover rows, one at a time.
    while (row < rows) : (row += 1) {
        const row_start = row * cols_packed;
        var sum: f32 = 0.0;

        for (0..cols) |col| {
            const byte_idx = row_start + col / 4;
            if (byte_idx >= weights.len) break;
            const shift: u3 = @intCast((col % 4) * 2);
            switch ((weights[byte_idx] >> shift) & 0x3) {
                0b01 => sum += input[col],
                0b10 => sum -= input[col],
                else => {},
            }
        }

        output[row] = sum;
    }
}
355+
174356// ═══════════════════════════════════════════════════════════════════════════════
175357// QUANTIZATION: Float -> Ternary
176358// ═══════════════════════════════════════════════════════════════════════════════
@@ -307,3 +489,99 @@ test "memory stats" {
307489 // Ternary: ~1.75 GB (16x smaller)
308490 try std .testing .expect (stats .ternary_bytes < 2_000_000_000 );
309491}
492+
test "simd ternary matmul" {
    // 2x8 matrix, two packed bytes per row. Trits are read lowest bit pair
    // first: 00 = 0, 01 = +1, 10 = -1.
    const weights = [_]u8{
        0b01_00_10_01, 0b00_01_10_01, // Row 0: +1,-1,0,+1, +1,-1,+1,0
        0b10_01_01_00, 0b01_00_00_10, // Row 1: 0,+1,+1,-1, -1,0,0,+1
    };
    const input = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };

    // Hand-computed dot products:
    //   row 0: 1 - 2 + 4 + 5 - 6 + 7 = 9
    //   row 1: 2 + 3 - 4 - 5 + 8     = 4
    const expected = [2]f32{ 9.0, 4.0 };

    var output_scalar: [2]f32 = undefined;
    var output_simd: [2]f32 = undefined;
    var output_simd16: [2]f32 = undefined;
    var output_batch: [2]f32 = undefined;

    ternaryMatVec(&output_scalar, &weights, &input, 2, 8);
    simdTernaryMatVec(&output_simd, &weights, &input, 2, 8);
    simd16TernaryMatVec(&output_simd16, &weights, &input, 2, 8);
    batchTernaryMatVec(&output_batch, &weights, &input, 2, 8);

    for (expected, 0..) |want, row| {
        // The SIMD result must match the scalar reference...
        try std.testing.expectApproxEqAbs(output_scalar[row], output_simd[row], 0.001);
        // ...and every vectorised variant must match the known answer.
        try std.testing.expectApproxEqAbs(want, output_simd[row], 0.001);
        try std.testing.expectApproxEqAbs(want, output_simd16[row], 0.001);
        try std.testing.expectApproxEqAbs(want, output_batch[row], 0.001);
    }
}
514+
/// Entry point when the file is run directly: benchmarks the ternary
/// mat-vec kernels on three square layer sizes.
pub fn main() void {
    const Case = struct { dim: usize, iterations: usize };
    const cases = [_]Case{
        .{ .dim = 768, .iterations = 1000 }, // Small layer
        .{ .dim = 2048, .iterations = 100 }, // Medium layer
        .{ .dim = 4096, .iterations = 50 }, // Large layer
    };

    for (cases) |case| {
        benchmarkTernaryMatVec(case.dim, case.dim, case.iterations);
    }
}
522+
/// Times the scalar, 8-wide, 16-wide and 4-row-batch ternary mat-vec kernels
/// on a `rows` x `cols` matrix and prints the results to stderr.
///
/// Each kernel is called `iterations` times on the same buffers. For every
/// kernel the report shows total wall time in milliseconds and throughput
/// counted as two operations per matrix element per iteration, divided by
/// elapsed nanoseconds. The vectorised kernels also show their speedup over
/// the scalar run.
///
/// Buffers come from `std.heap.page_allocator` and are released before
/// returning. If an allocation fails or no timer is available, a short
/// message is printed and the benchmark is skipped.
pub fn benchmarkTernaryMatVec(rows: usize, cols: usize, iterations: usize) void {
    const allocator = std.heap.page_allocator;

    // Allocate test data
    const weights = allocator.alloc(u8, rows * ((cols + 3) / 4)) catch |err| {
        std.debug.print("benchmark skipped: cannot allocate weights ({s})\n", .{@errorName(err)});
        return;
    };
    defer allocator.free(weights);
    const input = allocator.alloc(f32, cols) catch |err| {
        std.debug.print("benchmark skipped: cannot allocate input ({s})\n", .{@errorName(err)});
        return;
    };
    defer allocator.free(input);
    const output = allocator.alloc(f32, rows) catch |err| {
        std.debug.print("benchmark skipped: cannot allocate output ({s})\n", .{@errorName(err)});
        return;
    };
    defer allocator.free(output);

    // Deterministic filler so every run sees the same data.
    for (weights, 0..) |*w, i| w.* = @truncate(i * 17 + 31);
    for (input, 0..) |*v, i| v.* = @as(f32, @floatFromInt(i % 100)) / 100.0;

    // Work per run, computed in floating point so the product cannot
    // overflow `usize` on narrow targets.
    const total_ops = @as(f64, @floatFromInt(rows)) *
        @as(f64, @floatFromInt(cols)) *
        @as(f64, @floatFromInt(iterations)) * 2.0;

    std.debug.print("\n Ternary MatVec Benchmark ({d}x{d}, {d} iterations)\n ", .{ rows, cols, iterations });
    std.debug.print("=" ** 50 ++ "\n ", .{});

    // Benchmark scalar
    var timer = std.time.Timer.start() catch |err| {
        std.debug.print("benchmark skipped: no timer available ({s})\n", .{@errorName(err)});
        return;
    };
    for (0..iterations) |_| {
        ternaryMatVec(output, weights, input, rows, cols);
    }
    const scalar_ns: f64 = @floatFromInt(timer.read());
    std.debug.print("Scalar: {d:.2} ms ({d:.2} GFLOPS)\n ", .{
        scalar_ns / 1e6,
        total_ops / scalar_ns,
    });

    // Benchmark SIMD 8-wide
    timer.reset();
    for (0..iterations) |_| {
        simdTernaryMatVec(output, weights, input, rows, cols);
    }
    const simd8_ns: f64 = @floatFromInt(timer.read());
    std.debug.print("SIMD-8: {d:.2} ms ({d:.2} GFLOPS) - {d:.1}x speedup\n ", .{
        simd8_ns / 1e6,
        total_ops / simd8_ns,
        scalar_ns / simd8_ns,
    });

    // Benchmark SIMD 16-wide
    timer.reset();
    for (0..iterations) |_| {
        simd16TernaryMatVec(output, weights, input, rows, cols);
    }
    const simd16_ns: f64 = @floatFromInt(timer.read());
    std.debug.print("SIMD-16: {d:.2} ms ({d:.2} GFLOPS) - {d:.1}x speedup\n ", .{
        simd16_ns / 1e6,
        total_ops / simd16_ns,
        scalar_ns / simd16_ns,
    });

    // Benchmark batch
    timer.reset();
    for (0..iterations) |_| {
        batchTernaryMatVec(output, weights, input, rows, cols);
    }
    const batch_ns: f64 = @floatFromInt(timer.read());
    std.debug.print("Batch-4: {d:.2} ms ({d:.2} GFLOPS) - {d:.1}x speedup\n ", .{
        batch_ns / 1e6,
        total_ops / batch_ns,
        scalar_ns / batch_ns,
    });
}
0 commit comments