diff --git a/integration/test/tools_test.go b/integration/test/tools_test.go index 89abd32..4355547 100644 --- a/integration/test/tools_test.go +++ b/integration/test/tools_test.go @@ -79,6 +79,7 @@ type Agent struct { openaiClient openai.Client mcpSession *mcp.ClientSession mcpClient *mcp.Client + mcpTransport *mcp.StreamableClientTransport tools []*mcp.Tool model string } @@ -118,6 +119,7 @@ func NewAgent(llmUserName, llmToken, llmBaseURL, openaiModel, mcpServerURL strin openaiClient: openaiClient, mcpSession: session, mcpClient: mcpClient, + mcpTransport: mcpTransport, tools: toolsResult.Tools, model: openaiModel, }, nil @@ -135,11 +137,35 @@ func (a *Agent) convertMCPToolsToOpenAI() []openai.ChatCompletionToolUnionParam return tools } +func (a *Agent) reconnect(ctx context.Context) error { + if a.mcpSession != nil { + _ = a.mcpSession.Close() + a.mcpSession = nil + } + session, err := a.mcpClient.Connect(ctx, a.mcpTransport, nil) + if err != nil { + return fmt.Errorf("failed to reconnect to MCP server: %w", err) + } + a.mcpSession = session + slog.Info("Reconnected to MCP server") + return nil +} + func (a *Agent) callMCPTool(ctx context.Context, toolName string, arguments map[string]any) (string, error) { result, err := a.mcpSession.CallTool(ctx, &mcp.CallToolParams{ Name: toolName, Arguments: arguments, }) + if err != nil && errors.Is(err, mcp.ErrConnectionClosed) { + slog.Warn("MCP session dropped, reconnecting", slog.Any("error", err)) + if reconnErr := a.reconnect(ctx); reconnErr != nil { + return "", fmt.Errorf("failed to call tool: %w", errors.Join(err, fmt.Errorf("reconnect failed: %w", reconnErr))) + } + result, err = a.mcpSession.CallTool(ctx, &mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }) + } if err != nil { return "", fmt.Errorf("failed to call tool: %w", err) } @@ -166,6 +192,12 @@ func (a *Agent) callMCPTool(ctx context.Context, toolName string, arguments map[ func (a *Agent) ChatWithResponse(ctx context.Context, t *testing.T, userMessage string, expectedOntapErrorStr string) (string, error) { messages := []openai.ChatCompletionMessageParamUnion{ + openai.SystemMessage("When a tool call fails due to a validation error, check two things: " + + "(1) Did you misassign a value to the wrong parameter (e.g. passed the cluster name as the svm name)? " + + "(2) Did you misread or partially extract a value from the user's message (e.g. dropped a path prefix like /vol/vol1/)? " + + "In either case, correct the mistake using only values already present in the conversation — " + + "never truncate, invent, or change values beyond what is needed to fix the extraction or mapping error. " + + "If the error cannot be fixed this way, report it to the user and stop."), openai.UserMessage(userMessage), }