diff --git a/pkg/transport/middleware/token_injection.go b/pkg/transport/middleware/token_injection.go index cfd46fc1f8..49fce30fbe 100644 --- a/pkg/transport/middleware/token_injection.go +++ b/pkg/transport/middleware/token_injection.go @@ -14,8 +14,15 @@ import ( "github.com/stacklok/toolhive/pkg/transport/types" ) +// retryAfterSecs tells MCP clients how long to wait before retrying. +// Matches the initial MonitoredTokenSource backoff interval so that clients +// retry around the same time the next token refresh attempt happens. +const retryAfterSecs = "10" + // CreateTokenInjectionMiddleware returns a middleware that injects a Bearer token -// from the provided oauth2.TokenSource. It returns 401 when the workload is unauthenticated. +// from the provided oauth2.TokenSource. It returns 503 Service Unavailable with a +// Retry-After header when the token cannot be retrieved, so that MCP clients treat +// the failure as transient rather than initiating an OAuth discovery flow. func CreateTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.MiddlewareFunction { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -24,8 +31,11 @@ func CreateTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.Middle if err != nil { slog.Warn("unable to retrieve OAuth token", "error", err) // The token source (AuthenticatedTokenSource) handles marking - // the workload as unauthenticated in its Token() method - http.Error(w, "Authentication required", http.StatusUnauthorized) + // the workload as unauthenticated in its Token() method. + // Return 503 instead of 401 so MCP clients do not mistake this + // for a server that requires client-side OAuth authentication. + w.Header().Set("Retry-After", retryAfterSecs) + http.Error(w, "Token temporarily unavailable", http.StatusServiceUnavailable) return } diff --git a/pkg/transport/middleware/token_injection_test.go b/pkg/transport/middleware/token_injection_test.go new file mode 100644 index 0000000000..c8f08f7d7e --- /dev/null +++ b/pkg/transport/middleware/token_injection_test.go @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +// stubTokenSource implements oauth2.TokenSource for testing. +type stubTokenSource struct { + token *oauth2.Token + err error +} + +func (s *stubTokenSource) Token() (*oauth2.Token, error) { + return s.token, s.err +} + +func TestCreateTokenInjectionMiddleware(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tokenSource oauth2.TokenSource + wantStatus int + wantNextCalled bool + wantAuthHeader string + wantRetryAfter string + wantBodyContain string + }{ + { + name: "token source error returns 503 with Retry-After", + tokenSource: &stubTokenSource{ + err: errors.New("token expired"), + }, + wantStatus: http.StatusServiceUnavailable, + wantNextCalled: false, + wantRetryAfter: retryAfterSecs, + wantBodyContain: "Token temporarily unavailable", + }, + { + name: "token source succeeds injects Bearer token", + tokenSource: &stubTokenSource{ + token: &oauth2.Token{AccessToken: "test-access-token"}, + }, + wantStatus: http.StatusOK, + wantNextCalled: true, + wantAuthHeader: "Bearer test-access-token", + }, + { + name: "nil token source passes request through", + tokenSource: nil, + wantStatus: http.StatusOK, + wantNextCalled: true, + wantAuthHeader: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + nextCalled := false + var capturedReq *http.Request + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + capturedReq = r + w.WriteHeader(http.StatusOK) + }) + + mw := CreateTokenInjectionMiddleware(tt.tokenSource) + handler := mw(next) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, tt.wantStatus, rec.Code) + assert.Equal(t, tt.wantNextCalled, nextCalled) + + if tt.wantRetryAfter != "" { + assert.Equal(t, tt.wantRetryAfter, rec.Header().Get("Retry-After")) + } + + if tt.wantBodyContain != "" { + assert.Contains(t, rec.Body.String(), tt.wantBodyContain) + } + + if tt.wantNextCalled { + require.NotNil(t, capturedReq) + if tt.wantAuthHeader != "" { + assert.Equal(t, tt.wantAuthHeader, capturedReq.Header.Get("Authorization")) + } else { + assert.Empty(t, capturedReq.Header.Get("Authorization")) + } + } + }) + } +}