Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fixtures/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ var (
OaiResponsesBlockingConversation []byte

//go:embed openai/responses/blocking/http_error.txtar
OaiResponsesBlockingHttpErr []byte
OaiResponsesBlockingHTTPErr []byte

//go:embed openai/responses/blocking/prev_response_id.txtar
OaiResponsesBlockingPrevResponseID []byte
Expand Down Expand Up @@ -139,7 +139,7 @@ var (
OaiResponsesStreamingConversation []byte

//go:embed openai/responses/streaming/http_error.txtar
OaiResponsesStreamingHttpErr []byte
OaiResponsesStreamingHTTPErr []byte

//go:embed openai/responses/streaming/prev_response_id.txtar
OaiResponsesStreamingPrevResponseID []byte
Expand Down
22 changes: 11 additions & 11 deletions intercept/apidump/apidump.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,21 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
return xerrors.Errorf("write response header terminator: %w", err)
}

// Wrap the response body to capture it as it streams
if resp.Body != nil {
resp.Body = &streamingBodyDumper{
body: resp.Body,
dumpPath: dumpPath,
headerData: headerBuf.Bytes(),
logger: func(err error) {
d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err))
},
}
} else {
if resp.Body == nil {
// No body, just write headers
return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o644)
}

// Wrap the response body to capture it as it streams
resp.Body = &streamingBodyDumper{
body: resp.Body,
dumpPath: dumpPath,
headerData: headerBuf.Bytes(),
logger: func(err error) {
d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err))
},
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (i *interceptionBase) Model() string {
return "coder-aibridge-unknown"
}

return string(i.req.Model)
return i.req.Model
}

func (i *interceptionBase) newErrorResponse(err error) map[string]any {
Expand Down
2 changes: 1 addition & 1 deletion intercept/chatcompletions/paramswrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func generatePayload(messageCount int) []byte {
}
// Use realistic message content size
content := fmt.Sprintf("This is message number %d with some realistic content that might appear in a conversation.", i+1)
messages = append(messages, fmt.Sprintf(`{"role": "%s", "content": "%s"}`, role, content))
messages = append(messages, fmt.Sprintf(`{"role": %q, "content": %q}`, role, content))
}

