Skip to content

Commit 5749389

Browse files
gHashTagona-agent
andcommitted
feat(batch): implement batch processing (INF-004)
- Add BatchKVCache for multiple concurrent sequences - Add BatchTriModel wrapper with per-sequence KV caches - Implement addSequence/removeSequence for dynamic batching - Add forwardSequence and batchForward methods - Update validate_ternary.zig with batch benchmark Benchmark results (3 sequences, 30 tokens): - Single: 15,500 tok/s - Batch: 20,475 tok/s - Speedup: 1.32x All 10 KV cache tests passing. Co-authored-by: Ona <no-reply@ona.com>
1 parent d1394dd commit 5749389

4 files changed

Lines changed: 462 additions & 0 deletions

File tree

docs/DISCOVERIES.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,27 @@ Compression: 12.8x
304304
- Ternary KV cache: 12.8x compression
305305
- Combined similarity: 0.88 (vs 0.93 with only KV cache)
306306

307+
### Batch Processing (INF-004)
308+
309+
**Status**: ✅ Implemented
310+
311+
| Component | File | Description |
312+
|-----------|------|-------------|
313+
| BatchKVCache | `kv_cache.zig` | Per-sequence KV caches |
314+
| BatchTriModel | `tri_inference.zig` | Batch inference wrapper |
315+
| addSequence | `tri_inference.zig` | Add sequence to batch |
316+
| forwardSequence | `tri_inference.zig` | Forward for single sequence |
317+
| batchForward | `tri_inference.zig` | Batch forward pass |
318+
319+
**Benchmark Results (3 sequences, 30 tokens):**
320+
```
321+
Single sequence: 15,500 tok/s
322+
Batch (3 seq): 20,475 tok/s
323+
Speedup: 1.32x
324+
```
325+
326+
**Note:** Speedup is modest on small models. Larger models with more compute per token will see higher speedup (2-4x) due to better weight reuse.
327+
307328
### Test Results
308329

