@@ -1147,3 +1147,184 @@ test "cached attention" {
11471147
11481148 try std .testing .expectEqual (@as (usize , num_q_heads * config .head_dim ), output .len );
11491149}
1150+
1151+ // ═══════════════════════════════════════════════════════════════════════════════
1152+ // BATCH KV-CACHE (INF-004)
1153+ // Multiple sequences with separate KV caches
1154+ // ═══════════════════════════════════════════════════════════════════════════════
1155+
/// Batch KV cache for multiple concurrent sequences.
///
/// Owns one `RingKVCache` per (sequence slot, layer) pair. Every slot is
/// allocated up front by `init`; `addSequence` / `removeSequence` only flip a
/// slot between active and inactive, they never allocate or free. Release
/// everything with `deinit`.
pub const BatchKVCache = struct {
    allocator: std.mem.Allocator,
    max_batch_size: usize,
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,

    // Per-sequence KV caches, indexed [seq_idx][layer_idx].
    caches: [][]RingKVCache,

    // Per-slot state; both slices have `max_batch_size` entries.
    active: []bool,
    // Zeroed whenever a slot is claimed or released; nothing in this struct
    // advances it.
    positions: []usize,

    /// Allocate `max_batch_size` sequence slots, each holding `num_layers`
    /// ring caches built with `SlidingWindowConfig.default()`.
    ///
    /// All slots start inactive. On any allocation or cache-init failure,
    /// everything allocated so far is released before the error is returned.
    pub fn init(
        allocator: std.mem.Allocator,
        max_batch_size: usize,
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) !BatchKVCache {
        const caches = try allocator.alloc([]RingKVCache, max_batch_size);
        errdefer allocator.free(caches);

        const active = try allocator.alloc(bool, max_batch_size);
        errdefer allocator.free(active);

        const positions = try allocator.alloc(usize, max_batch_size);
        errdefer allocator.free(positions);

        @memset(active, false);
        @memset(positions, 0);

        // Number of sequence slots whose layer array is fully initialised.
        // The errdefer reads the value at scope exit, so it tears down exactly
        // the slots that were completed before the failure.
        var seqs_built: usize = 0;
        errdefer {
            for (caches[0..seqs_built]) |layers| {
                for (layers) |*layer| layer.deinit();
                allocator.free(layers);
            }
        }

        for (caches) |*slot| {
            const layers = try allocator.alloc(RingKVCache, num_layers);

            // Covers the slot currently under construction: only the layers
            // that were successfully initialised are deinitialised.
            var layers_built: usize = 0;
            errdefer {
                for (layers[0..layers_built]) |*layer| layer.deinit();
                allocator.free(layers);
            }

            for (layers) |*layer| {
                layer.* = try RingKVCache.init(
                    allocator,
                    num_kv_heads,
                    head_dim,
                    max_seq_len,
                    SlidingWindowConfig.default(),
                );
                layers_built += 1;
            }

            slot.* = layers;
            seqs_built += 1;
        }

        return BatchKVCache{
            .allocator = allocator,
            .max_batch_size = max_batch_size,
            .num_layers = num_layers,
            .num_kv_heads = num_kv_heads,
            .head_dim = head_dim,
            .max_seq_len = max_seq_len,
            .caches = caches,
            .active = active,
            .positions = positions,
        };
    }

    /// Deinitialise every per-layer cache and free all slices owned by the batch.
    pub fn deinit(self: *BatchKVCache) void {
        for (self.caches) |layers| {
            for (layers) |*layer| layer.deinit();
            self.allocator.free(layers);
        }
        self.allocator.free(self.caches);
        self.allocator.free(self.active);
        self.allocator.free(self.positions);
    }

    /// Add new sequence to batch, returns sequence ID or null if full.
    ///
    /// Claims the lowest-numbered inactive slot, zeroes its position and
    /// resets every layer cache of that slot before handing it out.
    pub fn addSequence(self: *BatchKVCache) ?usize {
        for (self.active, 0..) |*is_active, seq_idx| {
            if (is_active.*) continue;

            is_active.* = true;
            self.positions[seq_idx] = 0;
            // Reset KV caches so the new sequence starts from a clean slot.
            for (self.caches[seq_idx]) |*layer| layer.reset();
            return seq_idx;
        }
        return null; // Batch is full
    }

    /// Remove sequence from batch.
    ///
    /// Marks the slot inactive and zeroes its position. The slot's caches are
    /// not touched here; they are reset when the slot is next claimed by
    /// `addSequence`. An out-of-range `seq_idx` is ignored.
    pub fn removeSequence(self: *BatchKVCache, seq_idx: usize) void {
        if (seq_idx >= self.max_batch_size) return;
        self.active[seq_idx] = false;
        self.positions[seq_idx] = 0;
    }

    /// Get KV cache for specific sequence and layer.
    ///
    /// No explicit validation is done: both indices are used directly as slice
    /// indices, and the slot's active flag is not consulted.
    pub fn getCache(self: *BatchKVCache, seq_idx: usize, layer_idx: usize) *RingKVCache {
        return &self.caches[seq_idx][layer_idx];
    }

    /// Append K,V to specific sequence's cache.
    ///
    /// Silently does nothing when `seq_idx` is out of range or the slot is
    /// inactive. `layer_idx` is not validated here.
    pub fn append(self: *BatchKVCache, seq_idx: usize, layer_idx: usize, k: []const f32, v: []const f32) void {
        if (seq_idx >= self.max_batch_size) return;
        if (!self.active[seq_idx]) return;
        self.caches[seq_idx][layer_idx].append(k, v);
    }

    /// Get number of active sequences.
    pub fn activeCount(self: *const BatchKVCache) usize {
        var count: usize = 0;
        for (self.active) |is_active| {
            if (is_active) count += 1;
        }
        return count;
    }

    /// Get list of active sequence IDs.
    ///
    /// Writes IDs in ascending order into `out` and returns how many were
    /// written. If `out` is shorter than the number of active sequences, the
    /// list is truncated to `out.len`.
    pub fn getActiveSequences(self: *const BatchKVCache, out: []usize) usize {
        var count: usize = 0;
        for (self.active, 0..) |is_active, seq_idx| {
            if (is_active and count < out.len) {
                out[count] = seq_idx;
                count += 1;
            }
        }
        return count;
    }

    /// Memory usage in bytes.
    ///
    /// Sum of `memoryUsage()` over every per-layer cache in every slot, active
    /// or not. The bookkeeping slices of the batch itself are not included.
    pub fn memoryUsage(self: *const BatchKVCache) usize {
        var total: usize = 0;
        for (self.caches) |layers| {
            for (layers) |*layer| total += layer.memoryUsage();
        }
        return total;
    }
};
1291+
// Exercises the slot lifecycle of BatchKVCache: claim, append, release, reuse,
// and the "batch full" case.
test "batch kv cache" {
    const allocator = std.testing.allocator;

    var batch = try BatchKVCache.init(
        allocator,
        4, // max_batch_size
        2, // num_layers
        2, // num_kv_heads
        16, // head_dim
        32, // max_seq_len
    );
    defer batch.deinit();

    // Initially no active sequences
    try std.testing.expectEqual(@as(usize, 0), batch.activeCount());

    // Add sequences
    const seq0 = batch.addSequence() orelse return error.TestUnexpectedResult;
    try std.testing.expectEqual(@as(usize, 1), batch.activeCount());

    const seq1 = batch.addSequence() orelse return error.TestUnexpectedResult;
    try std.testing.expect(seq1 != seq0);
    try std.testing.expectEqual(@as(usize, 2), batch.activeCount());

    // Append to sequence 0 (num_kv_heads * head_dim = 32 values per K/V)
    const k = [_]f32{1.0} ** 32;
    const v = [_]f32{2.0} ** 32;
    batch.append(seq0, 0, &k, &v);

    // Remove sequence
    batch.removeSequence(seq0);
    try std.testing.expectEqual(@as(usize, 1), batch.activeCount());

    // Can add new sequence in freed slot: the lowest inactive slot is reused
    const seq2 = batch.addSequence() orelse return error.TestUnexpectedResult;
    try std.testing.expectEqual(seq0, seq2);
    try std.testing.expectEqual(@as(usize, 2), batch.activeCount());

    // Fill the remaining slots; one more request must report a full batch
    try std.testing.expect(batch.addSequence() != null);
    try std.testing.expect(batch.addSequence() != null);
    try std.testing.expectEqual(@as(usize, 4), batch.activeCount());
    try std.testing.expect(batch.addSequence() == null);
}
0 commit comments