Skip to content

Commit cb50cf8

Browse files
committed
fix(agent): recover from prompt-too-long runs
1 parent 474e0bd commit cb50cf8

19 files changed

Lines changed: 2307 additions & 506 deletions

Sources/AgentRunKit/Core/Agent+ContextBudget.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,31 @@ extension Agent {
3636
}
3737
}
3838

39+
/// Runs every `prune_context` tool call in `calls`, appending one tool message per call to `messages`.
///
/// When a `continuation` is supplied, a `toolCallCompleted` event is yielded for each call.
///
/// - Parameters:
///   - calls: The `prune_context` tool calls to execute, processed in order.
///   - messages: The conversation history; pruning may modify it and tool results are appended to it.
///   - continuation: Optional stream continuation that receives a completion event per call.
/// - Returns: `true` if at least one prune call reported that it rewrote the history.
@discardableResult
3940
func executePruneCalls(
4041
_ calls: [ToolCall],
4142
messages: inout [ChatMessage],
4243
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation? = nil
43-
) {
44+
) -> Bool {
4445
// The prune tool counts as enabled only when a context budget exists and opts in explicitly.
let pruneEnabled = configuration.contextBudget?.enablePruneTool == true
46+
var historyWasRewritten = false
4547
for call in calls {
4648
let result: ToolResult
4749
if !pruneEnabled {
4850
result = .error("Tool not available: prune_context is disabled.")
4951
} else {
5052
do {
51-
result = try executePruneContext(arguments: call.argumentsData, messages: &messages)
53+
let pruneResult = try executePruneContext(arguments: call.argumentsData, messages: &messages)
54+
result = pruneResult.toolResult
55+
// Sticky flag: once any call reports a rewrite, the overall result stays true.
historyWasRewritten = historyWasRewritten || pruneResult.historyWasRewritten
5256
} catch {
5357
// A failed prune becomes an error tool result instead of propagating out of this function.
result = .error("prune_context failed: \(error)")
5458
}
5559
}
5660
messages.append(.tool(id: call.id, name: call.name, content: result.content))
5761
continuation?.yield(.make(.toolCallCompleted(id: call.id, name: call.name, result: result)))
5862
}
63+
return historyWasRewritten
5964
}
6065

