Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
if reader := runner.GetUpstreamTokenReader(); reader != nil {
opts = append(opts, WithUpstreamTokenReader(reader))
}
if provider := runner.GetKeyProvider(); provider != nil {
opts = append(opts, WithKeyProvider(provider))
}

middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig, opts...)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions pkg/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ func TestCreateMiddleware_WithoutOIDCConfig(t *testing.T) {
// Create mock runner
mockRunner := mocks.NewMockMiddlewareRunner(ctrl)

// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
// Expect GetUpstreamTokenReader and GetKeyProvider to be called (returns nil = no auth server)
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)
mockRunner.EXPECT().GetKeyProvider().Return(nil)

// Expect AddMiddleware to be called with a middleware instance
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(name string, mw types.Middleware) {
Expand Down Expand Up @@ -213,8 +214,9 @@ func TestCreateMiddleware_EmptyParameters(t *testing.T) {

mockRunner := mocks.NewMockMiddlewareRunner(ctrl)

// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
// Expect GetUpstreamTokenReader and GetKeyProvider to be called (returns nil = no auth server)
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)
mockRunner.EXPECT().GetKeyProvider().Return(nil)

// Expect AddMiddleware to be called
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any())
Expand Down
123 changes: 101 additions & 22 deletions pkg/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/stacklok/toolhive-core/env"
"github.com/stacklok/toolhive/pkg/auth/oauth"
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
"github.com/stacklok/toolhive/pkg/networking"
oauthproto "github.com/stacklok/toolhive/pkg/oauth"
)
Expand Down Expand Up @@ -372,6 +373,12 @@ type TokenValidator struct {
// nil means no enrichment (no embedded auth server).
upstreamTokenReader upstreamtoken.TokenReader

// keyProvider provides in-process JWKS key lookups from the embedded auth
// server's key provider. When set, getKeyFromJWKS resolves keys locally
// before falling back to HTTP. Eliminates self-referential HTTP calls.
// nil when no embedded auth server is configured.
keyProvider keys.PublicKeyProvider

// Lazy JWKS registration
jwksRegistered bool
jwksRegistrationMu sync.Mutex
Expand Down Expand Up @@ -547,6 +554,7 @@ func registerIntrospectionProviders(config TokenValidatorConfig, clientSecret st
type tokenValidatorOptions struct {
envReader env.Reader
upstreamTokenReader upstreamtoken.TokenReader
keyProvider keys.PublicKeyProvider
}

// TokenValidatorOption is a functional option for NewTokenValidator.
Expand All @@ -570,6 +578,31 @@ func WithUpstreamTokenReader(reader upstreamtoken.TokenReader) TokenValidatorOpt
}
}

// WithKeyProvider configures the token validator to use an in-process key
// provider for JWKS lookups instead of fetching keys over HTTP. This is used
// when the embedded auth server's key provider is available in the same process,
// eliminating self-referential HTTP calls and the need for insecureAllowHTTP
// and jwksAllowPrivateIP flags.
//
// Only PublicKeyProvider is required — the validator never signs tokens.
func WithKeyProvider(provider keys.PublicKeyProvider) TokenValidatorOption {
return func(o *tokenValidatorOptions) {
o.keyProvider = provider
}
}

// resolveClientSecret returns the client secret from the config, falling back
// to the TOOLHIVE_OIDC_CLIENT_SECRET environment variable if not set.
func resolveClientSecret(configSecret string, envReader env.Reader) string {
if configSecret != "" {
return configSecret
}
if envSecret := envReader.Getenv("TOOLHIVE_OIDC_CLIENT_SECRET"); envSecret != "" {
return envSecret
}
return ""
}

// NewTokenValidator creates a new token validator.
func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ...TokenValidatorOption) (*TokenValidator, error) {
// Apply functional options
Expand Down Expand Up @@ -611,8 +644,9 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
slog.Debug("OIDC discovery deferred - will discover on first validation request", "issuer", config.Issuer)
}

// Ensure we have either an explicit JWKS URL or an issuer to discover from
if jwksURL == "" && config.Issuer == "" {
// Ensure we have either an explicit JWKS URL, an issuer to discover from,
// or a local key provider (embedded auth server).
if jwksURL == "" && config.Issuer == "" && o.keyProvider == nil {
return nil, ErrMissingIssuerAndJWKSURL
}

Expand All @@ -638,14 +672,8 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..

// Skip synchronous JWKS registration - will be done lazily on first use

// Load client secret from environment variable if not provided in config
// This allows secrets to be injected via Kubernetes Secret references
clientSecret := config.ClientSecret
if clientSecret == "" {
if envSecret := o.envReader.Getenv("TOOLHIVE_OIDC_CLIENT_SECRET"); envSecret != "" {
clientSecret = envSecret
}
}
// Resolve client secret from config or environment variable
clientSecret := resolveClientSecret(config.ClientSecret, o.envReader)

// Register introspection providers
registry, err := registerIntrospectionProviders(config, clientSecret)
Expand All @@ -667,6 +695,7 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
registry: registry,
insecureAllowHTTP: config.InsecureAllowHTTP,
upstreamTokenReader: o.upstreamTokenReader,
keyProvider: o.keyProvider,
}

