|
| 1 | +// context.zig — Context window management for tri-api |
| 2 | +// Token estimation, auto-compaction (truncate tool outputs → summarize). |
| 3 | +// Issue #67: Phase 8 Context Management |
| 4 | +const std = @import("std"); |
| 5 | +const proto = @import("tool_protocol.zig"); |
| 6 | + |
| 7 | +// ─── Config ────────────────────────────────────────────────────────────────── |
| 8 | + |
/// Tunable limits for context-window management.
pub const ContextConfig = struct {
    /// Total token budget assumed for the model's context window.
    max_tokens: u32 = 180_000,
    /// Estimated token count at which compaction is triggered
    /// (80% of the default `max_tokens`).
    compact_threshold: u32 = 144_000,
    /// Number of most recent assistant turns left untouched by
    /// truncation and summarization.
    keep_turns: u32 = 3,
};
| 14 | + |
| 15 | +// ─── Token estimation ──────────────────────────────────────────────────────── |
| 16 | + |
/// Estimate token count from byte length (~4 bytes/token average for cl100k_base).
///
/// Always returns at least 1, even for empty input. The result saturates at
/// `maxInt(u32)` instead of trapping when `text.len / 4` does not fit in a `u32`.
pub fn estimateTokens(text: []const u8) u32 {
    const approx: usize = @max(1, text.len / 4);
    // Checked narrowing: a plain @intCast would be illegal behavior for inputs
    // of 16 GiB or more on 64-bit targets.
    return std.math.cast(u32, approx) orelse std.math.maxInt(u32);
}
| 21 | + |
| 22 | +// ─── Context Manager ───────────────────────────────────────────────────────── |
| 23 | + |
/// Tracks token usage and compacts a serialized conversation.
///
/// The conversation is handled as raw JSON text (a messages array) held in a
/// `std.ArrayList(u8)`. All scanning is textual: it matches compact markers such
/// as `"role":"assistant"` with no whitespace and with `"role"` as the first key
/// of each message object. Functions that modify `messages` allocate and free
/// its buffer with `self.allocator`, so the list must use that same allocator.
pub const ContextManager = struct {
    allocator: std.mem.Allocator,
    config: ContextConfig,
    api_input_tokens: u32, // accurate count from API responses
    api_output_tokens: u32,

    const assistant_marker = "\"role\":\"assistant\"";
    const tool_marker = "\"type\":\"tool_result\"";
    const content_marker = "\"content\":\"";
    /// Max distance between a tool_result marker and its "content" field.
    const max_marker_gap = 200;
    /// Tool outputs up to this many bytes are kept verbatim.
    const max_kept_content = 200;
    /// Max number of conversation bytes embedded in a compaction request.
    const max_excerpt_bytes = 400_000;

    /// Create a manager with default configuration and zeroed usage counters.
    pub fn init(allocator: std.mem.Allocator) ContextManager {
        return .{
            .allocator = allocator,
            .config = .{},
            .api_input_tokens = 0,
            .api_output_tokens = 0,
        };
    }

    /// Track tokens reported by the API (more accurate than estimation).
    /// Counters saturate at `maxInt(u32)` rather than overflowing.
    pub fn trackApiUsage(self: *ContextManager, input_tokens: u32, output_tokens: u32) void {
        self.api_input_tokens +|= input_tokens;
        self.api_output_tokens +|= output_tokens;
    }

    /// Quick check: has the estimated size reached `config.compact_threshold`?
    pub fn isNearLimit(self: *ContextManager, messages: *const std.ArrayList(u8)) bool {
        return estimateTokens(messages.items) >= self.config.compact_threshold;
    }

    /// Format context usage for TUI display, e.g. `[~12K/180K tokens]`.
    ///
    /// The text is returned by value in a fixed 64-byte array and padded with
    /// NUL bytes; use `std.mem.sliceTo(&buf, 0)` to obtain the text. Nothing is
    /// heap-allocated.
    pub fn formatUsage(self: *ContextManager, messages: *const std.ArrayList(u8)) [64]u8 {
        const est = estimateTokens(messages.items);
        var buf = [_]u8{0} ** 64;
        _ = std.fmt.bufPrint(&buf, "[~{d}K/{d}K tokens]", .{ est / 1000, self.config.max_tokens / 1000 }) catch {
            // Discard any partial output, then copy exactly the fallback's length.
            const fallback = "[ctx ?/?K] ";
            buf = [_]u8{0} ** 64;
            @memcpy(buf[0..fallback.len], fallback);
        };
        return buf;
    }

    /// Phase 1: Truncate old tool_result content, keeping last N turns.
    ///
    /// String-valued `"content"` fields of tool_result blocks that start before
    /// the `keep_turns`-th assistant message from the end, and are longer than
    /// 200 bytes, are replaced with `[truncated N bytes]`.
    ///
    /// Returns true if any truncation happened. Returns false, leaving
    /// `messages` untouched, when there are fewer than `keep_turns` assistant
    /// turns, when `keep_turns` is 0, when nothing qualifies, or on allocation
    /// failure.
    pub fn truncateOldToolOutputs(self: *ContextManager, messages: *std.ArrayList(u8)) bool {
        const cutoff = findKeepBoundary(messages.items, self.config.keep_turns) orelse return false;

        var rebuilt = std.ArrayList(u8).empty;
        const modified = self.rewriteToolResults(messages.items, cutoff, &rebuilt) catch false;
        if (!modified) {
            rebuilt.deinit(self.allocator);
            return false;
        }

        // Swap buffers: cannot fail, so the conversation is never left half-written.
        messages.deinit(self.allocator);
        messages.* = rebuilt;
        return true;
    }

    /// Build a compaction request body for API summarization.
    /// Returns the JSON request body (caller owns memory, allocated with
    /// `self.allocator`), or null if the context is not near the limit or the
    /// body could not be built.
    /// The caller should POST this to the API, parse the text response,
    /// then call applySummary() with the result.
    pub fn buildCompactionRequest(self: *ContextManager, messages: *const std.ArrayList(u8), model: []const u8) ?[]const u8 {
        if (!self.isNearLimit(messages)) return null;
        return self.renderCompactionBody(messages.items, model) catch null;
    }

    /// Apply a summary: replace old messages with summary + keep recent turns.
    ///
    /// The result is a user message holding the summary, followed by everything
    /// from the `keep_turns`-th assistant message from the end. If there are
    /// fewer assistant turns than that (or `keep_turns` is 0), only the summary
    /// message remains. On failure `messages` is left unchanged.
    pub fn applySummary(self: *ContextManager, messages: *std.ArrayList(u8), summary: []const u8) void {
        const rebuilt = self.renderSummarized(messages.items, summary) catch return;
        messages.deinit(self.allocator);
        messages.* = rebuilt;
    }

    /// Byte offset of the `keep_turns`-th `"role":"assistant"` marker counted
    /// from the end of `data`, or null if there are fewer markers or
    /// `keep_turns` is 0.
    fn findKeepBoundary(data: []const u8, keep_turns: u32) ?usize {
        if (keep_turns == 0) return null;
        var search_end = data.len;
        var seen: u32 = 0;
        while (std.mem.lastIndexOf(u8, data[0..search_end], assistant_marker)) |pos| {
            seen += 1;
            if (seen == keep_turns) return pos;
            search_end = pos;
        }
        return null;
    }

    /// Index of the closing quote of a JSON string whose first content byte is
    /// at `start`, or null if the string is unterminated. Escape pairs are
    /// skipped as a unit so that `\\"` is read as an escaped backslash followed
    /// by the closing quote.
    fn findStringEnd(data: []const u8, start: usize) ?usize {
        var k = start;
        while (k < data.len) {
            switch (data[k]) {
                '"' => return k,
                '\\' => k += 2,
                else => k += 1,
            }
        }
        return null;
    }

    /// Copy `data` into `out`, replacing long tool_result content that starts
    /// before `cutoff`. Returns whether anything was replaced.
    fn rewriteToolResults(
        self: *ContextManager,
        data: []const u8,
        cutoff: usize,
        out: *std.ArrayList(u8),
    ) std.mem.Allocator.Error!bool {
        var modified = false;
        var cursor: usize = 0; // everything before `cursor` is already in `out`

        while (std.mem.indexOfPos(u8, data, cursor, tool_marker)) |tool_pos| {
            if (tool_pos >= cutoff) break; // inside the protected recent turns

            const after_tool = tool_pos + tool_marker.len;
            try out.appendSlice(self.allocator, data[cursor..after_tool]);
            cursor = after_tool;

            const content_pos = std.mem.indexOfPos(u8, data, after_tool, content_marker) orelse continue;
            // The content field must belong to this block: close by and still old.
            if (content_pos >= cutoff or content_pos - after_tool >= max_marker_gap) continue;

            const val_start = content_pos + content_marker.len;
            // Malformed (unterminated) value: leave the text as it is.
            const val_end = findStringEnd(data, val_start) orelse continue;
            const value_len = val_end - val_start;
            if (value_len <= max_kept_content) continue; // short content is kept verbatim

            try out.appendSlice(self.allocator, data[cursor..val_start]);
            var note_buf: [64]u8 = undefined;
            const note = std.fmt.bufPrint(&note_buf, "[truncated {d} bytes]", .{value_len}) catch "[truncated]";
            try out.appendSlice(self.allocator, note);
            cursor = val_end; // resume at the closing quote
            modified = true;
        }

        try out.appendSlice(self.allocator, data[cursor..]);
        return modified;
    }

    /// Render the summarization request body; the caller owns the returned slice.
    fn renderCompactionBody(self: *ContextManager, conversation: []const u8, model: []const u8) ![]const u8 {
        var body = std.ArrayList(u8).empty;
        errdefer body.deinit(self.allocator);

        try body.appendSlice(self.allocator, "{\"model\":\"");
        // Escape the model name as well, so it cannot break out of the JSON string.
        try proto.writeJsonEscaped(body.writer(self.allocator), model);
        try body.appendSlice(self.allocator, "\",\"max_tokens\":2048,\"messages\":[{\"role\":\"user\",\"content\":\"");

        // Inject summary prompt + conversation excerpt
        const prompt_prefix = "Summarize the following conversation concisely in 2-3 paragraphs. Preserve: all file paths mentioned, key decisions made, current task state, and any errors encountered. Conversation:\\n\\n";
        try body.appendSlice(self.allocator, prompt_prefix);

        // Include the first portion of the conversation (up to 400 000 bytes,
        // roughly 100K estimated tokens). Back off to a UTF-8 character boundary
        // so the excerpt never ends in the middle of a multi-byte sequence.
        var excerpt_len: usize = @min(conversation.len, max_excerpt_bytes);
        while (excerpt_len > 0 and excerpt_len < conversation.len and
            (conversation[excerpt_len] & 0xC0) == 0x80)
        {
            excerpt_len -= 1;
        }
        try proto.writeJsonEscaped(body.writer(self.allocator), conversation[0..excerpt_len]);

        try body.appendSlice(self.allocator, "\"}]}");
        return try body.toOwnedSlice(self.allocator);
    }

    /// Render `[summary message, ...recent turns]` into a fresh list owned by
    /// the caller.
    fn renderSummarized(self: *ContextManager, data: []const u8, summary: []const u8) !std.ArrayList(u8) {
        var out = std.ArrayList(u8).empty;
        errdefer out.deinit(self.allocator);

        try out.appendSlice(self.allocator, "[{\"role\":\"user\",\"content\":\"[Previous context summary]\\n");
        try proto.writeJsonEscaped(out.writer(self.allocator), summary);
        try out.appendSlice(self.allocator, "\"}");

        if (findKeepBoundary(data, self.config.keep_turns)) |marker_pos| {
            if (std.mem.lastIndexOfScalar(u8, data[0..marker_pos], ',')) |comma| {
                // The tail starts at the comma separating this message from the
                // previous one and runs through the array's closing bracket.
                try out.appendSlice(self.allocator, data[comma..]);
            } else {
                // The kept message is the first array element: skip the original
                // opening bracket and supply the separator ourselves.
                const obj_start = std.mem.lastIndexOfScalar(u8, data[0..marker_pos], '{') orelse marker_pos;
                try out.append(self.allocator, ',');
                try out.appendSlice(self.allocator, data[obj_start..]);
            }
        } else {
            // No recent turns kept: close the array so the result is valid JSON.
            try out.append(self.allocator, ']');
        }
        return out;
    }
};
| 230 | + |
| 231 | +// ─── Tests ─────────────────────────────────────────────────────────────────── |
| 232 | + |
test "estimateTokens" {
    // Any non-empty text yields a positive estimate.
    const greeting_estimate = estimateTokens("hello world");
    try std.testing.expect(greeting_estimate > 0);

    // Exact values: 8 bytes -> 2 tokens; inputs under 4 bytes round up to 1.
    const cases = [_]struct { text: []const u8, want: u32 }{
        .{ .text = "12345678", .want = 2 },
        .{ .text = "hi", .want = 1 },
    };
    for (cases) |case| {
        try std.testing.expectEqual(case.want, estimateTokens(case.text));
    }
}
| 238 | + |
test "ContextManager isNearLimit" {
    const gpa = std.testing.allocator;
    var manager = ContextManager.init(gpa);
    // Threshold of 10 tokens corresponds to 40 bytes of text.
    manager.config.compact_threshold = 10;

    // 20 bytes -> 5 estimated tokens, below the threshold.
    var below = std.ArrayList(u8).empty;
    defer below.deinit(gpa);
    try below.appendSlice(gpa, "12345678901234567890");
    try std.testing.expectEqual(false, manager.isNearLimit(&below));

    // 100 bytes -> 25 estimated tokens, above the threshold.
    var above = std.ArrayList(u8).empty;
    defer above.deinit(gpa);
    try above.appendNTimes(gpa, 'a', 100);
    try std.testing.expectEqual(true, manager.isNearLimit(&above));
}
| 255 | + |
test "truncateOldToolOutputs" {
    const allocator = std.testing.allocator;
    var ctx = ContextManager.init(allocator);
    ctx.config.keep_turns = 1;

    // Two assistant turns; the long tool_result sits before the last one, so
    // with keep_turns = 1 it is eligible for truncation.
    const head =
        \\[{"role":"user","content":"do something"},{"role":"assistant","content":"ok"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"
    ;
    const payload = "A" ** 256;
    const tail =
        \\"}]},{"role":"assistant","content":"done"},{"role":"user","content":"thanks"}]
    ;

    var messages = std.ArrayList(u8).empty;
    defer messages.deinit(allocator);
    try messages.appendSlice(allocator, head ++ payload ++ tail);

    try std.testing.expect(ctx.truncateOldToolOutputs(&messages));
    // Only the tool output is replaced; all surrounding JSON is preserved byte-for-byte.
    try std.testing.expectEqualStrings(head ++ "[truncated 256 bytes]" ++ tail, messages.items);

    // A second pass finds nothing left to shorten and leaves the text alone.
    try std.testing.expect(!ctx.truncateOldToolOutputs(&messages));
    try std.testing.expectEqualStrings(head ++ "[truncated 256 bytes]" ++ tail, messages.items);
}
0 commit comments