|
| 1 | +const std = @import("std"); |
| 2 | + |
| 3 | +pub const OhemConfig = struct { |
| 4 | + hard_ratio: f32 = 0.7, |
| 5 | + min_keep: usize = 1, |
| 6 | + loss_threshold: f32 = std.math.inf(f32), |
| 7 | +}; |
| 8 | + |
| 9 | +pub const OhemResult = struct { |
| 10 | + selected_indices: []usize, |
| 11 | + selected_losses: []f32, |
| 12 | + avg_loss: f32, |
| 13 | + hard_count: usize, |
| 14 | + total_count: usize, |
| 15 | +}; |
| 16 | + |
| 17 | +pub const OhemMiner = struct { |
| 18 | + allocator: std.mem.Allocator, |
| 19 | + config: OhemConfig, |
| 20 | + |
| 21 | + pub fn init(allocator: std.mem.Allocator, config: OhemConfig) OhemMiner { |
| 22 | + return .{ |
| 23 | + .allocator = allocator, |
| 24 | + .config = config, |
| 25 | + }; |
| 26 | + } |
| 27 | + |
| 28 | + pub fn mine(self: *OhemMiner, losses: []const f32) !OhemResult { |
| 29 | + const n = losses.len; |
| 30 | + if (n == 0) { |
| 31 | + return .{ |
| 32 | + .selected_indices = &[_]usize{}, |
| 33 | + .selected_losses = &[_]f32{}, |
| 34 | + .avg_loss = 0.0, |
| 35 | + .hard_count = 0, |
| 36 | + .total_count = 0, |
| 37 | + }; |
| 38 | + } |
| 39 | + |
| 40 | + const indices = try self.allocator.alloc(usize, n); |
| 41 | + errdefer self.allocator.free(indices); |
| 42 | + for (indices, 0..) |*idx, i| idx.* = i; |
| 43 | + |
| 44 | + const sorted = try self.allocator.dupe(usize, indices); |
| 45 | + errdefer self.allocator.free(sorted); |
| 46 | + |
| 47 | + const SortCtx = struct { |
| 48 | + losses: []const f32, |
| 49 | + pub fn lessThan(ctx: @This(), a: usize, b: usize) bool { |
| 50 | + return ctx.losses[a] > ctx.losses[b]; |
| 51 | + } |
| 52 | + }; |
| 53 | + std.mem.sort(usize, sorted, SortCtx{ .losses = losses }, SortCtx.lessThan); |
| 54 | + |
| 55 | + const keep_count = @max( |
| 56 | + @as(usize, @intFromFloat(@as(f32, @floatFromInt(n)) * self.config.hard_ratio)), |
| 57 | + self.config.min_keep, |
| 58 | + ); |
| 59 | + const final_keep = @min(keep_count, n); |
| 60 | + |
| 61 | + const selected = try self.allocator.alloc(usize, final_keep); |
| 62 | + const sel_losses = try self.allocator.alloc(f32, final_keep); |
| 63 | + |
| 64 | + var sum: f32 = 0.0; |
| 65 | + for (0..final_keep) |i| { |
| 66 | + selected[i] = sorted[i]; |
| 67 | + sel_losses[i] = losses[sorted[i]]; |
| 68 | + sum += sel_losses[i]; |
| 69 | + } |
| 70 | + |
| 71 | + self.allocator.free(indices); |
| 72 | + self.allocator.free(sorted); |
| 73 | + |
| 74 | + return .{ |
| 75 | + .selected_indices = selected, |
| 76 | + .selected_losses = sel_losses, |
| 77 | + .avg_loss = if (final_keep > 0) sum / @as(f32, @floatFromInt(final_keep)) else 0.0, |
| 78 | + .hard_count = final_keep, |
| 79 | + .total_count = n, |
| 80 | + }; |
| 81 | + } |
| 82 | + |
| 83 | + pub fn deinitResult(self: *OhemMiner, result: *OhemResult) void { |
| 84 | + self.allocator.free(result.selected_indices); |
| 85 | + self.allocator.free(result.selected_losses); |
| 86 | + } |
| 87 | + |
| 88 | + pub fn hardRatio(self: *const OhemMiner, result: *const OhemResult) f32 { |
| 89 | + if (result.total_count == 0) return 0.0; |
| 90 | + return @as(f32, @floatFromInt(result.hard_count)) / @as(f32, @floatFromInt(result.total_count)); |
| 91 | + } |
| 92 | +}; |
| 93 | + |
| 94 | +pub const LossStats = struct { |
| 95 | + mean: f32, |
| 96 | + max: f32, |
| 97 | + min: f32, |
| 98 | + variance: f32, |
| 99 | + above_mean_count: usize, |
| 100 | + total: usize, |
| 101 | + |
| 102 | + pub fn compute(losses: []const f32) LossStats { |
| 103 | + if (losses.len == 0) return .{ .mean = 0, .max = 0, .min = 0, .variance = 0, .above_mean_count = 0, .total = 0 }; |
| 104 | + |
| 105 | + var sum: f32 = 0; |
| 106 | + var mx: f32 = -std.math.inf(f32); |
| 107 | + var mn: f32 = std.math.inf(f32); |
| 108 | + for (losses) |l| { |
| 109 | + sum += l; |
| 110 | + mx = @max(mx, l); |
| 111 | + mn = @min(mn, l); |
| 112 | + } |
| 113 | + const mean = sum / @as(f32, @floatFromInt(losses.len)); |
| 114 | + |
| 115 | + var var_sum: f32 = 0; |
| 116 | + var above: usize = 0; |
| 117 | + for (losses) |l| { |
| 118 | + const d = l - mean; |
| 119 | + var_sum += d * d; |
| 120 | + if (l > mean) above += 1; |
| 121 | + } |
| 122 | + |
| 123 | + return .{ |
| 124 | + .mean = mean, |
| 125 | + .max = mx, |
| 126 | + .min = mn, |
| 127 | + .variance = var_sum / @as(f32, @floatFromInt(losses.len)), |
| 128 | + .above_mean_count = above, |
| 129 | + .total = losses.len, |
| 130 | + }; |
| 131 | + } |
| 132 | +}; |
| 133 | + |
| 134 | +test "ohem selects hardest examples" { |
| 135 | + const allocator = std.testing.allocator; |
| 136 | + var miner = OhemMiner.init(allocator, .{ .hard_ratio = 0.5, .min_keep = 1 }); |
| 137 | + |
| 138 | + const losses = [_]f32{ 0.1, 0.9, 0.2, 0.8, 0.3 }; |
| 139 | + var result = try miner.mine(&losses); |
| 140 | + defer miner.deinitResult(&result); |
| 141 | + |
| 142 | + try std.testing.expectEqual(@as(usize, 2), result.hard_count); |
| 143 | + try std.testing.expect(result.selected_losses[0] >= result.selected_losses[1]); |
| 144 | + try std.testing.expect(result.avg_loss > 0.5); |
| 145 | +} |
| 146 | + |
| 147 | +test "ohem respects min_keep" { |
| 148 | + const allocator = std.testing.allocator; |
| 149 | + var miner = OhemMiner.init(allocator, .{ .hard_ratio = 0.01, .min_keep = 3 }); |
| 150 | + |
| 151 | + const losses = [_]f32{ 0.1, 0.2, 0.3, 0.4, 0.5 }; |
| 152 | + var result = try miner.mine(&losses); |
| 153 | + defer miner.deinitResult(&result); |
| 154 | + |
| 155 | + try std.testing.expectEqual(@as(usize, 3), result.hard_count); |
| 156 | +} |
| 157 | + |
| 158 | +test "ohem handles empty losses" { |
| 159 | + const allocator = std.testing.allocator; |
| 160 | + var miner = OhemMiner.init(allocator, .{}); |
| 161 | + var result = try miner.mine(&[_]f32{}); |
| 162 | + try std.testing.expectEqual(@as(usize, 0), result.hard_count); |
| 163 | +} |
| 164 | + |
| 165 | +test "ohem all examples equally hard" { |
| 166 | + const allocator = std.testing.allocator; |
| 167 | + var miner = OhemMiner.init(allocator, .{ .hard_ratio = 0.5 }); |
| 168 | + |
| 169 | + const losses = [_]f32{ 0.5, 0.5, 0.5, 0.5 }; |
| 170 | + var result = try miner.mine(&losses); |
| 171 | + defer miner.deinitResult(&result); |
| 172 | + |
| 173 | + try std.testing.expectEqual(@as(usize, 2), result.hard_count); |
| 174 | + try std.testing.expectApproxEqAbs(@as(f32, 0.5), result.avg_loss, 1e-6); |
| 175 | +} |
| 176 | + |
| 177 | +test "loss stats compute" { |
| 178 | + const losses = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0 }; |
| 179 | + const stats = LossStats.compute(&losses); |
| 180 | + |
| 181 | + try std.testing.expectApproxEqAbs(@as(f32, 3.0), stats.mean, 1e-6); |
| 182 | + try std.testing.expectEqual(@as(f32, 5.0), stats.max); |
| 183 | + try std.testing.expectEqual(@as(f32, 1.0), stats.min); |
| 184 | + try std.testing.expect(stats.variance > 0); |
| 185 | + try std.testing.expectEqual(@as(usize, 2), stats.above_mean_count); |
| 186 | +} |
0 commit comments