Skip to content

Commit bfdd23b

Browse files
tgrunnagleMatteoManzoni
authored andcommitted
Wire in-process JWKS key resolution for vMCP embedded auth server (stacklok#4526)
When the embedded auth server is active in vMCP (VirtualMCPServer), token validation was failing silently because the OIDC middleware fetched JWKS keys over HTTP from the proxy's own endpoint — a self-referential HTTP round-trip that required operators to set `insecureAllowHTTP` and/or `jwksAllowPrivateIP` just to make token validation work. These are insecure workarounds, and the failures were difficult to diagnose. This PR extends the fix for the runner and proxy runner to vMCP. The embedded auth server's `KeyProvider` is now extracted in `runServe` and passed through to the OIDC middleware factory, where it is wired into the `TokenValidator` for in-process key resolution. HTTP JWKS fetch is retained as a fallback for key-ID misses and external OIDC providers.
1 parent 2b45920 commit bfdd23b

6 files changed

Lines changed: 237 additions & 14 deletions

File tree

cmd/vmcp/app/commands.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
2525
authserverconfig "github.com/stacklok/toolhive/pkg/authserver"
2626
authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner"
27+
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
2728
"github.com/stacklok/toolhive/pkg/container/runtime"
2829
"github.com/stacklok/toolhive/pkg/groups"
2930
"github.com/stacklok/toolhive/pkg/telemetry"
@@ -589,18 +590,20 @@ func runServe(cmd *cobra.Command, _ []string) error {
589590
}
590591
}
591592

592-
// Create an upstream token reader from the embedded auth server so that
593-
// the OIDC middleware can enrich Identity with upstream provider tokens.
594-
// This is required for the upstream_inject outgoing auth strategy.
593+
// Extract dependencies from the embedded auth server so the OIDC middleware
594+
// can (a) resolve JWKS keys in-process instead of self-referential HTTP
595+
// calls, and (b) enrich Identity with upstream provider tokens.
595596
var upstreamReader upstreamtoken.TokenReader
597+
var keyProvider keys.PublicKeyProvider
596598
if embeddedAuthServer != nil {
597599
stor := embeddedAuthServer.IDPTokenStorage()
598600
refresher := embeddedAuthServer.UpstreamTokenRefresher()
599601
upstreamReader = upstreamtoken.NewInProcessService(stor, refresher)
602+
keyProvider = embeddedAuthServer.KeyProvider()
600603
}
601604

602605
authMiddleware, authzMiddleware, authInfoHandler, err :=
603-
factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth, passThroughTools, upstreamReader)
606+
factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider)
604607
if err != nil {
605608
return fmt.Errorf("failed to create authentication middleware: %w", err)
606609
}

pkg/vmcp/auth/factory/authz_not_wired_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestNewIncomingAuthMiddleware_AuthzEnforced(t *testing.T) {
4747
},
4848
}
4949

50-
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil)
50+
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil)
5151
require.NoError(t, err, "middleware creation should succeed")
5252
require.NotNil(t, authMw, "auth middleware should not be nil")
5353
require.NotNil(t, authzMw, "authz middleware should not be nil")
@@ -105,7 +105,7 @@ func TestNewIncomingAuthMiddleware_AuthzEnforced(t *testing.T) {
105105
},
106106
}
107107

108-
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil)
108+
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil)
109109
require.NoError(t, err, "middleware creation should succeed")
110110
require.NotNil(t, authMw, "auth middleware should not be nil")
111111
require.NotNil(t, authzMw, "authz middleware should not be nil")
@@ -163,7 +163,7 @@ func TestNewIncomingAuthMiddleware_AuthzApproveAndBlock(t *testing.T) {
163163
},
164164
}
165165

166-
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil)
166+
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil)
167167
require.NoError(t, err, "middleware creation should succeed")
168168
require.NotNil(t, authMw, "auth middleware should not be nil")
169169
require.NotNil(t, authzMw, "authz middleware should not be nil")

