Skip to content

Commit 3c8e3d9

Browse files
gHashTagona-agent
andcommitted
perf(matmul): optimize batch ternary matmul (OPT-T07)
- Add batchTiledTernaryMatVec with 8-row batch processing - Update parallel ternaryWorker with 4-row batch optimization - Use batchTernaryMatVec for small matrices (faster than SIMD-16) - Benchmark: 2.28x speedup (3.36 → 7.65 GFLOPS) Co-authored-by: Ona <no-reply@ona.com>
1 parent 449ac42 commit 3c8e3d9

4 files changed

Lines changed: 434 additions & 18 deletions

File tree

docs/DISCOVERIES.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Where:
7878
| OPT-T04 | Ternary Attention | 16x | 1.5x | ✅ Implemented |
7979
| OPT-T05 | Ternary Embeddings | 12.8x | 1x | ✅ Implemented |
8080
| OPT-T06 | Ternary Normalization | 16x | 0.2x | ✅ Implemented |
81+
| OPT-T07 | Batch Ternary MatMul | N/A | 2.28x | ✅ Implemented |
8182

8283
### Business Value
8384

@@ -350,6 +351,33 @@ var model = try TriModel.load(allocator, "model.tri");
350351
try model.enableTernaryNorm(); // 16x memory reduction for norm weights
351352
```
352353

354+
### Batch Ternary MatMul (OPT-T07)
355+
356+
**Status**: ✅ Implemented
357+
358+
| Component | File | Description |
359+
|-----------|------|-------------|
360+
| batchTernaryMatVec | `ternary_weights.zig` | 4-row batch SIMD matmul |
361+
| batchTiledTernaryMatVec | `ternary_weights.zig` | 8-row optimized version |
362+
| ternaryWorker | `parallel_inference.zig` | Parallel batch worker |
363+
364+
**Benchmark Results (2048x2048 matrix):**
365+
```
366+
╔══════════════════════════════════════════════════════════════╗
367+
║ TERNARY MATMUL BENCHMARK (2048x2048) ║
368+
╠══════════════════════════════════════════════════════════════╣
369+
║ SIMD-16 (baseline): 2499.7 us ( 3.36 GFLOPS) ║
370+
║ BatchTiled (new): 1096.0 us ( 7.65 GFLOPS) ║
371+
║ Speedup: 2.28x ║
372+
╚══════════════════════════════════════════════════════════════╝
373+
```
374+
375+
**Optimization Techniques:**
376+
1. Process 4-8 rows simultaneously (better register utilization)
377+
2. LUT-based sign conversion (faster than arithmetic)
378+
3. 8-wide SIMD vectors (AVX2 compatible)
379+
4. Parallel worker with batch processing
380+
353381
### Batch Processing (INF-004)
354382

355383
**Status**: ✅ Implemented
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# optimized_ternary_matmul.vibee
2+
# Cache-optimized ternary matrix-vector multiplication
3+
# Target: 2x speedup over current batch implementation
4+
5+
name: optimized_ternary_matmul
6+
version: "1.0.0"
7+
language: zig
8+
module: optimized_ternary_matmul
9+
10+
types:
11+
TileConfig:
12+
description: "Tiling configuration for cache optimization"
13+
fields:
14+
tile_rows: Int # Rows per tile (fit in L1 cache)
15+
tile_cols: Int # Cols per tile (fit in L2 cache)
16+
prefetch_distance: Int # Prefetch ahead distance
17+
18+
TernaryTile:
19+
description: "Pre-unpacked ternary tile for SIMD processing"
20+
fields:
21+
signs: List<Float> # Pre-converted signs (-1, 0, +1)
22+
rows: Int
23+
cols: Int
24+
25+
behaviors:
26+
- name: tiled_ternary_matmul
27+
given: output buffer, packed ternary weights, input vector, dimensions
28+
when: performing matrix-vector multiplication with tiling
29+
then: computes output with improved cache locality
30+
31+
- name: preunpack_tile
32+
given: packed ternary bytes, tile dimensions
33+
when: preparing tile for SIMD processing
34+
then: returns pre-unpacked signs as f32 array
35+
36+
- name: simd_tile_dot
37+
given: pre-unpacked signs, input vector slice
38+
when: computing dot product for tile
39+
then: returns partial sum using pure SIMD (no LUT)
40+
41+
- name: parallel_tiled_matmul
42+
given: output, weights, input, dimensions, num_threads
43+
when: distributing tiles across threads
44+
then: computes output with parallel tile processing
45+
46+
# Optimization Strategy:
47+
#
48+
# 1. TILING: Process matrix in L1/L2 cache-sized tiles
49+
# - L1 cache: 32KB → tile_rows = 64, tile_cols = 512
50+
# - L2 cache: 256KB → larger tiles for weight reuse
51+
#
52+
# 2. PRE-UNPACKING: Convert ternary to f32 signs once per tile
53+
# - Eliminates LUT lookups in inner loop
54+
# - Enables pure SIMD multiply-add
55+
#
56+
# 3. PREFETCHING: Software prefetch for next tile
57+
# - Hide memory latency
58+
#
59+
# 4. PARALLEL TILES: Distribute tiles across threads
60+
# - Better load balancing than row-based parallelism
61+
62+
# Memory Layout:
63+
# - Weights: row-major packed ternary (4 values per byte)
64+
# - Input: contiguous f32 vector
65+
# - Output: contiguous f32 vector
66+
# - Tile buffer: pre-unpacked f32 signs (reused per tile)
67+
68+
# Expected Performance:
69+
# - Current: 6.11 GFLOPS (batch-4)
70+
# - Target: 12+ GFLOPS (2x speedup)
71+
# - Theoretical max: ~50 GFLOPS (memory bandwidth limited)

src/vibeec/parallel_inference.zig

Lines changed: 92 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -164,49 +164,123 @@ fn ternaryWorker(ctx: *const ParallelTernaryContext, chunk: WorkChunk) void {
164164
const cols_packed = (cols + 3) / 4;
165165
const sign_lut = [4]f32{ 0.0, 1.0, -1.0, 0.0 };
166166

167-
for (chunk.start_row..chunk.end_row) |row| {
167+
const num_rows = chunk.end_row - chunk.start_row;
168+
var row = chunk.start_row;
169+
170+
// Process 4 rows at a time (batch optimization)
171+
while (row + 4 <= chunk.end_row) {
172+
var sum0: Vec8f = @splat(0.0);
173+
var sum1: Vec8f = @splat(0.0);
174+
var sum2: Vec8f = @splat(0.0);
175+
var sum3: Vec8f = @splat(0.0);
176+
177+
var col: usize = 0;
178+
while (col + 8 <= cols) {
179+
const in_vec: Vec8f = ctx.input[col..][0..8].*;
180+
const col_byte = col / 4;
181+
182+
// Row 0
183+
const r0_start = row * cols_packed;
184+
if (r0_start + col_byte + 1 < ctx.weights.len) {
185+
const b0 = ctx.weights[r0_start + col_byte];
186+
const b1 = ctx.weights[r0_start + col_byte + 1];
187+
const s0: Vec8f = .{
188+
sign_lut[(b0 >> 0) & 0x3], sign_lut[(b0 >> 2) & 0x3],
189+
sign_lut[(b0 >> 4) & 0x3], sign_lut[(b0 >> 6) & 0x3],
190+
sign_lut[(b1 >> 0) & 0x3], sign_lut[(b1 >> 2) & 0x3],
191+
sign_lut[(b1 >> 4) & 0x3], sign_lut[(b1 >> 6) & 0x3],
192+
};
193+
sum0 += in_vec * s0;
194+
}
195+
196+
// Row 1
197+
const r1_start = (row + 1) * cols_packed;
198+
if (r1_start + col_byte + 1 < ctx.weights.len) {
199+
const b0 = ctx.weights[r1_start + col_byte];
200+
const b1 = ctx.weights[r1_start + col_byte + 1];
201+
const s1: Vec8f = .{
202+
sign_lut[(b0 >> 0) & 0x3], sign_lut[(b0 >> 2) & 0x3],
203+
sign_lut[(b0 >> 4) & 0x3], sign_lut[(b0 >> 6) & 0x3],
204+
sign_lut[(b1 >> 0) & 0x3], sign_lut[(b1 >> 2) & 0x3],
205+
sign_lut[(b1 >> 4) & 0x3], sign_lut[(b1 >> 6) & 0x3],
206+
};
207+
sum1 += in_vec * s1;
208+
}
209+
210+
// Row 2
211+
const r2_start = (row + 2) * cols_packed;
212+
if (r2_start + col_byte + 1 < ctx.weights.len) {
213+
const b0 = ctx.weights[r2_start + col_byte];
214+
const b1 = ctx.weights[r2_start + col_byte + 1];
215+
const s2: Vec8f = .{
216+
sign_lut[(b0 >> 0) & 0x3], sign_lut[(b0 >> 2) & 0x3],
217+
sign_lut[(b0 >> 4) & 0x3], sign_lut[(b0 >> 6) & 0x3],
218+
sign_lut[(b1 >> 0) & 0x3], sign_lut[(b1 >> 2) & 0x3],
219+
sign_lut[(b1 >> 4) & 0x3], sign_lut[(b1 >> 6) & 0x3],
220+
};
221+
sum2 += in_vec * s2;
222+
}
223+
224+
// Row 3
225+
const r3_start = (row + 3) * cols_packed;
226+
if (r3_start + col_byte + 1 < ctx.weights.len) {
227+
const b0 = ctx.weights[r3_start + col_byte];
228+
const b1 = ctx.weights[r3_start + col_byte + 1];
229+
const s3: Vec8f = .{
230+
sign_lut[(b0 >> 0) & 0x3], sign_lut[(b0 >> 2) & 0x3],
231+
sign_lut[(b0 >> 4) & 0x3], sign_lut[(b0 >> 6) & 0x3],
232+
sign_lut[(b1 >> 0) & 0x3], sign_lut[(b1 >> 2) & 0x3],
233+
sign_lut[(b1 >> 4) & 0x3], sign_lut[(b1 >> 6) & 0x3],
234+
};
235+
sum3 += in_vec * s3;
236+
}
237+
238+
col += 8;
239+
}
240+
241+
ctx.output[row + 0] = @reduce(.Add, sum0) * ctx.scale;
242+
ctx.output[row + 1] = @reduce(.Add, sum1) * ctx.scale;
243+
ctx.output[row + 2] = @reduce(.Add, sum2) * ctx.scale;
244+
ctx.output[row + 3] = @reduce(.Add, sum3) * ctx.scale;
245+
246+
row += 4;
247+
}
248+
249+
// Handle remaining rows
250+
while (row < chunk.end_row) : (row += 1) {
168251
var sum_vec: Vec8f = @splat(0.0);
169252
var sum_scalar: f32 = 0.0;
170253
const row_start = row * cols_packed;
171254

172255
var col: usize = 0;
173-
174-
// SIMD loop: 8 floats at a time
175256
while (col + 8 <= cols and row_start + col / 4 + 1 < ctx.weights.len) {
176257
const in_vec: Vec8f = ctx.input[col..][0..8].*;
177-
178258
const byte0 = ctx.weights[row_start + col / 4];
179259
const byte1 = ctx.weights[row_start + col / 4 + 1];
180-
181260
const signs: Vec8f = .{
182-
sign_lut[(byte0 >> 0) & 0x3],
183-
sign_lut[(byte0 >> 2) & 0x3],
184-
sign_lut[(byte0 >> 4) & 0x3],
185-
sign_lut[(byte0 >> 6) & 0x3],
186-
sign_lut[(byte1 >> 0) & 0x3],
187-
sign_lut[(byte1 >> 2) & 0x3],
188-
sign_lut[(byte1 >> 4) & 0x3],
189-
sign_lut[(byte1 >> 6) & 0x3],
261+
sign_lut[(byte0 >> 0) & 0x3], sign_lut[(byte0 >> 2) & 0x3],
262+
sign_lut[(byte0 >> 4) & 0x3], sign_lut[(byte0 >> 6) & 0x3],
263+
sign_lut[(byte1 >> 0) & 0x3], sign_lut[(byte1 >> 2) & 0x3],
264+
sign_lut[(byte1 >> 4) & 0x3], sign_lut[(byte1 >> 6) & 0x3],
190265
};
191-
192266
sum_vec += in_vec * signs;
193267
col += 8;
194268
}
195269

196270
sum_scalar = @reduce(.Add, sum_vec);
197271

198-
// Scalar tail
199272
while (col < cols) : (col += 1) {
200273
const byte_idx = row_start + col / 4;
201274
if (byte_idx >= ctx.weights.len) break;
202-
203275
const shift: u3 = @intCast((col % 4) * 2);
204276
const trit = (ctx.weights[byte_idx] >> shift) & 0x3;
205277
sum_scalar += ctx.input[col] * sign_lut[trit];
206278
}
207279

208280
ctx.output[row] = sum_scalar * ctx.scale;
209281
}
282+
283+
_ = num_rows;
210284
}
211285

212286
/// Minimum rows to justify parallelization overhead
@@ -221,9 +295,9 @@ pub fn parallelTernaryMatmul(
221295
cols: usize,
222296
scale: f32,
223297
) void {
224-
// For small matrices, use single-threaded SIMD (faster due to no thread overhead)
298+
// For small matrices, use single-threaded batch SIMD (fastest)
225299
if (rows < MIN_PARALLEL_ROWS) {
226-
ternary.simd16TernaryMatVec(output, weights, input, rows, cols);
300+
ternary.batchTernaryMatVec(output, weights, input, rows, cols);
227301
for (output) |*o| o.* *= scale;
228302
return;
229303
}

0 commit comments

Comments
 (0)