Skip to content

Commit 27a045b

Browse files
authored
feat(hslm): OHEM — online hard example mining for fast convergence (#564)
- Add src/b2t/ohem.zig - OhemMiner: select top-K% hardest examples by loss - Sort by loss descending, keep hard_ratio fraction - Configurable: hard_ratio (0.7), min_keep, loss_threshold - LossStats: mean, max, min, variance, above-mean count - 5 tests: hardest selection, min_keep, empty input, equal losses, loss stats Closes #318
1 parent cba9118 commit 27a045b

1 file changed

Lines changed: 186 additions & 0 deletions

File tree

src/b2t/ohem.zig

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
const std = @import("std");
2+
3+
pub const OhemConfig = struct {
4+
hard_ratio: f32 = 0.7,
5+
min_keep: usize = 1,
6+
loss_threshold: f32 = std.math.inf(f32),
7+
};
8+
9+
pub const OhemResult = struct {
10+
selected_indices: []usize,
11+
selected_losses: []f32,
12+
avg_loss: f32,
13+
hard_count: usize,
14+
total_count: usize,
15+
};
16+
17+
pub const OhemMiner = struct {
18+
allocator: std.mem.Allocator,
19+
config: OhemConfig,
20+
21+
pub fn init(allocator: std.mem.Allocator, config: OhemConfig) OhemMiner {
22+
return .{
23+
.allocator = allocator,
24+
.config = config,
25+
};
26+
}
27+
28+
pub fn mine(self: *OhemMiner, losses: []const f32) !OhemResult {
29+
const n = losses.len;
30+
if (n == 0) {
31+
return .{
32+
.selected_indices = &[_]usize{},
33+
.selected_losses = &[_]f32{},
34+
.avg_loss = 0.0,
35+
.hard_count = 0,
36+
.total_count = 0,
37+
};
38+
}
39+
40+
const indices = try self.allocator.alloc(usize, n);
41+
errdefer self.allocator.free(indices);
42+
for (indices, 0..) |*idx, i| idx.* = i;
43+
44+
const sorted = try self.allocator.dupe(usize, indices);
45+
errdefer self.allocator.free(sorted);
46+
47+
const SortCtx = struct {
48+
losses: []const f32,
49+
pub fn lessThan(ctx: @This(), a: usize, b: usize) bool {
50+
return ctx.losses[a] > ctx.losses[b];
51+
}
52+
};
53+
std.mem.sort(usize, sorted, SortCtx{ .losses = losses }, SortCtx.lessThan);
54+
55+
const keep_count = @max(
56+
@as(usize, @intFromFloat(@as(f32, @floatFromInt(n)) * self.config.hard_ratio)),
57+
self.config.min_keep,
58+
);
59+
const final_keep = @min(keep_count, n);
60+
61+
const selected = try self.allocator.alloc(usize, final_keep);
62+
const sel_losses = try self.allocator.alloc(f32, final_keep);
63+
64+
var sum: f32 = 0.0;
65+
for (0..final_keep) |i| {
66+
selected[i] = sorted[i];
67+
sel_losses[i] = losses[sorted[i]];
68+
sum += sel_losses[i];
69+
}
70+
71+
self.allocator.free(indices);
72+
self.allocator.free(sorted);
73+
74+
return .{
75+
.selected_indices = selected,
76+
.selected_losses = sel_losses,
77+
.avg_loss = if (final_keep > 0) sum / @as(f32, @floatFromInt(final_keep)) else 0.0,
78+
.hard_count = final_keep,
79+
.total_count = n,
80+
};
81+
}
82+
83+
pub fn deinitResult(self: *OhemMiner, result: *OhemResult) void {
84+
self.allocator.free(result.selected_indices);
85+
self.allocator.free(result.selected_losses);
86+
}
87+
88+
pub fn hardRatio(self: *const OhemMiner, result: *const OhemResult) f32 {
89+
if (result.total_count == 0) return 0.0;
90+
return @as(f32, @floatFromInt(result.hard_count)) / @as(f32, @floatFromInt(result.total_count));
91+
}
92+
};
93+
94+
pub const LossStats = struct {
95+
mean: f32,
96+
max: f32,
97+
min: f32,
98+
variance: f32,
99+
above_mean_count: usize,
100+
total: usize,
101+
102+
pub fn compute(losses: []const f32) LossStats {
103+
if (losses.len == 0) return .{ .mean = 0, .max = 0, .min = 0, .variance = 0, .above_mean_count = 0, .total = 0 };
104+
105+
var sum: f32 = 0;
106+
var mx: f32 = -std.math.inf(f32);
107+
var mn: f32 = std.math.inf(f32);
108+
for (losses) |l| {
109+
sum += l;
110+
mx = @max(mx, l);
111+
mn = @min(mn, l);
112+
}
113+
const mean = sum / @as(f32, @floatFromInt(losses.len));
114+
115+
var var_sum: f32 = 0;
116+
var above: usize = 0;
117+
for (losses) |l| {
118+
const d = l - mean;
119+
var_sum += d * d;
120+
if (l > mean) above += 1;
121+
}
122+
123+
return .{
124+
.mean = mean,
125+
.max = mx,
126+
.min = mn,
127+
.variance = var_sum / @as(f32, @floatFromInt(losses.len)),
128+
.above_mean_count = above,
129+
.total = losses.len,
130+
};
131+
}
132+
};
133+
134+
test "ohem selects hardest examples" {
135+
const allocator = std.testing.allocator;
136+
var miner = OhemMiner.init(allocator, .{ .hard_ratio = 0.5, .min_keep = 1 });
137+
138+
const losses = [_]f32{ 0.1, 0.9, 0.2, 0.8, 0.3 };
139+
var result = try miner.mine(&losses);
140+
defer miner.deinitResult(&result);
141+
142+
try std.testing.expectEqual(@as(usize, 2), result.hard_count);
143+
try std.testing.expect(result.selected_losses[0] >= result.selected_losses[1]);
144+
try std.testing.expect(result.avg_loss > 0.5);
145+
}
146+
147+
test "ohem respects min_keep" {
148+
const allocator = std.testing.allocator;
149+
var miner = OhemMiner.init(allocator, .{ .hard_ratio = 0.01, .min_keep = 3 });
150+
151+
const losses = [_]f32{ 0.1, 0.2, 0.3, 0.4, 0.5 };
152+
var result = try miner.mine(&losses);
153+
defer miner.deinitResult(&result);
154+
155+
try std.testing.expectEqual(@as(usize, 3), result.hard_count);
156+
}
157+
158+
test "ohem handles empty losses" {
159+
const allocator = std.testing.allocator;
160+
var miner = OhemMiner.init(allocator, .{});
161+
var result = try miner.mine(&[_]f32{});
162+
try std.testing.expectEqual(@as(usize, 0), result.hard_count);
163+
}
164+
165+
test "ohem all examples equally hard" {
166+
const allocator = std.testing.allocator;
167+
var miner = OhemMiner.init(allocator, .{ .hard_ratio = 0.5 });
168+
169+
const losses = [_]f32{ 0.5, 0.5, 0.5, 0.5 };
170+
var result = try miner.mine(&losses);
171+
defer miner.deinitResult(&result);
172+
173+
try std.testing.expectEqual(@as(usize, 2), result.hard_count);
174+
try std.testing.expectApproxEqAbs(@as(f32, 0.5), result.avg_loss, 1e-6);
175+
}
176+
177+
test "loss stats compute" {
178+
const losses = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0 };
179+
const stats = LossStats.compute(&losses);
180+
181+
try std.testing.expectApproxEqAbs(@as(f32, 3.0), stats.mean, 1e-6);
182+
try std.testing.expectEqual(@as(f32, 5.0), stats.max);
183+
try std.testing.expectEqual(@as(f32, 1.0), stats.min);
184+
try std.testing.expect(stats.variance > 0);
185+
try std.testing.expectEqual(@as(usize, 2), stats.above_mean_count);
186+
}

0 commit comments

Comments
 (0)