|
4 | 4 | // ═══════════════════════════════════════════════════════════════════════════════ |
5 | 5 |
|
6 | 6 | const std = @import("std"); |
| 7 | +const simd_matmul = @import("simd_ternary_matmul.zig"); |
7 | 8 |
|
8 | 9 | // ═══════════════════════════════════════════════════════════════════════════════ |
9 | 10 | // CONSTANTS - BitNet 2B Architecture |
@@ -139,41 +140,28 @@ pub const KVCache = struct { |
139 | 140 | }; |
140 | 141 |
|
141 | 142 | // ═══════════════════════════════════════════════════════════════════════════════ |
142 | | -// TERNARY MATMUL - Using packed 2-bit weights |
| 143 | +// TERNARY MATMUL - Using optimized SIMD from simd_ternary_matmul.zig |
143 | 144 | // ═══════════════════════════════════════════════════════════════════════════════ |
144 | 145 |
|
| 146 | +/// SIMD-optimized ternary matmul: output = weights @ input |
| 147 | +/// Thin wrapper: forwards all arguments unchanged to simd_matmul.simdTernaryMatmulOpt16 (16-wide SIMD, AVX-512 style) - fastest option |
| 148 | +pub fn ternaryMatmul(output: []f32, weights: []const u8, input: []const f32, rows: usize, cols: usize) void { |
| 149 | + // Use optimized 16-wide SIMD implementation for best performance |
| 150 | + simd_matmul.simdTernaryMatmulOpt16(output, weights, input, rows, cols); |
| 151 | +} |
| 152 | + |
| 153 | +// Keep local 2-bit code -> sign LUT (0b00 -> 0, 0b01 -> +1, 0b10 -> -1, 0b11 -> 0) for tests that don't use the full SIMD path |
145 | 154 | const SIGN_LUT: [4]f32 = .{ 0.0, 1.0, -1.0, 0.0 }; |
146 | 155 |
|
147 | | -/// SIMD ternary matmul: output = weights @ input |
148 | | -pub fn ternaryMatmul(output: []f32, weights: []const u8, input: []const f32, rows: usize, cols: usize) void { |
| 156 | +/// Scalar fallback for testing (not used in production) |
| 157 | +fn ternaryMatmulScalar(output: []f32, weights: []const u8, input: []const f32, rows: usize, cols: usize) void { |
149 | 158 | const cols_packed = (cols + 3) / 4; |
150 | | - const Vec8 = @Vector(8, f32); |
151 | 159 |
|
152 | 160 | for (0..rows) |row| { |
153 | 161 | var sum: f32 = 0.0; |
154 | 162 | const row_start = row * cols_packed; |
155 | | - var col: usize = 0; |
156 | | - |
157 | | - // SIMD loop |
158 | | - while (col + 8 <= cols) { |
159 | | - const byte_idx = row_start + col / 4; |
160 | | - if (byte_idx + 1 >= weights.len) break; |
161 | | - |
162 | | - const in_vec: Vec8 = input[col..][0..8].*; |
163 | | - const b0 = weights[byte_idx]; |
164 | | - const b1 = weights[byte_idx + 1]; |
165 | | - const signs: Vec8 = .{ |
166 | | - SIGN_LUT[(b0 >> 0) & 0x3], SIGN_LUT[(b0 >> 2) & 0x3], |
167 | | - SIGN_LUT[(b0 >> 4) & 0x3], SIGN_LUT[(b0 >> 6) & 0x3], |
168 | | - SIGN_LUT[(b1 >> 0) & 0x3], SIGN_LUT[(b1 >> 2) & 0x3], |
169 | | - SIGN_LUT[(b1 >> 4) & 0x3], SIGN_LUT[(b1 >> 6) & 0x3], |
170 | | - }; |
171 | | - sum += @reduce(.Add, in_vec * signs); |
172 | | - col += 8; |
173 | | - } |
174 | 163 |
|
175 | | - // Scalar tail |
176 | | - while (col < cols) : (col += 1) { |
| 164 | + for (0..cols) |col| { |
177 | 165 | const byte_idx = row_start + col / 4; |
178 | 166 | if (byte_idx >= weights.len) break; |
179 | 167 | const shift: u3 = @intCast((col % 4) * 2); |
|
0 commit comments