Skip to content

Commit af32694

Browse files
authored
feat(hslm): sparse ternary matmul — skip zero weights (#561)
- Add src/b2t/sparse_ternary.zig - CSR-like sparse storage: only non-zero entries stored - SparseTernaryMatrix: init from dense, sparsity tracking - matmul: iterate nnz entries only (O(nnz) vs O(rows*cols)) - matmulBatch: batched version for sequences - SparseStats: nnz, sparsity%, compression ratio - 5 tests: matches dense, skips zeros, all-nonzero, stats, batch Closes #316
1 parent c128444 commit af32694

1 file changed

Lines changed: 203 additions & 0 deletions

File tree

src/b2t/sparse_ternary.zig

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
const std = @import("std");
2+
3+
pub const Trit = enum(i8) { P = 1, Z = 0, N = -1 };
4+
5+
pub const SparseEntry = packed struct {
6+
row: u16,
7+
col: u16,
8+
value: Trit,
9+
};
10+
11+
pub const SparseTernaryMatrix = struct {
12+
entries: []SparseEntry,
13+
rows: usize,
14+
cols: usize,
15+
nnz: usize,
16+
17+
pub fn init(allocator: std.mem.Allocator, dense_rows: usize, dense_cols: usize, weights: []const Trit) !SparseTernaryMatrix {
18+
var nnz_count: usize = 0;
19+
for (weights) |w| {
20+
if (w != .Z) nnz_count += 1;
21+
}
22+
23+
const entries = try allocator.alloc(SparseEntry, nnz_count);
24+
var idx: usize = 0;
25+
for (weights, 0..) |w, flat| {
26+
if (w == .Z) continue;
27+
const r = flat / dense_cols;
28+
const c = flat % dense_cols;
29+
entries[idx] = SparseEntry{
30+
.row = @intCast(r),
31+
.col = @intCast(c),
32+
.value = w,
33+
};
34+
idx += 1;
35+
}
36+
37+
return .{
38+
.entries = entries,
39+
.rows = dense_rows,
40+
.cols = dense_cols,
41+
.nnz = nnz_count,
42+
};
43+
}
44+
45+
pub fn deinit(self: *SparseTernaryMatrix, allocator: std.mem.Allocator) void {
46+
allocator.free(self.entries);
47+
}
48+
49+
pub fn sparsity(self: *const SparseTernaryMatrix) f32 {
50+
const total: f32 = @floatFromInt(self.rows * self.cols);
51+
return 1.0 - @as(f32, @floatFromInt(self.nnz)) / total;
52+
}
53+
54+
pub fn matmul(self: *const SparseTernaryMatrix, input: []const f32, output: []f32) void {
55+
std.debug.assert(input.len >= self.cols);
56+
std.debug.assert(output.len >= self.rows);
57+
58+
@memset(output[0..self.rows], 0);
59+
60+
for (self.entries[0..self.nnz]) |entry| {
61+
const val: f32 = switch (entry.value) {
62+
.P => input[entry.col],
63+
.N => -input[entry.col],
64+
.Z => 0.0,
65+
};
66+
output[entry.row] += val;
67+
}
68+
}
69+
70+
pub fn matmulBatch(self: *const SparseTernaryMatrix, inputs: []const f32, outputs: []f32, batch_size: usize, seq_len: usize) void {
71+
for (0..batch_size * seq_len) |b| {
72+
const in_offset = b * self.cols;
73+
const out_offset = b * self.rows;
74+
self.matmul(inputs[in_offset..][0..self.cols], outputs[out_offset..][0..self.rows]);
75+
}
76+
}
77+
};
78+
79+
pub const SparseStats = struct {
80+
nnz: usize,
81+
total: usize,
82+
sparsity: f32,
83+
compression_ratio: f32,
84+
85+
pub fn format(self: SparseStats, writer: anytype) !void {
86+
try writer.print("SparseStats(nnz={}, total={}, sparsity={d:.1}%, compression={d:.1}x)\n", .{
87+
self.nnz,
88+
self.total,
89+
self.sparsity * 100.0,
90+
self.compression_ratio,
91+
});
92+
}
93+
};
94+
95+
pub fn computeStats(matrix: *const SparseTernaryMatrix) SparseStats {
96+
const total = matrix.rows * matrix.cols;
97+
const ratio: f32 = if (matrix.nnz > 0)
98+
@as(f32, @floatFromInt(total)) / @as(f32, @floatFromInt(matrix.nnz))
99+
else
100+
0.0;
101+
return .{
102+
.nnz = matrix.nnz,
103+
.total = total,
104+
.sparsity = matrix.sparsity(),
105+
.compression_ratio = ratio,
106+
};
107+
}
108+
109+
test "sparse matmul matches dense" {
110+
const allocator = std.testing.allocator;
111+
112+
const rows: usize = 3;
113+
const cols: usize = 4;
114+
const weights = [_]Trit{ .P, .Z, .N, .P, .Z, .P, .Z, .N, .N, .P, .Z, .Z };
115+
116+
var sparse = try SparseTernaryMatrix.init(allocator, rows, cols, &weights);
117+
defer sparse.deinit(allocator);
118+
119+
const input = [_]f32{ 1.0, 2.0, 3.0, 4.0 };
120+
var sparse_out: [3]f32 = undefined;
121+
sparse.matmul(&input, &sparse_out);
122+
123+
var dense_out: [3]f32 = [_]f32{0} ** 3;
124+
for (0..rows) |r| {
125+
for (0..cols) |c| {
126+
const w: f32 = switch (weights[r * cols + c]) {
127+
.P => 1.0,
128+
.N => -1.0,
129+
.Z => 0.0,
130+
};
131+
dense_out[r] += w * input[c];
132+
}
133+
}
134+
135+
for (0..rows) |i| {
136+
try std.testing.expectApproxEqAbs(dense_out[i], sparse_out[i], 1e-6);
137+
}
138+
}
139+
140+
test "sparse matmul skips zeros" {
141+
const allocator = std.testing.allocator;
142+
143+
const weights = [_]Trit{ .Z, .Z, .Z, .Z };
144+
var sparse = try SparseTernaryMatrix.init(allocator, 2, 2, &weights);
145+
defer sparse.deinit(allocator);
146+
147+
try std.testing.expectEqual(@as(usize, 0), sparse.nnz);
148+
try std.testing.expectEqual(@as(f32, 1.0), sparse.sparsity());
149+
150+
const input = [_]f32{ 1.0, 2.0 };
151+
var output: [2]f32 = undefined;
152+
sparse.matmul(&input, &output);
153+
for (output) |v| try std.testing.expect(v == 0.0);
154+
}
155+
156+
test "sparse matmul all non-zero" {
157+
const allocator = std.testing.allocator;
158+
159+
const weights = [_]Trit{ .P, .N, .N, .P };
160+
var sparse = try SparseTernaryMatrix.init(allocator, 2, 2, &weights);
161+
defer sparse.deinit(allocator);
162+
163+
try std.testing.expectEqual(@as(usize, 4), sparse.nnz);
164+
try std.testing.expectEqual(@as(f32, 0.0), sparse.sparsity());
165+
166+
const input = [_]f32{ 3.0, 5.0 };
167+
var output: [2]f32 = undefined;
168+
sparse.matmul(&input, &output);
169+
170+
try std.testing.expectApproxEqAbs(@as(f32, -2.0), output[0], 1e-6);
171+
try std.testing.expectApproxEqAbs(@as(f32, 2.0), output[1], 1e-6);
172+
}
173+
174+
test "sparse stats" {
175+
const allocator = std.testing.allocator;
176+
177+
const weights = [_]Trit{ .P, .Z, .Z, .P, .Z, .N, .P, .Z, .N };
178+
var sparse = try SparseTernaryMatrix.init(allocator, 3, 3, &weights);
179+
defer sparse.deinit(allocator);
180+
181+
const stats = computeStats(&sparse);
182+
try std.testing.expectEqual(@as(usize, 5), stats.nnz);
183+
try std.testing.expectEqual(@as(usize, 9), stats.total);
184+
try std.testing.expect(stats.sparsity > 0.4);
185+
try std.testing.expect(stats.compression_ratio > 1.0);
186+
}
187+
188+
test "batch sparse matmul" {
189+
const allocator = std.testing.allocator;
190+
191+
const weights = [_]Trit{ .P, .N, .Z, .P };
192+
var sparse = try SparseTernaryMatrix.init(allocator, 2, 2, &weights);
193+
defer sparse.deinit(allocator);
194+
195+
const inputs = [_]f32{ 1.0, 2.0, 3.0, 4.0 };
196+
var outputs: [4]f32 = undefined;
197+
sparse.matmulBatch(&inputs, &outputs, 2, 1);
198+
199+
try std.testing.expectApproxEqAbs(@as(f32, -1.0), outputs[0], 1e-6);
200+
try std.testing.expectApproxEqAbs(@as(f32, 2.0), outputs[1], 1e-6);
201+
try std.testing.expectApproxEqAbs(@as(f32, -1.0), outputs[2], 1e-6);
202+
try std.testing.expectApproxEqAbs(@as(f32, 4.0), outputs[3], 1e-6);
203+
}

0 commit comments

Comments
 (0)