return []byte(fmt.Sprintf(`{
Expand Down
70 changes: 34 additions & 36 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,14 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
})

toolCall = nil
} else {
} else if stream.Err() == nil {
// When the provider responds with only tool calls (no text content),
// no chunks are relayed to the client, so the stream is not yet
// initiated. Initiate it here so the SSE headers are sent and the
// ping ticker is started, preventing client timeout during tool invocation.
// Only initiate if no stream error, if there's an error, we'll return
// an HTTP error response instead of starting an SSE stream.
if stream.Err() == nil {
events.InitiateStream(w)
}
events.InitiateStream(w)
}
}

Expand Down Expand Up @@ -232,43 +230,43 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
})
}

if events.IsStreaming() {
// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
interceptionErr = oaiErr
} else {
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
// Unfortunately, the OpenAI SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
}

if interceptionErr != nil {
payload, err := i.marshalErr(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
}
} else {
if !events.IsStreaming() {
// response/downstream Stream has not started yet; write error response and exit.
i.writeUpstreamError(w, getErrorResponse(stream.Err()))
return stream.Err()
}

// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
interceptionErr = oaiErr
} else {
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
// Unfortunately, the OpenAI SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
}

if interceptionErr != nil {
payload, err := i.marshalErr(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
}

// No tool call, nothing more to do.
if toolCall == nil {
break
Expand Down
3 changes: 1 addition & 2 deletions intercept/eventstream/eventstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ func flush(w http.ResponseWriter) (err error) {
}

defer func() {
if r := recover(); r != nil {
// Likely a broken connection, don't spam the logs.
if r := recover(); r != nil { //nolint:revive // Intentionally swallowed; likely a broken connection.
}
}()

Expand Down
18 changes: 8 additions & 10 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,15 @@ func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recor

var thoughtRecords []*recorder.ModelThoughtRecord
for _, block := range msg.Content {
switch variant := block.AsAny().(type) {
case anthropic.ThinkingBlock:
if variant.Thinking == "" {
continue
}
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
Content: variant.Thinking,
Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking},
})
}
// anthropic.RedactedThinkingBlock also exists, but there's nothing useful we can capture.
variant, ok := block.AsAny().(anthropic.ThinkingBlock)
if !ok || variant.Thinking == "" {
continue
}
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
Content: variant.Thinking,
Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking},
})
}
return thoughtRecords
}
Expand Down
44 changes: 20 additions & 24 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ newStream:
// Tool-related handling.
switch event.Type {
case string(constant.ValueOf[constant.ContentBlockStart]()):
switch block := event.AsContentBlockStart().ContentBlock.AsAny().(type) {
case anthropic.ToolUseBlock:
if block, ok := event.AsContentBlockStart().ContentBlock.AsAny().(anthropic.ToolUseBlock); ok {
lastToolName = block.Name

if i.mcpProxy != nil && i.mcpProxy.GetTool(block.Name) != nil {
Expand Down Expand Up @@ -307,8 +306,7 @@ newStream:
foundTools int
)
for _, block := range message.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok {
foundTools++
if variant.Name == name {
input = variant.Input
Expand Down Expand Up @@ -431,24 +429,23 @@ newStream:
// Causes a new stream to be run with updated messages.
isFirst = false
continue newStream
} else {
// Find all the non-injected tools and track their uses.
for _, block := range message.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil {
continue
}
}

_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: message.ID,
ToolCallID: variant.ID,
Tool: variant.Name,
Args: variant.Input,
Injected: false,
})
// Find all the non-injected tools and track their uses.
for _, block := range message.Content {
if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok {
if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil {
continue
}

_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: message.ID,
ToolCallID: variant.ID,
Tool: variant.Name,
Args: variant.Input,
Injected: false,
})
}
}
}
Expand All @@ -464,11 +461,10 @@ newStream:
if eventstream.IsUnrecoverableError(err) {
logger.Debug(ctx, "processing terminated", slog.Error(err))
break // Stop processing if client disconnected or context canceled.
} else {
logger.Warn(ctx, "failed to relay event", slog.Error(err))
lastErr = xerrors.Errorf("relay event: %w", err)
break
}
logger.Warn(ctx, "failed to relay event", slog.Error(err))
lastErr = xerrors.Errorf("relay event: %w", err)
break
}
}

Expand Down
16 changes: 8 additions & 8 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,15 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte

func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs {
trimmed := strings.TrimSpace(raw)
if trimmed != "" {
var args recorder.ToolArgs
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
} else {
return args
}
if trimmed == "" {
return trimmed
}
var args recorder.ToolArgs
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
return trimmed
}
return trimmed
return args
}

func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) {
Expand Down
6 changes: 3 additions & 3 deletions internal/integrationtest/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ const (
)

func anthropicSuccessResponse(model string) string {
return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"%s","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model)
return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":%q,"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model)
}

func openAISuccessResponse(model string) string {
return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model)
return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":%q,"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model)
}

// TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle:
Expand Down Expand Up @@ -555,7 +555,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
)

doRequest := func(model string) *http.Response {
body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model)
body := fmt.Sprintf(`{"model":%q,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{
"x-api-key": {"test"},
"anthropic-version": {"2023-06-01"},
Expand Down
4 changes: 2 additions & 2 deletions internal/integrationtest/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestMetrics_Interception(t *testing.T) {
},
{
name: "oai_responses_blocking_error",
fixture: fixtures.OaiResponsesBlockingHttpErr,
fixture: fixtures.OaiResponsesBlockingHTTPErr,
path: pathOpenAIResponses,
headers: http.Header{"User-Agent": []string{"codex/1.0.0"}},
expectStatus: metrics.InterceptionCountStatusFailed,
Expand All @@ -127,7 +127,7 @@ func TestMetrics_Interception(t *testing.T) {
},
{
name: "oai_responses_streaming_error",
fixture: fixtures.OaiResponsesStreamingHttpErr,
fixture: fixtures.OaiResponsesStreamingHTTPErr,
path: pathOpenAIResponses,
headers: http.Header{"Originator": []string{"roo-code"}},
expectStatus: metrics.InterceptionCountStatusFailed,
Expand Down
4 changes: 2 additions & 2 deletions internal/integrationtest/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ func TestTraceOpenAIErr(t *testing.T) {
},
{
name: "trace_openai_responses_streaming_http_error",
fixture: fixtures.OaiResponsesStreamingHttpErr,
fixture: fixtures.OaiResponsesStreamingHTTPErr,
streaming: true,
allowOverflow: true, // 429 error causes retries

Expand All @@ -664,7 +664,7 @@ func TestTraceOpenAIErr(t *testing.T) {
},
{
name: "trace_openai_responses_blocking_http_error",
fixture: fixtures.OaiResponsesBlockingHttpErr,
fixture: fixtures.OaiResponsesBlockingHTTPErr,
streaming: false,

path: pathOpenAIResponses,
Expand Down
8 changes: 4 additions & 4 deletions internal/testutil/mockprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ import (
)

type MockProvider struct {
Name_ string
NameStr string
URL string
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
}

func (m *MockProvider) Type() string { return m.Name_ }
func (m *MockProvider) Name() string { return m.Name_ }
func (m *MockProvider) Type() string { return m.NameStr }
func (m *MockProvider) Name() string { return m.NameStr }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (m *MockProvider) AuthHeader() string { return "Authorization" }
Expand Down
Loading
Loading