return validator, nil
Expand Down Expand Up @@ -802,8 +831,67 @@ func (v *TokenValidator) ensureOIDCDiscovered(ctx context.Context) error {
return nil
}

// getKeyFromLocalProvider attempts to find a verification key from the local
// key provider (embedded auth server). Returns (key, nil) on success,
// (nil, nil) to signal fallback to HTTP, or (nil, error) for hard failures.
// validateTokenHeader checks the signing method is supported (RSA or ECDSA) and
// extracts the key ID from the token header. Returns an error for unsupported
// methods or a missing kid claim.
func validateTokenHeader(token *jwt.Token) (string, error) {
switch token.Method.(type) {
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
// Supported signing methods
default:
return "", fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

kid, ok := token.Header["kid"].(string)
if !ok {
return "", fmt.Errorf("token header missing kid")
}
return kid, nil
}

func (v *TokenValidator) getKeyFromLocalProvider(ctx context.Context, token *jwt.Token) (interface{}, error) {
if v.keyProvider == nil {
return nil, nil
}

kid, err := validateTokenHeader(token)
if err != nil {
return nil, err
}

pubKeys, err := v.keyProvider.PublicKeys(ctx)
if err != nil {
slog.Debug("local JWKS provider failed, falling back to HTTP", "error", err)
return nil, nil
}

for _, k := range pubKeys {
if k.KeyID == kid {
slog.Debug("resolved JWKS key from embedded auth server", "kid", kid)
return k.PublicKey, nil
}
}

// Key not found locally — fall back to HTTP JWKS
slog.Debug("key not found in local JWKS provider, falling back to HTTP", "kid", kid)
return nil, nil
}

// getKeyFromJWKS gets the key from the JWKS.
func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (interface{}, error) {
// Try local key provider first (embedded auth server in-process keys).
// This avoids self-referential HTTP calls when the auth server and
// token validator run in the same process.
if key, err := v.getKeyFromLocalProvider(ctx, token); err != nil {
return nil, err
} else if key != nil {
return key, nil
}

// Fall through to HTTP-based JWKS lookup.
// Defensive check: JWKS URL must be set before calling this function.
// This invariant is normally guaranteed by ValidateToken calling ensureOIDCDiscovered first.
if v.jwksURL == "" {
Expand All @@ -815,18 +903,9 @@ func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (
return nil, fmt.Errorf("JWKS registration failed: %w", err)
}

// Validate the signing method
switch token.Method.(type) {
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
// Supported RSA signing methods
default:
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// Get the key ID from the token header
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("token header missing kid")
kid, err := validateTokenHeader(token)
if err != nil {
return nil, err
}

// Get the key set from the JWKS
Expand Down
133 changes: 133 additions & 0 deletions pkg/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
envmocks "github.com/stacklok/toolhive-core/env/mocks"
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
upstreamtokenmocks "github.com/stacklok/toolhive/pkg/auth/upstreamtoken/mocks"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
keysmocks "github.com/stacklok/toolhive/pkg/authserver/server/keys/mocks"
"github.com/stacklok/toolhive/pkg/networking"
oauthproto "github.com/stacklok/toolhive/pkg/oauth"
)
Expand Down Expand Up @@ -2466,3 +2468,134 @@ func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) {
require.Nil(t, captured.UpstreamTokens)
})
}

func TestWithKeyProvider(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
opt := WithKeyProvider(provider)

o := &tokenValidatorOptions{}
opt(o)

require.Equal(t, provider, o.keyProvider)
}

func TestGetKeyFromLocalProvider(t *testing.T) {
t.Parallel()

// Generate a test RSA key pair for verification
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

t.Run("returns nil when no provider configured", func(t *testing.T) {
t.Parallel()

v := &TokenValidator{} // no keyProvider
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "test-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err)
require.Nil(t, key)
})

t.Run("returns key when kid matches", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
provider.EXPECT().PublicKeys(gomock.Any()).Return([]*keys.PublicKeyData{
{KeyID: "other-kid", PublicKey: &privateKey.PublicKey},
{KeyID: "target-kid", PublicKey: &privateKey.PublicKey},
}, nil)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "target-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err)
require.NotNil(t, key)
require.Equal(t, &privateKey.PublicKey, key)
})

t.Run("falls back when kid not found", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
provider.EXPECT().PublicKeys(gomock.Any()).Return([]*keys.PublicKeyData{
{KeyID: "other-kid", PublicKey: &privateKey.PublicKey},
}, nil)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "missing-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err)
require.Nil(t, key, "should return nil to signal HTTP fallback")
})

t.Run("falls back when provider returns error", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
provider.EXPECT().PublicKeys(gomock.Any()).Return(nil, errors.New("key unavailable"))

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "test-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err, "provider errors should trigger fallback, not hard failure")
require.Nil(t, key)
})

t.Run("rejects unsupported signing method", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockPublicKeyProvider(ctrl)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodHS256,
Header: map[string]interface{}{"alg": "HS256", "kid": "test-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.Error(t, err)
require.Contains(t, err.Error(), "unexpected signing method")
require.Nil(t, key)
})

t.Run("rejects missing kid", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockPublicKeyProvider(ctrl)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.Error(t, err)
require.Contains(t, err.Error(), "token header missing kid")
require.Nil(t, key)
})
}
Loading
Loading