Skip to content

Commit d42f238

Browse files
committed
feat: implement session-based API tool management and enhance tool message handling
1 parent 8e29785 commit d42f238

1 file changed

Lines changed: 79 additions & 3 deletions

File tree

agent/middleware/apiBasedTools.ts

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,70 @@ function getEnabledApiToolNames(messages: unknown[]) {
4747
return enabledToolNames;
4848
}
4949

50+
const loadedApiToolNamesBySession = new Map<string, Set<string>>();
51+
52+
function getSessionLoadedApiToolNames(sessionId: string) {
53+
let toolNames = loadedApiToolNamesBySession.get(sessionId);
54+
55+
if (!toolNames) {
56+
toolNames = new Set<string>();
57+
loadedApiToolNamesBySession.set(sessionId, toolNames);
58+
}
59+
60+
return toolNames;
61+
}
62+
63+
function getEnabledApiToolNamesForSession(messages: unknown[], sessionId?: string) {
64+
const enabledToolNames = getEnabledApiToolNames(messages);
65+
66+
if (!sessionId) {
67+
return enabledToolNames;
68+
}
69+
70+
for (const toolName of getSessionLoadedApiToolNames(sessionId)) {
71+
enabledToolNames.add(toolName);
72+
}
73+
74+
return enabledToolNames;
75+
}
76+
77+
function getToolMessageContent(message: unknown) {
78+
if (!ToolMessage.isInstance(message)) {
79+
return "";
80+
}
81+
82+
return typeof message.content === "string"
83+
? message.content
84+
: Array.isArray(message.content)
85+
? message.content
86+
.map((block) =>
87+
typeof block === "string"
88+
? block
89+
: "text" in block
90+
? block.text
91+
: "",
92+
)
93+
.join("")
94+
: "";
95+
}
96+
97+
function rememberLoadedToolFromFetchResult(sessionId: string, result: unknown) {
98+
if (!ToolMessage.isInstance(result) || result.name !== "fetch_tool_schema") {
99+
return;
100+
}
101+
102+
try {
103+
const parsed = JSON.parse(getToolMessageContent(result)) as {
104+
status?: number;
105+
name?: string;
106+
};
107+
108+
if (parsed.status === 200 && parsed.name) {
109+
getSessionLoadedApiToolNames(sessionId).add(parsed.name);
110+
}
111+
} catch {}
112+
}
113+
50114
export function createApiBasedToolsMiddleware(
51115
apiBasedTools: Record<string, ApiBasedTool>,
52116
adminforth: IAdminForth,
@@ -62,7 +126,11 @@ export function createApiBasedToolsMiddleware(
62126
return createMiddleware({
63127
name: "ApiBasedToolsMiddleware",
64128
async wrapModelCall(request, handler) {
65-
const enabledApiToolNames = getEnabledApiToolNames(request.state.messages);
129+
const { sessionId } = request.runtime.context as { sessionId?: string };
130+
const enabledApiToolNames = getEnabledApiToolNamesForSession(
131+
request.state.messages,
132+
sessionId,
133+
);
66134
const tools = [...enabledApiToolNames]
67135
.filter((toolName) => !alwaysAvailableApiToolNames.has(toolName))
68136
.map((toolName) => dynamicTools[toolName]);
@@ -80,9 +148,10 @@ export function createApiBasedToolsMiddleware(
80148
async wrapToolCall(request, handler) {
81149
const startedAt = Date.now();
82150
const toolInput = JSON.stringify(request.toolCall.args ?? {});
83-
const { adminUser, emitToolCallEvent, userTimeZone } = request.runtime.context as {
151+
const { adminUser, emitToolCallEvent, sessionId, userTimeZone } = request.runtime.context as {
84152
adminUser: AdminUser;
85153
emitToolCallEvent: ToolCallEventSink;
154+
sessionId: string;
86155
userTimeZone: string;
87156
};
88157
const toolArgs = (request.toolCall.args ?? {}) as Record<string, unknown>;
@@ -120,7 +189,10 @@ export function createApiBasedToolsMiddleware(
120189
if (request.tool) {
121190
result = await handler(request);
122191
} else {
123-
const enabledApiToolNames = getEnabledApiToolNames(request.state.messages);
192+
const enabledApiToolNames = getEnabledApiToolNamesForSession(
193+
request.state.messages,
194+
sessionId,
195+
);
124196

125197
if (enabledApiToolNames.has(request.toolCall.name)) {
126198
result = await handler({
@@ -137,6 +209,10 @@ export function createApiBasedToolsMiddleware(
137209
}
138210
}
139211

212+
if (sessionId) {
213+
rememberLoadedToolFromFetchResult(sessionId, result);
214+
}
215+
140216
toolCallTracker.finishSuccess(result);
141217
return result;
142218
} catch (error) {

0 commit comments

Comments
 (0)