@@ -164,49 +164,123 @@ fn ternaryWorker(ctx: *const ParallelTernaryContext, chunk: WorkChunk) void {
164164 const cols_packed = (cols + 3 ) / 4 ;
165165 const sign_lut = [4 ]f32 { 0.0 , 1.0 , -1.0 , 0.0 };
166166
167- for (chunk .start_row .. chunk .end_row ) | row | {
167+ const num_rows = chunk .end_row - chunk .start_row ;
168+ var row = chunk .start_row ;
169+
170+ // Process 4 rows at a time (batch optimization)
171+ while (row + 4 <= chunk .end_row ) {
172+ var sum0 : Vec8f = @splat (0.0 );
173+ var sum1 : Vec8f = @splat (0.0 );
174+ var sum2 : Vec8f = @splat (0.0 );
175+ var sum3 : Vec8f = @splat (0.0 );
176+
177+ var col : usize = 0 ;
178+ while (col + 8 <= cols ) {
179+ const in_vec : Vec8f = ctx .input [col .. ][0.. 8].* ;
180+ const col_byte = col / 4 ;
181+
182+ // Row 0
183+ const r0_start = row * cols_packed ;
184+ if (r0_start + col_byte + 1 < ctx .weights .len ) {
185+ const b0 = ctx .weights [r0_start + col_byte ];
186+ const b1 = ctx .weights [r0_start + col_byte + 1 ];
187+ const s0 : Vec8f = .{
188+ sign_lut [(b0 >> 0 ) & 0x3 ], sign_lut [(b0 >> 2 ) & 0x3 ],
189+ sign_lut [(b0 >> 4 ) & 0x3 ], sign_lut [(b0 >> 6 ) & 0x3 ],
190+ sign_lut [(b1 >> 0 ) & 0x3 ], sign_lut [(b1 >> 2 ) & 0x3 ],
191+ sign_lut [(b1 >> 4 ) & 0x3 ], sign_lut [(b1 >> 6 ) & 0x3 ],
192+ };
193+ sum0 += in_vec * s0 ;
194+ }
195+
196+ // Row 1
197+ const r1_start = (row + 1 ) * cols_packed ;
198+ if (r1_start + col_byte + 1 < ctx .weights .len ) {
199+ const b0 = ctx .weights [r1_start + col_byte ];
200+ const b1 = ctx .weights [r1_start + col_byte + 1 ];
201+ const s1 : Vec8f = .{
202+ sign_lut [(b0 >> 0 ) & 0x3 ], sign_lut [(b0 >> 2 ) & 0x3 ],
203+ sign_lut [(b0 >> 4 ) & 0x3 ], sign_lut [(b0 >> 6 ) & 0x3 ],
204+ sign_lut [(b1 >> 0 ) & 0x3 ], sign_lut [(b1 >> 2 ) & 0x3 ],
205+ sign_lut [(b1 >> 4 ) & 0x3 ], sign_lut [(b1 >> 6 ) & 0x3 ],
206+ };
207+ sum1 += in_vec * s1 ;
208+ }
209+
210+ // Row 2
211+ const r2_start = (row + 2 ) * cols_packed ;
212+ if (r2_start + col_byte + 1 < ctx .weights .len ) {
213+ const b0 = ctx .weights [r2_start + col_byte ];
214+ const b1 = ctx .weights [r2_start + col_byte + 1 ];
215+ const s2 : Vec8f = .{
216+ sign_lut [(b0 >> 0 ) & 0x3 ], sign_lut [(b0 >> 2 ) & 0x3 ],
217+ sign_lut [(b0 >> 4 ) & 0x3 ], sign_lut [(b0 >> 6 ) & 0x3 ],
218+ sign_lut [(b1 >> 0 ) & 0x3 ], sign_lut [(b1 >> 2 ) & 0x3 ],
219+ sign_lut [(b1 >> 4 ) & 0x3 ], sign_lut [(b1 >> 6 ) & 0x3 ],
220+ };
221+ sum2 += in_vec * s2 ;
222+ }
223+
224+ // Row 3
225+ const r3_start = (row + 3 ) * cols_packed ;
226+ if (r3_start + col_byte + 1 < ctx .weights .len ) {
227+ const b0 = ctx .weights [r3_start + col_byte ];
228+ const b1 = ctx .weights [r3_start + col_byte + 1 ];
229+ const s3 : Vec8f = .{
230+ sign_lut [(b0 >> 0 ) & 0x3 ], sign_lut [(b0 >> 2 ) & 0x3 ],
231+ sign_lut [(b0 >> 4 ) & 0x3 ], sign_lut [(b0 >> 6 ) & 0x3 ],
232+ sign_lut [(b1 >> 0 ) & 0x3 ], sign_lut [(b1 >> 2 ) & 0x3 ],
233+ sign_lut [(b1 >> 4 ) & 0x3 ], sign_lut [(b1 >> 6 ) & 0x3 ],
234+ };
235+ sum3 += in_vec * s3 ;
236+ }
237+
238+ col += 8 ;
239+ }
240+
241+ ctx .output [row + 0 ] = @reduce (.Add , sum0 ) * ctx .scale ;
242+ ctx .output [row + 1 ] = @reduce (.Add , sum1 ) * ctx .scale ;
243+ ctx .output [row + 2 ] = @reduce (.Add , sum2 ) * ctx .scale ;
244+ ctx .output [row + 3 ] = @reduce (.Add , sum3 ) * ctx .scale ;
245+
246+ row += 4 ;
247+ }
248+
249+ // Handle remaining rows
250+ while (row < chunk .end_row ) : (row += 1 ) {
168251 var sum_vec : Vec8f = @splat (0.0 );
169252 var sum_scalar : f32 = 0.0 ;
170253 const row_start = row * cols_packed ;
171254
172255 var col : usize = 0 ;
173-
174- // SIMD loop: 8 floats at a time
175256 while (col + 8 <= cols and row_start + col / 4 + 1 < ctx .weights .len ) {
176257 const in_vec : Vec8f = ctx .input [col .. ][0.. 8].* ;
177-
178258 const byte0 = ctx .weights [row_start + col / 4 ];
179259 const byte1 = ctx .weights [row_start + col / 4 + 1 ];
180-
181260 const signs : Vec8f = .{
182- sign_lut [(byte0 >> 0 ) & 0x3 ],
183- sign_lut [(byte0 >> 2 ) & 0x3 ],
184- sign_lut [(byte0 >> 4 ) & 0x3 ],
185- sign_lut [(byte0 >> 6 ) & 0x3 ],
186- sign_lut [(byte1 >> 0 ) & 0x3 ],
187- sign_lut [(byte1 >> 2 ) & 0x3 ],
188- sign_lut [(byte1 >> 4 ) & 0x3 ],
189- sign_lut [(byte1 >> 6 ) & 0x3 ],
261+ sign_lut [(byte0 >> 0 ) & 0x3 ], sign_lut [(byte0 >> 2 ) & 0x3 ],
262+ sign_lut [(byte0 >> 4 ) & 0x3 ], sign_lut [(byte0 >> 6 ) & 0x3 ],
263+ sign_lut [(byte1 >> 0 ) & 0x3 ], sign_lut [(byte1 >> 2 ) & 0x3 ],
264+ sign_lut [(byte1 >> 4 ) & 0x3 ], sign_lut [(byte1 >> 6 ) & 0x3 ],
190265 };
191-
192266 sum_vec += in_vec * signs ;
193267 col += 8 ;
194268 }
195269
196270 sum_scalar = @reduce (.Add , sum_vec );
197271
198- // Scalar tail
199272 while (col < cols ) : (col += 1 ) {
200273 const byte_idx = row_start + col / 4 ;
201274 if (byte_idx >= ctx .weights .len ) break ;
202-
203275 const shift : u3 = @intCast ((col % 4 ) * 2 );
204276 const trit = (ctx .weights [byte_idx ] >> shift ) & 0x3 ;
205277 sum_scalar += ctx .input [col ] * sign_lut [trit ];
206278 }
207279
208280 ctx .output [row ] = sum_scalar * ctx .scale ;
209281 }
282+
283+ _ = num_rows ;
210284}
211285
212286/// Minimum rows to justify parallelization overhead
@@ -221,9 +295,9 @@ pub fn parallelTernaryMatmul(
221295 cols : usize ,
222296 scale : f32 ,
223297) void {
224- // For small matrices, use single-threaded SIMD (faster due to no thread overhead )
298+ // For small matrices, use single-threaded batch SIMD (fastest )
225299 if (rows < MIN_PARALLEL_ROWS ) {
226- ternary .simd16TernaryMatVec (output , weights , input , rows , cols );
300+ ternary .batchTernaryMatVec (output , weights , input , rows , cols );
227301 for (output ) | * o | o .* *= scale ;
228302 return ;
229303 }
0 commit comments