Skip to content

Commit 557915f

Browse files
committed
add(checkpointer): persistent checkpoints, agent.resume, and AgentStream observer rehydration
1 parent fe43d03 commit 557915f

42 files changed

Lines changed: 3287 additions & 328 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import Foundation
2+
3+
extension Agent {
    /// Saves a checkpoint of the current streaming-loop state when checkpointing is configured.
    ///
    /// Nothing is saved unless a checkpointer was supplied and `eventFactory`
    /// carries both a session ID and a run ID.
    ///
    /// - Parameters:
    ///   - iterationNumber: Iteration recorded in the checkpoint.
    ///   - state: Loop state whose messages, budget phase, rewrite flag, and
    ///     session allowlist are captured.
    ///   - totalUsage: Token usage stored as the checkpoint's `tokenUsage`.
    ///   - iterationUsage: Token usage stored as the checkpoint's `iterationUsage`, if any.
    ///   - eventFactory: Source of the session and run identifiers.
    ///   - checkpointer: Destination for the checkpoint; `nil` disables checkpointing.
    /// - Returns: The saved checkpoint's `checkpointID`, or `nil` when nothing was saved.
    /// - Throws: Whatever `checkpointer.save(_:)` throws.
    @discardableResult
    func checkpointIfConfigured(
        iterationNumber: Int,
        state: StreamingLoopState,
        totalUsage: TokenUsage,
        iterationUsage: TokenUsage?,
        eventFactory: StreamEventFactory,
        checkpointer: (any AgentCheckpointer)?
    ) async throws -> CheckpointID? {
        guard let checkpointer else { return nil }
        guard let sessionID = eventFactory.sessionID,
              let runID = eventFactory.runID
        else { return nil }

        let history = state.messages
        let snapshot = AgentCheckpoint(
            messages: history,
            iteration: iterationNumber,
            tokenUsage: totalUsage,
            iterationUsage: iterationUsage,
            contextBudgetState: state.budgetPhase?.checkpointState,
            historyWasRewrittenLocally: state.historyWasRewrittenLocally,
            sessionAllowlist: state.sessionAllowlist,
            sessionID: sessionID,
            runID: runID,
            mcpToolBindings: mcpToolBindings(in: history)
        )
        try await checkpointer.save(snapshot)
        return snapshot.checkpointID
    }

    /// Collects the checkpoint bindings of the MCP tools that appear in `messages`.
    ///
    /// A tool counts as appearing when its name matches a tool call on an
    /// assistant message or the name carried by a tool message. Only tools that
    /// cast to `MCPTool<C>` contribute a binding.
    ///
    /// - Parameter messages: History to scan for tool names.
    /// - Returns: The bindings of the matching MCP tools.
    func mcpToolBindings(in messages: [ChatMessage]) -> Set<MCPToolBinding> {
        var referencedNames = Set<String>()
        for message in messages {
            switch message {
            case let .assistant(assistant):
                referencedNames.formUnion(assistant.toolCalls.map(\.name))
            case let .tool(_, name, _):
                referencedNames.insert(name)
            default:
                continue
            }
        }

        let bindings = tools.compactMap { tool -> MCPToolBinding? in
            guard referencedNames.contains(tool.name) else { return nil }
            return (tool as? MCPTool<C>)?.checkpointBinding
        }
        return Set(bindings)
    }
}

Sources/AgentRunKit/Core/Agent+ContextBudget.swift

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@ extension Agent {
1717
_ budgetPhase: inout ContextBudgetPhase?,
1818
usage: TokenUsage,
1919
messages: inout [ChatMessage],
20-
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation? = nil
20+
emit: StreamEmitter? = nil
2121
) {
2222
guard var phase = budgetPhase else { return }
2323
let result = phase.afterResponse(usage: usage, messages: &messages)
2424
budgetPhase = phase
25-
continuation?.yield(.make(.budgetUpdated(budget: result.budget)))
25+
emit?.yield(.budgetUpdated(budget: result.budget))
2626
if result.advisoryEmitted {
27-
continuation?.yield(.make(.budgetAdvisory(budget: result.budget)))
27+
emit?.yield(.budgetAdvisory(budget: result.budget))
2828
}
2929
}
3030

