Skip to content

Commit b3cd38b

Browse files
authored
feat: forward client request headers to upstream providers in bridge routes (#214)
* feat: forward client request headers to upstream providers in bridge routes * chore: address comments * test: update openai tests to table driven format
1 parent 099670f commit b3cd38b

18 files changed

Lines changed: 571 additions & 75 deletions

intercept/chatcompletions/base.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/coder/aibridge/config"
1111
aibcontext "github.com/coder/aibridge/context"
12+
"github.com/coder/aibridge/intercept"
1213
"github.com/coder/aibridge/intercept/apidump"
1314
"github.com/coder/aibridge/mcp"
1415
"github.com/coder/aibridge/recorder"
@@ -29,6 +30,10 @@ type interceptionBase struct {
2930
req *ChatCompletionNewParamsWrapper
3031
cfg config.OpenAI
3132

33+
// clientHeaders are the original HTTP headers from the client request.
34+
clientHeaders http.Header
35+
authHeaderName string
36+
3237
logger slog.Logger
3338
tracer trace.Tracer
3439

@@ -41,10 +46,21 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService
4146

4247
// Add extra headers if configured.
4348
// Some providers require additional headers that are not added by the SDK.
49+
// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
4450
for key, value := range i.cfg.ExtraHeaders {
4551
opts = append(opts, option.WithHeader(key, value))
4652
}
4753

54+
// Forward client headers to upstream. This middleware runs after the SDK
55+
// has built the request, and replaces the outgoing headers with the sanitized
56+
// client headers plus provider auth.
57+
if i.clientHeaders != nil {
58+
opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
59+
req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName)
60+
return next(req)
61+
}))
62+
}
63+
4864
// Add API dump middleware if configured
4965
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
5066
opts = append(opts, option.WithMiddleware(mw))

intercept/chatcompletions/blocking.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,21 @@ type BlockingInterception struct {
2828
interceptionBase
2929
}
3030

31-
func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception {
31+
func NewBlockingInterceptor(
32+
id uuid.UUID,
33+
req *ChatCompletionNewParamsWrapper,
34+
cfg config.OpenAI,
35+
clientHeaders http.Header,
36+
authHeaderName string,
37+
tracer trace.Tracer,
38+
) *BlockingInterception {
3239
return &BlockingInterception{interceptionBase: interceptionBase{
33-
id: id,
34-
req: req,
35-
cfg: cfg,
36-
tracer: tracer,
40+
id: id,
41+
req: req,
42+
cfg: cfg,
43+
clientHeaders: clientHeaders,
44+
authHeaderName: authHeaderName,
45+
tracer: tracer,
3746
}}
3847
}
3948

@@ -78,6 +87,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
7887

7988
var opts []option.RequestOption
8089
opts = append(opts, option.WithRequestTimeout(time.Second*600))
90+
91+
// TODO(ssncferreira): inject actor headers directly in the client-header
92+
// middleware instead of using SDK options.
8193
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
8294
opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
8395
}

intercept/chatcompletions/streaming.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,21 @@ type StreamingInterception struct {
3333
interceptionBase
3434
}
3535

36-
func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception {
36+
func NewStreamingInterceptor(
37+
id uuid.UUID,
38+
req *ChatCompletionNewParamsWrapper,
39+
cfg config.OpenAI,
40+
clientHeaders http.Header,
41+
authHeaderName string,
42+
tracer trace.Tracer,
43+
) *StreamingInterception {
3744
return &StreamingInterception{interceptionBase: interceptionBase{
38-
id: id,
39-
req: req,
40-
cfg: cfg,
41-
tracer: tracer,
45+
id: id,
46+
req: req,
47+
cfg: cfg,
48+
clientHeaders: clientHeaders,
49+
authHeaderName: authHeaderName,
50+
tracer: tracer,
4251
}}
4352
}
4453

@@ -115,6 +124,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
115124
for {
116125
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
117126
var opts []option.RequestOption
127+
128+
// TODO(ssncferreira): inject actor headers directly in the client-header
129+
// middleware instead of using SDK options.
118130
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
119131
opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
120132
}

intercept/chatcompletions/streaming_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,16 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
8181
Stream: true,
8282
}
8383

84+
// Create test request
85+
w := httptest.NewRecorder()
86+
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
87+
8488
tracer := otel.Tracer("test")
85-
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer)
89+
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer)
8690

8791
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
8892
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
8993

90-
// Create test request
91-
w := httptest.NewRecorder()
92-
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
93-
9494
// Process the request
9595
err := interceptor.ProcessRequest(w, httpReq)
9696

intercept/client_headers.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package intercept
2+
3+
import "net/http"
4+
5+
// hopByHopHeaders are connection-level headers specific to the connection
6+
// between client and AI Bridge, not meant for the upstream.
7+
// See https://www.rfc-editor.org/rfc/rfc2616#section-13.5.1
8+
var hopByHopHeaders = []string{
9+
"Connection",
10+
"Keep-Alive",
11+
"Proxy-Authenticate",
12+
"Proxy-Authorization",
13+
"Te",
14+
"Trailer",
15+
"Transfer-Encoding",
16+
"Upgrade",
17+
}
18+
19+
// nonForwardedHeaders are transport-level headers managed by aibridge or
20+
// Go's HTTP transport that must not be forwarded to the upstream provider.
21+
var nonForwardedHeaders = []string{
22+
"Host",
23+
"Accept-Encoding",
24+
"Content-Length",
25+
}
26+
27+
// authHeaders are headers that carry authentication credentials from the
28+
// client. The upstream request is built by the SDK, which sets the correct
29+
// provider credentials via option.WithAPIKey. Client auth headers are
30+
// stripped here and the provider credentials are re-injected by
31+
// BuildUpstreamHeaders from the SDK-built request.
32+
var authHeaders = []string{
33+
"Authorization",
34+
"X-Api-Key",
35+
}
36+
37+
// PrepareClientHeaders returns a copy of the client headers with hop-by-hop,
38+
// transport, and auth headers removed.
39+
func PrepareClientHeaders(clientHeaders http.Header) http.Header {
40+
prepared := clientHeaders.Clone()
41+
for _, h := range hopByHopHeaders {
42+
prepared.Del(h)
43+
}
44+
for _, h := range nonForwardedHeaders {
45+
prepared.Del(h)
46+
}
47+
for _, h := range authHeaders {
48+
prepared.Del(h)
49+
}
50+
return prepared
51+
}
52+
53+
// BuildUpstreamHeaders produces the header set for an upstream SDK request.
54+
// It starts from the prepared client headers, then preserves specific
55+
// headers from the SDK-built request that must not be overwritten.
56+
func BuildUpstreamHeaders(sdkHeader http.Header, clientHeaders http.Header, authHeaderName string) http.Header {
57+
headers := PrepareClientHeaders(clientHeaders)
58+
59+
// Preserve the auth header set by the SDK from the provider configuration.
60+
if v := sdkHeader.Get(authHeaderName); v != "" {
61+
headers.Set(authHeaderName, v)
62+
}
63+
64+
// Preserve actor headers injected by aibridge as per-request SDK options.
65+
for name, values := range sdkHeader {
66+
if IsActorHeader(name) {
67+
headers[name] = values
68+
}
69+
}
70+
71+
return headers
72+
}

0 commit comments

Comments
 (0)