diff --git a/api.go b/api.go index a3943c5..f8fc60f 100644 --- a/api.go +++ b/api.go @@ -62,5 +62,5 @@ func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { } func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) Recorder { - return recorder.NewRecorder(logger, tracer, clientFn) + return recorder.NewWrappedRecorder(logger, tracer, clientFn) } diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 4f78da3..ab744cb 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -177,10 +177,11 @@ func TestExecute_OnStateChange(t *testing.T) { // Trip circuit w := httptest.NewRecorder() - cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { + err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { rw.WriteHeader(http.StatusTooManyRequests) return nil }) + assert.NoError(t, err) // Verify state change callback was called with correct parameters assert.Len(t, stateChanges, 1) diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index 949af2e..7d7e3df 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -107,8 +107,9 @@ func (d *dumper) dumpRequest(req *http.Request) error { if err != nil { return xerrors.Errorf("write request header terminator: %w", err) } - buf.Write(prettyBody) - buf.WriteByte('\n') + // bytes.Buffer writes to in-memory storage and never return errors. + _, _ = buf.Write(prettyBody) + _ = buf.WriteByte('\n') return os.WriteFile(dumpPath, buf.Bytes(), 0o600) } diff --git a/intercept/apidump/streaming.go b/intercept/apidump/streaming.go index e2db42a..ef9805d 100644 --- a/intercept/apidump/streaming.go +++ b/intercept/apidump/streaming.go @@ -37,7 +37,7 @@ func (s *streamingBodyDumper) init() { // Write headers first. if _, err := s.file.Write(s.headerData); err != nil { s.initErr = xerrors.Errorf("write headers: %w", err) - s.file.Close() + _ = s.file.Close() // best-effort cleanup on header write failure s.file = nil } }) diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go index 2a39c1b..47c0492 100644 --- a/intercept/apidump/streaming_test.go +++ b/intercept/apidump/streaming_test.go @@ -42,10 +42,12 @@ func TestMiddleware_StreamingResponse(t *testing.T) { // Create a pipe to simulate streaming pr, pw := io.Pipe() go func() { + defer pw.Close() //nolint:revive // error handled via pipe read side for _, chunk := range chunks { - pw.Write([]byte(chunk)) + if _, err := pw.Write([]byte(chunk)); err != nil { + return + } } - pw.Close() }() resp, err := middleware(req, func(r *http.Request) (*http.Response, error) { @@ -65,7 +67,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) { for { n, err := resp.Body.Read(buf) if n > 0 { - receivedData.Write(buf[:n]) + _, _ = receivedData.Write(buf[:n]) // bytes.Buffer.Write never fails } if err == io.EOF { break diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index f8ca562..2e3a72e 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -390,10 +390,11 @@ func (i *StreamingInterception) marshalErr(err error) ([]byte, error) { } func (*StreamingInterception) encodeForStream(payload []byte) []byte { + // bytes.Buffer writes to in-memory storage and never return errors. var buf bytes.Buffer - buf.WriteString("data: ") - buf.Write(payload) - buf.WriteString("\n\n") + _, _ = buf.WriteString("data: ") + _, _ = buf.Write(payload) + _, _ = buf.WriteString("\n\n") return buf.Bytes() } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 29d2fd5..327a170 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -67,7 +67,7 @@ var bedrockSupportedBetaFlags = map[string]bool{ type interceptionBase struct { id uuid.UUID providerName string - reqPayload MessagesRequestPayload + reqPayload RequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 1a93f21..ec82912 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -763,10 +763,10 @@ func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) { } } -func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { +func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload { t.Helper() - payload, err := NewMessagesRequestPayload([]byte(requestBody)) + payload, err := NewRequestPayload([]byte(requestBody)) require.NoError(t, err) return payload diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 9ac31f3..7fb3f56 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -32,7 +32,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, - reqPayload MessagesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go index a5829f3..dfe52fc 100644 --- a/intercept/messages/reqpayload.go +++ b/intercept/messages/reqpayload.go @@ -82,12 +82,12 @@ var ( } ) -// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request. +// RequestPayload is raw JSON bytes of an Anthropic Messages API request. // Methods provide package-specific reads and rewrites while preserving the // original body for upstream pass-through. -type MessagesRequestPayload []byte +type RequestPayload []byte -func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { +func NewRequestPayload(raw []byte) (RequestPayload, error) { if len(bytes.TrimSpace(raw)) == 0 { return nil, xerrors.New("messages empty request body") } @@ -95,10 +95,10 @@ func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { return nil, xerrors.New("messages invalid JSON request body") } - return MessagesRequestPayload(raw), nil + return RequestPayload(raw), nil } -func (p MessagesRequestPayload) Stream() bool { +func (p RequestPayload) Stream() bool { v := gjson.GetBytes(p, messagesReqPathStream) if !v.IsBool() { return false @@ -106,11 +106,11 @@ func (p MessagesRequestPayload) Stream() bool { return v.Bool() } -func (p MessagesRequestPayload) model() string { +func (p RequestPayload) model() string { return gjson.GetBytes(p, messagesReqPathModel).Str } -func (p MessagesRequestPayload) correlatingToolCallID() *string { +func (p RequestPayload) correlatingToolCallID() *string { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.IsArray() { return nil @@ -147,7 +147,7 @@ func (p MessagesRequestPayload) correlatingToolCallID() *string { // lastUserPrompt returns the prompt text from the last user message. If no prompt // is found, it returns empty string, false, nil. Unexpected shapes are treated as // unsupported and do not fail the request path. -func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { +func (p RequestPayload) lastUserPrompt() (string, bool, error) { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.Exists() || messages.Type == gjson.Null { return "", false, nil @@ -195,7 +195,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { return "", false, nil } -func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) { +func (p RequestPayload) injectTools(injected []anthropic.ToolUnionParam) (RequestPayload, error) { if len(injected) == 0 { return p, nil } @@ -221,7 +221,7 @@ func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) return p.set(messagesReqPathTools, allTools) } -func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) { +func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) { toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice) // If no tool_choice was defined, assume auto. @@ -258,7 +258,7 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo } } -func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (MessagesRequestPayload, error) { +func (p RequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (RequestPayload, error) { if len(newMessages) == 0 { return p, nil } @@ -285,11 +285,11 @@ func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.Message return p.set(messagesReqPathMessages, allMessages) } -func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) { +func (p RequestPayload) withModel(model string) (RequestPayload, error) { return p.set(messagesReqPathModel, model) } -func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) { +func (p RequestPayload) messages() ([]json.RawMessage, error) { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.Exists() || messages.Type == gjson.Null { return nil, nil @@ -301,7 +301,7 @@ func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) { return p.resultToRawMessage(messages.Array()), nil } -func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { +func (p RequestPayload) tools() ([]json.RawMessage, error) { tools := gjson.GetBytes(p, messagesReqPathTools) if !tools.Exists() || tools.Type == gjson.Null { return nil, nil @@ -313,7 +313,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { return p.resultToRawMessage(tools.Array()), nil } -func (MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage { +func (RequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage { // gjson.Result conversion to json.RawMessage is needed because // gjson.Result does not implement json.Marshaler — would // serialize its struct fields instead of the raw JSON it represents. @@ -326,7 +326,7 @@ func (MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.Ra // convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens // conversion is needed for Bedrock models that does not support the "adaptive" thinking.type -func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) { +func (p RequestPayload) convertAdaptiveThinkingForBedrock() (RequestPayload, error) { thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType) if thinkingType.String() != constAdaptive { return p, nil @@ -377,7 +377,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq // removed when the corresponding flag is absent from the Anthropic-Beta header. // Model-specific beta flags must already be filtered from the header before // calling this method (see filterBedrockBetaFlags). -func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) { +func (p RequestPayload) removeUnsupportedBedrockFields(headers http.Header) (RequestPayload, error) { var payloadMap map[string]any if err := json.Unmarshal(p, &payloadMap); err != nil { return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err) @@ -400,13 +400,13 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head if err != nil { return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err) } - return MessagesRequestPayload(result), nil + return RequestPayload(result), nil } -func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { +func (p RequestPayload) set(path string, value any) (RequestPayload, error) { out, err := sjson.SetBytes(p, path, value) if err != nil { return p, xerrors.Errorf("set %s: %w", path, err) } - return MessagesRequestPayload(out), nil + return RequestPayload(out), nil } diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go index f16fa4f..d1b062f 100644 --- a/intercept/messages/reqpayload_test.go +++ b/intercept/messages/reqpayload_test.go @@ -11,7 +11,7 @@ import ( "github.com/coder/aibridge/utils" ) -func TestNewMessagesRequestPayload(t *testing.T) { +func TestNewRequestPayload(t *testing.T) { t.Parallel() testCases := []struct { @@ -42,7 +42,7 @@ func TestNewMessagesRequestPayload(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { t.Parallel() - payload, err := NewMessagesRequestPayload(testCase.requestBody) + payload, err := NewRequestPayload(testCase.requestBody) if testCase.expectError { require.Error(t, err) require.Nil(t, payload) @@ -50,12 +50,12 @@ func TestNewMessagesRequestPayload(t *testing.T) { } require.NoError(t, err) - require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload) + require.Equal(t, RequestPayload(testCase.requestBody), payload) }) } } -func TestMessagesRequestPayloadStream(t *testing.T) { +func TestRequestPayloadStream(t *testing.T) { t.Parallel() testCases := []struct { @@ -97,7 +97,7 @@ func TestMessagesRequestPayloadStream(t *testing.T) { } } -func TestMessagesRequestPayloadModel(t *testing.T) { +func TestRequestPayloadModel(t *testing.T) { t.Parallel() testCases := []struct { @@ -132,7 +132,7 @@ func TestMessagesRequestPayloadModel(t *testing.T) { } } -func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { +func TestRequestPayloadLastUserPrompt(t *testing.T) { t.Parallel() testCases := []struct { @@ -229,7 +229,7 @@ func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { } } -func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { +func TestRequestPayloadCorrelatingToolCallID(t *testing.T) { t.Parallel() testCases := []struct { @@ -266,7 +266,7 @@ func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { } } -func TestMessagesRequestPayloadInjectTools(t *testing.T) { +func TestRequestPayloadInjectTools(t *testing.T) { t.Parallel() payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`) @@ -291,7 +291,7 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) { require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String()) } -func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { +func TestRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { t.Parallel() testCases := []struct { @@ -361,7 +361,7 @@ func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { } } -func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { +func TestRequestPayloadDisableParallelToolCalls(t *testing.T) { t.Parallel() testCases := []struct { @@ -451,7 +451,7 @@ func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { } } -func TestMessagesRequestPayloadAppendedMessages(t *testing.T) { +func TestRequestPayloadAppendedMessages(t *testing.T) { t.Parallel() payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`) diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index d8a19f1..dfe6acc 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -37,7 +37,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, - reqPayload MessagesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, @@ -573,13 +573,14 @@ func (i *StreamingInterception) pingPayload() []byte { } func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte { + // bytes.Buffer writes to in-memory storage and never return errors. var buf bytes.Buffer - buf.WriteString("event: ") - buf.WriteString(typ) - buf.WriteString("\n") - buf.WriteString("data: ") - buf.Write(payload) - buf.WriteString("\n\n") + _, _ = buf.WriteString("event: ") + _, _ = buf.WriteString(typ) + _, _ = buf.WriteString("\n") + _, _ = buf.WriteString("data: ") + _, _ = buf.Write(payload) + _, _ = buf.WriteString("\n\n") return buf.Bytes() } diff --git a/intercept/responses/base.go b/intercept/responses/base.go index baa1f44..38cb8f9 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -42,7 +42,7 @@ type responsesInterceptionBase struct { // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string - reqPayload ResponsesRequestPayload + reqPayload RequestPayload cfg config.OpenAI recorder recorder.Recorder diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index e25f592..adf2322 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -363,9 +363,10 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) { respCopy := responseCopier{} body := "test_body" - respCopy.buff.Write([]byte(body)) + _, _ = respCopy.buff.Write([]byte(body)) // bytes.Buffer.Write never fails - respCopy.forwardResp(&mrw) + err := respCopy.forwardResp(&mrw) + require.NoError(t, err) require.False(t, mrw.headerCalled) require.False(t, mrw.writeCalled) require.False(t, mrw.writeHeaderCalled) @@ -373,7 +374,8 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) { // after response is received data is forwarded respCopy.responseReceived.Store(true) - respCopy.forwardResp(&mrw) + err = respCopy.forwardResp(&mrw) + require.NoError(t, err) require.True(t, mrw.headerCalled) require.True(t, mrw.writeCalled) require.True(t, mrw.writeHeaderCalled) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index ed75556..d32a5f1 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -28,7 +28,7 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, - reqPayload ResponsesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.OpenAI, clientHeaders http.Header, diff --git a/intercept/responses/reqpayload.go b/intercept/responses/reqpayload.go index 0208635..600402d 100644 --- a/intercept/responses/reqpayload.go +++ b/intercept/responses/reqpayload.go @@ -37,13 +37,13 @@ var ( reqPathType = string(constant.ValueOf[constant.Type]()) ) -// ResponsesRequestPayload is raw JSON bytes of a Responses API request. +// RequestPayload is raw JSON bytes of a Responses API request. // Methods provide package-specific reads and rewrites while preserving the // original body for upstream pass-through. // Note: No changes are made on schema error. -type ResponsesRequestPayload []byte +type RequestPayload []byte -func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) { +func NewRequestPayload(raw []byte) (RequestPayload, error) { if len(bytes.TrimSpace(raw)) == 0 { return nil, xerrors.New("empty request body") } @@ -51,22 +51,22 @@ func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) { return nil, xerrors.New("invalid JSON payload") } - return ResponsesRequestPayload(raw), nil + return RequestPayload(raw), nil } -func (p ResponsesRequestPayload) Stream() bool { +func (p RequestPayload) Stream() bool { return gjson.GetBytes(p, reqPathStream).Bool() } -func (p ResponsesRequestPayload) model() string { +func (p RequestPayload) model() string { return gjson.GetBytes(p, reqPathModel).String() } -func (p ResponsesRequestPayload) background() bool { +func (p RequestPayload) background() bool { return gjson.GetBytes(p, reqPathBackground).Bool() } -func (p ResponsesRequestPayload) correlatingToolCallID() *string { +func (p RequestPayload) correlatingToolCallID() *string { items := gjson.GetBytes(p, reqPathInput) if !items.IsArray() { return nil @@ -94,7 +94,7 @@ func (p ResponsesRequestPayload) correlatingToolCallID() *string { // item, or the string input value if present. If no prompt is found, it returns // empty string, false, nil. Unexpected shapes are treated as unsupported and do // not fail the request path. -func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) { +func (p RequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) { inputItems := gjson.GetBytes(p, reqPathInput) if !inputItems.Exists() || inputItems.Type == gjson.Null { return "", false, nil @@ -155,10 +155,10 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog } if promptExists { - sb.WriteByte('\n') + _ = sb.WriteByte('\n') // strings.Builder.WriteByte never fails } promptExists = true - sb.WriteString(text.Str) + _, _ = sb.WriteString(text.Str) // strings.Builder.WriteString never fails } if !promptExists { @@ -168,7 +168,7 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog return sb.String(), true, nil } -func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam) (ResponsesRequestPayload, error) { +func (p RequestPayload) injectTools(injected []responses.ToolUnionParam) (RequestPayload, error) { if len(injected) == 0 { return p, nil } @@ -189,11 +189,11 @@ func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam return p.set(reqPathTools, allTools) } -func (p ResponsesRequestPayload) disableParallelToolCalls() (ResponsesRequestPayload, error) { +func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) { return p.set(reqPathParallelToolCalls, false) } -func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (ResponsesRequestPayload, error) { +func (p RequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (RequestPayload, error) { if len(items) == 0 { return p, nil } @@ -212,7 +212,7 @@ func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInpu return p.set(reqPathInput, allInput) } -func (p ResponsesRequestPayload) inputItems() ([]any, error) { +func (p RequestPayload) inputItems() ([]any, error) { input := gjson.GetBytes(p, reqPathInput) if !input.Exists() || input.Type == gjson.Null { return []any{}, nil @@ -235,7 +235,7 @@ func (p ResponsesRequestPayload) inputItems() ([]any, error) { return existing, nil } -func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { +func (p RequestPayload) toolItems() ([]json.RawMessage, error) { tools := gjson.GetBytes(p, reqPathTools) if !tools.Exists() { return nil, nil @@ -253,7 +253,7 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { return existing, nil } -func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) { +func (p RequestPayload) set(path string, value any) (RequestPayload, error) { updated, err := sjson.SetBytes(p, path, value) if err != nil { return p, xerrors.Errorf("failed to set value at path %s: %w", path, err) diff --git a/intercept/responses/reqpayload_test.go b/intercept/responses/reqpayload_test.go index 09b7480..df99954 100644 --- a/intercept/responses/reqpayload_test.go +++ b/intercept/responses/reqpayload_test.go @@ -16,7 +16,7 @@ import ( "github.com/coder/aibridge/utils" ) -func TestNewResponsesRequestPayload(t *testing.T) { +func TestNewRequestPayload(t *testing.T) { t.Parallel() payloadWithWrongTypes := []byte(`{"model":123,"stream":"yes","input":42,"background":"nope"}`) @@ -42,7 +42,7 @@ func TestNewResponsesRequestPayload(t *testing.T) { err: "invalid JSON payload", }, { - // ResponsesRequestPayload just checks for JSON validity, + // RequestPayload just checks for JSON validity, // schema errors are not surfaced here and // the original body is preserved for upstream handling // similar to how reverse proxy would behave. @@ -59,7 +59,7 @@ func TestNewResponsesRequestPayload(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - payload, err := NewResponsesRequestPayload(tc.raw) + payload, err := NewRequestPayload(tc.raw) if tc.err != "" { require.ErrorContains(t, err, tc.err) @@ -518,10 +518,10 @@ func injectedFunctionTool(name string) responses.ToolUnionParam { } } -func mustPayload(t *testing.T, raw []byte) ResponsesRequestPayload { +func mustPayload(t *testing.T, raw []byte) RequestPayload { t.Helper() - payload, err := NewResponsesRequestPayload(raw) + payload, err := NewRequestPayload(raw) require.NoError(t, err) return payload } diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index f5a2346..c3ec3ba 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -35,7 +35,7 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, - reqPayload ResponsesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.OpenAI, clientHeaders http.Header, diff --git a/mcp/tool.go b/mcp/tool.go index 70845c0..b665743 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -106,12 +106,13 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp // - https://community.openai.com/t/function-call-description-max-length/529902 // - https://github.com/anthropics/claude-code/issues/2326 func EncodeToolID(server, tool string) string { + // strings.Builder writes to in-memory storage and never return errors. var sb strings.Builder - sb.WriteString(injectedToolPrefix) - sb.WriteString(injectedToolDelimiter) - sb.WriteString(server) - sb.WriteString(injectedToolDelimiter) - sb.WriteString(tool) + _, _ = sb.WriteString(injectedToolPrefix) + _, _ = sb.WriteString(injectedToolDelimiter) + _, _ = sb.WriteString(server) + _, _ = sb.WriteString(injectedToolDelimiter) + _, _ = sb.WriteString(tool) return sb.String() } diff --git a/provider/anthropic.go b/provider/anthropic.go index 7be9611..32de0d5 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -114,7 +114,7 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr return nil, xerrors.Errorf("read body: %w", err) } - reqPayload, err := messages.NewMessagesRequestPayload(payload) + reqPayload, err := messages.NewRequestPayload(payload) if err != nil { return nil, xerrors.Errorf("unmarshal request body: %w", err) } diff --git a/provider/copilot.go b/provider/copilot.go index 186dfdd..56caa6e 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -170,7 +170,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac if err != nil { return nil, xerrors.Errorf("read body: %w", err) } - reqPayload, err := responses.NewResponsesRequestPayload(payload) + reqPayload, err := responses.NewRequestPayload(payload) if err != nil { return nil, xerrors.Errorf("unmarshal request body: %w", err) } diff --git a/provider/openai.go b/provider/openai.go index 4b90823..6281aef 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -141,7 +141,7 @@ func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trace if err != nil { return nil, xerrors.Errorf("read body: %w", err) } - reqPayload, err := responses.NewResponsesRequestPayload(payload) + reqPayload, err := responses.NewRequestPayload(payload) if err != nil { return nil, xerrors.Errorf("unmarshal request body: %w", err) } diff --git a/recorder/recorder.go b/recorder/recorder.go index 7e2b988..c1a4b59 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -16,19 +16,19 @@ import ( ) var ( - _ Recorder = &RecorderWrapper{} + _ Recorder = &WrappedRecorder{} _ Recorder = &AsyncRecorder{} ) -// RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method. +// WrappedRecorder is a convenience struct which implements RecorderClient and resolves a client before calling each method. // It also sets the start/creation time of each record. -type RecorderWrapper struct { +type WrappedRecorder struct { logger slog.Logger tracer trace.Tracer clientFn func() (Recorder, error) } -func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { +func (r *WrappedRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -46,7 +46,7 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept return err } -func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { +func (r *WrappedRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -64,7 +64,7 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte return err } -func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { +func (r *WrappedRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -82,7 +82,7 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag return err } -func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { +func (r *WrappedRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -100,7 +100,7 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR return err } -func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { +func (r *WrappedRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -118,7 +118,7 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec return err } -func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) { +func (r *WrappedRecorder) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordModelThought", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -136,8 +136,8 @@ func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThou return err } -func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper { - return &RecorderWrapper{ +func NewWrappedRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *WrappedRecorder { + return &WrappedRecorder{ logger: logger, tracer: tracer, clientFn: clientFn,