Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions integration/test/tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type Agent struct {
openaiClient openai.Client
mcpSession *mcp.ClientSession
mcpClient *mcp.Client
mcpTransport *mcp.StreamableClientTransport
tools []*mcp.Tool
model string
}
Expand Down Expand Up @@ -118,6 +119,7 @@ func NewAgent(llmUserName, llmToken, llmBaseURL, openaiModel, mcpServerURL strin
openaiClient: openaiClient,
mcpSession: session,
mcpClient: mcpClient,
mcpTransport: mcpTransport,
tools: toolsResult.Tools,
model: openaiModel,
}, nil
Expand All @@ -135,11 +137,35 @@ func (a *Agent) convertMCPToolsToOpenAI() []openai.ChatCompletionToolUnionParam
return tools
}

func (a *Agent) reconnect(ctx context.Context) error {
if a.mcpSession != nil {
_ = a.mcpSession.Close()
a.mcpSession = nil
}
session, err := a.mcpClient.Connect(ctx, a.mcpTransport, nil)
if err != nil {
return fmt.Errorf("failed to reconnect to MCP server: %w", err)
}
a.mcpSession = session
slog.Info("Reconnected to MCP server")
return nil
}

func (a *Agent) callMCPTool(ctx context.Context, toolName string, arguments map[string]any) (string, error) {
Comment thread
rahulguptajss marked this conversation as resolved.
result, err := a.mcpSession.CallTool(ctx, &mcp.CallToolParams{
Name: toolName,
Arguments: arguments,
})
if err != nil && errors.Is(err, mcp.ErrConnectionClosed) {
slog.Warn("MCP session dropped, reconnecting", slog.Any("error", err))
if reconnErr := a.reconnect(ctx); reconnErr != nil {
return "", fmt.Errorf("failed to call tool: %w", errors.Join(err, fmt.Errorf("reconnect failed: %w", reconnErr)))
}
result, err = a.mcpSession.CallTool(ctx, &mcp.CallToolParams{
Name: toolName,
Arguments: arguments,
})
}
Comment thread
rahulguptajss marked this conversation as resolved.
if err != nil {
return "", fmt.Errorf("failed to call tool: %w", err)
}
Expand All @@ -166,6 +192,12 @@ func (a *Agent) callMCPTool(ctx context.Context, toolName string, arguments map[

func (a *Agent) ChatWithResponse(ctx context.Context, t *testing.T, userMessage string, expectedOntapErrorStr string) (string, error) {
messages := []openai.ChatCompletionMessageParamUnion{
openai.SystemMessage("When a tool call fails due to a validation error, check two things: " +
"(1) Did you misassign a value to the wrong parameter (e.g. passed the cluster name as the svm name)? " +
"(2) Did you misread or partially extract a value from the user's message (e.g. dropped a path prefix like /vol/vol1/)? " +
"In either case, correct the mistake using only values already present in the conversation — " +
"never truncate, invent, or change values beyond what is needed to fix the extraction or mapping error. " +
"If the error cannot be fixed this way, report it to the user and stop."),
openai.UserMessage(userMessage),
}

Expand Down
Loading