Skip to content

Commit fc636ee

Browse files
committed
fix(compaction): harden summarization and truncation behavior
1 parent 870e167 commit fc636ee

3 files changed

Lines changed: 220 additions & 8 deletions

File tree

Sources/AgentRunKit/Core/Agent.swift

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ extension Agent {
144144
var totalUsage = TokenUsage()
145145
var lastTotalTokens: Int?
146146
var sessionAllowlist: Set<String> = []
147-
let compactor = ContextCompactor(
147+
var compactor = ContextCompactor(
148148
client: client, toolDefinitions: toolDefinitions, configuration: configuration
149149
)
150150
var budgetPhase = try makeBudgetPhase()
@@ -266,7 +266,7 @@ private extension Agent {
266266
var sessionAllowlist: Set<String> = []
267267
let policy = StreamPolicy.agent
268268
let processor = StreamProcessor(client: client, toolDefinitions: toolDefinitions, policy: policy)
269-
let compactor = ContextCompactor(
269+
var compactor = ContextCompactor(
270270
client: client, toolDefinitions: toolDefinitions, configuration: configuration
271271
)
272272
var budgetPhase = try makeBudgetPhase()

Sources/AgentRunKit/Core/ContextCompactor.swift

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,20 @@ struct ContextCompactor {
55
let toolDefinitions: [ToolDefinition]
66
let configuration: AgentConfiguration
77

8+
init(client: any LLMClient, toolDefinitions: [ToolDefinition], configuration: AgentConfiguration) {
9+
self.client = client
10+
self.toolDefinitions = toolDefinitions
11+
self.configuration = configuration
12+
}
13+
814
private static let minimumPruningReduction = 0.2
915
private static let pruningPreviewLength = 80
16+
private static let maxConsecutiveSummarizationFailures = 3
17+
18+
private var consecutiveSummarizationFailures = 0
1019

1120
@discardableResult
12-
func compactOrTruncateIfNeeded(
21+
mutating func compactOrTruncateIfNeeded(
1322
_ messages: inout [ChatMessage],
1423
lastTotalTokens: Int?,
1524
totalUsage: inout TokenUsage
@@ -26,15 +35,23 @@ struct ContextCompactor {
2635
let (pruned, reductionRatio) = pruneObservations(messages)
2736
if reductionRatio > Self.minimumPruningReduction {
2837
messages = pruned
38+
consecutiveSummarizationFailures = 0
2939
return true
3040
}
3141

42+
guard consecutiveSummarizationFailures < Self.maxConsecutiveSummarizationFailures else {
43+
truncateIfNeeded(&messages)
44+
return false
45+
}
46+
3247
do {
3348
let (compacted, compactionUsage) = try await summarize(pruned)
3449
messages = compacted
3550
totalUsage += compactionUsage
51+
consecutiveSummarizationFailures = 0
3652
return true
3753
} catch {
54+
consecutiveSummarizationFailures += 1
3855
truncateIfNeeded(&messages)
3956
return false
4057
}
@@ -82,7 +99,8 @@ struct ContextCompactor {
8299
let taskContext = extractTaskContext(messages)
83100
let recentContext = extractRecentContext(messages)
84101

85-
let summaryRequest = messages + [.user(configuration.compactionPrompt ?? Self.summarizationPrompt)]
102+
let prompt = configuration.compactionPrompt ?? Self.summarizationPrompt
103+
let summaryRequest = Self.stripMedia(messages) + [.user(prompt)]
86104
let response = try await client.generate(
87105
messages: summaryRequest, tools: toolDefinitions, responseFormat: nil, requestContext: nil
88106
)
@@ -104,8 +122,9 @@ struct ContextCompactor {
104122
content.count > max else { return content }
105123
let marker = "\n\n...[truncated]...\n\n"
106124
let contentBudget = Swift.max(max - marker.count, 0)
107-
let half = contentBudget / 2
108-
return "\(content.prefix(half))\(marker)\(content.suffix(half))"
125+
let headBudget = contentBudget * 3 / 5
126+
let tailBudget = contentBudget - headBudget
127+
return "\(content.prefix(headBudget))\(marker)\(content.suffix(tailBudget))"
109128
}
110129

111130
private func truncateIfNeeded(_ messages: inout [ChatMessage]) {
@@ -115,6 +134,22 @@ struct ContextCompactor {
115134
}
116135

117136
private extension ContextCompactor {
137+
static func stripMedia(_ messages: [ChatMessage]) -> [ChatMessage] {
138+
messages.map { message in
139+
guard case let .userMultimodal(parts) = message else { return message }
140+
let stripped = parts.map { part -> ContentPart in
141+
switch part {
142+
case .text: return part
143+
case .imageURL, .imageBase64: return .text("[image]")
144+
case .videoBase64: return .text("[video]")
145+
case .pdfBase64: return .text("[PDF]")
146+
case .audioBase64: return .text("[audio]")
147+
}
148+
}
149+
return .userMultimodal(stripped)
150+
}
151+
}
152+
118153
func extractTaskContext(_ messages: [ChatMessage]) -> [ChatMessage] {
119154
var context: [ChatMessage] = []
120155
for message in messages {

Tests/AgentRunKitTests/ContextCompactionTests.swift

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ private actor CompactionMockLLMClient: LLMClient {
99
private let responses: [AssistantMessage]
1010
private var callIndex: Int = 0
1111
private(set) var allCapturedMessages: [[ChatMessage]] = []
12+
private(set) var generateCallCount: Int = 0
1213
private let failSummarization: Bool
1314

1415
init(
@@ -25,6 +26,7 @@ private actor CompactionMockLLMClient: LLMClient {
2526
messages: [ChatMessage], tools _: [ToolDefinition],
2627
responseFormat _: ResponseFormat?, requestContext _: RequestContext?
2728
) async throws -> AssistantMessage {
29+
generateCallCount += 1
2830
if failSummarization, case let .user(text) = messages.last,
2931
text.contains("CONTEXT CHECKPOINT") {
3032
throw AgentError.llmError(.other("Summarization failed"))
@@ -269,6 +271,120 @@ struct CompactionTriggerTests {
269271
// MARK: - Compaction Fallback Tests
270272

271273
struct CompactionFallbackTests {
274+
@Test
275+
func circuitBreakerSkipsSummarizationAfterConsecutiveFailures() async {
276+
let client = CompactionMockLLMClient(
277+
responses: [], contextWindowSize: 1000, failSummarization: true
278+
)
279+
var compactor = ContextCompactor(
280+
client: client,
281+
toolDefinitions: [],
282+
configuration: AgentConfiguration(maxMessages: 20, compactionThreshold: 0.5)
283+
)
284+
var messages: [ChatMessage] = [
285+
.user("Hello"),
286+
.assistant(AssistantMessage(content: "", toolCalls: [
287+
ToolCall(id: "call_1", name: "search", arguments: "{}"),
288+
])),
289+
.tool(id: "call_1", name: "search", content: String(repeating: "x", count: 10)),
290+
.assistant(AssistantMessage(content: "Done")),
291+
]
292+
var usage = TokenUsage()
293+
294+
for _ in 0 ..< 3 {
295+
await compactor.compactOrTruncateIfNeeded(
296+
&messages, lastTotalTokens: 900, totalUsage: &usage
297+
)
298+
}
299+
let callsAfterTripping = await client.generateCallCount
300+
#expect(callsAfterTripping == 3)
301+
302+
await compactor.compactOrTruncateIfNeeded(
303+
&messages, lastTotalTokens: 900, totalUsage: &usage
304+
)
305+
let callsAfterSkip = await client.generateCallCount
306+
#expect(callsAfterSkip == 3)
307+
}
308+
309+
@Test
310+
func circuitBreakerResetsOnSuccess() async {
311+
var compactor = ContextCompactor(
312+
client: CompactionMockLLMClient(
313+
responses: [
314+
AssistantMessage(content: "Summary.", tokenUsage: TokenUsage(input: 50, output: 100)),
315+
],
316+
contextWindowSize: 1000, failSummarization: false
317+
),
318+
toolDefinitions: [],
319+
configuration: AgentConfiguration(compactionThreshold: 0.5)
320+
)
321+
var messages: [ChatMessage] = [
322+
.user("Hello"),
323+
.assistant(AssistantMessage(content: "Done")),
324+
]
325+
var usage = TokenUsage()
326+
327+
let result = await compactor.compactOrTruncateIfNeeded(
328+
&messages, lastTotalTokens: 900, totalUsage: &usage
329+
)
330+
#expect(result)
331+
#expect(hasBridge(messages))
332+
}
333+
334+
@Test
335+
func circuitBreakerResetsAfterPruningSuccess() async {
336+
let client = CompactionMockLLMClient(
337+
responses: [], contextWindowSize: 1000, failSummarization: true
338+
)
339+
var compactor = ContextCompactor(
340+
client: client,
341+
toolDefinitions: [],
342+
configuration: AgentConfiguration(maxMessages: 20, compactionThreshold: 0.5)
343+
)
344+
var summarizationMessages: [ChatMessage] = [
345+
.user("Hello"),
346+
.assistant(AssistantMessage(content: "", toolCalls: [
347+
ToolCall(id: "call_1", name: "search", arguments: "{}"),
348+
])),
349+
.tool(id: "call_1", name: "search", content: String(repeating: "x", count: 10)),
350+
.assistant(AssistantMessage(content: "Done")),
351+
]
352+
var pruningMessages: [ChatMessage] = [
353+
.user("Hello"),
354+
.assistant(AssistantMessage(content: "", toolCalls: [
355+
ToolCall(id: "call_2", name: "read_file", arguments: "{}"),
356+
])),
357+
.tool(id: "call_2", name: "read_file", content: String(repeating: "x", count: 5000)),
358+
.assistant(AssistantMessage(content: "Done")),
359+
]
360+
var usage = TokenUsage()
361+
362+
for _ in 0 ..< 2 {
363+
await compactor.compactOrTruncateIfNeeded(
364+
&summarizationMessages, lastTotalTokens: 900, totalUsage: &usage
365+
)
366+
}
367+
#expect(await client.generateCallCount == 2)
368+
369+
let pruned = await compactor.compactOrTruncateIfNeeded(
370+
&pruningMessages, lastTotalTokens: 900, totalUsage: &usage
371+
)
372+
#expect(pruned)
373+
#expect(await client.generateCallCount == 2)
374+
375+
for _ in 0 ..< 3 {
376+
await compactor.compactOrTruncateIfNeeded(
377+
&summarizationMessages, lastTotalTokens: 900, totalUsage: &usage
378+
)
379+
}
380+
#expect(await client.generateCallCount == 5)
381+
382+
await compactor.compactOrTruncateIfNeeded(
383+
&summarizationMessages, lastTotalTokens: 900, totalUsage: &usage
384+
)
385+
#expect(await client.generateCallCount == 5)
386+
}
387+
272388
@Test
273389
func compactionFallsBackToTruncationOnError() async throws {
274390
let client = CompactionMockLLMClient(
@@ -527,6 +643,67 @@ struct ObservationPruningTests {
527643
}
528644
}
529645

646+
// MARK: - Media Stripping Tests
647+
648+
struct MediaStrippingTests {
649+
@Test
650+
func summarizationStripsMediaFromMultimodalMessages() async throws {
651+
let client = CompactionMockLLMClient(
652+
responses: [
653+
AssistantMessage(content: "Summary.", tokenUsage: TokenUsage(input: 50, output: 100)),
654+
]
655+
)
656+
let compactor = ContextCompactor(
657+
client: client, toolDefinitions: [], configuration: AgentConfiguration()
658+
)
659+
let messages: [ChatMessage] = [
660+
.user([
661+
.text("Describe this"),
662+
.image(data: Data(repeating: 0xFF, count: 1000), mimeType: "image/png"),
663+
.audio(data: Data(repeating: 0xAA, count: 500), format: .mp3),
664+
.video(data: Data(repeating: 0xBB, count: 500), mimeType: "video/mp4"),
665+
.pdf(data: Data(repeating: 0xCC, count: 500)),
666+
]),
667+
.assistant(AssistantMessage(content: "I see an image.")),
668+
]
669+
_ = try await compactor.summarize(messages)
670+
671+
let captured = await client.allCapturedMessages
672+
guard case let .userMultimodal(parts) = captured[0][0] else {
673+
Issue.record("Expected userMultimodal"); return
674+
}
675+
#expect(parts.count == 5)
676+
#expect(parts.allSatisfy { if case .text = $0 { true } else { false } })
677+
#expect(parts.contains { if case let .text(text) = $0 { text == "[image]" } else { false } })
678+
#expect(parts.contains { if case let .text(text) = $0 { text == "[audio]" } else { false } })
679+
#expect(parts.contains { if case let .text(text) = $0 { text == "[video]" } else { false } })
680+
#expect(parts.contains { if case let .text(text) = $0 { text == "[PDF]" } else { false } })
681+
}
682+
683+
@Test
684+
func summarizationPreservesTextOnlyMessages() async throws {
685+
let client = CompactionMockLLMClient(
686+
responses: [
687+
AssistantMessage(content: "Summary.", tokenUsage: TokenUsage(input: 50, output: 100)),
688+
]
689+
)
690+
let compactor = ContextCompactor(
691+
client: client, toolDefinitions: [], configuration: AgentConfiguration()
692+
)
693+
let messages: [ChatMessage] = [
694+
.user("Plain text message"),
695+
.assistant(AssistantMessage(content: "Response")),
696+
]
697+
_ = try await compactor.summarize(messages)
698+
699+
let captured = await client.allCapturedMessages
700+
guard case let .user(text) = captured[0][0] else {
701+
Issue.record("Expected user string message"); return
702+
}
703+
#expect(text == "Plain text message")
704+
}
705+
}
706+
530707
// MARK: - Tool Result Truncation Tests
531708

532709
struct ToolResultTruncationTests {
@@ -578,8 +755,8 @@ struct ToolResultTruncationTests {
578755
+ String(repeating: "Z", count: 100)
579756
let config = AgentConfiguration(maxToolResultCharacters: 60)
580757
let truncated = ContextCompactor.truncateToolResult(content, configuration: config)
581-
#expect(truncated.hasPrefix(String(repeating: "A", count: 19)))
582-
#expect(truncated.hasSuffix(String(repeating: "Z", count: 19)))
758+
#expect(truncated.hasPrefix(String(repeating: "A", count: 22)))
759+
#expect(truncated.hasSuffix(String(repeating: "Z", count: 16)))
583760
#expect(truncated.count <= 60)
584761
#expect(truncated.contains("truncated"))
585762
}

0 commit comments

Comments
 (0)