pkg/vmcp/auth/factory/incoming.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/stacklok/toolhive/pkg/auth"
1313
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
14+
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
1415
"github.com/stacklok/toolhive/pkg/authz"
1516
"github.com/stacklok/toolhive/pkg/authz/authorizers"
1617
"github.com/stacklok/toolhive/pkg/authz/authorizers/cedar"
@@ -51,6 +52,7 @@ func NewIncomingAuthMiddleware(
5152
cfg *config.IncomingAuthConfig,
5253
passThroughTools map[string]struct{},
5354
upstreamReader upstreamtoken.TokenReader,
55+
keyProvider keys.PublicKeyProvider,
5456
) (
5557
authMw func(http.Handler) http.Handler,
5658
authzMw func(http.Handler) http.Handler,
@@ -65,7 +67,7 @@ func NewIncomingAuthMiddleware(
6567

6668
switch cfg.Type {
6769
case "oidc":
68-
authMiddleware, authInfoHandler, err = newOIDCAuthMiddleware(ctx, cfg.OIDC, upstreamReader)
70+
authMiddleware, authInfoHandler, err = newOIDCAuthMiddleware(ctx, cfg.OIDC, upstreamReader, keyProvider)
6971
case "local":
7072
authMiddleware, authInfoHandler, err = newLocalAuthMiddleware(ctx)
7173
case "anonymous":
@@ -151,6 +153,7 @@ func newOIDCAuthMiddleware(
151153
ctx context.Context,
152154
oidcCfg *config.OIDCConfig,
153155
reader upstreamtoken.TokenReader,
156+
keyProvider keys.PublicKeyProvider,
154157
) (func(http.Handler) http.Handler, http.Handler, error) {
155158
if oidcCfg == nil {
156159
return nil, nil, fmt.Errorf("OIDC configuration required when Type='oidc'")
@@ -175,9 +178,13 @@ func newOIDCAuthMiddleware(
175178
Scopes: oidcCfg.Scopes,
176179
}
177180

178-
// Wire the upstream token reader so the JWT validator can enrich Identity
179-
// with upstream provider tokens (needed for upstream_inject auth strategy).
181+
// Wire optional dependencies from the embedded auth server so the JWT
182+
// validator can (a) resolve JWKS keys in-process instead of self-referential
183+
// HTTP calls, and (b) enrich Identity with upstream provider tokens.
180184
var opts []auth.TokenValidatorOption
185+
if keyProvider != nil {
186+
opts = append(opts, auth.WithKeyProvider(keyProvider))
187+
}
181188
if reader != nil {
182189
opts = append(opts, auth.WithUpstreamTokenReader(reader))
183190
}
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package factory
5+
6+
import (
7+
"crypto/ecdsa"
8+
"crypto/elliptic"
9+
"crypto/rand"
10+
"net/http"
11+
"net/http/httptest"
12+
"testing"
13+
"time"
14+
15+
"github.com/golang-jwt/jwt/v5"
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
"go.uber.org/mock/gomock"
19+
20+
pkgauth "github.com/stacklok/toolhive/pkg/auth"
21+
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
22+
keysmocks "github.com/stacklok/toolhive/pkg/authserver/server/keys/mocks"
23+
"github.com/stacklok/toolhive/pkg/vmcp/config"
24+
)
25+
26+
// TestNewOIDCAuthMiddleware_KeyProvider_LocalResolution verifies that when a
27+
// PublicKeyProvider is wired in, key resolution happens in-process via the
28+
// local provider rather than through an HTTP JWKS fetch.
29+
func TestNewOIDCAuthMiddleware_KeyProvider_LocalResolution(t *testing.T) {
30+
t.Parallel()
31+
32+
// Generate an ECDSA P-256 key pair (matching the embedded auth server's
33+
// default GeneratingProvider algorithm).
34+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
35+
require.NoError(t, err)
36+
37+
const ecdsaKeyID = "test-ecdsa-key-1"
38+
39+
// Stand up a minimal OIDC discovery server so issuer validation passes.
40+
// The JWKS endpoint returns an empty key set — all key resolution should
41+
// happen through the local provider, not HTTP.
42+
server, _ := newTestOIDCServer(t)
43+
t.Cleanup(server.Close)
44+
45+
issuer := server.URL
46+
47+
oidcCfg := &config.OIDCConfig{
48+
Issuer: issuer,
49+
ClientID: "test-client",
50+
Audience: "test-audience",
51+
InsecureAllowHTTP: true,
52+
JwksAllowPrivateIP: true,
53+
}
54+
55+
ctrl := gomock.NewController(t)
56+
mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl)
57+
mockProvider.EXPECT().
58+
PublicKeys(gomock.Any()).
59+
Return([]*keys.PublicKeyData{{
60+
KeyID: ecdsaKeyID,
61+
Algorithm: "ES256",
62+
PublicKey: &privateKey.PublicKey,
63+
CreatedAt: time.Now(),
64+
}}, nil).
65+
AnyTimes()
66+
67+
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider)
68+
require.NoError(t, err, "middleware creation should succeed with key provider")
69+
require.NotNil(t, authMw)
70+
71+
var capturedIdentity *pkgauth.Identity
72+
handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
73+
capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context())
74+
}))
75+
76+
// Sign a JWT with the ECDSA private key — only the local provider
77+
// holds the matching public key.
78+
tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
79+
"iss": issuer,
80+
"aud": "test-audience",
81+
"sub": "test-user",
82+
"exp": time.Now().Add(time.Hour).Unix(),
83+
})
84+
tok.Header["kid"] = ecdsaKeyID
85+
tokenString, err := tok.SignedString(privateKey)
86+
require.NoError(t, err)
87+
88+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
89+
req.Header.Set("Authorization", "Bearer "+tokenString)
90+
rr := httptest.NewRecorder()
91+
92+
handler.ServeHTTP(rr, req)
93+
94+
require.Equal(t, http.StatusOK, rr.Code, "request should succeed via local key provider")
95+
require.NotNil(t, capturedIdentity, "identity should be present in context")
96+
assert.Equal(t, "test-user", capturedIdentity.Subject)
97+
}
98+
99+
// TestNewOIDCAuthMiddleware_KeyProvider_HTTPFallback verifies that when the
100+
// key provider is nil, key resolution falls back to an HTTP JWKS fetch.
101+
func TestNewOIDCAuthMiddleware_KeyProvider_HTTPFallback(t *testing.T) {
102+
t.Parallel()
103+
104+
// Use the RSA key from the test OIDC server (served via HTTP JWKS).
105+
server, rsaPrivateKey := newTestOIDCServer(t)
106+
t.Cleanup(server.Close)
107+
108+
issuer := server.URL
109+
oidcCfg := &config.OIDCConfig{
110+
Issuer: issuer,
111+
ClientID: "test-client",
112+
Audience: "test-audience",
113+
InsecureAllowHTTP: true,
114+
JwksAllowPrivateIP: true,
115+
}
116+
117+
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, nil)
118+
require.NoError(t, err)
119+
require.NotNil(t, authMw)
120+
121+
var capturedIdentity *pkgauth.Identity
122+
handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
123+
capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context())
124+
}))
125+
126+
token := signJWT(t, rsaPrivateKey, jwt.MapClaims{
127+
"iss": issuer,
128+
"aud": "test-audience",
129+
"sub": "test-user",
130+
"exp": time.Now().Add(time.Hour).Unix(),
131+
})
132+
133+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
134+
req.Header.Set("Authorization", "Bearer "+token)
135+
rr := httptest.NewRecorder()
136+
137+
handler.ServeHTTP(rr, req)
138+
139+
require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback")
140+
require.NotNil(t, capturedIdentity, "identity should be present in context")
141+
assert.Equal(t, "test-user", capturedIdentity.Subject)
142+
}
143+
144+
// TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback verifies that when the
145+
// local PublicKeyProvider does not hold a key matching the JWT's kid, the
146+
// validator falls back to HTTP JWKS and the request still succeeds. This
147+
// confirms the end-to-end wiring for the kid-miss path at the factory level.
148+
func TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback(t *testing.T) {
149+
t.Parallel()
150+
151+
// Stand up a real OIDC server that serves the RSA key via HTTP JWKS.
152+
server, rsaPrivateKey := newTestOIDCServer(t)
153+
t.Cleanup(server.Close)
154+
155+
issuer := server.URL
156+
oidcCfg := &config.OIDCConfig{
157+
Issuer: issuer,
158+
ClientID: "test-client",
159+
Audience: "test-audience",
160+
InsecureAllowHTTP: true,
161+
JwksAllowPrivateIP: true,
162+
}
163+
164+
// Wire a mock provider that returns a key with a *different* kid than the
165+
// one in the JWT. The validator should call the local provider first, get a
166+
// kid-miss (nil key returned), and then fall back to HTTP JWKS.
167+
ctrl := gomock.NewController(t)
168+
mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl)
169+
170+
// Generate a throwaway ECDSA key so the mock returns a non-nil key list
171+
// with a different kid.
172+
throwawayKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
173+
require.NoError(t, err)
174+
175+
mockProvider.EXPECT().
176+
PublicKeys(gomock.Any()).
177+
Return([]*keys.PublicKeyData{{
178+
KeyID: "unrelated-key-id", // does NOT match testKeyID used by signJWT
179+
Algorithm: "ES256",
180+
PublicKey: &throwawayKey.PublicKey,
181+
CreatedAt: time.Now(),
182+
}}, nil).
183+
AnyTimes()
184+
185+
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider)
186+
require.NoError(t, err)
187+
require.NotNil(t, authMw)
188+
189+
var capturedIdentity *pkgauth.Identity
190+
handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
191+
capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context())
192+
}))
193+
194+
// Sign the JWT with the RSA key from the test server (kid = testKeyID).
195+
// The mock provider holds a key with a different kid, so the validator must
196+
// fall back to HTTP JWKS to find the matching key.
197+
token := signJWT(t, rsaPrivateKey, jwt.MapClaims{
198+
"iss": issuer,
199+
"aud": "test-audience",
200+
"sub": "test-user",
201+
"exp": time.Now().Add(time.Hour).Unix(),
202+
})
203+
204+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
205+
req.Header.Set("Authorization", "Bearer "+token)
206+
rr := httptest.NewRecorder()
207+
208+
handler.ServeHTTP(rr, req)
209+
210+
require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback on kid-miss")
211+
require.NotNil(t, capturedIdentity, "identity should be present in context")
212+
assert.Equal(t, "test-user", capturedIdentity.Subject)
213+
}

