Skip to content

Commit 68e79e9

Browse files
aron-muonclaude
andcommitted
Fix transparent proxy for remote MCP servers behind redirects
Three issues prevented MCPRemoteProxy from connecting to third-party upstream MCP servers that use HTTP redirects: 1. X-Forwarded-Host leaked the proxy's hostname to the upstream. The upstream used it to construct 307 redirect URLs pointing back to the proxy, creating a redirect loop. Fix: skip SetXForwarded() for remote upstreams (isRemote == true). 2. Go's http.Transport.RoundTrip does not follow redirects, but httputil.ReverseProxy uses Transport directly. Upstream 307/308 redirects (e.g. HTTPS→HTTP scheme changes, path canonicalization) were returned to the MCP client which cannot follow them through the proxy. Fix: add forwardFollowingRedirects that transparently follows up to 10 redirects, preserving method and body for 307/308 (RFC 7538). 3. When disableUpstreamTokenInjection is true, the client's ToolHive JWT was still forwarded to the upstream in the Authorization header. Fix: add strip-auth middleware that removes the Authorization header before forwarding. Also adds debug logging for outbound request headers and upstream response status codes to aid diagnosis of remote proxy issues. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bc36676 commit 68e79e9

3 files changed

Lines changed: 190 additions & 28 deletions

File tree

pkg/runner/middleware.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package runner
55

66
import (
77
"fmt"
8+
"net/http"
89

910
"github.com/stacklok/toolhive/pkg/audit"
1011
"github.com/stacklok/toolhive/pkg/auth"
@@ -37,6 +38,7 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory {
3738
audit.MiddlewareType: audit.CreateMiddleware,
3839
recovery.MiddlewareType: recovery.CreateMiddleware,
3940
headerfwd.HeaderForwardMiddlewareName: headerfwd.CreateMiddleware,
41+
stripAuthMiddlewareType: createStripAuthMiddleware,
4042
}
4143
}
4244

@@ -263,9 +265,10 @@ func addUpstreamSwapMiddleware(
263265
return middlewares, nil
264266
}
265267

266-
// Skip upstream token injection if explicitly disabled
268+
// When upstream token injection is disabled, strip the Authorization header
269+
// so the client's ToolHive JWT doesn't leak to the upstream server.
267270
if config.EmbeddedAuthServerConfig.DisableUpstreamTokenInjection {
268-
return middlewares, nil
271+
return addAuthHeaderStripMiddleware(middlewares)
269272
}
270273

271274
// Use provided config or defaults
@@ -287,6 +290,45 @@ func addUpstreamSwapMiddleware(
287290
return append(middlewares, *upstreamSwapMwConfig), nil
288291
}
289292

293+
// stripAuthMiddlewareType is the type identifier for the auth header stripping middleware.
294+
const stripAuthMiddlewareType = "strip-auth"
295+
296+
// addAuthHeaderStripMiddleware adds a middleware that removes the Authorization header
297+
// before forwarding to the upstream. This prevents the client's ToolHive JWT from
298+
// leaking to upstream servers that don't expect it.
299+
func addAuthHeaderStripMiddleware(
300+
middlewares []types.MiddlewareConfig,
301+
) ([]types.MiddlewareConfig, error) {
302+
mwConfig, err := types.NewMiddlewareConfig(stripAuthMiddlewareType, struct{}{})
303+
if err != nil {
304+
return nil, fmt.Errorf("failed to create strip-auth middleware config: %w", err)
305+
}
306+
return append(middlewares, *mwConfig), nil
307+
}
308+
309+
// createStripAuthMiddleware is the factory function for the auth header stripping middleware.
310+
func createStripAuthMiddleware(_ *types.MiddlewareConfig, runner types.MiddlewareRunner) error {
311+
mw := &stripAuthMiddleware{}
312+
runner.AddMiddleware(stripAuthMiddlewareType, mw)
313+
return nil
314+
}
315+
316+
// stripAuthMiddleware removes the Authorization header from requests.
317+
type stripAuthMiddleware struct{}
318+
319+
// Handler returns the middleware function.
320+
func (*stripAuthMiddleware) Handler() types.MiddlewareFunction {
321+
return func(next http.Handler) http.Handler {
322+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
323+
r.Header.Del("Authorization")
324+
next.ServeHTTP(w, r)
325+
})
326+
}
327+
}
328+
329+
// Close cleans up resources.
330+
func (*stripAuthMiddleware) Close() error { return nil }
331+
290332
// addAWSStsMiddleware adds AWS STS middleware if configured.
291333
// Returns an error if AWSStsConfig is set but RemoteURL is empty, because
292334
// SigV4 signing is only meaningful for remote MCP servers.

