Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pkg/transport/middleware/token_injection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@ import (
"github.com/stacklok/toolhive/pkg/transport/types"
)

// retryAfterSecs tells MCP clients how long to wait before retrying.
// Matches the initial MonitoredTokenSource backoff interval so that clients
// retry around the same time the next token refresh attempt happens.
const retryAfterSecs = "10"

// CreateTokenInjectionMiddleware returns a middleware that injects a Bearer token
// from the provided oauth2.TokenSource. It returns 401 when the workload is unauthenticated.
// from the provided oauth2.TokenSource. It returns 503 Service Unavailable with a
// Retry-After header when the token cannot be retrieved, so that MCP clients treat
// the failure as transient rather than initiating an OAuth discovery flow.
func CreateTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.MiddlewareFunction {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -24,8 +31,11 @@ func CreateTokenInjectionMiddleware(tokenSource oauth2.TokenSource) types.Middle
if err != nil {
slog.Warn("unable to retrieve OAuth token", "error", err)
// The token source (AuthenticatedTokenSource) handles marking
// the workload as unauthenticated in its Token() method
http.Error(w, "Authentication required", http.StatusUnauthorized)
// the workload as unauthenticated in its Token() method.
// Return 503 instead of 401 so MCP clients do not mistake this
// for a server that requires client-side OAuth authentication.
w.Header().Set("Retry-After", retryAfterSecs)
http.Error(w, "Token temporarily unavailable", http.StatusServiceUnavailable)
return
}

Expand Down
108 changes: 108 additions & 0 deletions pkg/transport/middleware/token_injection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package middleware

import (
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

// stubTokenSource implements oauth2.TokenSource for testing.
type stubTokenSource struct {
token *oauth2.Token
err error
}

func (s *stubTokenSource) Token() (*oauth2.Token, error) {
return s.token, s.err
}

func TestCreateTokenInjectionMiddleware(t *testing.T) {
t.Parallel()

tests := []struct {
name string
tokenSource oauth2.TokenSource
wantStatus int
wantNextCalled bool
wantAuthHeader string
wantRetryAfter string
wantBodyContain string
}{
{
name: "token source error returns 503 with Retry-After",
tokenSource: &stubTokenSource{
err: errors.New("token expired"),
},
wantStatus: http.StatusServiceUnavailable,
wantNextCalled: false,
wantRetryAfter: retryAfterSecs,
wantBodyContain: "Token temporarily unavailable",
},
{
name: "token source succeeds injects Bearer token",
tokenSource: &stubTokenSource{
token: &oauth2.Token{AccessToken: "test-access-token"},
},
wantStatus: http.StatusOK,
wantNextCalled: true,
wantAuthHeader: "Bearer test-access-token",
},
{
name: "nil token source passes request through",
tokenSource: nil,
wantStatus: http.StatusOK,
wantNextCalled: true,
wantAuthHeader: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

nextCalled := false
var capturedReq *http.Request

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
capturedReq = r
w.WriteHeader(http.StatusOK)
})

mw := CreateTokenInjectionMiddleware(tt.tokenSource)
handler := mw(next)

req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)

assert.Equal(t, tt.wantStatus, rec.Code)
assert.Equal(t, tt.wantNextCalled, nextCalled)

if tt.wantRetryAfter != "" {
assert.Equal(t, tt.wantRetryAfter, rec.Header().Get("Retry-After"))
}

if tt.wantBodyContain != "" {
assert.Contains(t, rec.Body.String(), tt.wantBodyContain)
}

if tt.wantNextCalled {
require.NotNil(t, capturedReq)
if tt.wantAuthHeader != "" {
assert.Equal(t, tt.wantAuthHeader, capturedReq.Header.Get("Authorization"))
} else {
assert.Empty(t, capturedReq.Header.Get("Authorization"))
}
}
})
}
}
Loading