Skip to content

Commit 2d43f46

Browse files
committed
feat: forward client request headers to upstream providers in bridge routes
1 parent 250e790 commit 2d43f46

18 files changed

Lines changed: 583 additions & 75 deletions

intercept/chatcompletions/base.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/coder/aibridge/config"
1111
aibcontext "github.com/coder/aibridge/context"
12+
"github.com/coder/aibridge/intercept"
1213
"github.com/coder/aibridge/intercept/apidump"
1314
"github.com/coder/aibridge/mcp"
1415
"github.com/coder/aibridge/recorder"
@@ -29,6 +30,10 @@ type interceptionBase struct {
2930
req *ChatCompletionNewParamsWrapper
3031
cfg config.OpenAI
3132

33+
// clientHeaders are the original HTTP headers from the client request.
34+
clientHeaders http.Header
35+
authHeaderName string
36+
3237
logger slog.Logger
3338
tracer trace.Tracer
3439

@@ -41,10 +46,21 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService
4146

4247
// Add extra headers if configured.
4348
// Some providers require additional headers that are not added by the SDK.
49+
// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
4450
for key, value := range i.cfg.ExtraHeaders {
4551
opts = append(opts, option.WithHeader(key, value))
4652
}
4753

54+
// Forward client headers to upstream. This middleware runs after the SDK
55+
// has built the request, and replaces the outgoing headers with the sanitized
56+
// client headers plus provider auth.
57+
if i.clientHeaders != nil {
58+
opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
59+
req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName)
60+
return next(req)
61+
}))
62+
}
63+
4864
// Add API dump middleware if configured
4965
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
5066
opts = append(opts, option.WithMiddleware(mw))

intercept/chatcompletions/blocking.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,21 @@ type BlockingInterception struct {
2828
interceptionBase
2929
}
3030

31-
func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception {
31+
func NewBlockingInterceptor(
32+
id uuid.UUID,
33+
req *ChatCompletionNewParamsWrapper,
34+
cfg config.OpenAI,
35+
clientHeaders http.Header,
36+
authHeaderName string,
37+
tracer trace.Tracer,
38+
) *BlockingInterception {
3239
return &BlockingInterception{interceptionBase: interceptionBase{
33-
id: id,
34-
req: req,
35-
cfg: cfg,
36-
tracer: tracer,
40+
id: id,
41+
req: req,
42+
cfg: cfg,
43+
clientHeaders: clientHeaders,
44+
authHeaderName: authHeaderName,
45+
tracer: tracer,
3746
}}
3847
}
3948

@@ -78,6 +87,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
7887

7988
var opts []option.RequestOption
8089
opts = append(opts, option.WithRequestTimeout(time.Second*600))
90+
91+
// TODO(ssncferreira): inject actor headers directly in the client-header
92+
// middleware instead of using SDK options.
8193
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
8294
opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
8395
}

intercept/chatcompletions/streaming.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,21 @@ type StreamingInterception struct {
3333
interceptionBase
3434
}
3535

36-
func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception {
36+
func NewStreamingInterceptor(
37+
id uuid.UUID,
38+
req *ChatCompletionNewParamsWrapper,
39+
cfg config.OpenAI,
40+
clientHeaders http.Header,
41+
authHeaderName string,
42+
tracer trace.Tracer,
43+
) *StreamingInterception {
3744
return &StreamingInterception{interceptionBase: interceptionBase{
38-
id: id,
39-
req: req,
40-
cfg: cfg,
41-
tracer: tracer,
45+
id: id,
46+
req: req,
47+
cfg: cfg,
48+
clientHeaders: clientHeaders,
49+
authHeaderName: authHeaderName,
50+
tracer: tracer,
4251
}}
4352
}
4453

@@ -121,6 +130,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
121130
for {
122131
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
123132
var opts []option.RequestOption
133+
134+
// TODO(ssncferreira): inject actor headers directly in the client-header
135+
// middleware instead of using SDK options.
124136
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
125137
opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
126138
}

intercept/chatcompletions/streaming_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,16 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
8181
Stream: true,
8282
}
8383

84+
// Create test request
85+
w := httptest.NewRecorder()
86+
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
87+
8488
tracer := otel.Tracer("test")
85-
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer)
89+
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer)
8690

8791
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
8892
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
8993

90-
// Create test request
91-
w := httptest.NewRecorder()
92-
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
93-
9494
// Process the request
9595
err := interceptor.ProcessRequest(w, httpReq)
9696

intercept/client_headers.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package intercept
2+
3+
import "net/http"
4+
5+
// hopByHopHeaders are connection-level headers specific to the connection
6+
// between client and AI Bridge, not meant for the upstream.
7+
// See https://www.rfc-editor.org/rfc/rfc2616#section-13.5.1
8+
var hopByHopHeaders = []string{
9+
"Connection",
10+
"Keep-Alive",
11+
"Proxy-Authenticate",
12+
"Proxy-Authorization",
13+
"Te",
14+
"Trailer",
15+
"Transfer-Encoding",
16+
"Upgrade",
17+
}
18+
19+
// nonForwardedHeaders are transport-level headers managed by aibridge or
20+
// Go's HTTP transport that must not be forwarded to the upstream provider.
21+
var nonForwardedHeaders = []string{
22+
"Host",
23+
"Accept-Encoding",
24+
"Content-Length",
25+
}
26+
27+
// authHeaders are headers that carry authentication credentials from the
28+
// client. These are stripped because the SDK re-injects the correct
29+
// provider credentials (API key or per-user token).
30+
var authHeaders = []string{
31+
"Authorization",
32+
"X-Api-Key",
33+
}
34+
35+
// SanitizeClientHeaders returns a copy of the client headers with hop-by-hop,
36+
// transport, and auth headers removed.
37+
func SanitizeClientHeaders(clientHeaders http.Header) http.Header {
38+
sanitized := clientHeaders.Clone()
39+
for _, h := range hopByHopHeaders {
40+
sanitized.Del(h)
41+
}
42+
for _, h := range nonForwardedHeaders {
43+
sanitized.Del(h)
44+
}
45+
for _, h := range authHeaders {
46+
sanitized.Del(h)
47+
}
48+
return sanitized
49+
}
50+
51+
// BuildUpstreamHeaders produces the header set for an upstream SDK request.
52+
// It starts from the sanitized client headers, then preserves specific
53+
// headers from the SDK-built request that must not be overwritten.
54+
func BuildUpstreamHeaders(sdkHeader http.Header, clientHeaders http.Header, authHeaderName string) http.Header {
55+
headers := SanitizeClientHeaders(clientHeaders)
56+
57+
// Preserve the auth header set by the SDK from the provider configuration.
58+
if v := sdkHeader.Get(authHeaderName); v != "" {
59+
headers.Set(authHeaderName, v)
60+
}
61+
62+
// Preserve actor headers injected by aibridge as per-request SDK options.
63+
for name, values := range sdkHeader {
64+
if IsActorHeader(name) {
65+
headers[name] = values
66+
}
67+
}
68+
69+
return headers
70+
}

0 commit comments

Comments
 (0)