From b3652de1540635fe55061a0a4faae02de47c2bf3 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Mon, 30 Mar 2026 18:26:21 +0300 Subject: [PATCH] Refactor OAuth token persistence and fix Resource/Audience conflation Extract TokenPersistenceManager to pkg/auth/remote to eliminate the repeated nil-check + fetch-cached-token + create-token-source pattern shared by three callers. Generalize RegistryOAuthConfig to OAuthConfig in pkg/config, adding a Resource field (RFC 8707) and an injectable configUpdater callback so callers can supply their own persistence logic. Fix a bug where Audience (provider-specific, e.g. Auth0) was passed where Resource (RFC 8707 resource indicator) was expected: Resource now flows to CreateOAuthConfigFromOIDC and Audience is routed into OAuthParams["audience"] for authorization URL parameters. Add field-level doc comments to OAuthConfig clarifying the distinction between Audience and Resource. Fix %w error wrapping in tryRestoreFromCache and tryRestoreFromCachedTokens. Convert configUpdaterFunc from a type alias to a named type. Add unit tests covering: FetchRefreshToken direct paths, the Resource-vs-Audience split in buildOAuthFlowConfig (regression guard), configUpdater callback invocation, endpoint-override logic, wrapWithPersistence persistence callbacks, and resolveClientCredentials priority logic. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/remote/doc.go | 22 +- pkg/auth/remote/handler.go | 17 +- pkg/auth/remote/handler_test.go | 320 ++++++++++++++++++ pkg/auth/remote/token_persistence_manager.go | 74 ++++ .../remote/token_persistence_manager_test.go | 239 +++++++++++++ pkg/config/config.go | 25 +- pkg/registry/auth/auth.go | 32 +- pkg/registry/auth/auth_test.go | 300 +++++++++++++--- pkg/registry/auth/login.go | 4 +- pkg/registry/auth/login_test.go | 4 +- pkg/registry/auth/oauth_token_source.go | 60 ++-- pkg/registry/auth_manager_test.go | 16 +- pkg/registry/factory.go | 5 +- 13 files changed, 1006 insertions(+), 112 deletions(-) create mode 100644 pkg/auth/remote/token_persistence_manager.go create mode 100644 pkg/auth/remote/token_persistence_manager_test.go diff --git a/pkg/auth/remote/doc.go b/pkg/auth/remote/doc.go index b543d71865..91aa14f61b 100644 --- a/pkg/auth/remote/doc.go +++ b/pkg/auth/remote/doc.go @@ -1,18 +1,28 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -// Package remote provides authentication handling for remote MCP servers. +// Package remote provides authentication handling for remote MCP servers, +// as well as general-purpose OAuth token source utilities used across the codebase. // -// This package implements OAuth/OIDC-based authentication with automatic -// discovery support for remote MCP servers. It handles: +// # Remote MCP server authentication +// +// Handler.Authenticate() is the main entry point: it takes a remote URL +// and performs all necessary discovery and authentication steps, including: // - OAuth issuer discovery (RFC 8414) // - Protected resource metadata (RFC 9728) // - OAuth flow execution (PKCE-based) // - Token source creation for HTTP transports // -// The main entry point is Handler.Authenticate() which takes a remote URL -// and performs all necessary discovery and authentication steps. -// // Configuration is defined in pkg/runner.RemoteAuthConfig as part of the // runner's RunConfig structure. +// +// # General-purpose token source utilities +// +// These types and functions are also used outside of remote MCP auth (e.g. registry auth): +// - PersistingTokenSource / NewPersistingTokenSource — wraps an oauth2.TokenSource +// and invokes a TokenPersister callback whenever tokens are refreshed. +// - CreateTokenSourceFromCached — restores an oauth2.TokenSource from a cached +// refresh token without requiring a new interactive flow. +// - TokenPersistenceManager / NewTokenPersistenceManager — retrieves a cached +// refresh token from a secrets provider and creates a token source from it. package remote diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index a798698831..41d6dda7d3 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -237,14 +237,9 @@ func (h *Handler) tryRestoreFromCachedTokens( scopes []string, authServerInfo *discovery.AuthServerInfo, ) (oauth2.TokenSource, error) { - // Resolve the refresh token from the secret manager - if h.secretProvider == nil { - return nil, fmt.Errorf("secret provider not configured, cannot restore cached tokens") - } - - refreshToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRefreshTokenRef) + mgr, err := NewTokenPersistenceManager(h.secretProvider) if err != nil { - return nil, fmt.Errorf("failed to retrieve cached refresh token: %w", err) + return nil, fmt.Errorf("secret provider not configured, cannot restore cached tokens: %w", err) } // Resolve client credentials - prefer cached DCR credentials over config @@ -284,12 +279,16 @@ func (h *Handler) tryRestoreFromCachedTokens( // Create token source from cached refresh token. // Passes resource for RFC 8707 compliance when configured. - baseSource := CreateTokenSourceFromCached( + baseSource, err := mgr.RestoreFromCache( + ctx, + h.config.CachedRefreshTokenRef, oauth2Config, - refreshToken, h.config.CachedTokenExpiry, h.config.Resource, ) + if err != nil { + return nil, err + } // Try to get a token to verify the cached tokens are valid // This will trigger a refresh since we don't have an access token diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index d68b465ed7..42b8c1e8b9 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -6,6 +6,7 @@ package remote import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -14,8 +15,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/secrets/mocks" ) const ( @@ -853,3 +857,319 @@ func TestAuthenticate_BearerTokenDiscovery(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) }) } + +// stubTokenSource is a minimal oauth2.TokenSource used in wrapWithPersistence tests. +type stubTokenSource struct{} + +func (*stubTokenSource) Token() (*oauth2.Token, error) { return &oauth2.Token{}, nil } + +func TestBuildOAuthFlowConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *Config + scopes []string + authServerInfo *discovery.AuthServerInfo + wantConfig *discovery.OAuthFlowConfig + }{ + { + name: "nil authServerInfo — config fields copied as-is", + config: &Config{ + ClientID: "client-id", + ClientSecret: "client-secret", + AuthorizeURL: "https://auth.example.com/authorize", + TokenURL: "https://auth.example.com/token", + CallbackPort: 8080, + SkipBrowser: true, + Resource: "https://api.example.com", + OAuthParams: map[string]string{"audience": "myapi"}, + }, + scopes: []string{"openid", "profile"}, + authServerInfo: nil, + wantConfig: &discovery.OAuthFlowConfig{ + ClientID: "client-id", + ClientSecret: "client-secret", + AuthorizeURL: "https://auth.example.com/authorize", + TokenURL: "https://auth.example.com/token", + Scopes: []string{"openid", "profile"}, + CallbackPort: 8080, + SkipBrowser: true, + Resource: "https://api.example.com", + OAuthParams: map[string]string{"audience": "myapi"}, + }, + }, + { + name: "authServerInfo used when config URLs are empty", + config: &Config{ + ClientID: "client-id", + }, + scopes: []string{"openid"}, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://discovered.example.com/authorize", + TokenURL: "https://discovered.example.com/token", + RegistrationEndpoint: "https://discovered.example.com/register", + }, + wantConfig: &discovery.OAuthFlowConfig{ + ClientID: "client-id", + AuthorizeURL: "https://discovered.example.com/authorize", + TokenURL: "https://discovered.example.com/token", + RegistrationEndpoint: "https://discovered.example.com/register", + Scopes: []string{"openid"}, + }, + }, + { + name: "config AuthorizeURL preserved when set", + config: &Config{ + AuthorizeURL: "https://static.example.com/authorize", + }, + scopes: nil, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://discovered.example.com/authorize", + TokenURL: "https://discovered.example.com/token", + }, + wantConfig: &discovery.OAuthFlowConfig{ + // AuthorizeURL set → authServerInfo is NOT used (TokenURL also not overwritten) + AuthorizeURL: "https://static.example.com/authorize", + TokenURL: "", + }, + }, + { + name: "config TokenURL preserved when set", + config: &Config{ + TokenURL: "https://static.example.com/token", + }, + scopes: nil, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://discovered.example.com/authorize", + TokenURL: "https://discovered.example.com/token", + }, + wantConfig: &discovery.OAuthFlowConfig{ + // TokenURL set → authServerInfo is NOT used (AuthorizeURL also not overwritten) + AuthorizeURL: "", + TokenURL: "https://static.example.com/token", + }, + }, + { + name: "Resource and OAuthParams passed through unchanged", + config: &Config{ + Resource: "https://api.example.com/resource", + OAuthParams: map[string]string{"key": "value"}, + }, + scopes: nil, + authServerInfo: nil, + wantConfig: &discovery.OAuthFlowConfig{ + Resource: "https://api.example.com/resource", + OAuthParams: map[string]string{"key": "value"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := &Handler{config: tt.config} + got := handler.buildOAuthFlowConfig(tt.scopes, tt.authServerInfo) + + assert.Equal(t, tt.wantConfig.ClientID, got.ClientID, "ClientID") + assert.Equal(t, tt.wantConfig.ClientSecret, got.ClientSecret, "ClientSecret") + assert.Equal(t, tt.wantConfig.AuthorizeURL, got.AuthorizeURL, "AuthorizeURL") + assert.Equal(t, tt.wantConfig.TokenURL, got.TokenURL, "TokenURL") + assert.Equal(t, tt.wantConfig.RegistrationEndpoint, got.RegistrationEndpoint, "RegistrationEndpoint") + assert.Equal(t, tt.wantConfig.Scopes, got.Scopes, "Scopes") + assert.Equal(t, tt.wantConfig.Resource, got.Resource, "Resource") + assert.Equal(t, tt.wantConfig.OAuthParams, got.OAuthParams, "OAuthParams") + assert.Equal(t, tt.wantConfig.CallbackPort, got.CallbackPort, "CallbackPort") + assert.Equal(t, tt.wantConfig.SkipBrowser, got.SkipBrowser, "SkipBrowser") + }) + } +} + +func TestWrapWithPersistence(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tokenPersister TokenPersister + clientCredentialsPersister ClientCredentialsPersister + result *discovery.OAuthFlowResult + wantPersistingSource bool // true if returned source should be a *PersistingTokenSource + }{ + { + name: "nil persisters — returns original token source unwrapped", + tokenPersister: nil, + result: &discovery.OAuthFlowResult{TokenSource: &stubTokenSource{}, RefreshToken: "rt"}, + wantPersistingSource: false, + }, + { + name: "token persister called when refresh token present", + tokenPersister: func(_ string, _ time.Time) error { + return nil + }, + result: &discovery.OAuthFlowResult{TokenSource: &stubTokenSource{}, RefreshToken: "rt"}, + wantPersistingSource: true, + }, + { + name: "token persister NOT called when refresh token empty", + tokenPersister: func(_ string, _ time.Time) error { + // This should NOT be called; if it is, returning an error makes the test meaningful + return errors.New("persister should not have been called") + }, + result: &discovery.OAuthFlowResult{ + TokenSource: &stubTokenSource{}, + RefreshToken: "", // empty — persister must not be invoked + }, + // tokenPersister is set so source is still wrapped + wantPersistingSource: true, + }, + { + name: "token persister error is non-fatal", + tokenPersister: func(_ string, _ time.Time) error { + return errors.New("persist failed") + }, + result: &discovery.OAuthFlowResult{TokenSource: &stubTokenSource{}, RefreshToken: "rt"}, + wantPersistingSource: true, + }, + { + name: "client credentials persister called when clientID present", + clientCredentialsPersister: func(clientID, clientSecret string) error { + assert.Equal(t, "my-client-id", clientID) + assert.Equal(t, "my-client-secret", clientSecret) + return nil + }, + result: &discovery.OAuthFlowResult{ + TokenSource: &stubTokenSource{}, + ClientID: "my-client-id", + ClientSecret: "my-client-secret", + }, + wantPersistingSource: false, // no tokenPersister set + }, + { + name: "client credentials persister NOT called when clientID empty", + clientCredentialsPersister: func(_, _ string) error { + return errors.New("persister should not have been called") + }, + result: &discovery.OAuthFlowResult{ + TokenSource: &stubTokenSource{}, + ClientID: "", // empty — persister must not be invoked + }, + wantPersistingSource: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := &Handler{ + config: &Config{}, + tokenPersister: tt.tokenPersister, + clientCredentialsPersister: tt.clientCredentialsPersister, + } + + got := handler.wrapWithPersistence(tt.result) + + require.NotNil(t, got) + if tt.wantPersistingSource { + _, ok := got.(*PersistingTokenSource) + assert.True(t, ok, "expected *PersistingTokenSource, got %T", got) + } else { + assert.Equal(t, tt.result.TokenSource, got) + } + }) + } +} + +func TestResolveClientCredentials(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *Config + setupMock func(provider *mocks.MockProvider) + wantClientID string + wantClientSecret string + }{ + { + name: "no cached credentials — static config used", + config: &Config{ + ClientID: "static-id", + ClientSecret: "static-secret", + }, + setupMock: nil, + wantClientID: "static-id", + wantClientSecret: "static-secret", + }, + { + name: "cached client ID overrides static", + config: &Config{ + ClientID: "static-id", + ClientSecret: "static-secret", + CachedClientID: "cached-id", + // CachedClientSecretRef empty → no secret fetch; static secret kept + }, + setupMock: nil, + wantClientID: "cached-id", + wantClientSecret: "static-secret", // static secret preserved when no ref to override it + }, + { + name: "cached client ID with secret ref — secret fetched", + config: &Config{ + CachedClientID: "cached-id", + CachedClientSecretRef: "secret-ref", + }, + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "secret-ref"). + Return("cached-secret", nil) + }, + wantClientID: "cached-id", + wantClientSecret: "cached-secret", + }, + { + name: "cached secret ref with provider error — falls back to empty secret", + config: &Config{ + CachedClientID: "cached-id", + CachedClientSecretRef: "secret-ref", + }, + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "secret-ref"). + Return("", errors.New("storage error")) + }, + wantClientID: "cached-id", + wantClientSecret: "", + }, + { + name: "nil secret provider — empty secret used even if ref set", + config: &Config{ + CachedClientID: "cached-id", + CachedClientSecretRef: "secret-ref", + }, + setupMock: nil, // secretProvider stays nil + wantClientID: "cached-id", + wantClientSecret: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := &Handler{config: tt.config} + + if tt.setupMock != nil { + ctrl := gomock.NewController(t) + mockProvider := mocks.NewMockProvider(ctrl) + tt.setupMock(mockProvider) + handler.secretProvider = mockProvider + } + + gotID, gotSecret := handler.resolveClientCredentials(context.Background()) + + assert.Equal(t, tt.wantClientID, gotID, "clientID") + assert.Equal(t, tt.wantClientSecret, gotSecret, "clientSecret") + }) + } +} diff --git a/pkg/auth/remote/token_persistence_manager.go b/pkg/auth/remote/token_persistence_manager.go new file mode 100644 index 0000000000..81775c1be2 --- /dev/null +++ b/pkg/auth/remote/token_persistence_manager.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "fmt" + "time" + + "golang.org/x/oauth2" + + "github.com/stacklok/toolhive/pkg/secrets" +) + +// TokenPersistenceManager retrieves cached refresh tokens from a secrets provider +// and creates an oauth2.TokenSource from them. +// The oauth2.Config is intentionally not held here — callers build it themselves +// because endpoint discovery and credential resolution differ between use sites. +type TokenPersistenceManager struct { + secretsProvider secrets.Provider +} + +// NewTokenPersistenceManager creates a TokenPersistenceManager backed by the +// given secrets provider. Returns an error if provider is nil. +func NewTokenPersistenceManager(provider secrets.Provider) (*TokenPersistenceManager, error) { + if provider == nil { + return nil, fmt.Errorf("secrets provider is required") + } + return &TokenPersistenceManager{secretsProvider: provider}, nil +} + +// FetchRefreshToken retrieves the refresh token stored under tokenKey from the +// secrets provider. Returns an error if the provider returns an error or the +// token is empty (not yet cached). +// +// Use this when you need to verify a cached token exists before performing +// expensive operations (e.g. network-based OIDC discovery), then create the +// token source separately via CreateTokenSourceFromCached. +func (m *TokenPersistenceManager) FetchRefreshToken(ctx context.Context, tokenKey string) (string, error) { + token, err := m.secretsProvider.GetSecret(ctx, tokenKey) + if err != nil { + return "", fmt.Errorf("failed to retrieve cached refresh token: %w", err) + } + if token == "" { + return "", fmt.Errorf("no cached refresh token found") + } + return token, nil +} + +// RestoreFromCache retrieves the refresh token stored under tokenKey from the +// secrets provider and creates a token source using the supplied oauth2.Config, +// expiry, and resource indicator. +// +// Use this when the oauth2.Config is already built before calling (e.g. config +// comes from static values or already-completed discovery). If building the +// config requires expensive operations like OIDC discovery, use FetchRefreshToken +// first to confirm a cached token exists before incurring that cost. +// +// It does NOT call Token() to verify, and does NOT wrap with a TokenPersister. +// Those are caller responsibilities. +func (m *TokenPersistenceManager) RestoreFromCache( + ctx context.Context, + tokenKey string, + oauth2Cfg *oauth2.Config, + expiry time.Time, + resource string, +) (oauth2.TokenSource, error) { + refreshToken, err := m.FetchRefreshToken(ctx, tokenKey) + if err != nil { + return nil, err + } + return CreateTokenSourceFromCached(oauth2Cfg, refreshToken, expiry, resource), nil +} diff --git a/pkg/auth/remote/token_persistence_manager_test.go b/pkg/auth/remote/token_persistence_manager_test.go new file mode 100644 index 0000000000..b3057c16c6 --- /dev/null +++ b/pkg/auth/remote/token_persistence_manager_test.go @@ -0,0 +1,239 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" + + "github.com/stacklok/toolhive/pkg/secrets/mocks" +) + +func TestNewTokenPersistenceManager(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider func(ctrl *gomock.Controller) *mocks.MockProvider + wantErr bool + wantManager bool + }{ + { + name: "nil provider returns error and nil manager", + provider: nil, + wantErr: true, + wantManager: false, + }, + { + name: "non-nil provider returns non-nil manager and nil error", + provider: func(ctrl *gomock.Controller) *mocks.MockProvider { + return mocks.NewMockProvider(ctrl) + }, + wantErr: false, + wantManager: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + var manager *TokenPersistenceManager + var err error + if tt.provider == nil { + manager, err = NewTokenPersistenceManager(nil) + } else { + manager, err = NewTokenPersistenceManager(tt.provider(ctrl)) + } + + if tt.wantErr { + require.Error(t, err) + assert.Nil(t, manager) + } else { + require.NoError(t, err) + assert.NotNil(t, manager) + } + }) + } +} + +func TestTokenPersistenceManager_RestoreFromCache(t *testing.T) { + t.Parallel() + + minimalOAuth2Config := &oauth2.Config{ + ClientID: "test-client", + Endpoint: oauth2.Endpoint{ + TokenURL: "https://example.com/token", + }, + } + + tests := []struct { + name string + tokenKey string + resource string + setupMock func(provider *mocks.MockProvider) + wantErr bool + wantErrContain string + wantSource bool + }{ + { + name: "GetSecret returns error", + tokenKey: "my-token-key", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("", errors.New("storage unavailable")) + }, + wantErr: true, + wantErrContain: "failed to retrieve cached refresh token", + }, + { + name: "GetSecret returns empty string", + tokenKey: "my-token-key", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("", nil) + }, + wantErr: true, + wantErrContain: "no cached refresh token found", + }, + { + name: "valid refresh token with no resource returns non-nil TokenSource", + tokenKey: "my-token-key", + resource: "", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("valid-refresh-token", nil) + }, + wantErr: false, + wantSource: true, + }, + { + name: "valid refresh token with non-empty resource returns non-nil TokenSource", + tokenKey: "my-token-key", + resource: "https://api.example.com/resource", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("valid-refresh-token", nil) + }, + wantErr: false, + wantSource: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockProvider := mocks.NewMockProvider(ctrl) + tt.setupMock(mockProvider) + + manager, err := NewTokenPersistenceManager(mockProvider) + require.NoError(t, err) + require.NotNil(t, manager) + + source, err := manager.RestoreFromCache( + context.Background(), + tt.tokenKey, + minimalOAuth2Config, + time.Now().Add(time.Hour), + tt.resource, + ) + + if tt.wantErr { + require.Error(t, err) + if tt.wantErrContain != "" { + assert.Contains(t, err.Error(), tt.wantErrContain) + } + assert.Nil(t, source) + } else { + require.NoError(t, err) + assert.NotNil(t, source) + } + }) + } +} + +func TestTokenPersistenceManager_FetchRefreshToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupMock func(provider *mocks.MockProvider) + wantErr bool + wantErrContain string + wantToken string + }{ + { + name: "GetSecret returns an error", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("", errors.New("storage unavailable")) + }, + wantErr: true, + wantErrContain: "failed to retrieve cached refresh token", + }, + { + name: "GetSecret returns empty string", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("", nil) + }, + wantErr: true, + wantErrContain: "no cached refresh token found", + }, + { + name: "GetSecret returns a non-empty token", + setupMock: func(provider *mocks.MockProvider) { + provider.EXPECT(). + GetSecret(gomock.Any(), "my-token-key"). + Return("valid-refresh-token", nil) + }, + wantErr: false, + wantToken: "valid-refresh-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockProvider := mocks.NewMockProvider(ctrl) + tt.setupMock(mockProvider) + + manager, err := NewTokenPersistenceManager(mockProvider) + require.NoError(t, err) + require.NotNil(t, manager) + + token, err := manager.FetchRefreshToken(context.Background(), "my-token-key") + + if tt.wantErr { + require.Error(t, err) + if tt.wantErrContain != "" { + assert.Contains(t, err.Error(), tt.wantErrContain) + } + assert.Empty(t, token) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantToken, token) + } + }) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 561cb3c944..6bb62b62c7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -57,17 +57,26 @@ type RegistryAuth struct { Type string `yaml:"type,omitempty"` // OAuth holds OAuth/OIDC authentication configuration. - OAuth *RegistryOAuthConfig `yaml:"oauth,omitempty"` + OAuth *OAuthConfig `yaml:"oauth,omitempty"` } -// RegistryOAuthConfig holds OAuth/OIDC configuration for registry authentication. +// OAuthConfig holds OAuth/OIDC configuration for browser-based authentication flows. // PKCE (S256) is always enforced per OAuth 2.1 requirements for public clients. -type RegistryOAuthConfig struct { - Issuer string `yaml:"issuer"` - ClientID string `yaml:"client_id"` - Scopes []string `yaml:"scopes,omitempty"` - Audience string `yaml:"audience,omitempty"` - CallbackPort int `yaml:"callback_port,omitempty"` +// Used for both registry auth and other service auth (e.g. enterprise config server). +type OAuthConfig struct { + Issuer string `yaml:"issuer"` + ClientID string `yaml:"client_id"` + Scopes []string `yaml:"scopes,omitempty"` + // Audience is a provider-specific request parameter (e.g. Auth0) sent as an + // extra authorization URL parameter. It is distinct from the RFC 8707 Resource + // field — use Audience for providers that require an "audience" URL param, and + // Resource for standard RFC 8707 resource indicators. + Audience string `yaml:"audience,omitempty"` + // Resource is the RFC 8707 resource indicator sent to the token endpoint. + // It is distinct from Audience — use this for standard OAuth 2.0 resource + // indicators, not for provider-specific audience parameters. + Resource string `yaml:"resource,omitempty"` + CallbackPort int `yaml:"callback_port,omitempty"` // Cached token references for session restoration across CLI invocations. CachedRefreshTokenRef string `yaml:"cached_refresh_token_ref,omitempty"` diff --git a/pkg/registry/auth/auth.go b/pkg/registry/auth/auth.go index ff0d8c0c24..d87d21fcf6 100644 --- a/pkg/registry/auth/auth.go +++ b/pkg/registry/auth/auth.go @@ -9,6 +9,8 @@ import ( "crypto/sha256" "encoding/hex" "errors" + "log/slog" + "time" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/secrets" @@ -25,16 +27,19 @@ type TokenSource interface { Token(ctx context.Context) (string, error) } -// NewTokenSource creates a TokenSource from registry OAuth configuration. +// NewTokenSource creates a TokenSource from OAuth configuration. // Returns nil, nil if oauth config is nil (no auth required). -// The registryURL is used to derive a unique secret key for token storage. +// The serviceURL is used to derive a unique secret key for token storage. // The secrets provider may be nil if secret storage is not available. // The interactive flag controls whether browser-based OAuth flows are allowed. +// configUpdater is called whenever a token ref or expiry needs to be persisted +// back to the caller's config store; pass nil to skip config persistence. func NewTokenSource( - cfg *config.RegistryOAuthConfig, - registryURL string, + cfg *config.OAuthConfig, + serviceURL string, secretsProvider secrets.Provider, interactive bool, + configUpdater func(tokenRef string, expiry time.Time), ) (TokenSource, error) { if cfg == nil { return nil, nil @@ -42,12 +47,29 @@ func NewTokenSource( return &oauthTokenSource{ oauthCfg: cfg, - registryURL: registryURL, + serviceURL: serviceURL, secretsProvider: secretsProvider, interactive: interactive, + configUpdater: configUpdater, }, nil } +// RegistryConfigUpdater returns a configUpdater callback that persists OAuth +// token references back to the toolhive on-disk config under RegistryAuth.OAuth. +// Pass this to NewTokenSource when using it for registry authentication. +func RegistryConfigUpdater() func(tokenRef string, expiry time.Time) { + return func(tokenRef string, expiry time.Time) { + if err := config.UpdateConfig(func(cfg *config.Config) { + if cfg.RegistryAuth.OAuth != nil { + cfg.RegistryAuth.OAuth.CachedRefreshTokenRef = tokenRef + cfg.RegistryAuth.OAuth.CachedTokenExpiry = expiry + } + }); err != nil { + slog.Warn("Failed to update config with token reference", "error", err) + } + } +} + // DeriveSecretKey computes the secret key for storing a registry's refresh token. // The key follows the formula: REGISTRY_OAUTH_<8 hex chars> // where the hex is derived from sha256(registryURL + "\x00" + issuer)[:4]. diff --git a/pkg/registry/auth/auth_test.go b/pkg/registry/auth/auth_test.go index eee3833f5d..b3f68a4824 100644 --- a/pkg/registry/auth/auth_test.go +++ b/pkg/registry/auth/auth_test.go @@ -28,34 +28,34 @@ func TestDeriveSecretKey(t *testing.T) { t.Parallel() tests := []struct { - name string - registryURL string - issuer string + name string + serviceURL string + issuer string }{ { - name: "typical registry and issuer", - registryURL: "https://registry.example.com", - issuer: "https://auth.example.com", + name: "typical registry and issuer", + serviceURL: "https://registry.example.com", + issuer: "https://auth.example.com", }, { - name: "empty strings", - registryURL: "", - issuer: "", + name: "empty strings", + serviceURL: "", + issuer: "", }, { - name: "empty issuer", - registryURL: "https://registry.example.com", - issuer: "", + name: "empty issuer", + serviceURL: "https://registry.example.com", + issuer: "", }, { - name: "empty registry URL", - registryURL: "", - issuer: "https://auth.example.com", + name: "empty registry URL", + serviceURL: "", + issuer: "https://auth.example.com", }, { - name: "localhost registry", - registryURL: "http://localhost:5000", - issuer: "http://localhost:8080", + name: "localhost registry", + serviceURL: "http://localhost:5000", + issuer: "http://localhost:8080", }, } @@ -63,7 +63,7 @@ func TestDeriveSecretKey(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - key := DeriveSecretKey(tt.registryURL, tt.issuer) + key := DeriveSecretKey(tt.serviceURL, tt.issuer) // Must start with the correct prefix require.True(t, len(key) > len("REGISTRY_OAUTH_"), "key too short") @@ -82,7 +82,7 @@ func TestDeriveSecretKey(t *testing.T) { } // Verify the derivation formula: sha256(registryURL + "\x00" + issuer)[:4] - h := sha256.Sum256([]byte(tt.registryURL + "\x00" + tt.issuer)) + h := sha256.Sum256([]byte(tt.serviceURL + "\x00" + tt.issuer)) expected := "REGISTRY_OAUTH_" + hex.EncodeToString(h[:4]) require.Equal(t, expected, key) }) @@ -105,8 +105,8 @@ func TestDeriveSecretKey_UniquePerInputCombination(t *testing.T) { t.Parallel() combinations := []struct { - registryURL string - issuer string + serviceURL string + issuer string }{ {"https://registry-a.example.com", "https://auth.example.com"}, {"https://registry-b.example.com", "https://auth.example.com"}, @@ -116,11 +116,11 @@ func TestDeriveSecretKey_UniquePerInputCombination(t *testing.T) { keys := make(map[string]struct{}, len(combinations)) for _, combo := range combinations { - key := DeriveSecretKey(combo.registryURL, combo.issuer) + key := DeriveSecretKey(combo.serviceURL, combo.issuer) _, alreadySeen := keys[key] require.False(t, alreadySeen, "DeriveSecretKey produced a duplicate key for registryURL=%q issuer=%q: %q", - combo.registryURL, combo.issuer, key, + combo.serviceURL, combo.issuer, key, ) keys[key] = struct{}{} } @@ -144,7 +144,7 @@ func TestNewTokenSource(t *testing.T) { tests := []struct { name string - cfg *config.RegistryOAuthConfig + cfg *config.OAuthConfig wantNil bool wantErrNil bool }{ @@ -156,7 +156,7 @@ func TestNewTokenSource(t *testing.T) { }, { name: "non-nil config returns non-nil source", - cfg: &config.RegistryOAuthConfig{ + cfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client-id", }, @@ -165,7 +165,7 @@ func TestNewTokenSource(t *testing.T) { }, { name: "config with scopes and audience returns non-nil source", - cfg: &config.RegistryOAuthConfig{ + cfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client-id", Scopes: []string{"openid", "profile"}, @@ -174,13 +174,23 @@ func TestNewTokenSource(t *testing.T) { wantNil: false, wantErrNil: true, }, + { + name: "config with Resource field returns non-nil source", + cfg: &config.OAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client-id", + Resource: "https://api.example.com/resource", + }, + wantNil: false, + wantErrNil: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - src, err := NewTokenSource(tt.cfg, "https://registry.example.com", nil, false) + src, err := NewTokenSource(tt.cfg, "https://registry.example.com", nil, false, nil) if tt.wantErrNil { require.NoError(t, err) @@ -231,11 +241,11 @@ func TestOAuthTokenSource_Token_NonInteractiveNoCache(t *testing.T) { } src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: provider, interactive: false, } @@ -277,12 +287,12 @@ func TestOAuthTokenSource_RefreshTokenKey(t *testing.T) { t.Parallel() src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: issuer, ClientID: "test-client", CachedRefreshTokenRef: tt.cachedRefreshTokenRef, }, - registryURL: registryURL, + serviceURL: registryURL, } got := src.refreshTokenKey() @@ -346,11 +356,11 @@ func TestOAuthTokenSource_Token_InMemoryCacheHit(t *testing.T) { } src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: nil, // should never be called interactive: false, tokenSource: &mockOAuth2TokenSource{token: validToken}, @@ -375,11 +385,11 @@ func TestOAuthTokenSource_Token_InMemoryCacheExpiredFallsThrough(t *testing.T) { } src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: nil, interactive: false, tokenSource: &mockOAuth2TokenSource{token: expiredToken}, @@ -400,11 +410,11 @@ func TestOAuthTokenSource_Token_InMemoryCacheErrorFallsThrough(t *testing.T) { t.Parallel() src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: nil, interactive: false, tokenSource: &mockOAuth2TokenSource{err: errors.New("token refresh failed")}, @@ -424,11 +434,11 @@ func TestOAuthTokenSource_TryRestoreFromCache_NilProvider(t *testing.T) { t.Parallel() src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: nil, // genuine nil interface — triggers the nil guard in tryRestoreFromCache } @@ -456,7 +466,7 @@ func TestOAuthTokenSource_TryRestoreFromCache(t *testing.T) { Return("", errors.New("vault unavailable")) return mock }, - wantErrContains: "failed to get cached refresh token", + wantErrContains: "failed to retrieve cached refresh token", }, { name: "GetSecret returns empty string", @@ -479,11 +489,11 @@ func TestOAuthTokenSource_TryRestoreFromCache(t *testing.T) { provider := tt.buildProvider(ctrl) src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: provider, } @@ -509,11 +519,11 @@ func TestOAuthTokenSource_TryRestoreFromCache_WithOIDCServer(t *testing.T) { Return("my-refresh-token", nil) src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: srv.URL, ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: mockProvider, } @@ -568,11 +578,11 @@ func TestOAuthTokenSource_CreateTokenPersister(t *testing.T) { tt.setupMock(mockProvider) src := &oauthTokenSource{ - oauthCfg: &config.RegistryOAuthConfig{ + oauthCfg: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", }, - registryURL: "https://registry.example.com", + serviceURL: "https://registry.example.com", secretsProvider: mockProvider, } @@ -590,3 +600,199 @@ func TestOAuthTokenSource_CreateTokenPersister(t *testing.T) { }) } } + +// TestBuildOAuthFlowConfig_ResourceVsAudience is the regression guard for the bug fix +// that separates RFC 8707 Resource from the provider-specific Audience parameter. +// Resource is passed directly to CreateOAuthConfigFromOIDC (and used at the token +// endpoint), whereas Audience is injected into OAuthParams (authorization URL only). +func TestBuildOAuthFlowConfig_ResourceVsAudience(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + audience string + resource string + wantAudienceParam string // expected value of OAuthParams["audience"], "" means absent + wantResourceParam bool // true if we expect "resource" key in OAuthParams (should always be false) + }{ + { + name: "Audience set — appears in OAuthParams[audience]", + audience: "https://api.auth0.com/", + resource: "", + wantAudienceParam: "https://api.auth0.com/", + wantResourceParam: false, + }, + { + name: "Resource set, Audience empty — OAuthParams not polluted", + audience: "", + resource: "https://api.example.com/", + wantAudienceParam: "", + wantResourceParam: false, + }, + { + name: "Both Resource and Audience set — only Audience in OAuthParams", + audience: "my-audience", + resource: "https://api.example.com/", + wantAudienceParam: "my-audience", + wantResourceParam: false, + }, + { + name: "Neither set — OAuthParams is nil", + audience: "", + resource: "", + wantAudienceParam: "", + wantResourceParam: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + srv := newOIDCTestServer(t) + + src := &oauthTokenSource{ + oauthCfg: &config.OAuthConfig{ + Issuer: srv.URL, + ClientID: "test-client", + Audience: tt.audience, + Resource: tt.resource, + }, + serviceURL: "https://registry.example.com", + } + + cfg, err := src.buildOAuthFlowConfig(context.Background()) + require.NoError(t, err) + require.NotNil(t, cfg) + + if tt.wantAudienceParam != "" { + require.NotNil(t, cfg.OAuthParams, "OAuthParams must be non-nil when audience is set") + require.Equal(t, tt.wantAudienceParam, cfg.OAuthParams["audience"], + "OAuthParams[audience] must match the configured Audience value") + } else { + // Either nil map or absent key — neither is acceptable as a spurious value. + if cfg.OAuthParams != nil { + _, hasAudience := cfg.OAuthParams["audience"] + require.False(t, hasAudience, + "OAuthParams must NOT contain 'audience' key when Audience is empty") + } + } + + // Resource must never appear in OAuthParams; it travels through a separate + // channel (passed to CreateOAuthConfigFromOIDC and the token endpoint). + if cfg.OAuthParams != nil { + _, hasResource := cfg.OAuthParams["resource"] + require.False(t, hasResource, + "Resource must NOT be placed in OAuthParams; it is handled by CreateOAuthConfigFromOIDC") + } + }) + } +} + +// TestOAuthTokenSource_ConfigUpdater covers the injectable configUpdater callback and +// the updateConfigTokenRef method that dispatches to it. +func TestOAuthTokenSource_ConfigUpdater(t *testing.T) { + t.Parallel() + + expiry := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC) + + t.Run("updateConfigTokenRef calls configUpdater when set", func(t *testing.T) { + t.Parallel() + + var gotRef string + var gotExpiry time.Time + + src := &oauthTokenSource{ + oauthCfg: &config.OAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + serviceURL: "https://registry.example.com", + configUpdater: func(tokenRef string, exp time.Time) { + gotRef = tokenRef + gotExpiry = exp + }, + } + + src.updateConfigTokenRef("my-token-ref", expiry) + + require.Equal(t, "my-token-ref", gotRef) + require.Equal(t, expiry, gotExpiry) + }) + + t.Run("updateConfigTokenRef is a no-op when configUpdater is nil", func(t *testing.T) { + t.Parallel() + + src := &oauthTokenSource{ + oauthCfg: &config.OAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + serviceURL: "https://registry.example.com", + configUpdater: nil, + } + + // Must not panic. + require.NotPanics(t, func() { + src.updateConfigTokenRef("my-token-ref", expiry) + }) + }) + + t.Run("NewTokenSource stores non-nil configUpdater", func(t *testing.T) { + t.Parallel() + + updater := func(_ string, _ time.Time) {} + + ts, err := NewTokenSource( + &config.OAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client-id", + }, + "https://registry.example.com", + nil, + false, + updater, + ) + require.NoError(t, err) + require.NotNil(t, ts) + + src, ok := ts.(*oauthTokenSource) + require.True(t, ok, "returned TokenSource must be *oauthTokenSource") + require.NotNil(t, src.configUpdater) + }) + + t.Run("createTokenPersister with non-nil configUpdater calls it", func(t *testing.T) { + t.Parallel() + + var gotRef string + var gotExpiry time.Time + + ctrl := gomock.NewController(t) + mockProvider := secretsmocks.NewMockProvider(ctrl) + mockProvider.EXPECT(). + SetSecret(gomock.Any(), "my-key", "refresh-token"). + Return(nil) + + src := &oauthTokenSource{ + oauthCfg: &config.OAuthConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client", + }, + serviceURL: "https://registry.example.com", + secretsProvider: mockProvider, + configUpdater: func(tokenRef string, exp time.Time) { + gotRef = tokenRef + gotExpiry = exp + }, + } + + persister := src.createTokenPersister("my-key") + require.NotNil(t, persister) + + err := persister("refresh-token", expiry) + require.NoError(t, err) + + require.Equal(t, "my-key", gotRef) + require.Equal(t, expiry, gotExpiry) + }) +} diff --git a/pkg/registry/auth/login.go b/pkg/registry/auth/login.go index 09ff49c8c5..1367a99664 100644 --- a/pkg/registry/auth/login.go +++ b/pkg/registry/auth/login.go @@ -82,7 +82,7 @@ func Login( registryURL := registryURLFromConfig(cfg) - ts, err := NewTokenSource(cfg.RegistryAuth.OAuth, registryURL, secretsProvider, true) + ts, err := NewTokenSource(cfg.RegistryAuth.OAuth, registryURL, secretsProvider, true, RegistryConfigUpdater()) if err != nil { return fmt.Errorf("creating token source: %w", err) } @@ -284,7 +284,7 @@ func ConfigureOAuth( return func(c *config.Config) { c.RegistryAuth = config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: issuer, ClientID: clientID, Scopes: resolvedScopes, diff --git a/pkg/registry/auth/login_test.go b/pkg/registry/auth/login_test.go index 67830401b9..5aa807a0b1 100644 --- a/pkg/registry/auth/login_test.go +++ b/pkg/registry/auth/login_test.go @@ -22,8 +22,8 @@ import ( // --- helpers --- // oauthConfig returns a minimal valid OAuth config for tests. -func oauthConfig() *config.RegistryOAuthConfig { - return &config.RegistryOAuthConfig{ +func oauthConfig() *config.OAuthConfig { + return &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "test-client", } diff --git a/pkg/registry/auth/oauth_token_source.go b/pkg/registry/auth/oauth_token_source.go index a1301f3071..c0423766a3 100644 --- a/pkg/registry/auth/oauth_token_source.go +++ b/pkg/registry/auth/oauth_token_source.go @@ -18,12 +18,17 @@ import ( "github.com/stacklok/toolhive/pkg/secrets" ) +// configUpdaterFunc is the signature for a callback that persists a token +// reference and expiry back to the caller's config store. +type configUpdaterFunc func(tokenRef string, expiry time.Time) + // oauthTokenSource implements TokenSource using an OIDC browser-based flow. type oauthTokenSource struct { - oauthCfg *config.RegistryOAuthConfig - registryURL string + oauthCfg *config.OAuthConfig + serviceURL string secretsProvider secrets.Provider interactive bool + configUpdater configUpdaterFunc mu sync.Mutex tokenSource oauth2.TokenSource } @@ -72,18 +77,15 @@ func (o *oauthTokenSource) Token(ctx context.Context) (string, error) { // tryRestoreFromCache attempts to restore token source from cached refresh token. func (o *oauthTokenSource) tryRestoreFromCache(ctx context.Context) error { - if o.secretsProvider == nil { - return fmt.Errorf("no secrets provider available") + mgr, err := remote.NewTokenPersistenceManager(o.secretsProvider) + if err != nil { + return fmt.Errorf("no secrets provider available: %w", err) } - refreshTokenKey := o.refreshTokenKey() - - refreshToken, err := o.secretsProvider.GetSecret(ctx, refreshTokenKey) + // Check the secret exists before OIDC discovery to avoid unnecessary network calls. + refreshToken, err := mgr.FetchRefreshToken(ctx, o.refreshTokenKey()) if err != nil { - return fmt.Errorf("failed to get cached refresh token: %w", err) - } - if refreshToken == "" { - return fmt.Errorf("no cached refresh token found") + return err } oauth2Cfg, err := o.buildOAuth2Config(ctx) @@ -91,7 +93,7 @@ func (o *oauthTokenSource) tryRestoreFromCache(ctx context.Context) error { return fmt.Errorf("failed to create oauth2 config: %w", err) } - o.tokenSource = remote.CreateTokenSourceFromCached(oauth2Cfg, refreshToken, o.oauthCfg.CachedTokenExpiry, "") + o.tokenSource = remote.CreateTokenSourceFromCached(oauth2Cfg, refreshToken, o.oauthCfg.CachedTokenExpiry, o.oauthCfg.Resource) return nil } @@ -153,7 +155,7 @@ func (o *oauthTokenSource) buildOAuthFlowConfig(ctx context.Context) (*oauth.Con scopes := ensureOfflineAccess(o.oauthCfg.Scopes) - return oauth.CreateOAuthConfigFromOIDC( + cfg, err := oauth.CreateOAuthConfigFromOIDC( ctx, o.oauthCfg.Issuer, o.oauthCfg.ClientID, @@ -161,8 +163,23 @@ func (o *oauthTokenSource) buildOAuthFlowConfig(ctx context.Context) (*oauth.Con scopes, true, // Always use PKCE (S256) callbackPort, - o.oauthCfg.Audience, + o.oauthCfg.Resource, ) + if err != nil { + return nil, err + } + + // Audience is a provider-specific request parameter (e.g. Auth0) distinct from + // the RFC 8707 resource indicator. Pass it as an extra auth URL parameter so + // providers that require it receive it correctly. + if o.oauthCfg.Audience != "" { + if cfg.OAuthParams == nil { + cfg.OAuthParams = make(map[string]string) + } + cfg.OAuthParams["audience"] = o.oauthCfg.Audience + } + + return cfg, nil } // ensureOfflineAccess returns scopes with "offline_access" included. @@ -210,15 +227,10 @@ func (o *oauthTokenSource) createTokenPersister(refreshTokenKey string) remote.T } } -// updateConfigTokenRef updates the config with the refresh token reference and expiry. -func (*oauthTokenSource) updateConfigTokenRef(refreshTokenKey string, expiry time.Time) { - if err := config.UpdateConfig(func(cfg *config.Config) { - if cfg.RegistryAuth.OAuth != nil { - cfg.RegistryAuth.OAuth.CachedRefreshTokenRef = refreshTokenKey - cfg.RegistryAuth.OAuth.CachedTokenExpiry = expiry - } - }); err != nil { - slog.Warn("Failed to update config with token reference", "error", err) +// updateConfigTokenRef delegates to the injected configUpdater if one was provided. +func (o *oauthTokenSource) updateConfigTokenRef(refreshTokenKey string, expiry time.Time) { + if o.configUpdater != nil { + o.configUpdater(refreshTokenKey, expiry) } } @@ -228,5 +240,5 @@ func (o *oauthTokenSource) refreshTokenKey() string { if o.oauthCfg.CachedRefreshTokenRef != "" { return o.oauthCfg.CachedRefreshTokenRef } - return DeriveSecretKey(o.registryURL, o.oauthCfg.Issuer) + return DeriveSecretKey(o.serviceURL, o.oauthCfg.Issuer) } diff --git a/pkg/registry/auth_manager_test.go b/pkg/registry/auth_manager_test.go index 5d00daa388..b778ed454e 100644 --- a/pkg/registry/auth_manager_test.go +++ b/pkg/registry/auth_manager_test.go @@ -50,7 +50,7 @@ func TestDefaultAuthManager_UnsetAuth(t *testing.T) { cfg := &config.Config{ RegistryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", }, @@ -93,7 +93,7 @@ func TestDefaultAuthManager_GetAuthInfo(t *testing.T) { name: "returns oauth type without cached tokens when OAuth section has no ref", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", }, @@ -105,7 +105,7 @@ func TestDefaultAuthManager_GetAuthInfo(t *testing.T) { name: "returns oauth type with cached tokens when CachedRefreshTokenRef is set", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", CachedRefreshTokenRef: "REGISTRY_OAUTH_aabbccdd", @@ -164,7 +164,7 @@ func TestDefaultAuthManager_GetAuthStatus(t *testing.T) { name: "returns configured when OAuth set but no cached tokens", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", }, @@ -176,7 +176,7 @@ func TestDefaultAuthManager_GetAuthStatus(t *testing.T) { name: "returns authenticated when OAuth set with cached tokens", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", CachedRefreshTokenRef: "REGISTRY_OAUTH_aabbccdd", @@ -241,7 +241,7 @@ func TestDefaultAuthManager_GetOAuthPublicConfig(t *testing.T) { name: "returns config with all fields populated", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", Audience: "api://toolhive", @@ -259,7 +259,7 @@ func TestDefaultAuthManager_GetOAuthPublicConfig(t *testing.T) { name: "returns config without optional fields", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", }, @@ -273,7 +273,7 @@ func TestDefaultAuthManager_GetOAuthPublicConfig(t *testing.T) { name: "excludes cached token fields", registryAuth: config.RegistryAuth{ Type: config.RegistryAuthTypeOAuth, - OAuth: &config.RegistryOAuthConfig{ + OAuth: &config.OAuthConfig{ Issuer: "https://auth.example.com", ClientID: "my-client", CachedRefreshTokenRef: "REGISTRY_OAUTH_aabbccdd", diff --git a/pkg/registry/factory.go b/pkg/registry/factory.go index c15c5d5fcc..1bfa1912b8 100644 --- a/pkg/registry/factory.go +++ b/pkg/registry/factory.go @@ -141,7 +141,10 @@ func resolveTokenSource(cfg *config.Config, interactive bool) auth.TokenSource { } } - tokenSource, err := auth.NewTokenSource(cfg.RegistryAuth.OAuth, cfg.RegistryApiUrl, secretsProvider, interactive) + tokenSource, err := auth.NewTokenSource( + cfg.RegistryAuth.OAuth, cfg.RegistryApiUrl, secretsProvider, + interactive, auth.RegistryConfigUpdater(), + ) if err != nil { slog.Warn("Failed to create registry auth token source", "error", err) return nil