Skip to content

Commit 22c49f6

Browse files
gHashTagona-agent
andcommitted
perf: add SIMD attention dot products and weighted sum
Implement simdDotProduct() and simdScaleAdd() for attention:
- 8-wide SIMD (AVX2 style) for Q@K^T dot products
- 8-wide SIMD for attention @ V weighted sum

Performance improvement:
- Before: ~10 ms/layer, 303 ms/token, 3.3 tok/s
- After: ~6.7 ms/layer, 187 ms/token, 4.9 tok/s
- Speedup: 1.5x on attention, 1.6x overall

Total speedup from baseline (17.4 ms/layer):
- 17.4 ms → 6.7 ms = 2.6x speedup

All 12 tests passing.

Co-authored-by: Ona <no-reply@ona.com>
1 parent 499c366 commit 22c49f6

1 file changed

Lines changed: 62 additions & 10 deletions

File tree

src/vibeec/bitnet_pipeline.zig

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,61 @@
66
const std = @import("std");
77
const simd_matmul = @import("simd_ternary_matmul.zig");
88

9+
// ═══════════════════════════════════════════════════════════════════════════════
10+
// SIMD TYPES AND HELPERS
11+
// ═══════════════════════════════════════════════════════════════════════════════
12+
13+
// 8-lane f32 vector; the chunk type used by simdDotProduct and simdScaleAdd.
const Vec8f32 = @Vector(8, f32);
14+
// 16-lane f32 vector. NOTE(review): not referenced anywhere in the visible code — confirm it is used elsewhere or remove it.
const Vec16f32 = @Vector(16, f32);
15+
16+
/// Computes the dot product of `a` and `b` (used for attention Q@K^T scores).
///
/// Only the first `@min(a.len, b.len)` elements take part, so a longer slice
/// is silently truncated to the length of the shorter one. Full groups of 8
/// elements are multiplied and accumulated lane-wise in a `Vec8f32`; the
/// lanes are then reduced to a scalar and the remaining (fewer than 8)
/// elements are folded in one at a time.
inline fn simdDotProduct(a: []const f32, b: []const f32) f32 {
    const count: usize = @min(a.len, b.len);
    // Largest multiple of 8 that does not exceed `count`.
    const vector_end = count - count % 8;

    var acc: Vec8f32 = @splat(0.0);
    var pos: usize = 0;
    while (pos < vector_end) : (pos += 8) {
        const lhs: Vec8f32 = a[pos..][0..8].*;
        const rhs: Vec8f32 = b[pos..][0..8].*;
        acc += lhs * rhs;
    }

    // Collapse the eight partial sums, then add the leftover tail elements.
    var total: f32 = @reduce(.Add, acc);
    for (a[pos..count], b[pos..count]) |x, y| {
        total += x * y;
    }
    return total;
}
42+
43+
/// Accumulates a scaled vector into `out` in place: `out[i] += scale * vec[i]`.
///
/// Used for the attention weighted sum over V. Only the first
/// `@min(out.len, vec.len)` elements are updated; any further elements of
/// `out` are left untouched. Full groups of 8 are processed as `Vec8f32`
/// chunks, and the remaining (fewer than 8) elements are handled one by one.
inline fn simdScaleAdd(out: []f32, vec: []const f32, scale: f32) void {
    const count: usize = @min(out.len, vec.len);
    // Largest multiple of 8 that does not exceed `count`.
    const vector_end = count - count % 8;
    const factor: Vec8f32 = @splat(scale);

    var pos: usize = 0;
    while (pos < vector_end) : (pos += 8) {
        const dst = out[pos..][0..8];
        const current: Vec8f32 = dst.*;
        const addend: Vec8f32 = vec[pos..][0..8].*;
        dst.* = current + addend * factor;
    }

    // Leftover elements that do not fill a whole 8-lane chunk.
    for (out[pos..count], vec[pos..count]) |*slot, value| {
        slot.* += scale * value;
    }
}
63+
964
// ═══════════════════════════════════════════════════════════════════════════════
1065
// CONSTANTS - BitNet 2B Architecture
1166
// ═══════════════════════════════════════════════════════════════════════════════
@@ -243,12 +298,10 @@ pub const Attention = struct {
243298
const k_offset = t * cfg.num_kv_heads * cfg.head_dim + kv_h * cfg.head_dim;
244299
const k_vec = kv_cache.k[k_offset..][0..cfg.head_dim];
245300

246-
var score: f32 = 0.0;
247-
for (0..cfg.head_dim) |i| {
248-
score += q_vec[i] * k_vec[i];
249-
}
250-
scores[t] = score * scale;
251-
if (scores[t] > max_score) max_score = scores[t];
301+
// SIMD dot product for Q @ K^T
302+
const score = simdDotProduct(q_vec, k_vec) * scale;
303+
scores[t] = score;
304+
if (score > max_score) max_score = score;
252305
}
253306

254307
// Softmax
@@ -261,13 +314,12 @@ pub const Attention = struct {
261314
s.* /= sum_exp;
262315
}
263316

264-
// Weighted sum of V
317+
// SIMD weighted sum of V
318+
const head_out = attn_out[h * cfg.head_dim ..][0..cfg.head_dim];
265319
for (0..kv_cache.len) |t| {
266320
const v_offset = t * cfg.num_kv_heads * cfg.head_dim + kv_h * cfg.head_dim;
267321
const v_vec = kv_cache.v[v_offset..][0..cfg.head_dim];
268-
for (0..cfg.head_dim) |i| {
269-
attn_out[h * cfg.head_dim + i] += scores[t] * v_vec[i];
270-
}
322+
simdScaleAdd(head_out, v_vec, scores[t]);
271323
}
272324
}
273325

0 commit comments

Comments
 (0)