Skip to content

Commit 4411d22

Browse files
committed
fix: serialize structured tool output before membrane pointerizing
1 parent eedae7a commit 4411d22

2 files changed

Lines changed: 150 additions & 1 deletion

File tree

Sources/Swarm/Agents/Agent.swift

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1885,7 +1885,7 @@ public struct Agent: AgentRuntime, Sendable {
18851885
}
18861886

18871887
if outcome.result.isSuccess {
1888-
var toolOutputText = outcome.result.output.stringValue ?? outcome.result.output.description
1888+
var toolOutputText = Self.toolOutputText(for: outcome.result.output)
18891889
if let membraneAdapter {
18901890
do {
18911891
let currentToolOutput = toolOutputText
@@ -1948,6 +1948,25 @@ public struct Agent: AgentRuntime, Sendable {
19481948
}
19491949
}
19501950

1951+
private static func toolOutputText(for output: SendableValue) -> String {
1952+
if let string = output.stringValue {
1953+
return string
1954+
}
1955+
1956+
do {
1957+
let object = output.toJSONObject()
1958+
let data = try JSONSerialization.data(
1959+
withJSONObject: object,
1960+
options: [.sortedKeys, .fragmentsAllowed]
1961+
)
1962+
if let text = String(data: data, encoding: .utf8) {
1963+
return text
1964+
}
1965+
} catch {}
1966+
1967+
return output.description
1968+
}
1969+
19511970
// MARK: - Handoff Tool Schema Integration
19521971

19531972
/// Builds tool schemas including handoff tool schemas.

Tests/SwarmTests/MembraneIntegrationTests.swift

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,41 @@ struct MembraneIntegrationTests {
126126
#expect(result.metadata["membrane.fallback.error"]?.stringValue?.contains("forced membrane failure") == true)
127127
}
128128

129+
@Test("Pointerized structured tool output resolves as JSON")
130+
func pointerizedStructuredToolOutputResolvesAsJSON() async throws {
131+
let provider = PointerResolvingInferenceProvider()
132+
let agent = try Agent(
133+
tools: [StructuredPayloadTool()],
134+
instructions: "Call the structured payload tool, resolve the pointer, then finish.",
135+
configuration: AgentConfiguration(
136+
name: "membrane-structured-tool-output",
137+
maxIterations: 4,
138+
defaultTracingEnabled: false
139+
),
140+
inferenceProvider: provider
141+
).environment(
142+
\.membrane,
143+
MembraneEnvironment(
144+
isEnabled: true,
145+
configuration: MembraneFeatureConfiguration(
146+
jitMinToolCount: 12,
147+
defaultJITLoadCount: 2,
148+
pointerThresholdBytes: 32,
149+
pointerSummaryMaxChars: 80
150+
)
151+
)
152+
)
153+
154+
let result = try await agent.run("exercise pointerized structured output")
155+
156+
let resolvedPointerOutput = try #require(result.toolResults.last?.output.stringValue)
157+
let data = Data(resolvedPointerOutput.utf8)
158+
let object = try JSONSerialization.jsonObject(with: data) as? [String: Any]
159+
let dictionary = try #require(object)
160+
#expect(dictionary["message"] as? String == "quoted \"value\" with newline\nsecond line")
161+
#expect(dictionary["count"] as? Int == 42)
162+
}
163+
129164
@Test("Default adapter checkpoint state roundtrips loaded tools")
130165
func defaultAdapterCheckpointRoundtrip() async throws {
131166
let adapter = DefaultMembraneAgentAdapter(
@@ -276,6 +311,101 @@ private struct MembraneTestTool: AnyJSONTool, Sendable {
276311
}
277312
}
278313

314+
private struct StructuredPayloadTool: AnyJSONTool, Sendable {
315+
let name = "structured_payload"
316+
let description = "Returns structured JSON-compatible data large enough to pointerize."
317+
let parameters: [ToolParameter] = []
318+
319+
func execute(arguments _: [String: SendableValue]) async throws -> SendableValue {
320+
.dictionary([
321+
"message": .string("quoted \"value\" with newline\nsecond line"),
322+
"count": .int(42),
323+
"items": .array((0 ..< 20).map { .string("item-\($0)") })
324+
])
325+
}
326+
}
327+
328+
private actor PointerResolvingInferenceProvider: InferenceProvider, ConversationInferenceProvider {
329+
private var turn = 0
330+
331+
func generate(prompt _: String, options _: InferenceOptions) async throws -> String {
332+
"done"
333+
}
334+
335+
nonisolated func stream(prompt _: String, options _: InferenceOptions) -> AsyncThrowingStream<String, Error> {
336+
AsyncThrowingStream { continuation in
337+
continuation.yield("done")
338+
continuation.finish()
339+
}
340+
}
341+
342+
func generateWithToolCalls(
343+
prompt _: String,
344+
tools _: [ToolSchema],
345+
options _: InferenceOptions
346+
) async throws -> InferenceResponse {
347+
try nextResponse(from: "")
348+
}
349+
350+
func generate(messages: [InferenceMessage], options _: InferenceOptions) async throws -> String {
351+
InferenceMessage.flattenPrompt(messages)
352+
}
353+
354+
nonisolated func stream(
355+
messages: [InferenceMessage],
356+
options _: InferenceOptions
357+
) -> AsyncThrowingStream<String, Error> {
358+
let text = InferenceMessage.flattenPrompt(messages)
359+
return AsyncThrowingStream { continuation in
360+
continuation.yield(text)
361+
continuation.finish()
362+
}
363+
}
364+
365+
func generateWithToolCalls(
366+
messages: [InferenceMessage],
367+
tools _: [ToolSchema],
368+
options _: InferenceOptions
369+
) async throws -> InferenceResponse {
370+
try nextResponse(from: InferenceMessage.flattenPrompt(messages))
371+
}
372+
373+
private func nextResponse(from prompt: String) throws -> InferenceResponse {
374+
defer { turn += 1 }
375+
switch turn {
376+
case 0:
377+
return InferenceResponse(
378+
toolCalls: [
379+
.init(id: "call-structured", name: "structured_payload", arguments: [:])
380+
],
381+
finishReason: .toolCall
382+
)
383+
case 1:
384+
let pointerID = try Self.pointerID(from: prompt)
385+
return InferenceResponse(
386+
toolCalls: [
387+
.init(
388+
id: "call-resolve",
389+
name: MembraneInternalToolName.resolvePointer,
390+
arguments: ["pointer_id": .string(pointerID)]
391+
)
392+
],
393+
finishReason: .toolCall
394+
)
395+
default:
396+
return InferenceResponse(content: "done", finishReason: .completed)
397+
}
398+
}
399+
400+
private static func pointerID(from prompt: String) throws -> String {
401+
guard let range = prompt.range(of: #"ptr_[0-9a-f]{12}"#, options: .regularExpression) else {
402+
struct MissingPointer: Error {}
403+
throw MissingPointer()
404+
}
405+
return String(prompt[range])
406+
}
407+
}
408+
279409
private actor ThrowingMembraneAdapter: MembraneAgentAdapter {
280410
func plan(
281411
prompt _: String,

0 commit comments

Comments
 (0)