Skip to content

Commit aed7dd9

Browse files
fix: preserve the stream property for chat/completions calls (#164)
* fix: preserve the stream property for chat/completions calls * test: add request body validation to mock server * document aibridge stream mashalling behaviour for chat completions --------- Co-authored-by: Susana Cardoso Ferreira <susana@coder.com>
1 parent 9e2857a commit aed7dd9

4 files changed

Lines changed: 109 additions & 4 deletions

File tree

bridge_integration_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,6 +2076,35 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requ
20762076
defer r.Body.Close()
20772077
require.NoError(t, err)
20782078

2079+
// Validate request body based on endpoint.
2080+
var validationErr error
2081+
if strings.Contains(r.URL.Path, "/chat/completions") {
2082+
validationErr = validateOpenAIChatCompletionRequest(body)
2083+
} else if strings.Contains(r.URL.Path, "/responses") {
2084+
validationErr = validateOpenAIResponsesRequest(body)
2085+
} else if strings.Contains(r.URL.Path, "/messages") {
2086+
validationErr = validateAnthropicMessagesRequest(body)
2087+
}
2088+
2089+
// If validation failed, return error response
2090+
if validationErr != nil {
2091+
// Return HTTP error response
2092+
w.Header().Set("Content-Type", "application/json")
2093+
w.WriteHeader(http.StatusBadRequest)
2094+
errResp := map[string]any{
2095+
"error": map[string]any{
2096+
"message": fmt.Sprintf("Request #%d validation failed: %v", ms.callCount.Load(), validationErr),
2097+
"type": "invalid_request_error",
2098+
},
2099+
}
2100+
json.NewEncoder(w).Encode(errResp)
2101+
2102+
// Mark test as failed with detailed message
2103+
t.Errorf("Request #%d validation failed: %v\n\nRequest body:\n%s",
2104+
ms.callCount.Load(), validationErr, string(body))
2105+
return
2106+
}
2107+
20792108
type msg struct {
20802109
Stream bool `json:"stream"`
20812110
}
@@ -2135,6 +2164,72 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requ
21352164
return ms
21362165
}
21372166

2167+
// validateOpenAIChatCompletionRequest validates that an OpenAI chat completion request
2168+
// has all required fields. Returns an error if validation fails.
2169+
func validateOpenAIChatCompletionRequest(body []byte) error {
2170+
var req openai.ChatCompletionNewParams
2171+
if err := json.Unmarshal(body, &req); err != nil {
2172+
return fmt.Errorf("request should unmarshal into ChatCompletionNewParams: %w", err)
2173+
}
2174+
2175+
// Collect all validation errors
2176+
var errs []string
2177+
if req.Model == "" {
2178+
errs = append(errs, "model field is required but empty")
2179+
}
2180+
if len(req.Messages) == 0 {
2181+
errs = append(errs, "messages field is required but empty")
2182+
}
2183+
2184+
if len(errs) > 0 {
2185+
return fmt.Errorf("validation failed: %s", strings.Join(errs, "; "))
2186+
}
2187+
return nil
2188+
}
2189+
2190+
// validateOpenAIResponsesRequest validates that an OpenAI responses request
2191+
// has all required fields. Returns an error if validation fails.
2192+
func validateOpenAIResponsesRequest(body []byte) error {
2193+
var reqBody map[string]any
2194+
if err := json.Unmarshal(body, &reqBody); err != nil {
2195+
return fmt.Errorf("request should be valid JSON: %w", err)
2196+
}
2197+
2198+
// Verify required fields for OpenAI responses
2199+
// Note: Using map here since there's no specific SDK type for responses endpoint
2200+
model, ok := reqBody["model"]
2201+
if !ok || model == "" {
2202+
return fmt.Errorf("model field is required but missing or empty")
2203+
}
2204+
return nil
2205+
}
2206+
2207+
// validateAnthropicMessagesRequest validates that an Anthropic messages request
2208+
// has all required fields. Returns an error if validation fails.
2209+
func validateAnthropicMessagesRequest(body []byte) error {
2210+
var req anthropic.MessageNewParams
2211+
if err := json.Unmarshal(body, &req); err != nil {
2212+
return fmt.Errorf("request should unmarshal into MessageNewParams: %w", err)
2213+
}
2214+
2215+
// Collect all validation errors
2216+
var errs []string
2217+
if req.Model == "" {
2218+
errs = append(errs, "model field is required but empty")
2219+
}
2220+
if len(req.Messages) == 0 {
2221+
errs = append(errs, "messages field is required but empty")
2222+
}
2223+
if req.MaxTokens == 0 {
2224+
errs = append(errs, "max_tokens field is required but zero")
2225+
}
2226+
2227+
if len(errs) > 0 {
2228+
return fmt.Errorf("validation failed: %s", strings.Join(errs, "; "))
2229+
}
2230+
return nil
2231+
}
2232+
21382233
const mockToolName = "coder_list_workspaces"
21392234

