Skip to content

Commit dc0ad44

Browse files
committed
fix openAI tool arguments consistency
1 parent 454cd8b commit dc0ad44

6 files changed

Lines changed: 29 additions & 14 deletions

bridge_integration_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
442442

443443
require.Len(t, recorderClient.toolUsages, 1)
444444
assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool)
445-
require.IsType(t, "", recorderClient.toolUsages[0].Args)
445+
require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args)
446446
require.Contains(t, recorderClient.toolUsages[0].Args, "path")
447-
assert.Equal(t, "README.md", gjson.Get(recorderClient.toolUsages[0].Args.(string), "path").Str)
447+
assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"])
448448

449449
require.Len(t, recorderClient.userPrompts, 1)
450450
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)
@@ -852,7 +852,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
852852
// Ensure expected tool was invoked with expected input.
853853
require.Len(t, recorderClient.toolUsages, 1)
854854
require.Equal(t, mockToolName, recorderClient.toolUsages[0].Tool)
855-
expected := "{\"owner\":\"admin\"}"
855+
expected := map[string]any{"owner": "admin"}
856856
require.EqualValues(t, expected, recorderClient.toolUsages[0].Args)
857857

858858
var (

intercept_openai_chat_base.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"net/http"
7+
"strings"
78

89
"github.com/coder/aibridge/mcp"
910
"github.com/coder/aibridge/tracing"
@@ -104,6 +105,18 @@ func (i *OpenAIChatInterceptionBase) injectTools() {
104105
}
105106
}
106107

108+
func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) {
109+
if len(strings.TrimSpace(in)) == 0 {
110+
return args // An empty string will fail JSON unmarshaling.
111+
}
112+
113+
if err := json.Unmarshal([]byte(in), &args); err != nil {
114+
i.logger.Warn(context.Background(), "failed to unmarshal tool args", slog.Error(err))
115+
}
116+
117+
return args
118+
}
119+
107120
// writeUpstreamError marshals and writes a given error.
108121
func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) {
109122
if oaiErr == nil {

intercept_openai_chat_blocking.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
119119
InterceptionID: i.ID().String(),
120120
MsgID: completion.ID,
121121
Tool: toolCall.Function.Name,
122-
Args: toolCall.Function.Arguments,
122+
Args: i.unmarshalArgs(toolCall.Function.Arguments),
123123
Injected: false,
124124
})
125125
}
@@ -150,13 +150,14 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
150150
appendedPrevMsg = true
151151
}
152152

153-
res, err := tool.Call(ctx, i.tracer, tc.Function.Arguments)
153+
args := i.unmarshalArgs(tc.Function.Arguments)
154+
res, err := tool.Call(ctx, i.tracer, args)
154155
_ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{
155156
InterceptionID: i.ID().String(),
156157
MsgID: completion.ID,
157158
ServerURL: &tool.ServerURL,
158159
Tool: tool.Name,
159-
Args: tc.Function.Arguments,
160+
Args: args,
160161
Injected: true,
161162
InvocationError: err,
162163
})

intercept_openai_chat_streaming.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
154154
InterceptionID: i.ID().String(),
155155
MsgID: processor.getMsgID(),
156156
Tool: toolCall.Name,
157-
Args: toolCall.Arguments,
157+
Args: i.unmarshalArgs(toolCall.Arguments),
158158
Injected: false,
159159
})
160160
toolCall = nil
@@ -241,13 +241,14 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
241241
i.req.Messages = append(i.req.Messages, processor.getLastCompletion().ToParam())
242242

243243
id := toolCall.ID
244-
toolRes, toolErr := tool.Call(streamCtx, i.tracer, toolCall.Arguments)
244+
args := i.unmarshalArgs(toolCall.Arguments)
245+
toolRes, toolErr := tool.Call(streamCtx, i.tracer, args)
245246
_ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{
246247
InterceptionID: i.ID().String(),
247248
MsgID: processor.getMsgID(),
248249
ServerURL: &tool.ServerURL,
249250
Tool: tool.Name,
250-
Args: toolCall.Arguments,
251+
Args: args,
251252
Injected: true,
252253
InvocationError: toolErr,
253254
})

mcp/tool.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp
5454
attribute.String(tracing.MCPServerName, t.ServerName),
5555
attribute.String(tracing.MCPServerURL, t.ServerURL),
5656
)
57+
ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...))
58+
defer tracing.EndSpanErr(span, &outErr)
59+
5760
inputJson, err := json.Marshal(input)
5861
if err != nil {
5962
t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs: %v", err)
@@ -62,12 +65,9 @@ func (t *Tool) Call(ctx context.Context, tracer trace.Tracer, input any) (_ *mcp
6265
if len(strJson) > maxSpanInputAttrLen {
6366
strJson = strJson[:100]
6467
}
65-
spanAttrs = append(spanAttrs, attribute.String(tracing.MCPInput, strJson))
68+
span.SetAttributes(attribute.String(tracing.MCPInput, strJson))
6669
}
6770

68-
ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...))
69-
defer tracing.EndSpanErr(span, &outErr)
70-
7171
return t.Client.CallTool(ctx, mcp.CallToolRequest{
7272
Params: mcp.CallToolParams{
7373
Name: t.Name,

trace_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) {
638638
attribute.String(tracing.Provider, aibridge.ProviderOpenAI),
639639
attribute.String(tracing.Model, gjson.Get(reqBody, "model").Str),
640640
attribute.String(tracing.InitiatorID, userID),
641-
attribute.String(tracing.MCPInput, "\"{\\\"owner\\\":\\\"admin\\\"}\""),
641+
attribute.String(tracing.MCPInput, "{\"owner\":\"admin\"}"),
642642
attribute.String(tracing.MCPToolName, "coder_list_workspaces"),
643643
attribute.String(tracing.MCPServerName, tool.ServerName),
644644
attribute.String(tracing.MCPServerURL, tool.ServerURL),

0 commit comments

Comments
 (0)