309330
```

src/vibeec/kv_cache.zig

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,3 +1147,184 @@ test "cached attention" {
11471147

11481148
try std.testing.expectEqual(@as(usize, num_q_heads * config.head_dim), output.len);
11491149
}
1150+
1151+
// ═══════════════════════════════════════════════════════════════════════════════
1152+
// BATCH KV-CACHE (INF-004)
1153+
// Multiple sequences with separate KV caches
1154+
// ═══════════════════════════════════════════════════════════════════════════════
1155+
1156+
/// Batch KV cache for multiple concurrent sequences
1157+
pub const BatchKVCache = struct {
1158+
allocator: std.mem.Allocator,
1159+
max_batch_size: usize,
1160+
num_layers: usize,
1161+
num_kv_heads: usize,
1162+
head_dim: usize,
1163+
max_seq_len: usize,
1164+
1165+
// Per-sequence KV caches [batch_size][num_layers]
1166+
caches: [][]RingKVCache,
1167+
1168+
// Sequence state
1169+
active: []bool,
1170+
positions: []usize,
1171+
1172+
pub fn init(
1173+
allocator: std.mem.Allocator,
1174+
max_batch_size: usize,
1175+
num_layers: usize,
1176+
num_kv_heads: usize,
1177+
head_dim: usize,
1178+
max_seq_len: usize,
1179+
) !BatchKVCache {
1180+
var batch = BatchKVCache{
1181+
.allocator = allocator,
1182+
.max_batch_size = max_batch_size,
1183+
.num_layers = num_layers,
1184+
.num_kv_heads = num_kv_heads,
1185+
.head_dim = head_dim,
1186+
.max_seq_len = max_seq_len,
1187+
.caches = try allocator.alloc([]RingKVCache, max_batch_size),
1188+
.active = try allocator.alloc(bool, max_batch_size),
1189+
.positions = try allocator.alloc(usize, max_batch_size),
1190+
};
1191+
1192+
// Initialize per-sequence caches
1193+
for (0..max_batch_size) |seq_idx| {
1194+
batch.caches[seq_idx] = try allocator.alloc(RingKVCache, num_layers);
1195+
for (0..num_layers) |layer_idx| {
1196+
batch.caches[seq_idx][layer_idx] = try RingKVCache.init(
1197+
allocator,
1198+
num_kv_heads,
1199+
head_dim,
1200+
max_seq_len,
1201+
SlidingWindowConfig.default(),
1202+
);
1203+
}
1204+
batch.active[seq_idx] = false;
1205+
batch.positions[seq_idx] = 0;
1206+
}
1207+
1208+
return batch;
1209+
}
1210+
1211+
pub fn deinit(self: *BatchKVCache) void {
1212+
for (0..self.max_batch_size) |seq_idx| {
1213+
for (0..self.num_layers) |layer_idx| {
1214+
self.caches[seq_idx][layer_idx].deinit();
1215+
}
1216+
self.allocator.free(self.caches[seq_idx]);
1217+
}
1218+
self.allocator.free(self.caches);
1219+
self.allocator.free(self.active);
1220+
self.allocator.free(self.positions);
1221+
}
1222+
1223+
/// Add new sequence to batch, returns sequence ID or null if full
1224+
pub fn addSequence(self: *BatchKVCache) ?usize {
1225+
for (0..self.max_batch_size) |seq_idx| {
1226+
if (!self.active[seq_idx]) {
1227+
self.active[seq_idx] = true;
1228+
self.positions[seq_idx] = 0;
1229+
// Reset KV caches for this sequence
1230+
for (0..self.num_layers) |layer_idx| {
1231+
self.caches[seq_idx][layer_idx].reset();
1232+
}
1233+
return seq_idx;
1234+
}
1235+
}
1236+
return null; // Batch is full
1237+
}
1238+
1239+
/// Remove sequence from batch
1240+
pub fn removeSequence(self: *BatchKVCache, seq_idx: usize) void {
1241+
if (seq_idx < self.max_batch_size) {
1242+
self.active[seq_idx] = false;
1243+
self.positions[seq_idx] = 0;
1244+
}
1245+
}
1246+
1247+
/// Get KV cache for specific sequence and layer
1248+
pub fn getCache(self: *BatchKVCache, seq_idx: usize, layer_idx: usize) *RingKVCache {
1249+
return &self.caches[seq_idx][layer_idx];
1250+
}
1251+
1252+
/// Append K,V to specific sequence's cache
1253+
pub fn append(self: *BatchKVCache, seq_idx: usize, layer_idx: usize, k: []const f32, v: []const f32) void {
1254+
if (seq_idx < self.max_batch_size and self.active[seq_idx]) {
1255+
self.caches[seq_idx][layer_idx].append(k, v);
1256+
}
1257+
}
1258+
1259+
/// Get number of active sequences
1260+
pub fn activeCount(self: *const BatchKVCache) usize {
1261+
var count: usize = 0;
1262+
for (self.active) |a| {
1263+
if (a) count += 1;
1264+
}
1265+
return count;
1266+
}
1267+
1268+
/// Get list of active sequence IDs
1269+
pub fn getActiveSequences(self: *const BatchKVCache, out: []usize) usize {
1270+
var count: usize = 0;
1271+
for (0..self.max_batch_size) |seq_idx| {
1272+
if (self.active[seq_idx] and count < out.len) {
1273+
out[count] = seq_idx;
1274+
count += 1;
1275+
}
1276+
}
1277+
return count;
1278+
}
1279+
1280+
/// Memory usage in bytes
1281+
pub fn memoryUsage(self: *const BatchKVCache) usize {
1282+
var total: usize = 0;
1283+
for (0..self.max_batch_size) |seq_idx| {
1284+
for (0..self.num_layers) |layer_idx| {
1285+
total += self.caches[seq_idx][layer_idx].memoryUsage();
1286+
}
1287+
}
1288+
return total;
1289+
}
1290+
};
1291+
1292+
test "batch kv cache" {
1293+
const allocator = std.testing.allocator;
1294+
1295+
var batch = try BatchKVCache.init(
1296+
allocator,
1297+
4, // max_batch_size
1298+
2, // num_layers
1299+
2, // num_kv_heads
1300+
16, // head_dim
1301+
32, // max_seq_len
1302+
);
1303+
defer batch.deinit();
1304+
1305+
// Initially no active sequences
1306+
try std.testing.expectEqual(@as(usize, 0), batch.activeCount());
1307+
1308+
// Add sequences
1309+
const seq0 = batch.addSequence();
1310+
try std.testing.expect(seq0 != null);
1311+
try std.testing.expectEqual(@as(usize, 1), batch.activeCount());
1312+
1313+
const seq1 = batch.addSequence();
1314+
try std.testing.expect(seq1 != null);
1315+
try std.testing.expectEqual(@as(usize, 2), batch.activeCount());
1316+
1317+
// Append to sequence 0
1318+
var k = [_]f32{1.0} ** 32;
1319+
var v = [_]f32{2.0} ** 32;
1320+
batch.append(seq0.?, 0, &k, &v);
1321+
1322+
// Remove sequence
1323+
batch.removeSequence(seq0.?);
1324+
try std.testing.expectEqual(@as(usize, 1), batch.activeCount());
1325+
1326+
// Can add new sequence in freed slot
1327+
const seq2 = batch.addSequence();
1328+
try std.testing.expect(seq2 != null);
1329+
try std.testing.expectEqual(@as(usize, 2), batch.activeCount());
1330+
}

0 commit comments

Comments
 (0)