pkg/runner/middleware_test.go

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,10 @@ func TestAddUpstreamSwapMiddleware(t *testing.T) {
253253
t.Parallel()
254254

255255
tests := []struct {
256-
name string
257-
config *RunConfig
258-
wantAppended bool
256+
name string
257+
config *RunConfig
258+
wantAppended bool
259+
wantType string // expected middleware type when appended
259260
}{
260261
{
261262
name: "nil EmbeddedAuthServerConfig returns input unchanged",
@@ -269,15 +270,17 @@ func TestAddUpstreamSwapMiddleware(t *testing.T) {
269270
UpstreamSwapConfig: nil,
270271
},
271272
wantAppended: true,
273+
wantType: upstreamswap.MiddlewareType,
272274
},
273275
{
274-
name: "EmbeddedAuthServerConfig with DisableUpstreamTokenInjection skips middleware",
276+
name: "DisableUpstreamTokenInjection adds strip-auth middleware instead",
275277
config: func() *RunConfig {
276278
cfg := createMinimalAuthServerConfig()
277279
cfg.DisableUpstreamTokenInjection = true
278280
return &RunConfig{EmbeddedAuthServerConfig: cfg}
279281
}(),
280-
wantAppended: false,
282+
wantAppended: true,
283+
wantType: stripAuthMiddlewareType,
281284
},
282285
{
283286
name: "EmbeddedAuthServerConfig set with explicit UpstreamSwapConfig uses provided config",
@@ -288,6 +291,7 @@ func TestAddUpstreamSwapMiddleware(t *testing.T) {
288291
},
289292
},
290293
wantAppended: true,
294+
wantType: upstreamswap.MiddlewareType,
291295
},
292296
{
293297
name: "EmbeddedAuthServerConfig with custom header strategy config",
@@ -299,6 +303,7 @@ func TestAddUpstreamSwapMiddleware(t *testing.T) {
299303
},
300304
},
301305
wantAppended: true,
306+
wantType: upstreamswap.MiddlewareType,
302307
},
303308
}
304309

