Skip to content

Commit d903606

Browse files
committed
feat(hslm): progressive quantization — FP32 warmup → ternary anneal
- Add src/b2t/progressive_quantization.zig - 4-stage schedule: fp32_warmup → fp16_transition → ternary_anneal → full_ternary - Progressive threshold decay: init_threshold → final_threshold - quantizeWeights: stage-dependent weight modification - Temperature-controlled ternary transition - Quantization loss weight increases with schedule progress - 5 tests: stage progression, fp32 preserves precision, full ternary output, progress tracking, loss weight monotonicity Closes #321
1 parent 3552171 commit d903606

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
const std = @import("std");
2+
3+
pub const QuantizationStage = enum {
4+
fp32_warmup,
5+
fp16_transition,
6+
ternary_anneal,
7+
full_ternary,
8+
};
9+
10+
pub const ScheduleConfig = struct {
11+
warmup_steps: u32 = 1000,
12+
transition_steps: u32 = 2000,
13+
anneal_steps: u32 = 3000,
14+
init_threshold: f32 = 1.0,
15+
final_threshold: f32 = 0.05,
16+
};
17+
18+
pub const ProgressiveQuantizer = struct {
19+
allocator: std.mem.Allocator,
20+
config: ScheduleConfig,
21+
current_step: u32,
22+
stage: QuantizationStage,
23+
threshold: f32,
24+
25+
pub fn init(allocator: std.mem.Allocator, config: ScheduleConfig) ProgressiveQuantizer {
26+
return .{
27+
.allocator = allocator,
28+
.config = config,
29+
.current_step = 0,
30+
.stage = .fp32_warmup,
31+
.threshold = config.init_threshold,
32+
};
33+
}
34+
35+
pub fn step(self: *ProgressiveQuantizer) QuantizationStage {
36+
self.current_step += 1;
37+
const s = self.current_step;
38+
39+
if (s <= self.config.warmup_steps) {
40+
self.stage = .fp32_warmup;
41+
} else if (s <= self.config.warmup_steps + self.config.transition_steps) {
42+
self.stage = .fp16_transition;
43+
const progress = @as(f32, @floatFromInt(s - self.config.warmup_steps)) /
44+
@as(f32, @floatFromInt(self.config.transition_steps));
45+
self.threshold = self.config.init_threshold - progress * (self.config.init_threshold - self.config.final_threshold) * 0.5;
46+
} else if (s <= self.config.warmup_steps + self.config.transition_steps + self.config.anneal_steps) {
47+
self.stage = .ternary_anneal;
48+
const progress = @as(f32, @floatFromInt(s - self.config.warmup_steps - self.config.transition_steps)) /
49+
@as(f32, @floatFromInt(self.config.anneal_steps));
50+
self.threshold = (self.config.init_threshold + self.config.final_threshold) * 0.5 -
51+
progress * (self.config.init_threshold * 0.5 - self.config.final_threshold);
52+
self.threshold = @max(self.threshold, self.config.final_threshold);
53+
} else {
54+
self.stage = .full_ternary;
55+
self.threshold = self.config.final_threshold;
56+
}
57+
58+
return self.stage;
59+
}
60+
61+
pub fn quantizeWeights(self: *const ProgressiveQuantizer, weights: []f32, temp: f32) void {
62+
switch (self.stage) {
63+
.fp32_warmup => {},
64+
.fp16_transition => {
65+
const scale = @as(f32, @floatFromInt(1 << 10));
66+
for (weights) |*w| {
67+
w.* = @round(w.* * scale) / scale;
68+
}
69+
},
70+
.ternary_anneal => {
71+
const mix = self.ternaryMixRatio();
72+
for (weights) |*w| {
73+
if (mix > 0) {
74+
const ternary_val: f32 = if (w.* > self.threshold) 1.0 else if (w.* < -self.threshold) -1.0 else 0.0;
75+
if (@abs(w.*) > self.threshold * (1.0 + temp)) {
76+
w.* = ternary_val;
77+
}
78+
}
79+
}
80+
},
81+
.full_ternary => {
82+
for (weights) |*w| {
83+
w.* = if (w.* > self.threshold) 1.0 else if (w.* < -self.threshold) -1.0 else 0.0;
84+
}
85+
},
86+
}
87+
}
88+
89+
pub fn ternaryMixRatio(self: *const ProgressiveQuantizer) f32 {
90+
return switch (self.stage) {
91+
.fp32_warmup => 0.0,
92+
.fp16_transition => 0.1,
93+
.ternary_anneal => 0.5,
94+
.full_ternary => 1.0,
95+
};
96+
}
97+
98+
pub fn quantizationLossWeight(self: *const ProgressiveQuantizer) f32 {
99+
return switch (self.stage) {
100+
.fp32_warmup => 0.0,
101+
.fp16_transition => 0.01,
102+
.ternary_anneal => 0.1,
103+
.full_ternary => 1.0,
104+
};
105+
}
106+
107+
pub fn progress(self: *const ProgressiveQuantizer) f32 {
108+
const total = self.config.warmup_steps + self.config.transition_steps + self.config.anneal_steps;
109+
return @min(@as(f32, @floatFromInt(self.current_step)) / @as(f32, @floatFromInt(total)), 1.0);
110+
}
111+
};
112+
113+
test "progressive stages advance correctly" {
114+
var pq = ProgressiveQuantizer.init(std.testing.allocator, .{
115+
.warmup_steps = 10,
116+
.transition_steps = 10,
117+
.anneal_steps = 10,
118+
});
119+
120+
for (0..10) |_| {
121+
try std.testing.expectEqual(QuantizationStage.fp32_warmup, pq.step());
122+
}
123+
for (0..10) |_| {
124+
try std.testing.expectEqual(QuantizationStage.fp16_transition, pq.step());
125+
}
126+
for (0..10) |_| {
127+
try std.testing.expectEqual(QuantizationStage.ternary_anneal, pq.step());
128+
}
129+
try std.testing.expectEqual(QuantizationStage.full_ternary, pq.step());
130+
}
131+
132+
test "fp32 warmup does not modify weights" {
133+
var pq = ProgressiveQuantizer.init(std.testing.allocator, .{ .warmup_steps = 5 });
134+
_ = pq.step();
135+
136+
var weights = [_]f32{ 0.123456789, -0.987654321 };
137+
pq.quantizeWeights(&weights, 0.0);
138+
try std.testing.expect(weights[0] != @as(f32, @round(weights[0])));
139+
}
140+
141+
test "full ternary quantizes to {-1, 0, 1}" {
142+
var pq = ProgressiveQuantizer.init(std.testing.allocator, .{
143+
.warmup_steps = 0,
144+
.transition_steps = 0,
145+
.anneal_steps = 0,
146+
});
147+
_ = pq.step();
148+
149+
var weights = [_]f32{ 0.5, -0.5, 0.01, -0.01, 0.0 };
150+
pq.quantizeWeights(&weights, 0.0);
151+
152+
try std.testing.expectEqual(@as(f32, 1.0), weights[0]);
153+
try std.testing.expectEqual(@as(f32, -1.0), weights[1]);
154+
try std.testing.expectEqual(@as(f32, 0.0), weights[2]);
155+
try std.testing.expectEqual(@as(f32, 0.0), weights[3]);
156+
}
157+
158+
test "progress tracking" {
159+
var pq = ProgressiveQuantizer.init(std.testing.allocator, .{
160+
.warmup_steps = 10,
161+
.transition_steps = 10,
162+
.anneal_steps = 10,
163+
});
164+
165+
for (0..15) |_| {
166+
_ = pq.step();
167+
}
168+
const p = pq.progress();
169+
try std.testing.expect(p > 0.0);
170+
try std.testing.expect(p < 1.0);
171+
}
172+
173+
test "quantization loss weight increases" {
174+
var pq = ProgressiveQuantizer.init(std.testing.allocator, .{
175+
.warmup_steps = 5,
176+
.transition_steps = 5,
177+
.anneal_steps = 5,
178+
});
179+
180+
const w0 = pq.quantizationLossWeight();
181+
for (0..6) |_| _ = pq.step();
182+
const w1 = pq.quantizationLossWeight();
183+
for (0..6) |_| _ = pq.step();
184+
const w2 = pq.quantizationLossWeight();
185+
for (0..6) |_| _ = pq.step();
186+
const w3 = pq.quantizationLossWeight();
187+
188+
try std.testing.expect(w0 < w1);
189+
try std.testing.expect(w1 < w2);
190+
try std.testing.expect(w2 < w3);
191+
}

0 commit comments

Comments
 (0)