diff --git a/src/routes/chat.ts b/src/routes/chat.ts
index 04623a7..05b7825 100644
--- a/src/routes/chat.ts
+++ b/src/routes/chat.ts
@@ -4,6 +4,66 @@ import { getSystemPrompt } from '../domains';
 
 const LITELLM_SERVER_URL = process.env.LITELLM_SERVER_URL || 'http://localhost:4000';
 
+/**
+ * Tool definitions advertised to the model for function calling.
+ *
+ * Every `name` listed here needs a matching `case` in `executeToolCall`;
+ * anything else is answered with an "Unknown tool" error payload.
+ */
+const TOOLS = [
+  {
+    type: 'function',
+    function: {
+      name: 'get_user',
+      description: 'Get information about the current authenticated user including their ID, role, and token details',
+      parameters: {
+        type: 'object',
+        properties: {},
+        required: []
+      }
+    }
+  }
+];
+
+/**
+ * Runs one tool requested by the model and returns its result as a JSON string.
+ *
+ * Problems (no authenticated user, unknown tool name) are returned as an
+ * `{ error, message }` JSON payload instead of being thrown.
+ *
+ * @param toolName - Name of the requested tool.
+ * @param args - Parsed tool arguments; currently unused because `get_user` takes none.
+ * @param req - Incoming request; `req.user` is the source of the reported user fields.
+ * @returns JSON-encoded tool result or error payload.
+ */
+function executeToolCall(toolName: string, args: unknown, req: Request): string {
+  switch (toolName) {
+    case 'get_user':
+      if (req.user) {
+        return JSON.stringify({
+          userId: req.user.sub,
+          role: req.user.role || 'unknown',
+          issuer: req.user.iss,
+          audience: req.user.aud,
+          // iat/exp are multiplied by 1000 before building the Date, i.e. treated as epoch seconds.
+          issuedAt: req.user.iat ? new Date(req.user.iat * 1000).toISOString() : null,
+          expiresAt: req.user.exp ? new Date(req.user.exp * 1000).toISOString() : null,
+          scope: req.user.scope || null
+        });
+      } else {
+        return JSON.stringify({
+          error: 'No authenticated user',
+          message: 'This endpoint requires authentication to retrieve user information'
+        });
+      }
+    default:
+      return JSON.stringify({
+        error: 'Unknown tool',
+        message: `Tool '${toolName}' is not available`
+      });
+  }
+}
+
 // Map fish names to internal security levels
 const FISH_TO_LEVEL: Record = {
   minnow: 'insecure',
@@ -116,7 +176,9 @@ export async function chatHandler(req: Request, res: Response): Promise {
 
     // Prepare LiteLLM request
     const litellmRequest: any = {
-      messages: messages
+      messages: messages,
+      tools: TOOLS,
+      tool_choice: 'auto'
     };
 
     // Add model if provided
@@ -124,26 +186,80 @@ export async function chatHandler(req: Request, res: Response): Promise {
       litellmRequest.model = model;
     }
 
-    // Forward request to LiteLLM server
-    const response = await fetch(`${LITELLM_SERVER_URL}/v1/chat/completions`, {
-      method: 'POST',
-      headers: {
-        'Content-Type': 'application/json',
-      },
-      body: JSON.stringify(litellmRequest),
-    });
+    // Tool call loop - continue until we get a final response
+    const MAX_TOOL_ITERATIONS = 10;
+    let iteration = 0;
 
-    if (!response.ok) {
-      const errorText = await response.text();
-      res.status(response.status).json({
-        error: 'LiteLLM server error',
-        message: errorText
+    while (iteration < MAX_TOOL_ITERATIONS) {
+      iteration++;
+
+      // Forward request to LiteLLM server
+      const response = await fetch(`${LITELLM_SERVER_URL}/v1/chat/completions`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify(litellmRequest),
       });
+
+      if (!response.ok) {
+        const errorText = await response.text();
+        res.status(response.status).json({
+          error: 'LiteLLM server error',
+          message: errorText
+        });
+        return;
+      }
+
+      const data = await response.json();
+      const assistantMessage = data.choices?.[0]?.message;
+
+      // Check if the model wants to call tools
+      if (assistantMessage?.tool_calls && assistantMessage.tool_calls.length > 0) {
+        // Add the assistant's message with tool calls to the conversation
+        litellmRequest.messages.push({
+          role: 'assistant',
+          content: assistantMessage.content || null,
+          tool_calls: assistantMessage.tool_calls
+        });
+
+        // Execute each tool call and add results
+        for (const toolCall of assistantMessage.tool_calls) {
+          const toolName = toolCall.function.name;
+          let toolArgs: unknown = {};
+
+          try {
+            toolArgs = JSON.parse(toolCall.function.arguments || '{}');
+          } catch (parseError) {
+            // Malformed arguments: keep the best-effort fallback to empty args,
+            // but log it so the failure is visible instead of silently swallowed.
+            console.warn(`Ignoring malformed arguments for tool '${toolName}':`, parseError);
+          }
+
+          const toolResult = executeToolCall(toolName, toolArgs, req);
+
+          // Add tool result to messages
+          litellmRequest.messages.push({
+            role: 'tool',
+            tool_call_id: toolCall.id,
+            content: toolResult
+          });
+        }
+
+        // Continue the loop to get the model's response to tool results
+        continue;
+      }
+
+      // No tool calls - return the final response
+      res.json(data);
       return;
     }
 
-    const data = await response.json();
-    res.json(data);
+    // If we hit max iterations, return an error
+    res.status(500).json({
+      error: 'Tool call limit exceeded',
+      message: `Maximum tool call iterations (${MAX_TOOL_ITERATIONS}) reached`
+    });
   } catch (error) {
     console.error('Error in chat handler:', error);
     res.status(500).json({