Skip to content

Commit 8890752

Browse files
committed
refactor(agent): centralize tool execution helpers
1 parent ac13556 commit 8890752

11 files changed

Lines changed: 63 additions & 71 deletions

File tree

Sources/AgentRunKit/Core/AgentLoop/Agent+ContextBudget.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ extension Agent {
150150
}
151151

152152
func toolResultCharacterLimit(for toolName: String) -> Int? {
153-
tool(named: toolName)?.maxResultCharacters ?? configuration.maxToolResultCharacters
153+
firstTool(named: toolName, in: tools)?.maxResultCharacters ?? configuration.maxToolResultCharacters
154154
}
155155

156156
func truncatedToolResult(_ result: ToolResult, toolName: String) -> ToolResult {

Sources/AgentRunKit/Core/AgentLoop/Agent+Run.swift

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,7 @@ extension Agent {
9797
iterations: Int,
9898
history: [ChatMessage]
9999
) throws -> AgentResult {
100-
let decoded: FinishArguments
101-
do {
102-
decoded = try JSONDecoder().decode(FinishArguments.self, from: call.argumentsData)
103-
} catch {
104-
throw AgentError.finishDecodingFailed(message: String(describing: error))
105-
}
100+
let decoded = try decodeFinishArguments(from: call.argumentsData)
106101
return try AgentResult(
107102
finishReason: FinishReason(decoded.reason ?? "completed"),
108103
content: decoded.content,

Sources/AgentRunKit/Core/AgentLoop/Agent+ToolApproval.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ extension Agent {
2626
var denied: [IndexedToolResult] = []
2727

2828
for indexed in calls {
29-
guard let tool = tool(named: indexed.call.name) else {
29+
guard let tool = firstTool(named: indexed.call.name, in: tools) else {
3030
approved.append(indexed)
3131
continue
3232
}

Sources/AgentRunKit/Core/AgentLoop/Agent+ToolExecution.swift

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,21 @@
11
import Foundation
22

33
extension Agent {
4-
func tool(named name: String) -> (any AnyTool<C>)? {
5-
tools.first(where: { $0.name == name })
6-
}
7-
84
func resolveTimeout(for call: ToolCall) -> Duration {
9-
guard let tool = tool(named: call.name) else {
5+
guard let tool = firstTool(named: call.name, in: tools) else {
106
return configuration.toolTimeout
117
}
128
return resolvedToolTimeout(for: tool, default: configuration.toolTimeout)
139
}
1410

15-
func withTimeout<T: Sendable>(
16-
_ timeout: Duration,
17-
toolName: String,
18-
operation: @Sendable @escaping () async throws -> T
19-
) async throws -> T {
20-
try await withThrowingTaskGroup(of: T.self) { group in
21-
group.addTask { try await operation() }
22-
group.addTask {
23-
try await Task.sleep(for: timeout)
24-
throw AgentError.toolTimeout(tool: toolName)
25-
}
26-
guard let result = try await group.next() else {
27-
preconditionFailure("ThrowingTaskGroup with two tasks must yield a result")
28-
}
29-
group.cancelAll()
30-
return result
31-
}
32-
}
33-
3411
func executeWithTimeout(
3512
_ call: ToolCall, context: C, approvalHandler: ToolApprovalHandler? = nil
3613
) async throws -> ToolResult {
3714
do {
38-
return try await withTimeout(resolveTimeout(for: call), toolName: call.name) {
15+
return try await withToolTimeout(resolveTimeout(for: call), toolName: call.name) {
3916
if let handler = approvalHandler,
40-
let approvalAware = self.tool(named: call.name) as? any ApprovalAwareSubAgentTool<C> {
17+
let tool = firstTool(named: call.name, in: self.tools),
18+
let approvalAware = tool as? any ApprovalAwareSubAgentTool<C> {
4119
return try await approvalAware.executeWithApproval(
4220
arguments: call.argumentsData, context: context, approvalHandler: handler
4321
)
@@ -73,7 +51,7 @@ extension Agent {
7351

7452
let result: ToolResult
7553
do {
76-
result = try await withTimeout(resolveTimeout(for: call), toolName: call.name) {
54+
result = try await withToolTimeout(resolveTimeout(for: call), toolName: call.name) {
7755
try await tool.executeStreaming(
7856
toolCallId: call.id, arguments: call.argumentsData,
7957
context: context, parentSessionID: eventFactory.sessionID,
@@ -108,7 +86,7 @@ extension Agent {
10886
))
10987
} else {
11088
let call = wave.calls[0]
111-
let result: ToolResult = if let streamableTool = tool(named: call.name)
89+
let result: ToolResult = if let streamableTool = firstTool(named: call.name, in: tools)
11290
as? any StreamableSubAgentTool<C> {
11391
try await executeStreamableWithTimeout(
11492
call, tool: streamableTool, context: context,
@@ -150,7 +128,7 @@ extension Agent {
150128
}
151129

152130
func executeTool(_ call: ToolCall, context: C) async throws -> ToolResult {
153-
guard let tool = tool(named: call.name) else {
131+
guard let tool = firstTool(named: call.name, in: tools) else {
154132
throw AgentError.toolNotFound(name: call.name)
155133
}
156134
return try await tool.execute(arguments: call.argumentsData, context: context)
@@ -161,7 +139,7 @@ extension Agent {
161139
var waves: [ExecutionWave] = []
162140
var safeBatch: [ToolCall] = []
163141
for call in calls {
164-
if tool(named: call.name)?.isConcurrencySafe ?? false {
142+
if firstTool(named: call.name, in: tools)?.isConcurrencySafe ?? false {
165143
safeBatch.append(call)
166144
} else {
167145
if !safeBatch.isEmpty {
@@ -210,7 +188,7 @@ extension Agent {
210188
return try await withThrowingTaskGroup(of: (Int, ToolCall, ToolResult).self) { group in
211189
for (index, call) in calls.enumerated() {
212190
group.addTask {
213-
let result: ToolResult = if let streamableTool = self.tool(named: call.name)
191+
let result: ToolResult = if let streamableTool = firstTool(named: call.name, in: self.tools)
214192
as? any StreamableSubAgentTool<C> {
215193
try await self.executeStreamableWithTimeout(
216194
call, tool: streamableTool, context: context,

Sources/AgentRunKit/Core/AgentLoop/Chat.swift

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ private extension Chat {
317317
}
318318

319319
func toolResultCharacterLimit(for toolName: String) -> Int? {
320-
tool(named: toolName)?.maxResultCharacters ?? maxToolResultCharacters
320+
firstTool(named: toolName, in: tools)?.maxResultCharacters ?? maxToolResultCharacters
321321
}
322322

323323
func truncatedToolResult(_ result: ToolResult, toolName: String) -> ToolResult {
@@ -335,7 +335,7 @@ private extension Chat {
335335
allowlist: inout Set<String>,
336336
continuation: AsyncThrowingStream<StreamEvent, Error>.Continuation
337337
) async throws -> ToolResult {
338-
guard let tool = tool(named: call.name) else {
338+
guard let tool = firstTool(named: call.name, in: tools) else {
339339
return .error(AgentError.toolNotFound(name: call.name).feedbackMessage)
340340
}
341341

@@ -400,25 +400,13 @@ private extension Chat {
400400
approvalHandler: ToolApprovalHandler? = nil
401401
) async throws -> ToolResult {
402402
do {
403-
return try await withThrowingTaskGroup(of: ToolResult.self) { group in
404-
group.addTask {
405-
try await executeTool(
406-
call,
407-
with: resolvedTool,
408-
context: context,
409-
approvalHandler: approvalHandler
410-
)
411-
}
412-
group.addTask {
413-
try await Task.sleep(for: self.resolveTimeout(for: resolvedTool))
414-
throw AgentError.toolTimeout(tool: call.name)
415-
}
416-
417-
guard let result = try await group.next() else {
418-
preconditionFailure("ThrowingTaskGroup with two tasks must yield a result")
419-
}
420-
group.cancelAll()
421-
return result
403+
return try await withToolTimeout(resolveTimeout(for: resolvedTool), toolName: call.name) {
404+
try await self.executeTool(
405+
call,
406+
with: resolvedTool,
407+
context: context,
408+
approvalHandler: approvalHandler
409+
)
422410
}
423411
} catch is CancellationError {
424412
throw CancellationError()
@@ -445,8 +433,4 @@ private extension Chat {
445433
}
446434
return try await tool.execute(arguments: call.argumentsData, context: context)
447435
}
448-
449-
func tool(named name: String) -> (any AnyTool<C>)? {
450-
tools.first(where: { $0.name == name })
451-
}
452436
}

Sources/AgentRunKit/Core/AgentLoop/FinishArguments.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ public struct FinishArguments: Codable, Sendable {
1010
}
1111
}
1212

13+
func decodeFinishArguments(from arguments: Data) throws -> FinishArguments {
14+
do {
15+
return try JSONDecoder().decode(FinishArguments.self, from: arguments)
16+
} catch {
17+
throw AgentError.finishDecodingFailed(message: String(describing: error))
18+
}
19+
}
20+
1321
package let reservedFinishToolDefinition = ToolDefinition(
1422
name: "finish",
1523
description: """
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import Foundation
2+
3+
func firstTool<C: ToolContext>(
4+
named name: String,
5+
in tools: [any AnyTool<C>]
6+
) -> (any AnyTool<C>)? {
7+
tools.first(where: { $0.name == name })
8+
}
9+
10+
func withToolTimeout<T: Sendable>(
11+
_ timeout: Duration,
12+
toolName: String,
13+
operation: @Sendable @escaping () async throws -> T
14+
) async throws -> T {
15+
try await withThrowingTaskGroup(of: T.self) { group in
16+
group.addTask { try await operation() }
17+
group.addTask {
18+
try await Task.sleep(for: timeout)
19+
throw AgentError.toolTimeout(tool: toolName)
20+
}
21+
guard let result = try await group.next() else {
22+
preconditionFailure("ThrowingTaskGroup with two tasks must yield a result")
23+
}
24+
group.cancelAll()
25+
return result
26+
}
27+
}

Sources/AgentRunKit/Core/Streaming/Agent+StreamEvents.swift

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@ extension Agent {
77
history: [ChatMessage],
88
eventFactory: StreamEventFactory
99
) throws -> StreamEvent {
10-
let decoded: FinishArguments
11-
do {
12-
decoded = try JSONDecoder().decode(FinishArguments.self, from: finishCall.argumentsData)
13-
} catch {
14-
throw AgentError.finishDecodingFailed(message: String(describing: error))
15-
}
10+
let decoded = try decodeFinishArguments(from: finishCall.argumentsData)
1611
return try makeFinishedEvent(
1712
tokenUsage: tokenUsage,
1813
content: decoded.content,

Tests/AgentRunKitTests/Core/AgentLoop/AgentTests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ struct AgentTests {
194194
}.last
195195
#expect(toolMessage?.0 == "slow")
196196
#expect(toolMessage?.1.contains("timed out") == true)
197+
#expect(toolMessage?.1.contains("'slow'") == true)
197198
}
198199

199200
@Test

Tests/AgentRunKitTests/Core/AgentLoop/ChatTimeoutTests.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ struct ChatTimeoutTests {
4949

5050
let toolCompletedEvent = events.first { event in
5151
if case let .toolCallCompleted(_, name, result) = event.kind {
52-
return name == "slow" && result.isError && result.content.contains("timed out")
52+
return name == "slow"
53+
&& result.isError
54+
&& result.content.contains("timed out")
55+
&& result.content.contains("'slow'")
5356
}
5457
return false
5558
}

0 commit comments

Comments
 (0)