3131
@discardableResult
3232
func executePruneCalls(
3333
_ calls: [IndexedToolCall],
3434
messages: inout [ChatMessage],
35-
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation? = nil
35+
emit: StreamEmitter? = nil
3636
) -> (historyWasRewritten: Bool, results: [IndexedToolResult]) {
3737
let pruneEnabled = configuration.contextBudget?.enablePruneTool == true
3838
var historyWasRewritten = false
@@ -54,11 +54,9 @@ extension Agent {
5454
}
5555
}
5656
results.append(IndexedToolResult(index: indexed.index, call: indexed.call, result: result))
57-
continuation?.yield(.make(.toolCallCompleted(
58-
id: indexed.call.id,
59-
name: indexed.call.name,
60-
result: result
61-
)))
57+
emit?.yield(.toolCallCompleted(
58+
id: indexed.call.id, name: indexed.call.name, result: result
59+
))
6260
}
6361
return (historyWasRewritten: historyWasRewritten, results: results)
6462
}
@@ -82,7 +80,7 @@ extension Agent {
8280
var allResults = try await executeIndexedCalls(autoExecute, context: executionContext, approvalHandler: handler)
8381

8482
let (approved, denied) = try await resolveApprovals(
85-
needsApproval, handler: handler, allowlist: &allowlist, continuation: nil
83+
needsApproval, handler: handler, allowlist: &allowlist
8684
)
8785
try Task.checkCancellation()
8886

@@ -97,47 +95,50 @@ extension Agent {
9795

9896
func executeStreamingResults(
9997
_ calls: [IndexedToolCall], context: C, messages: [ChatMessage],
98+
options: InvocationOptions,
10099
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation,
101-
approvalHandler: ToolApprovalHandler? = nil, allowlist: inout Set<String>
100+
allowlist: inout Set<String>
102101
) async throws -> [IndexedToolResult] {
103102
guard !calls.isEmpty else { return [] }
104103
let executionContext = context.withParentHistory(messages.resolvedPrefixForInheritance())
104+
let approvalHandler = options.approvalHandler
105+
let emit = StreamEmitter(factory: options.eventFactory, continuation: continuation)
105106

106107
guard let handler = approvalHandler, configuration.approvalPolicy != .none else {
107108
return try await executeIndexedStreamingCalls(
108109
calls,
109110
context: executionContext,
110-
continuation: continuation,
111-
approvalHandler: approvalHandler
111+
options: options,
112+
continuation: continuation
112113
)
113114
}
114115

115116
let (autoExecute, needsApproval) = partitionCallsRequiringApproval(calls, allowlist: allowlist)
116117
var allResults = try await executeIndexedStreamingCalls(
117118
autoExecute,
118119
context: executionContext,
119-
continuation: continuation,
120-
approvalHandler: handler
120+
options: options,
121+
continuation: continuation
121122
)
122123

123124
let (approved, denied) = try await resolveApprovals(
124-
needsApproval, handler: handler, allowlist: &allowlist, continuation: continuation
125+
needsApproval, handler: handler, emit: emit, allowlist: &allowlist
125126
)
126127
try Task.checkCancellation()
127128

128129
for entry in denied {
129130
let truncatedEntry = truncatedIndexedToolResult(entry)
130-
continuation.yield(.make(.toolCallCompleted(
131+
emit.yield(.toolCallCompleted(
131132
id: truncatedEntry.call.id, name: truncatedEntry.call.name, result: truncatedEntry.result
132-
)))
133+
))
133134
allResults.append(truncatedEntry)
134135
}
135136

136137
try await allResults.append(contentsOf: executeIndexedStreamingCalls(
137138
approved,
138139
context: executionContext,
139-
continuation: continuation,
140-
approvalHandler: handler
140+
options: options,
141+
continuation: continuation
141142
))
142143
return allResults.sorted { $0.index < $1.index }
143144
}
@@ -194,14 +195,14 @@ extension Agent {
194195
private func executeIndexedStreamingCalls(
195196
_ calls: [IndexedToolCall],
196197
context: C,
197-
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation,
198-
approvalHandler: ToolApprovalHandler?
198+
options: InvocationOptions,
199+
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
199200
) async throws -> [IndexedToolResult] {
200201
let results = try await executeToolsStreaming(
201202
calls.map(\.call),
202203
context: context,
203-
continuation: continuation,
204-
approvalHandler: approvalHandler
204+
options: options,
205+
continuation: continuation
205206
)
206207
return zip(calls, results).map { indexed, entry in
207208
IndexedToolResult(

Sources/AgentRunKit/Core/Agent+IterationHistory.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ extension Agent {
66
iterationNumber: Int,
77
messages: [ChatMessage],
88
context: C,
9+
eventFactory: StreamEventFactory,
910
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
1011
) {
1112
guard let usage = iteration.usage else { return }
12-
continuation.yield(.make(.iterationCompleted(
13+
continuation.yield(eventFactory.make(.iterationCompleted(
1314
usage: usage,
1415
iteration: iterationNumber,
1516
history: emittedIterationHistory(messages: messages, context: context)
@@ -34,14 +35,16 @@ extension Agent {
3435
case let .iterationCompleted(usage, iteration, history) where depth > limit && !history.isEmpty:
3536
return StreamEvent(
3637
id: event.id, timestamp: event.timestamp,
37-
sessionID: event.sessionID, runID: event.runID, parentEventID: event.parentEventID,
38+
sessionID: event.sessionID, runID: event.runID,
39+
parentEventID: event.parentEventID, origin: event.origin,
3840
kind: .iterationCompleted(usage: usage, iteration: iteration, history: [])
3941
)
4042
case let .subAgentEvent(toolCallId, toolName, nested):
4143
let rewritten = rewritingHistoryEmission(in: nested, depth: depth + 1, limit: limit)
4244
return StreamEvent(
4345
id: event.id, timestamp: event.timestamp,
44-
sessionID: event.sessionID, runID: event.runID, parentEventID: event.parentEventID,
46+
sessionID: event.sessionID, runID: event.runID,
47+
parentEventID: event.parentEventID, origin: event.origin,
4548
kind: .subAgentEvent(toolCallId: toolCallId, toolName: toolName, event: rewritten)
4649
)
4750
default:
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import Foundation
2+
3+
extension Agent {
    /// Resumes a previously checkpointed run, replaying iteration history before continuing live.
    ///
    /// - Parameters:
    ///   - checkpointID: Identifier of the checkpoint to load from `checkpointer`.
    ///   - checkpointer: Store the checkpoint is loaded from; it is also passed
    ///     on to the resumed run's invocation options.
    ///   - context: Context handed to the resumed loop.
    ///   - tokenBudget: Optional token budget for the resumed run.
    ///   - requestContext: Optional request context for the resumed run.
    ///   - approvalHandler: Optional tool-approval handler for the resumed run.
    /// - Returns: A stream of events for the resumed run.
    /// - Throws: Errors from loading the checkpoint or from validating it.
    public func resume(
        from checkpointID: CheckpointID,
        checkpointer: any AgentCheckpointer,
        context: C,
        tokenBudget: Int? = nil,
        requestContext: RequestContext? = nil,
        approvalHandler: ToolApprovalHandler? = nil
    ) async throws -> AsyncThrowingStream<StreamEvent, Error> {
        let loadedCheckpoint = try await checkpointer.load(checkpointID)
        return try await resume(
            target: loadedCheckpoint,
            checkpointer: checkpointer,
            context: context,
            tokenBudget: tokenBudget,
            requestContext: requestContext,
            approvalHandler: approvalHandler
        )
    }

    /// Resumes from an already-loaded checkpoint.
    ///
    /// The resumed run keeps the checkpoint's session ID and gets a fresh run ID.
    /// The checkpointed history and MCP bindings are validated before the stream
    /// is created, so those failures are thrown here rather than delivered
    /// through the stream.
    ///
    /// - Returns: A stream whose producer task is cancelled when the stream terminates.
    /// - Throws: Errors from history validation or `AgentCheckpointError.mcpBindingMismatch`.
    func resume(
        target: AgentCheckpoint,
        checkpointer: any AgentCheckpointer,
        context: C,
        tokenBudget: Int? = nil,
        requestContext: RequestContext? = nil,
        approvalHandler: ToolApprovalHandler? = nil
    ) async throws -> AsyncThrowingStream<StreamEvent, Error> {
        let options = InvocationOptions(
            tokenBudget: tokenBudget,
            requestContext: requestContext,
            systemPromptOverride: nil,
            approvalHandler: approvalHandler,
            sessionID: target.sessionID,
            runID: RunID(),
            checkpointer: checkpointer
        )
        validateInvocation(options)
        try target.messages.validateForAgentHistory()
        try validateMCPBindings(target.mcpToolBindings)

        return AsyncThrowingStream { continuation in
            let worker = Task { [self] in
                do {
                    try await replayAndContinueResume(
                        target: target,
                        context: context,
                        options: options,
                        continuation: continuation
                    )
                } catch {
                    // Any failure while replaying or looping ends the stream with that error.
                    continuation.finish(throwing: error)
                }
            }
            continuation.onTermination = { _ in worker.cancel() }
        }
    }

    /// Verifies that every MCP binding stored in a checkpoint is still offered by a live tool.
    ///
    /// - Parameter checkpointed: Bindings recorded in the checkpoint; an empty set always passes.
    /// - Throws: `AgentCheckpointError.mcpBindingMismatch` carrying the bindings
    ///   that no current `MCPTool<C>` provides.
    private func validateMCPBindings(_ checkpointed: Set<MCPToolBinding>) throws {
        if checkpointed.isEmpty { return }

        var liveBindings = Set<MCPToolBinding>()
        for tool in tools {
            if let binding = (tool as? MCPTool<C>)?.checkpointBinding {
                liveBindings.insert(binding)
            }
        }

        let unavailable = checkpointed.subtracting(liveBindings)
        if !unavailable.isEmpty {
            throw AgentCheckpointError.mcpBindingMismatch(Array(unavailable))
        }
    }

    /// Emits the replayed iteration event, then either finishes early or re-enters the stream loop.
    ///
    /// The replayed event carries the checkpoint's own session and run IDs and
    /// an origin of `.replayed(from:)`. The loop restarts at the iteration after
    /// the checkpointed one, seeded with the checkpoint's usage totals.
    private func replayAndContinueResume(
        target: AgentCheckpoint,
        context: C,
        options: InvocationOptions,
        continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
    ) async throws {
        let replayedEvent = StreamEventFactory(
            sessionID: target.sessionID,
            runID: target.runID,
            origin: .replayed(from: target.checkpointID)
        ).make(.iterationCompleted(
            usage: target.iterationUsage ?? TokenUsage(),
            iteration: target.iteration,
            history: target.messages
        ))
        continuation.yield(replayedEvent)

        // The rewrite flag is always set to true on resume, whatever the checkpoint recorded.
        var state = StreamingLoopState(
            messages: target.messages,
            historyWasRewrittenLocally: true,
            budgetPhase: target.contextBudgetState.map { ContextBudgetPhase(checkpointState: $0) },
            sessionAllowlist: target.sessionAllowlist
        )

        guard let finished = earlyFinishEvent(target: target, options: options) else {
            try await performStreamLoop(
                state: &state,
                startIteration: target.iteration + 1,
                totalUsage: target.tokenUsage,
                lastTotalTokens: target.iterationUsage?.total,
                context: context,
                options: options,
                continuation: continuation
            )
            return
        }
        finishStreaming(continuation: continuation, event: finished)
    }

    /// Builds the finish event for a checkpoint that is already past a limit, if any.
    ///
    /// The token budget is checked first, then the iteration limit.
    ///
    /// - Returns: A finished event when the checkpoint's total usage exceeds
    ///   `options.tokenBudget` or its iteration has reached
    ///   `configuration.maxIterations`; otherwise `nil`.
    private func earlyFinishEvent(target: AgentCheckpoint, options: InvocationOptions) -> StreamEvent? {
        let usage = target.tokenUsage
        if let budget = options.tokenBudget, usage.total > budget {
            return makeFinishedEvent(
                tokenUsage: usage,
                content: nil,
                reason: .tokenBudgetExceeded(budget: budget, used: usage.total),
                history: target.messages,
                eventFactory: options.eventFactory
            )
        }

        let iterationLimit = configuration.maxIterations
        guard target.iteration >= iterationLimit else { return nil }
        return makeFinishedEvent(
            tokenUsage: usage,
            content: nil,
            reason: .maxIterationsReached(limit: iterationLimit),
            history: target.messages,
            eventFactory: options.eventFactory
        )
    }
}

0 commit comments

Comments
 (0)