|
| 1 | +// @origin(spec:sparse_simd.tri) @regen(manual-impl) |
| 2 | +// Sparse Ternary SIMD — Zero-Weight Skipping for 30-50% Speedup |
| 3 | +// ~66% of ternary weights are zero → skip entire chunks via @reduce(.Or) |
| 4 | +// |
| 5 | +// Key insight: if all 16 weights in a chunk are zero, skip compute entirely |
| 6 | +// Uses f16 for activations (half the memory traffic of f32); products and per-chunk sums are f32, the running total is f64 |
| 7 | +// |
| 8 | +// φ² + 1/φ² = 3 | TRINITY |
| 9 | + |
| 10 | +const std = @import("std"); |
| 11 | +const f16_utils = @import("f16_utils.zig"); |
| 12 | + |
| 13 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 14 | +// TYPES |
| 15 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 16 | + |
/// 16 lanes of i8 weights (one skip-test chunk).
const Vec16i8 = @Vector(16, i8);
/// 16 lanes of f16 activations as loaded from memory.
const Vec16f16 = @Vector(16, f16);
/// 16 lanes of f32 used for the widened multiply.
const Vec16f32 = @Vector(16, f32);

/// All-zero weight chunk, compared against to detect skippable chunks.
const zero_vec_i8: Vec16i8 = @splat(@as(i8, 0));
/// All-zero f16 chunk.
const zero_vec_f16: Vec16f16 = @splat(0.0);
| 23 | + |
| 24 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 25 | +// SPARSE DOT PRODUCT — Skip zero chunks |
| 26 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 27 | + |
/// Dot product of an i8 weight vector with an f16 activation vector that
/// skips every aligned 16-element chunk whose weights are all zero.
///
/// A chunk is skipped only when all 16 of its weights are zero, so the
/// saving depends on how zeros cluster, not just on the overall zero fraction.
/// Products and per-chunk sums are computed in f32; the running total is f64.
/// Elements past the last full chunk are processed one at a time, skipping
/// individual zero weights.
///
/// `weights` and `activations` must have the same length (asserted).
pub fn sparseTernaryDot(weights: []const i8, activations: []const f16) f64 {
    std.debug.assert(weights.len == activations.len);

    const VEC_SIZE = 16;
    // Index one past the last element covered by a full 16-wide chunk.
    const vec_end = weights.len - weights.len % VEC_SIZE;
    var total: f64 = 0;

    var base: usize = 0;
    while (base < vec_end) : (base += VEC_SIZE) {
        const w_lanes: Vec16i8 = weights[base..][0..VEC_SIZE].*;

        // An all-zero chunk contributes nothing; skip it before loading activations.
        if (@reduce(.And, w_lanes == zero_vec_i8)) continue;

        const a_lanes: Vec16f16 = activations[base..][0..VEC_SIZE].*;
        const products = @as(Vec16f32, @floatCast(a_lanes)) * @as(Vec16f32, @floatFromInt(w_lanes));

        // Sum the lanes in index order in f32, then widen once per chunk.
        const lanes: [VEC_SIZE]f32 = products;
        var chunk_sum: f32 = 0;
        for (lanes) |p| chunk_sum += p;
        total += @as(f64, chunk_sum);
    }

    // Scalar tail: the remaining (len % 16) elements.
    for (weights[vec_end..], activations[vec_end..]) |w, a| {
        if (w == 0) continue;
        const a_f32: f32 = @floatCast(a);
        const w_f32: f32 = @floatFromInt(w);
        total += @as(f64, a_f32 * w_f32);
    }

    return total;
}
| 71 | + |
/// Dense ternary dot product (baseline for comparison).
/// Always computes all elements — no skipping, not even of individual zeros.
///
/// Each weight is converted to f32 by value and multiplied with the
/// corresponding activation (widened from f16 to f32); the products are
/// accumulated in f64. The weights are NOT reinterpreted as f16 bit patterns:
/// an i8 slice cannot be pointer-cast to an f16 slice without changing both
/// the element values and the number of bytes read.
///
/// `weights` and `activations` must have the same length (asserted).
pub fn denseTernaryDot(weights: []const i8, activations: []const f16) f64 {
    std.debug.assert(weights.len == activations.len);

    var acc: f64 = 0;
    for (weights, activations) |w, a| {
        const a_f32: f32 = @floatCast(a);
        const w_f32: f32 = @floatFromInt(w);
        acc += @as(f64, a_f32 * w_f32);
    }
    return acc;
}
| 77 | + |
| 78 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 79 | +// SPARSE MATRIX-VECTOR — Skip zero rows/chunks |
| 80 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 81 | + |
/// Sparse ternary matrix-vector multiplication with 16-wide zero-chunk skipping.
///
/// weights:     [out_dim][in_dim] row-major i8 ternary matrix
/// activations: [in_dim] f16 input vector
/// output:      [out_dim] f16 result (caller-allocated)
///
/// For each row, aligned 16-element chunks whose weights are all zero are
/// skipped; other chunks are multiplied in f32 and summed per chunk in f32.
/// The per-row total is kept in f64 and narrowed to f16 when stored.
/// Columns past the last full chunk are handled one at a time.
///
/// The three slice lengths must agree with `out_dim` / `in_dim` (asserted).
pub fn sparseTernaryMatvec(
    weights: []const i8,
    activations: []const f16,
    output: []f16,
    out_dim: usize,
    in_dim: usize,
) void {
    std.debug.assert(weights.len == out_dim * in_dim);
    std.debug.assert(activations.len == in_dim);
    std.debug.assert(output.len == out_dim);

    const VEC_SIZE = 16;
    const WeightVec = @Vector(VEC_SIZE, i8);
    const ActivationVec = @Vector(VEC_SIZE, f16);
    const WideVec = @Vector(VEC_SIZE, f32);
    const zero_weights: WeightVec = @splat(0);

    // Index one past the last column covered by a full chunk (same for every row).
    const vec_end = in_dim - in_dim % VEC_SIZE;

    for (0..out_dim) |row| {
        const row_weights = weights[row * in_dim ..][0..in_dim];
        var acc: f64 = 0;

        var col: usize = 0;
        while (col < vec_end) : (col += VEC_SIZE) {
            const w_vec: WeightVec = row_weights[col..][0..VEC_SIZE].*;

            // Skip all-zero chunks. `continue` runs the loop's continue
            // expression, which already advances `col` by one chunk; advancing
            // it here as well would jump over the following chunk and could
            // push `col` past `in_dim`, dropping the scalar tail.
            if (!@reduce(.Or, w_vec != zero_weights)) continue;

            const a_vec: ActivationVec = activations[col..][0..VEC_SIZE].*;
            const a_wide: WideVec = @floatCast(a_vec);
            const w_wide: WideVec = @floatFromInt(w_vec);

            const prod = a_wide * w_wide;
            var chunk_sum: f32 = 0;
            inline for (0..VEC_SIZE) |lane| {
                chunk_sum += prod[lane];
            }
            acc += @as(f64, chunk_sum);
        }

        // Scalar tail: remaining (in_dim % 16) columns.
        while (col < in_dim) : (col += 1) {
            const w = row_weights[col];
            if (w == 0) continue;
            const a_f32: f32 = @floatCast(activations[col]);
            const w_f32: f32 = @floatFromInt(w);
            acc += @as(f64, a_f32 * w_f32);
        }

        output[row] = @floatCast(acc);
    }
}
| 140 | + |
/// Scalar reference implementation of the ternary matrix-vector product.
///
/// weights:     [out_dim][in_dim] row-major i8 matrix
/// activations: [in_dim] f16 input vector
/// output:      [out_dim] f16 result (caller-allocated)
///
/// Walks every row element by element, ignoring individual zero weights,
/// accumulates in f64 and narrows each row total to f16 when storing it.
/// The three slice lengths must agree with `out_dim` / `in_dim` (asserted).
pub fn denseTernaryMatvec(
    weights: []const i8,
    activations: []const f16,
    output: []f16,
    out_dim: usize,
    in_dim: usize,
) void {
    std.debug.assert(weights.len == out_dim * in_dim);
    std.debug.assert(activations.len == in_dim);
    std.debug.assert(output.len == out_dim);

    var row: usize = 0;
    while (row < out_dim) : (row += 1) {
        const row_weights = weights[row * in_dim ..][0..in_dim];
        var sum: f64 = 0;

        for (row_weights, activations[0..in_dim]) |w, a| {
            if (w != 0) {
                // Activation is widened f16 -> f32 -> f64; the product is formed in f64.
                const a_f32: f32 = @floatCast(a);
                const w_f64: f64 = @floatFromInt(w);
                sum += @as(f64, a_f32) * w_f64;
            }
        }

        output[row] = @floatCast(sum);
    }
}
| 167 | + |
| 168 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 169 | +// SPARSITY ANALYSIS |
| 170 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 171 | + |
/// Number of aligned 16-element chunks of `data` that consist solely of zeros.
/// Trailing elements that do not fill a whole chunk are not examined.
pub fn countZeroChunks(data: []const i8) usize {
    const VEC_SIZE = 16;
    var zero_chunks: usize = 0;

    for (0..data.len / VEC_SIZE) |chunk| {
        const lanes: Vec16i8 = data[chunk * VEC_SIZE ..][0..VEC_SIZE].*;
        // "Every lane is zero" expressed as "no lane is non-zero".
        const has_nonzero = @reduce(.Or, lanes != zero_vec_i8);
        if (!has_nonzero) zero_chunks += 1;
    }

    return zero_chunks;
}
| 187 | + |
/// Fraction of elements in `data` that are exactly zero, in [0, 1].
/// An empty slice yields 0.
pub fn sparsityRatio(data: []const i8) f64 {
    if (data.len == 0) return 0;

    var zeros: usize = 0;
    for (data) |value| {
        zeros += @intFromBool(value == 0);
    }

    const zeros_f: f64 = @floatFromInt(zeros);
    const total_f: f64 = @floatFromInt(data.len);
    return zeros_f / total_f;
}
| 199 | + |
/// Rough speedup estimate for the sparse kernels relative to the dense ones.
///
/// Computed as `1.0 + 0.5 * (all-zero chunks / total full chunks)` at
/// 16-element granularity. This is a heuristic, not a measurement. Inputs
/// shorter than one chunk yield 1.0.
pub fn estimateSpeedup(weights: []const i8) f64 {
    const chunk_total = weights.len / 16;
    if (chunk_total == 0) return 1.0;

    const skipped: f64 = @floatFromInt(countZeroChunks(weights));
    const total: f64 = @floatFromInt(chunk_total);

    // Heuristic: each skipped chunk is credited with saving ~50% of its work.
    return 1.0 + (skipped / total) * 0.5;
}
| 212 | + |
| 213 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 214 | +// TESTS |
| 215 | +// ═══════════════════════════════════════════════════════════════════════════════ |
| 216 | + |
test "sparse dot product matches dense" {
    // One full chunk: the (+1, 0, -1, 0) weight pattern repeated four times.
    const weights = [_]i8{ 1, 0, -1, 0 } ** 4;
    const activations = [_]f16{ 0.5, 0.3, 0.7, 0.2 } ** 4;

    // Element-by-element reference value accumulated in f64.
    var reference: f64 = 0;
    for (weights, activations) |w, a| {
        const a_f32: f32 = @floatCast(a);
        const w_f64: f64 = @floatFromInt(w);
        reference += @as(f64, a_f32) * w_f64;
    }

    const got = sparseTernaryDot(&weights, &activations);
    try std.testing.expectApproxEqAbs(reference, got, 0.001);
}
| 232 | + |
test "sparse dot product all zeros" {
    // A single chunk of all-zero weights is skipped, so the result is exactly 0.
    const weights = std.mem.zeroes([16]i8);
    const activations = [_]f16{0.5} ** 16;

    try std.testing.expectEqual(@as(f64, 0), sparseTernaryDot(&weights, &activations));
}
| 240 | + |
test "sparse dot product all nonzeros" {
    // Fully dense chunk: nothing is skipped.
    const lane_count = 16;
    const weights = [_]i8{1} ** lane_count;
    const activations = [_]f16{0.5} ** lane_count;

    // 16 lanes × (1 × 0.5)
    const expected: f64 = 8.0;
    const got = sparseTernaryDot(&weights, &activations);
    try std.testing.expectApproxEqAbs(expected, got, 0.001);
}
| 249 | + |
test "sparse dot product 50% sparse" {
    // Weight 1 on even indices, 0 on odd indices; activation i holds the value i.
    var weights: [16]i8 = undefined;
    var activations: [16]f16 = undefined;
    var expected: f64 = 0;

    for (&weights, &activations, 0..) |*w, *a, idx| {
        const is_even = idx % 2 == 0;
        w.* = if (is_even) 1 else 0;
        a.* = @floatFromInt(idx);
        // Only the even indices contribute to the expected sum.
        if (is_even) expected += @as(f64, a.*);
    }

    const got = sparseTernaryDot(&weights, &activations);
    try std.testing.expectApproxEqAbs(expected, got, 0.01);
}
| 272 | + |
test "sparse matvec matches dense" {
    const out_dim: usize = 4;
    const in_dim: usize = 8;

    // Even rows are all +1, odd rows are all zero.
    // NOTE: in_dim < 16, so only the scalar tail path of sparseTernaryMatvec runs here.
    var weights: [out_dim * in_dim]i8 = undefined;
    for (&weights, 0..) |*w, idx| {
        const row = idx / in_dim;
        w.* = if (row % 2 == 0) 1 else 0;
    }

    const activations = [_]f16{0.1} ** in_dim;

    var from_sparse: [out_dim]f16 = undefined;
    var from_dense: [out_dim]f16 = undefined;
    sparseTernaryMatvec(&weights, &activations, &from_sparse, out_dim, in_dim);
    denseTernaryMatvec(&weights, &activations, &from_dense, out_dim, in_dim);

    for (from_dense, from_sparse) |d, s| {
        const want: f64 = d;
        const got: f64 = s;
        try std.testing.expectApproxEqAbs(want, got, 0.001);
    }
}
| 299 | + |
test "count zero chunks" {
    // Each fixture is 32 elements long, i.e. exactly two chunks.
    const zeros = [_]i8{0} ** 32;
    const ones = [_]i8{1} ** 32;
    const zeros_then_ones = [_]i8{0} ** 16 ++ [_]i8{1} ** 16;

    try std.testing.expectEqual(@as(usize, 2), countZeroChunks(&zeros));
    try std.testing.expectEqual(@as(usize, 0), countZeroChunks(&ones));
    try std.testing.expectEqual(@as(usize, 1), countZeroChunks(&zeros_then_ones));
}
| 310 | + |
test "sparsity ratio" {
    const Case = struct { data: []const i8, want: f64 };
    const cases = [_]Case{
        .{ .data = &([_]i8{0} ** 10), .want = 1.0 },
        .{ .data = &([_]i8{1} ** 10), .want = 0.0 },
        .{ .data = &([_]i8{0} ** 5 ++ [_]i8{1} ** 5), .want = 0.5 },
    };

    for (cases) |case| {
        try std.testing.expectApproxEqAbs(case.want, sparsityRatio(case.data), 0.01);
    }
}
| 321 | + |
test "estimate speedup" {
    // Every chunk skippable: the estimate must reach at least 1.5×.
    const zeros = [_]i8{0} ** 32;
    try std.testing.expect(estimateSpeedup(&zeros) >= 1.5);

    // No chunk skippable: the estimate stays at about 1.0× (no speedup).
    const ones = [_]i8{1} ** 32;
    try std.testing.expectApproxEqAbs(@as(f64, 1.0), estimateSpeedup(&ones), 0.1);
}
| 331 | + |
test "sparse dot product non-aligned length" {
    // Case 1: 11 elements — shorter than one 16-wide chunk, so the whole
    // input is handled by the scalar tail.
    const weights = [_]i8{ 1, 0, -1, 0, 1, 0, -1, 0, 1, 0, -1 };
    const activations = [_]f16{ 0.5, 0.3, 0.7, 0.2, 0.5, 0.3, 0.7, 0.2, 0.5, 0.3, 0.7 };

    const result = sparseTernaryDot(&weights, &activations);

    // Check the value, not just that it is finite.
    var expected: f64 = 0;
    for (weights, activations) |w, a| {
        const w_f64: f64 = @floatFromInt(w);
        expected += @as(f64, a) * w_f64;
    }
    try std.testing.expect(std.math.isFinite(result));
    try std.testing.expectApproxEqAbs(expected, result, 0.001);

    // Case 2: 19 elements — one full chunk followed by a 3-element tail.
    // All activations are exactly representable in f16, so the sum is exact:
    // 4 × (0.5 − 0.75) + (0.5 − 0.75) = −1.25.
    const weights_19 = [_]i8{ 1, 0, -1, 0 } ** 4 ++ [_]i8{ 1, 0, -1 };
    const activations_19 = [_]f16{ 0.5, 0.25, 0.75, 0.125 } ** 4 ++ [_]f16{ 0.5, 0.25, 0.75 };
    const result_19 = sparseTernaryDot(&weights_19, &activations_19);
    try std.testing.expectApproxEqAbs(@as(f64, -1.25), result_19, 0.001);
}
| 340 | + |
test "sparse matvec single row" {
    // 1×4 matrix: narrower than one chunk, so only the scalar tail runs.
    const weights = [_]i8{ 1, 0, -1, 1 };
    const activations = [_]f16{ 0.5, 0.3, -0.7, 0.2 };

    var output: [1]f16 = undefined;
    sparseTernaryMatvec(&weights, &activations, &output, 1, 4);

    // 1*0.5 + 0*0.3 + (-1)*(-0.7) + 1*0.2
    const expected: f64 = 0.5 + 0.7 + 0.2;
    const got: f64 = output[0];
    try std.testing.expectApproxEqAbs(expected, got, 0.01);
}
| 351 | + |
| 352 | +// φ² + 1/φ² = 3 | TRINITY |
0 commit comments