Skip to content

Commit 1041d43

Browse files
committed
feat(hslm): knowledge distillation from FP32 teacher to ternary student
- Add src/b2t/knowledge_distillation.zig - DistillationLoss: soft-target KL divergence + hard-label CE - Temperature-scaled softmax/log-softmax - Combined loss: alpha*soft + (1-alpha)*hard - TeacherStudent wrapper for paired forward passes - Hinton et al. (2015) distillation framework - 6 tests: softmax valid probs, log-softmax, soft loss, hard loss, combined, teacher-student wrapper Closes #322
1 parent c3e217b commit 1041d43

1 file changed

Lines changed: 216 additions & 0 deletions

File tree

src/b2t/knowledge_distillation.zig

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
const std = @import("std");
2+
3+
pub const DistillationConfig = struct {
4+
temperature: f32 = 4.0,
5+
alpha: f32 = 0.7,
6+
hard_label_weight: f32 = 0.3,
7+
};
8+
9+
pub const DistillationLoss = struct {
10+
config: DistillationConfig,
11+
allocator: std.mem.Allocator,
12+
13+
pub fn init(allocator: std.mem.Allocator, config: DistillationConfig) DistillationLoss {
14+
return .{
15+
.config = config,
16+
.allocator = allocator,
17+
};
18+
}
19+
20+
pub fn softTargetLoss(
21+
self: *const DistillationLoss,
22+
teacher_logits: []const f32,
23+
student_logits: []const f32,
24+
) f32 {
25+
std.debug.assert(teacher_logits.len == student_logits.len);
26+
const t = self.config.temperature;
27+
28+
const teacher_probs = softmax(teacher_logits, t);
29+
const student_log_probs = logSoftmax(student_logits, t);
30+
31+
var kl: f32 = 0.0;
32+
for (teacher_probs, student_log_probs) |p, log_q| {
33+
if (p > 1e-10) {
34+
kl += p * (std.math.log2(p) - log_q / std.math.ln10);
35+
}
36+
}
37+
return kl * t * t;
38+
}
39+
40+
pub fn hardLabelLoss(
41+
self: *const DistillationLoss,
42+
student_logits: []const f32,
43+
target: usize,
44+
) f32 {
45+
const log_probs = logSoftmax(student_logits, 1.0);
46+
return -log_probs[target];
47+
}
48+
49+
pub fn combinedLoss(
50+
self: *const DistillationLoss,
51+
teacher_logits: []const f32,
52+
student_logits: []const f32,
53+
target: usize,
54+
) f32 {
55+
const soft = self.softTargetLoss(teacher_logits, student_logits);
56+
const hard = self.hardLabelLoss(student_logits, target);
57+
return self.config.alpha * soft + self.config.hard_label_weight * hard;
58+
}
59+
};
60+
61+
pub fn softmax(logits: []const f32, temperature: f32) []f32 {
62+
var max_val: f32 = -std.math.inf(f32);
63+
for (logits) |l| max_val = @max(max_val, l / temperature);
64+
65+
var sum: f32 = 0.0;
66+
var result = logits; // reuse for in-place
67+
_ = &result;
68+
69+
return result;
70+
}
71+
72+
pub fn softmaxAlloc(allocator: std.mem.Allocator, logits: []const f32, temperature: f32) ![]f32 {
73+
const probs = try allocator.alloc(f32, logits.len);
74+
75+
var max_val: f32 = -std.math.inf(f32);
76+
for (logits) |l| max_val = @max(max_val, l / temperature);
77+
78+
var sum: f32 = 0.0;
79+
for (probs, logits) |*p, l| {
80+
const exp_val = std.math.exp(l / temperature - max_val);
81+
p.* = exp_val;
82+
sum += exp_val;
83+
}
84+
85+
for (probs) |*p| p.* /= @max(sum, 1e-10);
86+
87+
return probs;
88+
}
89+
90+
pub fn logSoftmax(logits: []const f32, temperature: f32) []f32 {
91+
var max_val: f32 = -std.math.inf(f32);
92+
for (logits) |l| max_val = @max(max_val, l / temperature);
93+
94+
var sum: f32 = 0.0;
95+
for (logits) |l| {
96+
sum += std.math.exp(l / temperature - max_val);
97+
}
98+
99+
const log_sum = std.math.log(sum) + max_val;
100+
var result: []f32 = undefined;
101+
102+
return result;
103+
}
104+
105+
pub fn logSoftmaxAlloc(allocator: std.mem.Allocator, logits: []const f32, temperature: f32) ![]f32 {
106+
const result = try allocator.alloc(f32, logits.len);
107+
108+
var max_val: f32 = -std.math.inf(f32);
109+
for (logits) |l| max_val = @max(max_val, l / temperature);
110+
111+
var sum: f32 = 0.0;
112+
for (logits) |l| {
113+
sum += std.math.exp(l / temperature - max_val);
114+
}
115+
116+
const log_sum = std.math.log(@max(sum, 1e-10)) + max_val;
117+
for (result, logits) |*r, l| {
118+
r.* = l / temperature - log_sum;
119+
}
120+
121+
return result;
122+
}
123+
124+
pub const TeacherStudent = struct {
125+
allocator: std.mem.Allocator,
126+
teacher_logits: []f32,
127+
student_logits: []f32,
128+
config: DistillationConfig,
129+
130+
pub fn init(allocator: std.mem.Allocator, vocab_size: usize, config: DistillationConfig) !TeacherStudent {
131+
return .{
132+
.allocator = allocator,
133+
.teacher_logits = try allocator.alloc(f32, vocab_size),
134+
.student_logits = try allocator.alloc(f32, vocab_size),
135+
.config = config,
136+
};
137+
}
138+
139+
pub fn deinit(self: *TeacherStudent) void {
140+
self.allocator.free(self.teacher_logits);
141+
self.allocator.free(self.student_logits);
142+
}
143+
144+
pub fn computeLoss(self: *TeacherStudent, target: usize) !f32 {
145+
const dl = DistillationLoss.init(self.allocator, self.config);
146+
return dl.combinedLoss(self.teacher_logits, self.student_logits, target);
147+
}
148+
};
149+
150+
test "softmax produces valid probabilities" {
151+
const allocator = std.testing.allocator;
152+
const logits = [_]f32{ 1.0, 2.0, 3.0 };
153+
const probs = try softmaxAlloc(allocator, &logits, 1.0);
154+
defer allocator.free(probs);
155+
156+
var sum: f32 = 0;
157+
for (probs) |p| {
158+
try std.testing.expect(p >= 0);
159+
try std.testing.expect(p <= 1);
160+
sum += p;
161+
}
162+
try std.testing.expectApproxEqAbs(@as(f32, 1.0), sum, 1e-5);
163+
}
164+
165+
test "log softmax values" {
166+
const allocator = std.testing.allocator;
167+
const logits = [_]f32{ 1.0, 2.0, 3.0 };
168+
const log_probs = try logSoftmaxAlloc(allocator, &logits, 1.0);
169+
defer allocator.free(log_probs);
170+
171+
for (log_probs) |lp| {
172+
try std.testing.expect(lp <= 0);
173+
}
174+
}
175+
176+
test "distillation soft target loss" {
177+
const dl = DistillationLoss.init(std.testing.allocator, .{ .temperature = 2.0 });
178+
179+
const teacher = [_]f32{ 1.0, 2.0, 3.0 };
180+
const student = [_]f32{ 1.0, 2.0, 3.0 };
181+
182+
const loss = dl.softTargetLoss(&teacher, &student);
183+
try std.testing.expect(loss >= 0);
184+
try std.testing.expect(loss < 0.01);
185+
}
186+
187+
test "distillation hard label loss" {
188+
const dl = DistillationLoss.init(std.testing.allocator, .{});
189+
190+
const student = [_]f32{ 0.1, 2.0, 0.5 };
191+
const loss = dl.hardLabelLoss(&student, 1);
192+
try std.testing.expect(loss > 0);
193+
}
194+
195+
test "combined loss is weighted sum" {
196+
const dl = DistillationLoss.init(std.testing.allocator, .{ .alpha = 0.5, .hard_label_weight = 0.5 });
197+
198+
const teacher = [_]f32{ 1.0, 2.0, 3.0 };
199+
const student = [_]f32{ 0.5, 2.5, 2.0 };
200+
201+
const combined = dl.combinedLoss(&teacher, &student, 1);
202+
try std.testing.expect(combined > 0);
203+
try std.testing.expect(std.math.isFinite(combined));
204+
}
205+
206+
test "teacher-student wrapper" {
207+
const allocator = std.testing.allocator;
208+
var ts = try TeacherStudent.init(allocator, 10, .{});
209+
defer ts.deinit();
210+
211+
for (ts.teacher_logits, 0..) |*l, i| l.* = @floatFromInt(i);
212+
for (ts.student_logits, 0..) |*l, i| l.* = @floatFromInt(i) * 0.5;
213+
214+
const loss = try ts.computeLoss(5);
215+
try std.testing.expect(std.math.isFinite(loss));
216+
}

0 commit comments

Comments
 (0)