Skip to content

Commit ec569df

Browse files
committed
add(core): human-in-the-loop tool approval with sub-agent propagation
1 parent 59b0d38 commit ec569df

32 files changed

Lines changed: 2185 additions & 124 deletions

Sources/AgentRunKit/Core/Agent+ContextBudget.swift

Lines changed: 123 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,27 +59,139 @@ extension Agent {
5959
}
6060

6161
func executeAndAppendResults(
62-
_ calls: [ToolCall], context: C, messages: inout [ChatMessage]
62+
_ calls: [ToolCall], context: C, messages: inout [ChatMessage],
63+
approvalHandler: ToolApprovalHandler? = nil, allowlist: inout Set<String>
6364
) async throws {
6465
guard !calls.isEmpty else { return }
65-
let results = try await executeToolsInParallel(calls, context: context.withParentHistory(messages))
66-
for (call, result) in results {
67-
let content = ContextCompactor.truncateToolResult(result.content, configuration: configuration)
68-
messages.append(.tool(id: call.id, name: call.name, content: content))
66+
let executionContext = context.withParentHistory(messages)
67+
68+
guard let handler = approvalHandler, configuration.approvalPolicy != .none else {
69+
let results = try await executeToolsInParallel(
70+
calls,
71+
context: executionContext,
72+
approvalHandler: approvalHandler
73+
)
74+
for (call, result) in results {
75+
let content = ContextCompactor.truncateToolResult(result.content, configuration: configuration)
76+
messages.append(.tool(id: call.id, name: call.name, content: content))
77+
}
78+
return
79+
}
80+
81+
var autoExecute: [IndexedToolCall] = []
82+
var needsApproval: [IndexedToolCall] = []
83+
for (offset, call) in calls.enumerated() {
84+
let indexed = IndexedToolCall(index: offset, call: call)
85+
if requiresApproval(call, allowlist: allowlist) {
86+
needsApproval.append(indexed)
87+
} else {
88+
autoExecute.append(indexed)
89+
}
90+
}
91+
92+
var allResults: [IndexedToolResult] = []
93+
94+
if !autoExecute.isEmpty {
95+
let results = try await executeToolsInParallel(
96+
autoExecute.map(\.call), context: executionContext, approvalHandler: handler
97+
)
98+
for (position, (call, result)) in results.enumerated() {
99+
allResults.append(IndexedToolResult(index: autoExecute[position].index, call: call, result: result))
100+
}
101+
}
102+
103+
let (approved, denied) = try await resolveApprovals(
104+
needsApproval, handler: handler, allowlist: &allowlist, continuation: nil
105+
)
106+
try Task.checkCancellation()
107+
108+
allResults.append(contentsOf: denied)
109+
110+
if !approved.isEmpty {
111+
let results = try await executeToolsInParallel(
112+
approved.map(\.call), context: executionContext, approvalHandler: handler
113+
)
114+
for (position, (call, result)) in results.enumerated() {
115+
allResults.append(IndexedToolResult(index: approved[position].index, call: call, result: result))
116+
}
117+
}
118+
119+
allResults.sort { $0.index < $1.index }
120+
for entry in allResults {
121+
let content = ContextCompactor.truncateToolResult(entry.result.content, configuration: configuration)
122+
messages.append(.tool(id: entry.call.id, name: entry.call.name, content: content))
69123
}
70124
}
71125

72126
func executeStreamingAndAppendResults(
73127
_ calls: [ToolCall], context: C, messages: inout [ChatMessage],
74-
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
128+
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation,
129+
approvalHandler: ToolApprovalHandler? = nil, allowlist: inout Set<String>
75130
) async throws {
76131
guard !calls.isEmpty else { return }
77-
let results = try await executeToolsStreaming(
78-
calls, context: context.withParentHistory(messages), continuation: continuation
132+
let executionContext = context.withParentHistory(messages)
133+
134+
guard let handler = approvalHandler, configuration.approvalPolicy != .none else {
135+
let results = try await executeToolsStreaming(
136+
calls,
137+
context: executionContext,
138+
continuation: continuation,
139+
approvalHandler: approvalHandler
140+
)
141+
for (call, result) in results {
142+
let content = ContextCompactor.truncateToolResult(result.content, configuration: configuration)
143+
messages.append(.tool(id: call.id, name: call.name, content: content))
144+
}
145+
return
146+
}
147+
148+
var autoExecute: [IndexedToolCall] = []
149+
var needsApproval: [IndexedToolCall] = []
150+
for (offset, call) in calls.enumerated() {
151+
let indexed = IndexedToolCall(index: offset, call: call)
152+
if requiresApproval(call, allowlist: allowlist) {
153+
needsApproval.append(indexed)
154+
} else {
155+
autoExecute.append(indexed)
156+
}
157+
}
158+
159+
var allResults: [IndexedToolResult] = []
160+
161+
if !autoExecute.isEmpty {
162+
let results = try await executeToolsStreaming(
163+
autoExecute.map(\.call), context: executionContext,
164+
continuation: continuation, approvalHandler: handler
165+
)
166+
for (position, (call, result)) in results.enumerated() {
167+
allResults.append(IndexedToolResult(index: autoExecute[position].index, call: call, result: result))
168+
}
169+
}
170+
171+
let (approved, denied) = try await resolveApprovals(
172+
needsApproval, handler: handler, allowlist: &allowlist, continuation: continuation
79173
)
80-
for (call, result) in results {
81-
let content = ContextCompactor.truncateToolResult(result.content, configuration: configuration)
82-
messages.append(.tool(id: call.id, name: call.name, content: content))
174+
try Task.checkCancellation()
175+
176+
for entry in denied {
177+
continuation.yield(.toolCallCompleted(id: entry.call.id, name: entry.call.name, result: entry.result))
178+
allResults.append(entry)
179+
}
180+
181+
if !approved.isEmpty {
182+
let results = try await executeToolsStreaming(
183+
approved.map(\.call), context: executionContext,
184+
continuation: continuation, approvalHandler: handler
185+
)
186+
for (position, (call, result)) in results.enumerated() {
187+
allResults.append(IndexedToolResult(index: approved[position].index, call: call, result: result))
188+
}
189+
}
190+
191+
allResults.sort { $0.index < $1.index }
192+
for entry in allResults {
193+
let content = ContextCompactor.truncateToolResult(entry.result.content, configuration: configuration)
194+
messages.append(.tool(id: entry.call.id, name: entry.call.name, content: content))
83195
}
84196
}
85197
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import Foundation
2+
3+
struct InvocationOptions {
4+
let tokenBudget: Int?
5+
let requestContext: RequestContext?
6+
let systemPromptOverride: String?
7+
let approvalHandler: ToolApprovalHandler?
8+
}
9+
10+
struct IndexedToolCall {
11+
let index: Int
12+
let call: ToolCall
13+
}
14+
15+
struct IndexedToolResult {
16+
let index: Int
17+
let call: ToolCall
18+
let result: ToolResult
19+
}
20+
21+
extension Agent {
22+
func requiresApproval(_ call: ToolCall, allowlist: Set<String>) -> Bool {
23+
configuration.approvalPolicy.requiresApproval(toolName: call.name, allowlist: allowlist)
24+
}
25+
26+
func resolveApprovals(
27+
_ calls: [IndexedToolCall],
28+
handler: @escaping ToolApprovalHandler,
29+
allowlist: inout Set<String>,
30+
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation?
31+
) async throws -> (approved: [IndexedToolCall], denied: [IndexedToolResult]) {
32+
var approved: [IndexedToolCall] = []
33+
var denied: [IndexedToolResult] = []
34+
35+
for indexed in calls {
36+
guard let tool = tool(named: indexed.call.name) else {
37+
approved.append(indexed)
38+
continue
39+
}
40+
41+
if allowlist.contains(indexed.call.name) {
42+
approved.append(indexed)
43+
continue
44+
}
45+
46+
let request = ToolApprovalRequest(
47+
toolCallId: indexed.call.id,
48+
toolName: indexed.call.name,
49+
arguments: indexed.call.arguments,
50+
toolDescription: tool.description
51+
)
52+
continuation?.yield(.toolApprovalRequested(request))
53+
let decision = try await awaitApprovalDecision(for: request, using: handler)
54+
continuation?.yield(.toolApprovalResolved(toolCallId: indexed.call.id, decision: decision))
55+
56+
switch decision {
57+
case .approve:
58+
approved.append(indexed)
59+
case .approveAlways:
60+
allowlist.insert(indexed.call.name)
61+
approved.append(indexed)
62+
case let .approveWithModifiedArguments(newArgs):
63+
let modified = ToolCall(id: indexed.call.id, name: indexed.call.name, arguments: newArgs)
64+
approved.append(IndexedToolCall(index: indexed.index, call: modified))
65+
case let .deny(reason):
66+
let result = ToolResult.error(reason ?? "Tool call was denied.")
67+
denied.append(IndexedToolResult(index: indexed.index, call: indexed.call, result: result))
68+
}
69+
}
70+
71+
return (approved: approved, denied: denied)
72+
}
73+
}

Sources/AgentRunKit/Core/Agent+ToolExecution.swift

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import Foundation
22

33
extension Agent {
4+
func tool(named name: String) -> (any AnyTool<C>)? {
5+
tools.first(where: { $0.name == name })
6+
}
7+
48
func resolveTimeout(for call: ToolCall) -> Duration? {
5-
guard let tool = tools.first(where: { $0.name == call.name }) else {
9+
guard let tool = tool(named: call.name) else {
610
return configuration.toolTimeout
711
}
812
if let overriding = tool as? any TimeoutOverriding {
@@ -31,10 +35,18 @@ extension Agent {
3135
}
3236
}
3337

34-
func executeWithTimeout(_ call: ToolCall, context: C) async throws -> ToolResult {
38+
func executeWithTimeout(
39+
_ call: ToolCall, context: C, approvalHandler: ToolApprovalHandler? = nil
40+
) async throws -> ToolResult {
3541
do {
3642
return try await withTimeout(resolveTimeout(for: call), toolName: call.name) {
37-
try await self.executeTool(call, context: context)
43+
if let handler = approvalHandler,
44+
let approvalAware = self.tool(named: call.name) as? any ApprovalAwareSubAgentTool<C> {
45+
return try await approvalAware.executeWithApproval(
46+
arguments: call.argumentsData, context: context, approvalHandler: handler
47+
)
48+
}
49+
return try await self.executeTool(call, context: context)
3850
}
3951
} catch is CancellationError {
4052
throw CancellationError()
@@ -49,7 +61,8 @@ extension Agent {
4961
_ call: ToolCall,
5062
tool: any StreamableSubAgentTool<C>,
5163
context: C,
52-
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
64+
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation,
65+
approvalHandler: ToolApprovalHandler? = nil
5366
) async throws -> ToolResult {
5467
continuation.yield(.subAgentStarted(toolCallId: call.id, toolName: call.name))
5568

@@ -66,7 +79,8 @@ extension Agent {
6679
result = try await withTimeout(resolveTimeout(for: call), toolName: call.name) {
6780
try await tool.executeStreaming(
6881
toolCallId: call.id, arguments: call.argumentsData,
69-
context: context, eventHandler: eventHandler
82+
context: context, eventHandler: eventHandler,
83+
approvalHandler: approvalHandler
7084
)
7185
}
7286
} catch is CancellationError {
@@ -83,18 +97,20 @@ extension Agent {
8397
func executeToolsStreaming(
8498
_ calls: [ToolCall],
8599
context: C,
86-
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
100+
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation,
101+
approvalHandler: ToolApprovalHandler? = nil
87102
) async throws -> [(call: ToolCall, result: ToolResult)] {
88103
try await withThrowingTaskGroup(of: (Int, ToolCall, ToolResult).self) { group in
89104
for (index, call) in calls.enumerated() {
90105
group.addTask {
91-
let result: ToolResult = if let streamableTool = self.tools.first(where: { $0.name == call.name })
106+
let result: ToolResult = if let streamableTool = self.tool(named: call.name)
92107
as? any StreamableSubAgentTool<C> {
93108
try await self.executeStreamableWithTimeout(
94-
call, tool: streamableTool, context: context, continuation: continuation
109+
call, tool: streamableTool, context: context,
110+
continuation: continuation, approvalHandler: approvalHandler
95111
)
96112
} else {
97-
try await self.executeWithTimeout(call, context: context)
113+
try await self.executeWithTimeout(call, context: context, approvalHandler: approvalHandler)
98114
}
99115
return (index, call, result)
100116
}
@@ -111,12 +127,15 @@ extension Agent {
111127

112128
func executeToolsInParallel(
113129
_ calls: [ToolCall],
114-
context: C
130+
context: C,
131+
approvalHandler: ToolApprovalHandler? = nil
115132
) async throws -> [(call: ToolCall, result: ToolResult)] {
116133
try await withThrowingTaskGroup(of: (Int, ToolCall, ToolResult).self) { group in
117134
for (index, call) in calls.enumerated() {
118135
group.addTask {
119-
let result = try await self.executeWithTimeout(call, context: context)
136+
let result = try await self.executeWithTimeout(
137+
call, context: context, approvalHandler: approvalHandler
138+
)
120139
return (index, call, result)
121140
}
122141
}
@@ -130,7 +149,7 @@ extension Agent {
130149
}
131150

132151
func executeTool(_ call: ToolCall, context: C) async throws -> ToolResult {
133-
guard let tool = tools.first(where: { $0.name == call.name }) else {
152+
guard let tool = tool(named: call.name) else {
134153
throw AgentError.toolNotFound(name: call.name)
135154
}
136155
return try await tool.execute(arguments: call.argumentsData, context: context)

0 commit comments

Comments
 (0)