@@ -318,20 +323,20 @@ func TestAddUpstreamSwapMiddleware(t *testing.T) {
318323
// Should have one additional entry.
319324
require.Len(t, got, len(initial)+1)
320325
added := got[len(got)-1]
321-
assert.Equal(t, upstreamswap.MiddlewareType, added.Type)
322-
323-
// Verify serialized params contain the expected config.
324-
var params upstreamswap.MiddlewareParams
325-
require.NoError(t, json.Unmarshal(added.Parameters, &params))
326+
assert.Equal(t, tt.wantType, added.Type)
326327

327-
if tt.config.UpstreamSwapConfig != nil {
328-
// Should use the provided config
329-
require.NotNil(t, params.Config)
330-
assert.Equal(t, tt.config.UpstreamSwapConfig.HeaderStrategy, params.Config.HeaderStrategy)
331-
assert.Equal(t, tt.config.UpstreamSwapConfig.CustomHeaderName, params.Config.CustomHeaderName)
332-
} else {
333-
// Should use defaults (empty config is valid)
334-
require.NotNil(t, params.Config)
328+
// For upstreamswap type, verify serialized params
329+
if tt.wantType == upstreamswap.MiddlewareType {
330+
var params upstreamswap.MiddlewareParams
331+
require.NoError(t, json.Unmarshal(added.Parameters, &params))
332+
333+
if tt.config.UpstreamSwapConfig != nil {
334+
require.NotNil(t, params.Config)
335+
assert.Equal(t, tt.config.UpstreamSwapConfig.HeaderStrategy, params.Config.HeaderStrategy)
336+
assert.Equal(t, tt.config.UpstreamSwapConfig.CustomHeaderName, params.Config.CustomHeaderName)
337+
} else {
338+
require.NotNil(t, params.Config)
339+
}
335340
}
336341
})
337342
}
@@ -344,6 +349,7 @@ func TestPopulateMiddlewareConfigs_UpstreamSwap(t *testing.T) {
344349
name string
345350
config *RunConfig
346351
wantUpstreamSwap bool
352+
wantStripAuth bool
347353
wantHeaderStrategy string
348354
}{
349355
{
@@ -357,13 +363,14 @@ func TestPopulateMiddlewareConfigs_UpstreamSwap(t *testing.T) {
357363
wantUpstreamSwap: false,
358364
},
359365
{
360-
name: "DisableUpstreamTokenInjection omits upstream-swap",
366+
name: "DisableUpstreamTokenInjection adds strip-auth instead of upstream-swap",
361367
config: func() *RunConfig {
362368
cfg := createMinimalAuthServerConfig()
363369
cfg.DisableUpstreamTokenInjection = true
364370
return &RunConfig{EmbeddedAuthServerConfig: cfg}
365371
}(),
366372
wantUpstreamSwap: false,
373+
wantStripAuth: true,
367374
},
368375
{
369376
name: "explicit UpstreamSwapConfig is used",
@@ -385,20 +392,25 @@ func TestPopulateMiddlewareConfigs_UpstreamSwap(t *testing.T) {
385392
err := PopulateMiddlewareConfigs(tt.config)
386393
require.NoError(t, err)
387394

388-
var found bool
395+
var foundSwap bool
396+
var foundStrip bool
389397
var foundConfig *types.MiddlewareConfig
390398
for i, mw := range tt.config.MiddlewareConfigs {
391399
if mw.Type == upstreamswap.MiddlewareType {
392-
found = true
400+
foundSwap = true
393401
foundConfig = &tt.config.MiddlewareConfigs[i]
394-
break
402+
}
403+
if mw.Type == stripAuthMiddlewareType {
404+
foundStrip = true
395405
}
396406
}
397-
assert.Equal(t, tt.wantUpstreamSwap, found,
407+
assert.Equal(t, tt.wantUpstreamSwap, foundSwap,
398408
"upstream-swap middleware presence mismatch")
409+
assert.Equal(t, tt.wantStripAuth, foundStrip,
410+
"strip-auth middleware presence mismatch")
399411

400412
// Verify config values if we expect the middleware and have specific expectations
401-
if found && tt.wantHeaderStrategy != "" {
413+
if foundSwap && tt.wantHeaderStrategy != "" {
402414
var params upstreamswap.MiddlewareParams
403415
require.NoError(t, json.Unmarshal(foundConfig.Parameters, &params))
404416
require.NotNil(t, params.Config)

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,10 @@ func (p *TransparentProxy) serverInitialized() bool {
352352
return p.isServerInitialized.Load()
353353
}
354354

355+
// maxRedirects is the maximum number of HTTP redirects the transport follows
356+
// before returning an error. Matches http.Client's default of 10.
357+
const maxRedirects = 10
358+
355359
func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
356360
tr := t.base
357361
if tr == nil {
@@ -360,6 +364,87 @@ func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
360364
return tr.RoundTrip(req)
361365
}
362366

367+
// forwardFollowingRedirects sends req and transparently follows 3xx redirects.
368+
//
369+
// Go's http.Transport.RoundTrip does not follow redirects (that is
370+
// http.Client's job), but httputil.ReverseProxy uses Transport directly.
371+
// Without this, upstream redirects (e.g. HTTPS→HTTP scheme changes, path
372+
// canonicalization, or API gateway routing) are returned to the MCP client
373+
// which cannot follow them through the proxy.
374+
//
375+
// 307/308 preserve method and body (RFC 7538). 301/302/303 change to GET.
376+
func (t *tracingTransport) forwardFollowingRedirects(req *http.Request, body []byte) (*http.Response, error) {
377+
for range maxRedirects {
378+
resp, err := t.forward(req)
379+
if err != nil {
380+
return nil, err
381+
}
382+
383+
if !isRedirectStatus(resp.StatusCode) {
384+
return resp, nil
385+
}
386+
387+
location := resp.Header.Get("Location")
388+
if location == "" {
389+
return resp, nil
390+
}
391+
392+
redirectURL, err := req.URL.Parse(location)
393+
if err != nil {
394+
slog.Warn("unparsable redirect Location, returning raw redirect response",
395+
"location", location, "error", err)
396+
return resp, nil
397+
}
398+
399+
slog.Warn("following upstream redirect",
400+
"status", resp.StatusCode,
401+
"from", req.URL.String(),
402+
"to", redirectURL.String(),
403+
)
404+
405+
// Drain and close the body so the TCP connection can be reused.
406+
_, _ = io.Copy(io.Discard, resp.Body)
407+
resp.Body.Close()
408+
409+
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, redirectURL.String(), nil)
410+
if err != nil {
411+
return nil, fmt.Errorf("failed to build redirect request: %w", err)
412+
}
413+
newReq.Header = req.Header.Clone()
414+
newReq.Host = redirectURL.Host
415+
416+
switch resp.StatusCode {
417+
case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther:
418+
// 301/302/303 → change to GET, drop body (RFC 7231)
419+
newReq.Method = http.MethodGet
420+
newReq.ContentLength = 0
421+
default:
422+
// 307/308 → preserve method and body (RFC 7538)
423+
if len(body) > 0 {
424+
newReq.Body = io.NopCloser(bytes.NewReader(body))
425+
newReq.ContentLength = int64(len(body))
426+
}
427+
}
428+
429+
req = newReq
430+
}
431+
432+
return nil, fmt.Errorf("upstream exceeded %d redirects", maxRedirects)
433+
}
434+
435+
// isRedirectStatus returns true for HTTP 3xx codes that carry a Location header.
436+
func isRedirectStatus(code int) bool {
437+
switch code {
438+
case http.StatusMovedPermanently, // 301
439+
http.StatusFound, // 302
440+
http.StatusSeeOther, // 303
441+
http.StatusTemporaryRedirect, // 307
442+
http.StatusPermanentRedirect: // 308
443+
return true
444+
}
445+
return false
446+
}
447+
363448
// nolint:gocyclo // This function handles multiple request types and is complex by design
364449
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
365450
// Always rewrite Host header to match the target URL to avoid "Invalid Host" errors
@@ -371,6 +456,14 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
371456
req.Host = req.URL.Host
372457
}
373458

459+
slog.Debug("outbound request to upstream",
460+
"method", req.Method,
461+
"url", req.URL.String(),
462+
"host", req.Host,
463+
"accept", req.Header.Get("Accept"),
464+
"content_type", req.Header.Get("Content-Type"),
465+
)
466+
374467
reqBody := readRequestBody(req)
375468

376469
// thv proxy does not provide the transport type, so we need to detect it from the request
@@ -385,7 +478,7 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
385478
sawInitialize = t.detectInitialize(reqBody)
386479
}
387480

388-
resp, err := t.forward(req)
481+
resp, err := t.forwardFollowingRedirects(req, reqBody)
389482
if err != nil {
390483
if errors.Is(err, context.Canceled) {
391484
// Expected during shutdown or client disconnect—silently ignore
@@ -395,6 +488,13 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
395488
return nil, err
396489
}
397490

491+
slog.Debug("upstream response received",
492+
"status", resp.StatusCode,
493+
"url", req.URL.String(),
494+
"content_type", resp.Header.Get("Content-Type"),
495+
"mcp_session_id", resp.Header.Get("Mcp-Session-Id"),
496+
)
497+
398498
// Check for 401 Unauthorized response (bearer token authentication failure)
399499
if resp.StatusCode == http.StatusUnauthorized {
400500
//nolint:gosec // G706: logging target URI from config
@@ -504,7 +604,15 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
504604
FlushInterval: -1,
505605
Rewrite: func(pr *httputil.ProxyRequest) {
506606
pr.SetURL(targetURL)
507-
pr.SetXForwarded()
607+
608+
// Only set X-Forwarded-* headers for local backends.
609+
// For remote upstreams, these headers leak the proxy's hostname
610+
// (X-Forwarded-Host) to third-party servers, which can cause
611+
// 307 redirect loops when the upstream uses that header to
612+
// construct redirect URLs pointing back to the proxy.
613+
if !p.isRemote {
614+
pr.SetXForwarded()
615+
}
508616

509617
// Stash the original inbound request in the outbound request's
510618
// context so that ModifyResponse (SSE response processor) can

0 commit comments

Comments
 (0)