pkg/vmcp/auth/factory/incoming_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func TestNewIncomingAuthMiddleware(t *testing.T) {
135135
t.Run(tt.name, func(t *testing.T) {
136136
t.Parallel()
137137

138-
authMw, authzMw, authInfo, err := NewIncomingAuthMiddleware(t.Context(), tt.cfg, nil, nil)
138+
authMw, authzMw, authInfo, err := NewIncomingAuthMiddleware(t.Context(), tt.cfg, nil, nil, nil)
139139

140140
if tt.wantErr {
141141
require.Error(t, err)

pkg/vmcp/auth/factory/incoming_upstream_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) {
113113
GetAllValidTokens(gomock.Any(), "session-abc").
114114
Return(map[string]string{"google": "gcp-access-token"}, nil)
115115

116-
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader)
116+
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader, nil)
117117
require.NoError(t, err, "middleware creation should succeed with non-nil reader")
118118
require.NotNil(t, authMw)
119119

@@ -145,7 +145,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) {
145145
t.Run("upstream tokens nil when reader is nil", func(t *testing.T) {
146146
t.Parallel()
147147

148-
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil)
148+
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, nil)
149149
require.NoError(t, err)
150150
require.NotNil(t, authMw)
151151

@@ -181,7 +181,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) {
181181
reader := upstreamtokenmocks.NewMockTokenReader(ctrl)
182182
// No EXPECT -- reader should not be called when tsid is absent.
183183

184-
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader)
184+
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader, nil)
185185
require.NoError(t, err)
186186
require.NotNil(t, authMw)
187187

0 commit comments

Comments
 (0)