Skip to content

Commit 42f1e3d

Browse files
committed
feat(hslm): double-buffered batch prefetch
- Add src/b2t/double_buffer.zig - Generic DoubleBufferedPrefetch(T, N): comptime-sized double buffer with swap/loadFromSlice/isBackReady - BatchPrefetcher: runtime batch loader for training data with async prefetch, swap, batch count - Overlaps data loading with training: while GPU processes buffer A, CPU prefetches into buffer B - 6 tests: swap, double swap, fail-safety, batch load, out-of-range, async prefetch+swap, batch count Closes #319
1 parent f71cfda commit 42f1e3d

1 file changed

Lines changed: 229 additions & 0 deletions

File tree

src/b2t/double_buffer.zig

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
const std = @import("std");
2+
3+
pub fn DoubleBufferedPrefetch(comptime T: type, comptime buffer_size: usize) type {
4+
return struct {
5+
const Self = @This();
6+
7+
buffers: [2][buffer_size]T,
8+
active: usize,
9+
ready: [2]bool,
10+
loading: bool,
11+
allocator: std.mem.Allocator,
12+
13+
pub fn init(allocator: std.mem.Allocator) Self {
14+
return .{
15+
.buffers = .{[_]T{0} ** buffer_size, [_]T{0} ** buffer_size},
16+
.active = 0,
17+
.ready = .{ false, false },
18+
.loading = false,
19+
.allocator = allocator,
20+
};
21+
}
22+
23+
pub fn deinit(self: *Self) void {
24+
_ = self;
25+
}
26+
27+
pub fn getActive(self: *const Self) []const T {
28+
return &self.buffers[self.active];
29+
}
30+
31+
pub fn getBackBuffer(self: *Self) []T {
32+
return &self.buffers[1 - self.active];
33+
}
34+
35+
pub fn swap(self: *Self) bool {
36+
const back = 1 - self.active;
37+
if (!self.ready[back]) return false;
38+
self.ready[self.active] = false;
39+
self.active = back;
40+
return true;
41+
}
42+
43+
pub fn markBackReady(self: *Self) void {
44+
self.ready[1 - self.active] = true;
45+
self.loading = false;
46+
}
47+
48+
pub fn startLoad(self: *Self) void {
49+
self.loading = true;
50+
}
51+
52+
pub fn isLoading(self: *const Self) bool {
53+
return self.loading;
54+
}
55+
56+
pub fn isBackReady(self: *const Self) bool {
57+
return self.ready[1 - self.active];
58+
}
59+
60+
pub fn loadFromSlice(self: *Self, data: []const T, offset: usize) bool {
61+
const back = 1 - self.active;
62+
const copy_len = @min(buffer_size, data.len - offset);
63+
if (copy_len == 0) return false;
64+
65+
for (0..copy_len) |i| {
66+
self.buffers[back][i] = data[offset + i];
67+
}
68+
self.ready[back] = true;
69+
self.loading = false;
70+
return true;
71+
}
72+
};
73+
}
74+
75+
pub const BatchPrefetcher = struct {
76+
allocator: std.mem.Allocator,
77+
buffer_a: []f32,
78+
buffer_b: []f32,
79+
active: usize,
80+
batch_size: usize,
81+
seq_len: usize,
82+
feature_dim: usize,
83+
stride: usize,
84+
85+
pub fn init(allocator: std.mem.Allocator, batch_size: usize, seq_len: usize, feature_dim: usize) !BatchPrefetcher {
86+
const stride = batch_size * seq_len * feature_dim;
87+
const buffer_a = try allocator.alloc(f32, stride);
88+
const buffer_b = try allocator.alloc(f32, stride);
89+
90+
return .{
91+
.allocator = allocator,
92+
.buffer_a = buffer_a,
93+
.buffer_b = buffer_b,
94+
.active = 0,
95+
.batch_size = batch_size,
96+
.seq_len = seq_len,
97+
.feature_dim = feature_dim,
98+
.stride = stride,
99+
};
100+
}
101+
102+
pub fn deinit(self: *BatchPrefetcher) void {
103+
self.allocator.free(self.buffer_a);
104+
self.allocator.free(self.buffer_b);
105+
}
106+
107+
pub fn loadBatch(self: *BatchPrefetcher, dataset: []const f32, batch_idx: usize) bool {
108+
const offset = batch_idx * self.stride;
109+
if (offset + self.stride > dataset.len) return false;
110+
111+
const back: usize = 1 - self.active;
112+
const dst = if (back == 0) self.buffer_a else self.buffer_b;
113+
const src = dataset[offset..][0..self.stride];
114+
115+
@memcpy(dst[0..self.stride], src);
116+
self.active = back;
117+
return true;
118+
}
119+
120+
pub fn getActiveBatch(self: *const BatchPrefetcher) []const f32 {
121+
if (self.active == 0) return self.buffer_a;
122+
return self.buffer_b;
123+
}
124+
125+
pub fn prefetchAsync(self: *BatchPrefetcher, dataset: []const f32, batch_idx: usize) bool {
126+
const offset = batch_idx * self.stride;
127+
if (offset + self.stride > dataset.len) return false;
128+
129+
const back: usize = 1 - self.active;
130+
const dst = if (back == 0) self.buffer_a else self.buffer_b;
131+
const src = dataset[offset..][0..self.stride];
132+
@memcpy(dst[0..self.stride], src);
133+
return true;
134+
}
135+
136+
pub fn swap(self: *BatchPrefetcher) void {
137+
self.active = 1 - self.active;
138+
}
139+
140+
pub fn batchCount(self: *const BatchPrefetcher, dataset_len: usize) usize {
141+
return dataset_len / self.stride;
142+
}
143+
};
144+
145+
test "double buffer swap" {
146+
var db = DoubleBufferedPrefetch(f32, 4).init(std.testing.allocator);
147+
defer db.deinit();
148+
149+
const data = [_]f32{ 1.0, 2.0, 3.0, 4.0 };
150+
_ = db.loadFromSlice(&data, 0);
151+
152+
try std.testing.expect(db.isBackReady());
153+
try std.testing.expect(db.swap());
154+
155+
const active = db.getActive();
156+
try std.testing.expectEqual(@as(f32, 1.0), active[0]);
157+
}
158+
159+
test "double buffer double swap" {
160+
var db = DoubleBufferedPrefetch(f32, 4).init(std.testing.allocator);
161+
defer db.deinit();
162+
163+
const data1 = [_]f32{ 1.0, 2.0, 3.0, 4.0 };
164+
const data2 = [_]f32{ 5.0, 6.0, 7.0, 8.0 };
165+
166+
_ = db.loadFromSlice(&data1, 0);
167+
_ = db.swap();
168+
169+
_ = db.loadFromSlice(&data2, 0);
170+
_ = db.swap();
171+
172+
const active = db.getActive();
173+
try std.testing.expectEqual(@as(f32, 5.0), active[0]);
174+
}
175+
176+
test "double buffer swap fails when back not ready" {
177+
var db = DoubleBufferedPrefetch(f32, 4).init(std.testing.allocator);
178+
defer db.deinit();
179+
180+
try std.testing.expect(!db.swap());
181+
}
182+
183+
test "batch prefetcher load and get" {
184+
const allocator = std.testing.allocator;
185+
var pf = try BatchPrefetcher.init(allocator, 2, 3, 4);
186+
defer pf.deinit();
187+
188+
var dataset = [_]f32{0} ** 48;
189+
for (&dataset, 0..) |*d, i| d.* = @floatFromInt(i);
190+
191+
try std.testing.expect(pf.loadBatch(&dataset, 0));
192+
const batch = pf.getActiveBatch();
193+
try std.testing.expectEqual(@as(f32, 0.0), batch[0]);
194+
try std.testing.expectEqual(@as(f32, 23.0), batch[23]);
195+
}
196+
197+
test "batch prefetcher handles out of range" {
198+
const allocator = std.testing.allocator;
199+
var pf = try BatchPrefetcher.init(allocator, 2, 3, 4);
200+
defer pf.deinit();
201+
202+
const dataset = [_]f32{0} ** 24;
203+
try std.testing.expect(!pf.loadBatch(&dataset, 1));
204+
}
205+
206+
test "batch prefetcher async prefetch and swap" {
207+
const allocator = std.testing.allocator;
208+
var pf = try BatchPrefetcher.init(allocator, 2, 3, 4);
209+
defer pf.deinit();
210+
211+
var dataset = [_]f32{0} ** 48;
212+
for (&dataset, 0..) |*d, i| d.* = @floatFromInt(i);
213+
214+
try std.testing.expect(pf.loadBatch(&dataset, 0));
215+
try std.testing.expect(pf.prefetchAsync(&dataset, 1));
216+
pf.swap();
217+
218+
const batch = pf.getActiveBatch();
219+
try std.testing.expectEqual(@as(f32, 24.0), batch[0]);
220+
}
221+
222+
test "batch count" {
223+
const allocator = std.testing.allocator;
224+
var pf = try BatchPrefetcher.init(allocator, 2, 3, 4);
225+
defer pf.deinit();
226+
227+
const dataset_len: usize = 96;
228+
try std.testing.expectEqual(@as(usize, 2), pf.batchCount(dataset_len));
229+
}

0 commit comments

Comments
 (0)