21402235
// callAccumulator tracks all tool invocations by name and each instance's arguments.

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,4 @@ require (
9494
replace github.com/anthropics/anthropic-sdk-go v1.13.0 => github.com/dannykopping/anthropic-sdk-go v0.0.0-20251230111224-88a4315810bd
9595

9696
// https://github.com/openai/openai-go/pull/602
97-
replace github.com/openai/openai-go/v3 => github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95
97+
replace github.com/openai/openai-go/v3 => github.com/SasSwart/openai-go/v3 v3.0.0-20260204134041-fb987b42a728

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ cloud.google.com/go/logging v1.8.1 h1:26skQWPeYhvIasWKm48+Eq7oUqdcdbwsCVwz5Ys0Fv
77
cloud.google.com/go/logging v1.8.1/go.mod h1:TJjR+SimHwuC8MZ9cjByQulAMgni+RkXeI3wwctHJEI=
88
cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI=
99
cloud.google.com/go/longrunning v0.5.1/go.mod h1:spvimkwdz6SPWKEt/XBij79E9fiTkHSQl/fRUUQJYJc=
10-
github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95 h1:HVJp3FanNaeFAlwg0/lkdkSnwFemHnwwjXBM8KRj540=
11-
github.com/SasSwart/openai-go/v3 v3.0.0-20260202093810-72af3b857f95/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
10+
github.com/SasSwart/openai-go/v3 v3.0.0-20260204134041-fb987b42a728 h1:FOjd3xOH+arcrtz1e5P6WZ/VtRD5KQHHRg4kc4BZers=
11+
github.com/SasSwart/openai-go/v3 v3.0.0-20260204134041-fb987b42a728/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
1212
github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
1313
github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
1414
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg=

intercept/chatcompletions/streaming.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,16 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
125125
opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
126126
}
127127

128+
// We take control of request body here and pass it to the SDK as a raw byte slice.
129+
// This is because the SDK's serialization applies hidden request options that result in
130+
// unexpected, breaking behaviour. See https://github.com/coder/aibridge/pull/164
131+
body, err := json.Marshal(i.req.ChatCompletionNewParams)
132+
if err != nil {
133+
return fmt.Errorf("marshal request body: %w", err)
134+
}
135+
opts = append(opts, option.WithRequestBody("application/json", body))
136+
opts = append(opts, option.WithJSONSet("stream", true))
137+
128138
stream = i.newStream(streamCtx, svc, opts)
129139
processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName)
130140

@@ -380,7 +390,7 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo
380390
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
381391
defer span.End()
382392

383-
return svc.NewStreaming(ctx, i.req.ChatCompletionNewParams, opts...)
393+
return svc.NewStreaming(ctx, openai.ChatCompletionNewParams{}, opts...)
384394
}
385395

386396
type streamProcessor struct {

0 commit comments

Comments
 (0)