6166
func executeAndAppendResults(
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import Foundation
2+
3+
extension Agent {
4+
/// Performs one non-streaming run iteration: compacts or truncates the history if needed,
/// requests the next assistant response, and records that response.
///
/// - Parameters:
///   - messages: Conversation history; may be rewritten by compaction and receives the new assistant message.
///   - totalUsage: Running token usage; updated by compaction and by the response's usage when present.
///   - lastTotalTokens: Set to the response's total token count when the response reports usage.
///   - compactor: The context compactor used for both proactive and reactive compaction.
///   - historyWasRewrittenLocally: Set to `true` when compaction rewrites the history;
///     also passed on to `generateRunResponse`, which may update it.
///   - requestContext: Optional context forwarded to the main generation request.
/// - Returns: The assistant response, which has already been appended to `messages`.
/// - Throws: Any error thrown by `generateRunResponse`.
func executeRunIteration(
    messages: inout [ChatMessage],
    totalUsage: inout TokenUsage,
    lastTotalTokens: inout Int?,
    compactor: inout ContextCompactor,
    historyWasRewrittenLocally: inout Bool,
    requestContext: RequestContext?
) async throws -> AssistantMessage {
    // Resolved before compaction runs; the summary closure below uses this value.
    let modeForSummary = requestMode(for: historyWasRewrittenLocally)
    let outcome = await compactor.compactOrTruncateIfNeeded(
        &messages,
        lastTotalTokens: lastTotalTokens,
        totalUsage: &totalUsage
    ) { summaryMessages in
        try await self.client.generateForRun(
            messages: summaryMessages,
            tools: self.toolDefinitions,
            responseFormat: nil,
            requestContext: nil,
            requestMode: modeForSummary
        )
    }
    // Sticky flag: a rewrite reported here is added to whatever the caller already recorded.
    historyWasRewrittenLocally = historyWasRewrittenLocally || outcome.didRewriteHistory

    let reply = try await generateRunResponse(
        messages: &messages,
        compactor: &compactor,
        historyWasRewrittenLocally: &historyWasRewrittenLocally,
        requestContext: requestContext
    )
    messages.append(.assistant(reply))

    // Usage bookkeeping only applies when the response carries token usage.
    guard let usage = reply.tokenUsage else {
        return reply
    }
    totalUsage += usage
    lastTotalTokens = usage.total
    return reply
}
43+
44+
/// Requests an assistant response, recovering at most once from a prompt-too-long failure.
///
/// On a prompt-too-long error the compactor is asked to reactively compact `messages` and the
/// request is retried. The error is rethrown if recovery was already attempted or if the
/// compaction did not rewrite the history.
///
/// - Parameters:
///   - messages: Conversation history sent to the client; may be rewritten by reactive compaction.
///   - compactor: The compactor that performs reactive compaction.
///   - historyWasRewrittenLocally: Selects the request mode; reset to `false` after a successful
///     request and set to `true` after a reactive rewrite.
///   - requestContext: Optional context forwarded to the client.
/// - Returns: The assistant response from the client.
/// - Throws: The prompt-too-long `AgentError.llmError` when recovery is not possible, or any
///   other error thrown by the client.
func generateRunResponse(
    messages: inout [ChatMessage],
    compactor: inout ContextCompactor,
    historyWasRewrittenLocally: inout Bool,
    requestContext: RequestContext?
) async throws -> AssistantMessage {
    var recoveryAlreadyTried = false

    while true {
        let mode = requestMode(for: historyWasRewrittenLocally)
        do {
            let reply = try await client.generateForRun(
                messages: messages,
                tools: toolDefinitions,
                responseFormat: nil,
                requestContext: requestContext,
                requestMode: mode
            )
            historyWasRewrittenLocally = false
            return reply
        } catch let AgentError.llmError(transport) where transport.isPromptTooLong {
            // Only one reactive recovery per call; a second overflow is surfaced to the caller.
            if recoveryAlreadyTried {
                throw AgentError.llmError(transport)
            }
            recoveryAlreadyTried = true

            // Retrying is pointless when compaction left the history untouched.
            if !compactor.reactiveCompact(&messages).didRewriteHistory {
                throw AgentError.llmError(transport)
            }
            historyWasRewrittenLocally = true
        }
    }
}
76+
77+
/// Maps the local history-rewrite flag to the request mode for the next client call.
///
/// - Parameter historyWasRewrittenLocally: Whether the history was rewritten locally.
/// - Returns: `.forceFullRequest` when the flag is `true`, otherwise `.auto`.
func requestMode(for historyWasRewrittenLocally: Bool) -> RunRequestMode {
    guard historyWasRewrittenLocally else {
        return .auto
    }
    return .forceFullRequest
}
80+
81+
/// Builds the final `AgentResult` from a `finish` tool call.
///
/// - Parameters:
///   - call: The tool call whose `argumentsData` holds JSON-encoded `FinishArguments`.
///   - tokenUsage: Total token usage to report in the result.
///   - iterations: Number of iterations to report in the result.
///   - history: Conversation history to attach to the result.
/// - Returns: A result carrying the decoded content and finish reason; the reason falls back
///   to `"completed"` when the arguments do not provide one.
/// - Throws: `AgentError.finishDecodingFailed` when the arguments cannot be decoded.
func parseFinishResult(
    _ call: ToolCall,
    tokenUsage: TokenUsage,
    iterations: Int,
    history: [ChatMessage]
) throws -> AgentResult {
    let arguments: FinishArguments
    do {
        let decoder = JSONDecoder()
        arguments = try decoder.decode(FinishArguments.self, from: call.argumentsData)
    } catch let decodingFailure {
        throw AgentError.finishDecodingFailed(message: String(describing: decodingFailure))
    }

    let reasonText = arguments.reason ?? "completed"
    return AgentResult(
        finishReason: FinishReason(reasonText),
        content: arguments.content,
        totalTokenUsage: tokenUsage,
        iterations: iterations,
        history: history
    )
}
101+
102+
/// Creates an `AgentResult` with no content for a run that ends with the given reason.
///
/// - Parameters:
///   - reason: The finish reason to record.
///   - tokenUsage: Total token usage to report in the result.
///   - iterations: Number of iterations to report in the result.
///   - history: Conversation history to attach to the result.
/// - Returns: A result whose `content` is `nil`.
func makeTerminalResult(
    reason: FinishReason,
    tokenUsage: TokenUsage,
    iterations: Int,
    history: [ChatMessage]
) -> AgentResult {
    let terminalResult = AgentResult(
        finishReason: reason,
        content: nil,
        totalTokenUsage: tokenUsage,
        iterations: iterations,
        history: history
    )
    return terminalResult
}
116+
}

Sources/AgentRunKit/Core/Agent.swift

Lines changed: 20 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -148,24 +148,19 @@ extension Agent {
148148
client: client, toolDefinitions: toolDefinitions, configuration: configuration
149149
)
150150
var budgetPhase = try makeBudgetPhase()
151+
var historyWasRewrittenLocally = false
151152

152153
for iteration in 1 ... configuration.maxIterations {
153154
try Task.checkCancellation()
154155

155-
await compactor.compactOrTruncateIfNeeded(
156-
&messages, lastTotalTokens: lastTotalTokens, totalUsage: &totalUsage
157-
)
158-
let response = try await client.generate(
159-
messages: messages,
160-
tools: toolDefinitions,
161-
responseFormat: nil,
156+
let response = try await executeRunIteration(
157+
messages: &messages,
158+
totalUsage: &totalUsage,
159+
lastTotalTokens: &lastTotalTokens,
160+
compactor: &compactor,
161+
historyWasRewrittenLocally: &historyWasRewrittenLocally,
162162
requestContext: options.requestContext
163163
)
164-
messages.append(.assistant(response))
165-
if let usage = response.tokenUsage {
166-
totalUsage += usage
167-
lastTotalTokens = usage.total
168-
}
169164
let budgetUsage = try requireBudgetUsage(response.tokenUsage, budgetPhase: budgetPhase)
170165

171166
if let finishCall = response.toolCalls.first(where: { $0.name == "finish" }) {
@@ -189,7 +184,13 @@ extension Agent {
189184
let pruneCalls = response.toolCalls.filter { $0.name == "prune_context" }
190185
let regularCalls = response.toolCalls.filter { $0.name != "finish" && $0.name != "prune_context" }
191186

192-
executePruneCalls(pruneCalls, messages: &messages)
187+
let pruneRewroteHistory = executePruneCalls(
188+
pruneCalls,
189+
messages: &messages
190+
)
191+
if pruneRewroteHistory {
192+
historyWasRewrittenLocally = true
193+
}
193194
try await executeAndAppendResults(
194195
regularCalls, context: context, messages: &messages,
195196
approvalHandler: options.approvalHandler, allowlist: &sessionAllowlist
@@ -281,10 +282,14 @@ extension Agent {
281282
for iterationNumber in 1 ... configuration.maxIterations {
282283
try Task.checkCancellation()
283284

284-
let compacted = await compactor.compactOrTruncateIfNeeded(
285+
let compactionOutcome = await compactor.compactOrTruncateIfNeeded(
285286
&messages, lastTotalTokens: lastTotalTokens, totalUsage: &totalUsage
286287
)
287-
emitCompactionEventIfNeeded(compacted, lastTotalTokens: lastTotalTokens, continuation: continuation)
288+
emitCompactionEventIfNeeded(
289+
compactionOutcome.emitsCompactionEvent,
290+
lastTotalTokens: lastTotalTokens,
291+
continuation: continuation
292+
)
288293
let iteration = try await processor.process(
289294
messages: messages, totalUsage: &totalUsage, continuation: continuation,
290295
requestContext: options.requestContext
@@ -376,28 +381,6 @@ extension Agent {
376381
)
377382
}
378383

379-
func parseFinishResult(
380-
_ call: ToolCall,
381-
tokenUsage: TokenUsage,
382-
iterations: Int,
383-
history: [ChatMessage]
384-
) throws -> AgentResult {
385-
let data = call.argumentsData
386-
let decoded: FinishArguments
387-
do {
388-
decoded = try JSONDecoder().decode(FinishArguments.self, from: data)
389-
} catch {
390-
throw AgentError.finishDecodingFailed(message: String(describing: error))
391-
}
392-
return AgentResult(
393-
finishReason: FinishReason(decoded.reason ?? "completed"),
394-
content: decoded.content,
395-
totalTokenUsage: tokenUsage,
396-
iterations: iterations,
397-
history: history
398-
)
399-
}
400-
401384
private func makeFinishedEvent(
402385
tokenUsage: TokenUsage,
403386
content: String?,
@@ -412,21 +395,6 @@ extension Agent {
412395
))
413396
}
414397

415-
private func makeTerminalResult(
416-
reason: FinishReason,
417-
tokenUsage: TokenUsage,
418-
iterations: Int,
419-
history: [ChatMessage]
420-
) -> AgentResult {
421-
AgentResult(
422-
finishReason: reason,
423-
content: nil,
424-
totalTokenUsage: tokenUsage,
425-
iterations: iterations,
426-
history: history
427-
)
428-
}
429-
430398
private func emitCompactionEventIfNeeded(
431399
_ compacted: Bool,
432400
lastTotalTokens: Int?,

0 commit comments

Comments
 (0)