Skip to content

Commit 729d11f

Browse files
gkatz2claude
andauthored
Return 503 for expired proxy tokens (#4722)
* Return 503 for expired proxy tokens When the proxy's OAuth token source fails, the token injection middleware returned 401 Unauthorized. MCP clients interpret 401 as a signal to begin OAuth authentication, but the proxy manages OAuth internally and has no client-facing OAuth metadata. The failed discovery is cached, blocking reconnection even after the token refreshes. Return 503 Service Unavailable with Retry-After instead, which clients treat as a transient connection error. Fixes #4721 Co-Authored-By: Claude <noreply@anthropic.com> Signed-off-by: Greg Katz <gkatz@indeed.com> * Extract Retry-After value into named constant Address review nit: extract the retry delay into a package-level constant with a comment linking it to the MonitoredTokenSource backoff interval. Signed-off-by: Greg Katz <gkatz@indeed.com> --------- Signed-off-by: Greg Katz <gkatz@indeed.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent fae9f0d commit 729d11f

2 files changed

Lines changed: 121 additions & 3 deletions

File tree

pkg/transport/middleware/token_injection.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,15 @@ import (
1414
"github.com/stacklok/toolhive/pkg/transport/types"
1515
)
1616

17+
// retryAfterSecs tells MCP clients how long to wait before retrying.
18+
// Matches the initial MonitoredTokenSource backoff interval so that clients
19+
// retry around the same time the next token refresh attempt happens.
20+
const retryAfterSecs = "10"
21+
1722
// CreateTokenInjectionMiddleware returns a middleware that injects a Bearer token
18-
// from the provided oauth2.TokenSource. It returns 401 when the workload is unauthenticated.
23+
// from the provided oauth2.TokenSource. It returns 503 Service Unavailable with a
24+
// Retry-After header when the token cannot be retrieved, so that MCP clients treat
25+
// the failure as transient rather than initiating an OAuth discovery flow.
1926
func CreateTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.MiddlewareFunction {
2027
return func(next http.Handler) http.Handler {
2128
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -24,8 +31,11 @@ func CreateTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.Middle
2431
if err != nil {
2532
slog.Warn("unable to retrieve OAuth token", "error", err)
2633
// The token source (AuthenticatedTokenSource) handles marking
27-
// the workload as unauthenticated in its Token() method
28-
http.Error(w, "Authentication required", http.StatusUnauthorized)
34+
// the workload as unauthenticated in its Token() method.
35+
// Return 503 instead of 401 so MCP clients do not mistake this
36+
// for a server that requires client-side OAuth authentication.
37+
w.Header().Set("Retry-After", retryAfterSecs)
38+
http.Error(w, "Token temporarily unavailable", http.StatusServiceUnavailable)
2939
return
3040
}
3141

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package middleware
5+
6+
import (
7+
"errors"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
"golang.org/x/oauth2"
15+
)
16+
17+
// stubTokenSource implements oauth2.TokenSource for testing.
18+
type stubTokenSource struct {
19+
token *oauth2.Token
20+
err error
21+
}
22+
23+
func (s *stubTokenSource) Token() (*oauth2.Token, error) {
24+
return s.token, s.err
25+
}
26+
27+
func TestCreateTokenInjectionMiddleware(t *testing.T) {
28+
t.Parallel()
29+
30+
tests := []struct {
31+
name string
32+
tokenSource oauth2.TokenSource
33+
wantStatus int
34+
wantNextCalled bool
35+
wantAuthHeader string
36+
wantRetryAfter string
37+
wantBodyContain string
38+
}{
39+
{
40+
name: "token source error returns 503 with Retry-After",
41+
tokenSource: &stubTokenSource{
42+
err: errors.New("token expired"),
43+
},
44+
wantStatus: http.StatusServiceUnavailable,
45+
wantNextCalled: false,
46+
wantRetryAfter: retryAfterSecs,
47+
wantBodyContain: "Token temporarily unavailable",
48+
},
49+
{
50+
name: "token source succeeds injects Bearer token",
51+
tokenSource: &stubTokenSource{
52+
token: &oauth2.Token{AccessToken: "test-access-token"},
53+
},
54+
wantStatus: http.StatusOK,
55+
wantNextCalled: true,
56+
wantAuthHeader: "Bearer test-access-token",
57+
},
58+
{
59+
name: "nil token source passes request through",
60+
tokenSource: nil,
61+
wantStatus: http.StatusOK,
62+
wantNextCalled: true,
63+
wantAuthHeader: "",
64+
},
65+
}
66+
67+
for _, tt := range tests {
68+
t.Run(tt.name, func(t *testing.T) {
69+
t.Parallel()
70+
71+
nextCalled := false
72+
var capturedReq *http.Request
73+
74+
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
75+
nextCalled = true
76+
capturedReq = r
77+
w.WriteHeader(http.StatusOK)
78+
})
79+
80+
mw := CreateTokenInjectionMiddleware(tt.tokenSource)
81+
handler := mw(next)
82+
83+
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
84+
rec := httptest.NewRecorder()
85+
handler.ServeHTTP(rec, req)
86+
87+
assert.Equal(t, tt.wantStatus, rec.Code)
88+
assert.Equal(t, tt.wantNextCalled, nextCalled)
89+
90+
if tt.wantRetryAfter != "" {
91+
assert.Equal(t, tt.wantRetryAfter, rec.Header().Get("Retry-After"))
92+
}
93+
94+
if tt.wantBodyContain != "" {
95+
assert.Contains(t, rec.Body.String(), tt.wantBodyContain)
96+
}
97+
98+
if tt.wantNextCalled {
99+
require.NotNil(t, capturedReq)
100+
if tt.wantAuthHeader != "" {
101+
assert.Equal(t, tt.wantAuthHeader, capturedReq.Header.Get("Authorization"))
102+
} else {
103+
assert.Empty(t, capturedReq.Header.Get("Authorization"))
104+
}
105+
}
106+
})
107+
}
108+
}

0 commit comments

Comments
 (0)