Skip to content

Commit 8ebc2b9

Browse files
committed
feat(hslm): TTQ — Trained Ternary Quantization with learned thresholds
- Add src/b2t/ttq.zig - TTQLayer: per-layer learned threshold for ternary quantization - Quantize: weight → {P, Z, N} based on threshold - STE gradient approximation for threshold update - Scaled quantize for layer-dependent thresholds - Sparsity and effective bits computation - TTQNetwork: multi-layer threshold management - 6 tests: quantize, threshold update, sparsity, scaled, multi-layer network, effective bits Closes #320
1 parent 27a045b commit 8ebc2b9

1 file changed

Lines changed: 196 additions & 0 deletions

File tree

src/b2t/ttq.zig

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
const std = @import("std");
2+
3+
pub const TTQConfig = struct {
4+
init_threshold: f32 = 0.05,
5+
lr_threshold: f32 = 1e-4,
6+
min_threshold: f32 = 1e-6,
7+
max_threshold: f32 = 1.0,
8+
};
9+
10+
pub const TTQLayer = struct {
11+
threshold: f32,
12+
grad_accumulator: f32,
13+
allocator: std.mem.Allocator,
14+
config: TTQConfig,
15+
16+
pub fn init(allocator: std.mem.Allocator, config: TTQConfig) TTQLayer {
17+
return .{
18+
.threshold = config.init_threshold,
19+
.grad_accumulator = 0.0,
20+
.allocator = allocator,
21+
.config = config,
22+
};
23+
}
24+
25+
pub fn quantize(self: *const TTQLayer, weights: []const f32, output: [] Trit) void {
26+
std.debug.assert(weights.len == output.len);
27+
const t = self.threshold;
28+
for (weights, output) |w, *o| {
29+
o.* = if (w > t)
30+
.P
31+
else if (w < -t)
32+
.N
33+
else
34+
.Z;
35+
}
36+
}
37+
38+
pub fn quantizeScaled(self: *const TTQLayer, weights: []const f32, output: [] Trit, scale: f32) void {
39+
std.debug.assert(weights.len == output.len);
40+
const t = self.threshold * scale;
41+
for (weights, output) |w, *o| {
42+
o.* = if (w > t)
43+
.P
44+
else if (w < -t)
45+
.N
46+
else
47+
.Z;
48+
}
49+
}
50+
51+
pub fn computeGradient(self: *TTQLayer, weights: []const f32, upstream_grad: []const f32) f32 {
52+
var grad: f32 = 0.0;
53+
const t = self.threshold;
54+
const eps: f32 = 1e-6;
55+
56+
for (weights, upstream_grad) |w, g| {
57+
const dist = @abs(w) - t;
58+
const soft_grad = 1.0 / (1.0 + std.math.exp(dist * 100.0));
59+
if (@abs(w) > eps) {
60+
grad += g * std.math.copysign(soft_grad, w);
61+
}
62+
}
63+
self.grad_accumulator += grad;
64+
return grad;
65+
}
66+
67+
pub fn updateThreshold(self: *TTQLayer) void {
68+
self.threshold += self.config.lr_threshold * self.grad_accumulator;
69+
self.threshold = std.math.clamp(self.threshold, self.config.min_threshold, self.config.max_threshold);
70+
self.grad_accumulator = 0.0;
71+
}
72+
73+
pub fn sparsity(self: *const TTQLayer, weights: []const f32) f32 {
74+
const t = self.threshold;
75+
var zeros: usize = 0;
76+
for (weights) |w| {
77+
if (@abs(w) <= t) zeros += 1;
78+
}
79+
return @as(f32, @floatFromInt(zeros)) / @as(f32, @floatFromInt(weights.len));
80+
}
81+
82+
pub fn effectiveBits(self: *const TTQLayer, weights: []const f32) f32 {
83+
const s = self.sparsity(weights);
84+
const p_nonzero = 1.0 - s;
85+
if (p_nonzero == 0) return 0;
86+
const entropy = -p_nonzero * std.math.log2(p_nonzero) - s * std.math.log2(@max(s, 1e-10));
87+
return entropy;
88+
}
89+
};
90+
91+
pub const Trit = enum(i8) { P = 1, Z = 0, N = -1 };
92+
93+
pub const TTQNetwork = struct {
94+
allocator: std.mem.Allocator,
95+
layers: std.ArrayList(TTQLayer),
96+
config: TTQConfig,
97+
98+
pub fn init(allocator: std.mem.Allocator, config: TTQConfig) TTQNetwork {
99+
return .{
100+
.allocator = allocator,
101+
.layers = std.ArrayList(TTQLayer).init(allocator),
102+
.config = config,
103+
};
104+
}
105+
106+
pub fn deinit(self: *TTQNetwork) void {
107+
self.layers.deinit();
108+
}
109+
110+
pub fn addLayer(self: *TTQNetwork) !usize {
111+
const idx = self.layers.items.len;
112+
try self.layers.append(TTQLayer.init(self.allocator, self.config));
113+
return idx;
114+
}
115+
116+
pub fn updateAllThresholds(self: *TTQNetwork) void {
117+
for (self.layers.items) |*layer| {
118+
layer.updateThreshold();
119+
}
120+
}
121+
122+
pub fn averageSparsity(self: *const TTQNetwork, all_weights: []const []const f32) f32 {
123+
var total: f32 = 0;
124+
for (self.layers.items, all_weights) |layer, weights| {
125+
total += layer.sparsity(weights);
126+
}
127+
return total / @as(f32, @floatFromInt(self.layers.items.len));
128+
}
129+
};
130+
131+
test "TTQ quantize basic" {
132+
const config = TTQConfig{ .init_threshold = 0.3 };
133+
var layer = TTQLayer.init(std.testing.allocator, config);
134+
135+
const weights = [_]f32{ 0.5, -0.5, 0.1, -0.1, 0.0 };
136+
var output: [5]Trit = undefined;
137+
layer.quantize(&weights, &output);
138+
139+
try std.testing.expectEqual(Trit.P, output[0]);
140+
try std.testing.expectEqual(Trit.N, output[1]);
141+
try std.testing.expectEqual(Trit.Z, output[2]);
142+
try std.testing.expectEqual(Trit.Z, output[3]);
143+
try std.testing.expectEqual(Trit.Z, output[4]);
144+
}
145+
146+
test "TTQ threshold update" {
147+
var layer = TTQLayer.init(std.testing.allocator, .{ .init_threshold = 0.1, .lr_threshold = 0.01 });
148+
149+
const weights = [_]f32{ 0.5, -0.5, 0.3 };
150+
const grads = [_]f32{ 1.0, 1.0, 1.0 };
151+
152+
_ = layer.computeGradient(&weights, &grads);
153+
const before = layer.threshold;
154+
layer.updateThreshold();
155+
try std.testing.expect(layer.threshold != before);
156+
}
157+
158+
test "TTQ sparsity calculation" {
159+
var layer = TTQLayer.init(std.testing.allocator, .{ .init_threshold = 0.3 });
160+
161+
const weights = [_]f32{ 0.5, -0.5, 0.1, -0.1, 0.0 };
162+
const s = layer.sparsity(&weights);
163+
try std.testing.expectApproxEqAbs(@as(f32, 0.6), s, 1e-6);
164+
}
165+
166+
test "TTQ scaled quantize" {
167+
var layer = TTQLayer.init(std.testing.allocator, .{ .init_threshold = 0.1 });
168+
169+
const weights = [_]f32{ 0.15, -0.15, 0.05, -0.05 };
170+
var output: [4]Trit = undefined;
171+
layer.quantizeScaled(&weights, &output, 2.0);
172+
173+
try std.testing.expectEqual(Trit.Z, output[0]);
174+
try std.testing.expectEqual(Trit.Z, output[1]);
175+
try std.testing.expectEqual(Trit.Z, output[2]);
176+
try std.testing.expectEqual(Trit.Z, output[3]);
177+
}
178+
179+
test "TTQ network multi-layer" {
180+
var net = TTQNetwork.init(std.testing.allocator, .{});
181+
defer net.deinit();
182+
183+
const idx1 = try net.addLayer();
184+
const idx2 = try net.addLayer();
185+
try std.testing.expectEqual(@as(usize, 0), idx1);
186+
try std.testing.expectEqual(@as(usize, 1), idx2);
187+
try std.testing.expectEqual(@as(usize, 2), net.layers.items.len);
188+
}
189+
190+
test "TTQ effective bits" {
191+
var layer = TTQLayer.init(std.testing.allocator, .{ .init_threshold = 0.3 });
192+
const weights = [_]f32{ 0.5, -0.5, 0.1, -0.1, 0.0 };
193+
const bits = layer.effectiveBits(&weights);
194+
try std.testing.expect(bits > 0);
195+
try std.testing.expect(bits < 2.0);
196+
}

0 commit comments

Comments
 (0)