Skip to content

Commit 1937653

Browse files
authored
chore: replace ResponsesNewParamsWrapper with ResponsesRequestPayload in responses interceptor (#213)
Refactored the responses interceptor, replacing `ResponsesNewParamsWrapper` with a new `ResponsesRequestPayload` type that wraps the raw JSON bytes and consolidates the operations performed on the request payload.
1 parent 5c071a7 commit 1937653

11 files changed

Lines changed: 615 additions & 234 deletions

File tree

intercept/responses/base.go

Lines changed: 61 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -38,21 +38,19 @@ const (
3838
)
3939

4040
type responsesInterceptionBase struct {
41-
id uuid.UUID
42-
req *ResponsesNewParamsWrapper
43-
reqPayload []byte
44-
cfg config.OpenAI
45-
model string
46-
41+
id uuid.UUID
4742
// clientHeaders are the original HTTP headers from the client request.
4843
clientHeaders http.Header
4944
authHeaderName string
45+
reqPayload ResponsesRequestPayload
5046

47+
cfg config.OpenAI
5148
recorder recorder.Recorder
5249
mcpProxy mcp.ServerProxier
53-
logger slog.Logger
54-
metrics metrics.Metrics
55-
tracer trace.Tracer
50+
51+
logger slog.Logger
52+
metrics metrics.Metrics
53+
tracer trace.Tracer
5654
}
5755

5856
func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
@@ -88,26 +86,37 @@ func (i *responsesInterceptionBase) ID() uuid.UUID {
8886
}
8987

9088
func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
91-
i.logger = logger.With(slog.F("model", i.model))
89+
i.logger = logger.With(slog.F("model", i.Model()))
9290
i.recorder = recorder
9391
i.mcpProxy = mcpProxy
9492
}
9593

9694
func (i *responsesInterceptionBase) Model() string {
97-
return i.model
95+
return i.reqPayload.model()
9896
}
9997

10098
func (i *responsesInterceptionBase) CorrelatingToolCallID() *string {
101-
if len(i.req.Input.OfInputItemList) == 0 {
99+
items := gjson.GetBytes(i.reqPayload, "input")
100+
if !items.IsArray() {
101+
return nil
102+
}
103+
104+
arr := items.Array()
105+
if len(arr) == 0 {
106+
return nil
107+
}
108+
109+
last := arr[len(arr)-1]
110+
if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) {
102111
return nil
103112
}
104113

105-
// The tool result should be the last input message.
106-
item := i.req.Input.OfInputItemList[len(i.req.Input.OfInputItemList)-1]
107-
if item.OfFunctionCallOutput == nil {
114+
callID := last.Get("call_id").String()
115+
if callID == "" {
108116
return nil
109117
}
110-
return &item.OfFunctionCallOutput.CallID
118+
119+
return &callID
111120
}
112121

113122
func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
@@ -122,13 +131,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami
122131
}
123132

124133
func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.ResponseWriter) error {
125-
if i.req == nil {
126-
err := errors.New("developer error: req is nil")
127-
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
128-
return err
129-
}
130-
131-
if i.req.Background.Value {
134+
if i.reqPayload.background() {
132135
err := fmt.Errorf("background requests are currently not supported by AI Bridge")
133136
i.sendCustomErr(ctx, w, http.StatusNotImplemented, err)
134137
return err
@@ -161,15 +164,15 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
161164
// eg. Codex CLI produces requests without ID set in reasoning items: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-item-reasoning-id
162165
// when re-encoded, ID field is set to empty string which results
163166
// in bad request while not sending ID field at all somehow works.
164-
option.WithRequestBody("application/json", i.reqPayload),
167+
option.WithRequestBody("application/json", []byte(i.reqPayload)),
165168

166169
// copyMiddleware copies the original response body to the buffer in responseCopier,
167170
// a reference to the headers and status code is also kept in responseCopier.
168171
// responseCopier is used by interceptors to forward response as it was received,
169172
// eliminating any possibility of JSON re-encoding issues.
170173
option.WithMiddleware(respCopy.copyMiddleware),
171174
}
172-
if !i.req.Stream {
175+
if !i.reqPayload.Stream() {
173176
opts = append(opts, option.WithRequestTimeout(requestTimeout))
174177
}
175178
return opts
@@ -182,77 +185,80 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string,
182185
if i == nil {
183186
return "", false, errors.New("cannot get last user prompt: nil struct")
184187
}
185-
if i.req == nil {
188+
if i.reqPayload == nil {
186189
return "", false, errors.New("cannot get last user prompt: nil request struct")
187190
}
188191

189-
// 'input' field can be a string or array of objects:
192+
// 'input' can be either a string or an array of input items:
190193
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input
191-
192-
// Check string variant
193-
if i.req.Input.OfString.Valid() {
194-
return i.req.Input.OfString.Value, true, nil
194+
inputItems := gjson.GetBytes(i.reqPayload, "input")
195+
if !inputItems.Exists() || inputItems.Type == gjson.Null {
196+
return "", false, nil
195197
}
196198

197-
// Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field.
198-
// If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList'
199-
// It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message
200-
// example: fixtures/openai/responses/blocking/builtin_tool.txtar
201-
inputItems := gjson.GetBytes(i.reqPayload, "input")
199+
// String variant: treat the whole input as the user prompt.
200+
if inputItems.Type == gjson.String {
201+
return inputItems.String(), true, nil
202+
}
202203

204+
// Array variant: only the last input item is checked.
203205
if !inputItems.IsArray() {
204-
if inputItems.Type == gjson.Null {
205-
return "", false, nil
206-
}
207-
return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String())
206+
return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type)
208207
}
209208

210209
inputItemsArr := inputItems.Array()
211210
if len(inputItemsArr) == 0 {
212211
return "", false, nil
213212
}
214-
lastItem := inputItemsArr[len(inputItemsArr)-1]
215213

216-
// Request was likely not human-initiated.
214+
lastItem := inputItemsArr[len(inputItemsArr)-1]
217215
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) {
216+
// Request was likely not initiated by a prompt but is an iteration of an agentic loop.
218217
return "", false, nil
219218
}
220219

221-
// content can be a string or array of objects:
220+
// Message content can be either a string or an array of typed content items:
222221
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
223222
content := lastItem.Get(string(constant.ValueOf[constant.Content]()))
223+
if !content.Exists() || content.Type == gjson.Null {
224+
return "", false, nil
225+
}
226+
227+
// String variant: use it directly as the prompt.
228+
if content.Type == gjson.String {
229+
return content.Str, true, nil
230+
}
224231

225-
// non array case, should be string
226232
if !content.IsArray() {
227-
if content.Type == gjson.String {
228-
return content.Str, true, nil
229-
}
230-
return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String())
233+
return "", false, fmt.Errorf("unexpected input content type: %s", content.Type)
231234
}
232235

233236
var sb strings.Builder
234237
promptExists := false
235238
for _, c := range content.Array() {
236-
// ignore inputs of not `input_text` type
239+
// Ignore non-text content blocks such as images or files.
237240
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) {
238241
continue
239242
}
240243

241244
text := c.Get(string(constant.ValueOf[constant.Text]()))
242-
if text.Type == gjson.String {
243-
promptExists = true
244-
sb.WriteString(text.Str + "\n")
245-
} else {
245+
if text.Type != gjson.String {
246246
i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type))
247+
continue
248+
}
249+
250+
if promptExists {
251+
sb.WriteByte('\n')
247252
}
253+
promptExists = true
254+
sb.WriteString(text.Str)
248255
}
249256

250257
if !promptExists {
251258
return "", false, nil
252259
}
253260

254-
prompt := strings.TrimSuffix(sb.String(), "\n")
255-
return prompt, true, nil
261+
return sb.String(), true, nil
256262
}
257263

258264
func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {

intercept/responses/base_test.go

Lines changed: 32 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/coder/aibridge/utils"
1313
"github.com/google/uuid"
1414
oairesponses "github.com/openai/openai-go/v3/responses"
15+
"github.com/stretchr/testify/assert"
1516
"github.com/stretchr/testify/require"
1617
)
1718

@@ -20,95 +21,53 @@ func TestScanForCorrelatingToolCallID(t *testing.T) {
2021

2122
tests := []struct {
2223
name string
23-
input []oairesponses.ResponseInputItemUnionParam
24-
expected *string
24+
payload []byte
25+
wantCall *string
2526
}{
2627
{
27-
name: "no input items",
28-
input: nil,
29-
expected: nil,
28+
name: "no input",
29+
payload: []byte(`{"model":"gpt-4o"}`),
3030
},
3131
{
32-
name: "no function_call_output items",
33-
input: []oairesponses.ResponseInputItemUnionParam{
34-
{
35-
OfMessage: &oairesponses.EasyInputMessageParam{
36-
Role: "user",
37-
},
38-
},
39-
},
40-
expected: nil,
32+
name: "empty input array",
33+
payload: []byte(`{"model":"gpt-4o","input":[]}`),
4134
},
4235
{
43-
name: "single function_call_output",
44-
input: []oairesponses.ResponseInputItemUnionParam{
45-
{
46-
OfMessage: &oairesponses.EasyInputMessageParam{
47-
Role: "user",
48-
},
49-
},
50-
{
51-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
52-
CallID: "call_abc",
53-
},
54-
},
55-
},
56-
expected: utils.PtrTo("call_abc"),
36+
name: "no function_call_output items",
37+
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`),
5738
},
5839
{
59-
name: "multiple function_call_outputs returns last",
60-
input: []oairesponses.ResponseInputItemUnionParam{
61-
{
62-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
63-
CallID: "call_first",
64-
},
65-
},
66-
{
67-
OfMessage: &oairesponses.EasyInputMessageParam{
68-
Role: "user",
69-
},
70-
},
71-
{
72-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
73-
CallID: "call_second",
74-
},
75-
},
76-
},
77-
expected: utils.PtrTo("call_second"),
40+
name: "single function_call_output",
41+
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`),
42+
wantCall: utils.PtrTo("call_abc"),
7843
},
7944
{
80-
name: "last input is not a tool result",
81-
input: []oairesponses.ResponseInputItemUnionParam{
82-
{
83-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
84-
CallID: "call_first",
85-
},
86-
},
87-
{
88-
OfMessage: &oairesponses.EasyInputMessageParam{
89-
Role: "user",
90-
},
91-
},
92-
},
93-
expected: nil,
45+
name: "multiple function_call_outputs returns last",
46+
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`),
47+
wantCall: utils.PtrTo("call_second"),
48+
},
49+
{
50+
name: "last input is not a tool result",
51+
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`),
52+
},
53+
{
54+
name: "missing call id",
55+
payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`),
9456
},
9557
}
9658

9759
for _, tc := range tests {
9860
t.Run(tc.name, func(t *testing.T) {
9961
t.Parallel()
10062

63+
rp, err := NewResponsesRequestPayload(tc.payload)
64+
require.NoError(t, err)
10165
base := &responsesInterceptionBase{
102-
req: &ResponsesNewParamsWrapper{
103-
ResponseNewParams: oairesponses.ResponseNewParams{
104-
Input: oairesponses.ResponseNewParamsInputUnion{
105-
OfInputItemList: tc.input,
106-
},
107-
},
108-
},
66+
reqPayload: rp,
10967
}
11068

111-
require.Equal(t, tc.expected, base.CorrelatingToolCallID())
69+
callID := base.CorrelatingToolCallID()
70+
assert.Equal(t, tc.wantCall, callID)
11271
})
11372
}
11473
}
@@ -161,13 +120,10 @@ func TestLastUserPrompt(t *testing.T) {
161120
t.Run(tc.name, func(t *testing.T) {
162121
t.Parallel()
163122

164-
req := &ResponsesNewParamsWrapper{}
165-
err := req.UnmarshalJSON(tc.reqPayload)
123+
rp, err := NewResponsesRequestPayload(tc.reqPayload)
166124
require.NoError(t, err)
167-
168125
base := &responsesInterceptionBase{
169-
req: req,
170-
reqPayload: tc.reqPayload,
126+
reqPayload: rp,
171127
}
172128

173129
prompt, promptFound, err := base.lastUserPrompt(t.Context())
@@ -253,13 +209,11 @@ func TestLastUserPromptNotFound(t *testing.T) {
253209
t.Run(tc.name, func(t *testing.T) {
254210
t.Parallel()
255211

256-
req := &ResponsesNewParamsWrapper{}
257-
err := req.UnmarshalJSON(tc.reqPayload)
212+
rp, err := NewResponsesRequestPayload(tc.reqPayload)
258213
require.NoError(t, err)
259214

260215
base := &responsesInterceptionBase{
261-
req: req,
262-
reqPayload: tc.reqPayload,
216+
reqPayload: rp,
263217
}
264218

265219
prompt, promptFound, err := base.lastUserPrompt(t.Context())

0 commit comments

Comments
 (0)