diff --git a/pkg/authserver/server/handlers/authorize.go b/pkg/authserver/server/handlers/authorize.go index dece4979ad..ba767e88c9 100644 --- a/pkg/authserver/server/handlers/authorize.go +++ b/pkg/authserver/server/handlers/authorize.go @@ -66,7 +66,9 @@ func (h *Handler) AuthorizeHandler(w http.ResponseWriter, req *http.Request) { return } - slog.Debug("parsed client-requested scopes", //nolint:gosec // G706: scope count is an integer + slog.Debug("authorize request received", + "client_id", clientID, + "redirect_uri", redirectURI, "scope_count", len(scopes), ) diff --git a/pkg/authserver/server/handlers/dcr.go b/pkg/authserver/server/handlers/dcr.go index bd9f8ef4df..f7dabed355 100644 --- a/pkg/authserver/server/handlers/dcr.go +++ b/pkg/authserver/server/handlers/dcr.go @@ -77,7 +77,7 @@ func (h *Handler) RegisterClientHandler(w http.ResponseWriter, req *http.Request // offline_access) — every DCR-registered client gains the ability to request // these scopes at /oauth/authorize regardless of what they registered with. if len(h.config.BaselineClientScopes) > 0 { - effective := unionScopes(scopes, h.config.BaselineClientScopes) + effective := registration.UnionScopes(scopes, h.config.BaselineClientScopes) if !slices.Equal(effective, scopes) { // Baseline-driven expansion is the intended behavior whenever // baseline_client_scopes is configured, so per-registration diff --git a/pkg/authserver/server/handlers/scopes.go b/pkg/authserver/server/handlers/scopes.go deleted file mode 100644 index 6f4705c25b..0000000000 --- a/pkg/authserver/server/handlers/scopes.go +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package handlers - -// unionScopes returns the union of requested and baseline scopes, preserving -// the order of requested first, then appending any baseline scopes not already -// present. Duplicates are removed. Returns nil when the result is empty. -// -// This is used by the DCR registration handler to inject an -// operator-configured scope baseline into the registered client's scope set: -// if a client narrows the scope field at /oauth/register, the baseline scopes -// are still part of the client's registered set so that the client can -// request them later at /oauth/authorize without invalid_scope rejection. -// -// Both inputs must already be validated by the caller (e.g. via ValidateScopes -// for client-supplied scopes). unionScopes does not filter empty strings or -// validate scope syntax — it only deduplicates and merges in stable order. -func unionScopes(requested, baseline []string) []string { - seen := make(map[string]bool, len(requested)+len(baseline)) - out := make([]string, 0, len(requested)+len(baseline)) - for _, s := range requested { - if !seen[s] { - seen[s] = true - out = append(out, s) - } - } - for _, s := range baseline { - if !seen[s] { - seen[s] = true - out = append(out, s) - } - } - if len(out) == 0 { - return nil - } - return out -} diff --git a/pkg/authserver/server/handlers/scopes_test.go b/pkg/authserver/server/handlers/scopes_test.go deleted file mode 100644 index 93ebf32dce..0000000000 --- a/pkg/authserver/server/handlers/scopes_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package handlers - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestUnionScopes(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - req []string - baseline []string - want []string - }{ - { - name: "both nil returns nil", - req: nil, - baseline: nil, - want: nil, - }, - { - name: "both empty returns nil", - req: []string{}, - baseline: []string{}, - want: nil, - }, - { - name: "nil requested empty baseline returns nil", - req: nil, - baseline: []string{}, - want: nil, - }, - { - name: "requested only preserved unchanged", - req: []string{"openid", "profile"}, - baseline: nil, - want: []string{"openid", "profile"}, - }, - { - name: "baseline only returned when no requested", - req: nil, - baseline: []string{"openid", "email"}, - want: []string{"openid", "email"}, - }, - { - name: "requested subset of baseline: requested first then non-overlapping baseline", - req: []string{"openid"}, - baseline: []string{"openid", "profile", "email"}, - want: []string{"openid", "profile", "email"}, - }, - { - name: "disjoint sets: requested first then baseline", - req: []string{"openid", "profile"}, - baseline: []string{"email", "offline_access"}, - want: []string{"openid", "profile", "email", "offline_access"}, - }, - { - name: "exact match produces no duplicates", - req: []string{"openid", "profile"}, - baseline: []string{"openid", "profile"}, - want: []string{"openid", "profile"}, - }, - { - name: "duplicates in requested are deduplicated requested-first order preserved", - req: []string{"openid", "openid", "profile"}, - baseline: nil, - want: []string{"openid", "profile"}, - }, - { - name: "duplicates in baseline are deduplicated", - req: nil, - baseline: []string{"openid", "profile", "openid"}, - want: []string{"openid", "profile"}, - }, - { - name: "duplicates in both are deduplicated with requested-first order", - req: []string{"openid", "openid", "profile"}, - baseline: []string{"profile", "email", "email"}, - want: []string{"openid", "profile", "email"}, - }, - { - name: "no expansion when baseline is subset of requested (WARN not triggered in handler)", - req: []string{"openid", "profile", "email"}, - baseline: []string{"openid", "profile"}, - want: []string{"openid", "profile", "email"}, - }, - { - name: "multi-element requested with strict-superset baseline preserves requested order then appends", - req: []string{"openid", "profile"}, - baseline: []string{"openid", "profile", "email", "offline_access"}, - want: []string{"openid", "profile", "email", "offline_access"}, - }, - { - name: "empty-string entry in requested is passed through unchanged (precondition: caller must validate)", - req: []string{"", "openid"}, - baseline: nil, - want: []string{"", "openid"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - got := unionScopes(tt.req, tt.baseline) - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/pkg/authserver/server/registration/dcr.go b/pkg/authserver/server/registration/dcr.go index 3197c1e807..cc290d1acb 100644 --- a/pkg/authserver/server/registration/dcr.go +++ b/pkg/authserver/server/registration/dcr.go @@ -327,3 +327,47 @@ func ValidateScopes(requestedScopes, allowedScopes []string) ([]string, *DCRErro return scopes, nil } + +// UnionScopes returns the union of requested and baseline scopes, preserving +// the order of requested first, then appending any baseline scopes not already +// present. Duplicates are removed. Returns nil when the result is empty. +// +// Both inputs must already be validated by the caller. UnionScopes does not +// filter empty strings or validate scope syntax — it only deduplicates and +// merges in stable order. +func UnionScopes(requested, baseline []string) []string { + seen := make(map[string]bool, len(requested)+len(baseline)) + out := make([]string, 0, len(requested)+len(baseline)) + for _, s := range requested { + if !seen[s] { + seen[s] = true + out = append(out, s) + } + } + for _, s := range baseline { + if !seen[s] { + seen[s] = true + out = append(out, s) + } + } + if len(out) == 0 { + return nil + } + return out +} + +// ValidatePublicGrantTypes validates the grant_types for a public OAuth client, +// applying the same rules as DCR: authorization_code must be present, and all +// declared values must be in the allowed set. Returns the validated slice (with +// defaults applied when nil/empty) or a *DCRError on violation. +func ValidatePublicGrantTypes(grantTypes []string) ([]string, *DCRError) { + return validateGrantTypes(grantTypes) +} + +// ValidatePublicResponseTypes validates the response_types for a public OAuth +// client, applying the same rules as DCR: code must be present and all declared +// values must be in the allowed set. Returns the validated slice (with defaults +// applied when nil/empty) or a *DCRError on violation. +func ValidatePublicResponseTypes(responseTypes []string) ([]string, *DCRError) { + return validateResponseTypes(responseTypes) +} diff --git a/pkg/authserver/server/registration/dcr_test.go b/pkg/authserver/server/registration/dcr_test.go index 07efa63f86..657c1ab9fa 100644 --- a/pkg/authserver/server/registration/dcr_test.go +++ b/pkg/authserver/server/registration/dcr_test.go @@ -694,3 +694,31 @@ func TestValidateScopeSubset(t *testing.T) { }) } } + +func TestUnionScopes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req []string + baseline []string + want []string + }{ + {name: "both nil returns nil", req: nil, baseline: nil, want: nil}, + {name: "both empty returns nil", req: []string{}, baseline: []string{}, want: nil}, + {name: "requested only preserved unchanged", req: []string{"openid", "profile"}, baseline: nil, want: []string{"openid", "profile"}}, + {name: "baseline only returned when no requested", req: nil, baseline: []string{"openid", "email"}, want: []string{"openid", "email"}}, + {name: "requested subset of baseline expands correctly", req: []string{"openid"}, baseline: []string{"openid", "profile", "email"}, want: []string{"openid", "profile", "email"}}, + {name: "disjoint sets: requested first then baseline", req: []string{"openid", "profile"}, baseline: []string{"email", "offline_access"}, want: []string{"openid", "profile", "email", "offline_access"}}, + {name: "exact match produces no duplicates", req: []string{"openid", "profile"}, baseline: []string{"openid", "profile"}, want: []string{"openid", "profile"}}, + {name: "duplicates in requested are deduplicated", req: []string{"openid", "openid", "profile"}, baseline: nil, want: []string{"openid", "profile"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := UnionScopes(tt.req, tt.baseline) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/authserver/server_impl.go b/pkg/authserver/server_impl.go index 39a22468ca..250f0a68c0 100644 --- a/pkg/authserver/server_impl.go +++ b/pkg/authserver/server_impl.go @@ -177,7 +177,19 @@ func newServer(ctx context.Context, cfg Config, stor storage.Storage, opts ...se // so that GetClient calls for HTTPS client_id values are intercepted at the // fosite level (not just the handler level). if cfg.CIMDEnabled { - stor, err = storage.NewCIMDStorageDecorator(stor, true, cfg.CIMDCacheMaxSize, cfg.CIMDCacheFallbackTTL) + if len(cfg.BaselineClientScopes) > 0 { + slog.Warn("CIMD is enabled with baseline_client_scopes configured; "+ + "any third-party client resolved via CIMD will also receive these scopes — "+ + "ensure they are scopes you would grant by default to any unknown client", + "baseline_client_scopes", cfg.BaselineClientScopes) + } + stor, err = storage.NewCIMDStorageDecorator(stor, storage.CIMDDecoratorConfig{ + Enabled: true, + CacheMaxSize: cfg.CIMDCacheMaxSize, + FallbackTTL: cfg.CIMDCacheFallbackTTL, + ScopesSupported: cfg.ScopesSupported, + BaselineClientScopes: cfg.BaselineClientScopes, + }) if err != nil { return nil, fmt.Errorf("failed to initialize CIMD storage decorator: %w", err) } diff --git a/pkg/authserver/storage/cimd_decorator.go b/pkg/authserver/storage/cimd_decorator.go index fbead1ef50..d8a8c5c508 100644 --- a/pkg/authserver/storage/cimd_decorator.go +++ b/pkg/authserver/storage/cimd_decorator.go @@ -28,10 +28,12 @@ import ( // Only GetClient is overridden. DCR clients (opaque IDs) continue to work // exactly as before. type CIMDStorageDecorator struct { - Storage // embed full interface — all methods delegate - sf singleflight.Group // deduplicates concurrent fetches for the same URL - cache *lru.Cache[string, *cimdCacheEntry] - ttl time.Duration + Storage // embed full interface — all methods delegate + sf singleflight.Group // deduplicates concurrent fetches for the same URL + cache *lru.Cache[string, *cimdCacheEntry] + ttl time.Duration + scopesSupported []string // AS-configured scopes; nil means accept any + baselineClientScopes []string // unioned into every client's scope set, same as DCR } type cimdCacheEntry struct { @@ -39,33 +41,45 @@ type cimdCacheEntry struct { expires time.Time } -// NewCIMDStorageDecorator wraps base with CIMD client lookup. When enabled=false -// it returns base unchanged (no allocation). cacheMaxSize must be >= 1; -// fallbackTTL is the fixed TTL applied to every cache entry (Cache-Control -// header parsing is not yet implemented; all entries use this value). -func NewCIMDStorageDecorator( - base Storage, - enabled bool, - cacheMaxSize int, - fallbackTTL time.Duration, -) (Storage, error) { - if !enabled { +// CIMDDecoratorConfig holds the configuration for NewCIMDStorageDecorator. +// Using a struct prevents silent swaps of the two adjacent []string fields. +type CIMDDecoratorConfig struct { + // Enabled returns base unchanged when false, avoiding an allocation. + Enabled bool + // CacheMaxSize is the maximum number of documents in the LRU cache (must be >= 1). + CacheMaxSize int + // FallbackTTL is the fixed TTL applied to every cache entry. + FallbackTTL time.Duration + // ScopesSupported is the AS scope allowlist; see pkg/authserver/config.go + // applyDefaults for production guarantees. Pass nil in tests only. + ScopesSupported []string + // BaselineClientScopes is unioned into every CIMD client's scope set, + // matching DCR handler behaviour. + BaselineClientScopes []string +} + +// NewCIMDStorageDecorator wraps base with CIMD client lookup. +// When cfg.Enabled is false it returns base unchanged (no allocation). +func NewCIMDStorageDecorator(base Storage, cfg CIMDDecoratorConfig) (Storage, error) { + if !cfg.Enabled { return base, nil } - if cacheMaxSize < 1 { - return nil, fmt.Errorf("CIMD storage decorator cacheMaxSize must be >= 1, got %d", cacheMaxSize) + if cfg.CacheMaxSize < 1 { + return nil, fmt.Errorf("CIMD storage decorator cacheMaxSize must be >= 1, got %d", cfg.CacheMaxSize) } - c, err := lru.New[string, *cimdCacheEntry](cacheMaxSize) + c, err := lru.New[string, *cimdCacheEntry](cfg.CacheMaxSize) if err != nil { return nil, fmt.Errorf("failed to create CIMD LRU cache: %w", err) } return &CIMDStorageDecorator{ - Storage: base, - cache: c, - ttl: fallbackTTL, + Storage: base, + cache: c, + ttl: cfg.FallbackTTL, + scopesSupported: slices.Clone(cfg.ScopesSupported), + baselineClientScopes: slices.Clone(cfg.BaselineClientScopes), }, nil } @@ -119,17 +133,69 @@ func (d *CIMDStorageDecorator) fetch(ctx context.Context, id string) (fosite.Cli } // Reject documents that declare an auth method this AS does not support. - // The embedded AS only advertises "none"; accepting a doc that says - // "private_key_jwt" and then silently treating the client as public would - // mislead operators and break clients that actually try to use JWT assertions. + // ErrInvalidClient: the document was fetched successfully but its declared + // metadata violates AS policy (distinct from ErrNotFound which means the + // document could not be fetched at all). if m := doc.TokenEndpointAuthMethod; m != "" && m != defaultCIMDTokenEndpointAuthMethod { return nil, fmt.Errorf("%w: CIMD document at %s claims token_endpoint_auth_method %q "+ "but this server only supports %q", - fosite.ErrNotFound.WithHint("unsupported token_endpoint_auth_method"), + fosite.ErrInvalidClient.WithHint("unsupported token_endpoint_auth_method"), id, m, defaultCIMDTokenEndpointAuthMethod) } - client := buildFositeClient(doc) + // Reject documents that declare grant_types or response_types the embedded AS + // does not support for public clients. Uses the same validators as DCR so the + // error messages and allowed sets are identical on both registration paths. + if _, dcrErr := registration.ValidatePublicGrantTypes(doc.GrantTypes); dcrErr != nil { + return nil, fmt.Errorf("%w: CIMD document at %s: %s", + fosite.ErrInvalidClient.WithHint(dcrErr.ErrorDescription), id, dcrErr.ErrorDescription) + } + if _, dcrErr := registration.ValidatePublicResponseTypes(doc.ResponseTypes); dcrErr != nil { + return nil, fmt.Errorf("%w: CIMD document at %s: %s", + fosite.ErrInvalidClient.WithHint(dcrErr.ErrorDescription), id, dcrErr.ErrorDescription) + } + + // Compute and validate the client scope list consistent with DCR. + // When ScopesSupported is configured: + // - Declared scopes are validated via registration.ValidateScopes (same + // function as the DCR handler). + // - Omitted scope uses ValidateScopes(nil, scopesSupported) which returns + // DefaultScopes when DefaultScopes ⊆ ScopesSupported, matching DCR. + // If DefaultScopes ⊄ ScopesSupported the document must declare scope + // explicitly to avoid ambiguous privilege grant. + // When ScopesSupported is not configured: no AS-level validation; declared + // scopes are used directly, or nil to let buildFositeClient apply DefaultScopes. + // In both cases BaselineClientScopes is unioned in after validation, + // matching the DCR handler's behaviour. + var resolvedScopes []string + if len(d.scopesSupported) > 0 { + if doc.Scope != "" { + computed, dcrErr := registration.ValidateScopes(strings.Fields(doc.Scope), d.scopesSupported) + if dcrErr != nil { + return nil, fmt.Errorf("%w: CIMD document at %s: %s", + fosite.ErrInvalidClient.WithHint(dcrErr.ErrorDescription), id, dcrErr.ErrorDescription) + } + resolvedScopes = computed + } else { + // Omitted scope: match DCR — give DefaultScopes when they fit, else require explicit scope. + computed, dcrErr := registration.ValidateScopes(nil, d.scopesSupported) + if dcrErr != nil { + return nil, fmt.Errorf("%w: CIMD document at %s omits scope but "+ + "DefaultScopes are not a subset of this server's scopes_supported — "+ + "the document must explicitly declare its required scopes", + fosite.ErrInvalidClient.WithHint("scope field required"), + id) + } + resolvedScopes = computed + } + } else if doc.Scope != "" { + resolvedScopes = strings.Fields(doc.Scope) + } + if len(d.baselineClientScopes) > 0 { + resolvedScopes = registration.UnionScopes(resolvedScopes, d.baselineClientScopes) + } + + client := buildFositeClient(doc, resolvedScopes) d.cache.Add(id, &cimdCacheEntry{ client: client, @@ -157,7 +223,10 @@ const defaultCIMDTokenEndpointAuthMethod = "none" // buildFositeClient converts a ClientMetadataDocument into a fosite.Client. // Redirect URIs containing http://localhost are wrapped in a LoopbackClient // so that RFC 8252 §7.3 dynamic port matching applies. -func buildFositeClient(doc *cimd.ClientMetadataDocument) fosite.Client { +// resolvedScopes is the already-validated scope list computed by fetch() via +// registration.ValidateScopes; when empty, DefaultScopes is used — this occurs when +// the decorator has no ScopesSupported restriction (unconstrained AS). +func buildFositeClient(doc *cimd.ClientMetadataDocument, resolvedScopes []string) fosite.Client { grantTypes := doc.GrantTypes if len(grantTypes) == 0 { grantTypes = defaultCIMDGrantTypes @@ -173,13 +242,12 @@ func buildFositeClient(doc *cimd.ClientMetadataDocument) fosite.Client { tokenEndpointAuthMethod = defaultCIMDTokenEndpointAuthMethod } - // When the document omits the scope field, apply the same defaults as DCR - // registration so CIMD clients can request openid/profile/email/offline_access - // without needing to enumerate them explicitly in the metadata document. - // Clone to avoid aliasing the package-level DefaultScopes slice. - scopes := slices.Clone(registration.DefaultScopes) - if doc.Scope != "" { - scopes = strings.Fields(doc.Scope) + // Scopes were computed and validated by fetch() via registration.ValidateScopes, + // consistent with the DCR handler. Fall back to DefaultScopes only when the + // decorator has no ScopesSupported restriction (unconstrained AS). + scopes := resolvedScopes + if len(scopes) == 0 { + scopes = slices.Clone(registration.DefaultScopes) } defaultClient := &fosite.DefaultClient{ diff --git a/pkg/authserver/storage/cimd_decorator_test.go b/pkg/authserver/storage/cimd_decorator_test.go index 99c3d66dc8..1666cc81de 100644 --- a/pkg/authserver/storage/cimd_decorator_test.go +++ b/pkg/authserver/storage/cimd_decorator_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "sync" "sync/atomic" "testing" @@ -62,7 +63,7 @@ func newTestBase(t *testing.T) *MemoryStorage { // newEnabledDecorator creates a CIMDStorageDecorator wrapping base. func newEnabledDecorator(t *testing.T, base *MemoryStorage, maxSize int, ttl time.Duration) *CIMDStorageDecorator { t.Helper() - got, err := NewCIMDStorageDecorator(base, true, maxSize, ttl) + got, err := NewCIMDStorageDecorator(base, CIMDDecoratorConfig{Enabled: true, CacheMaxSize: maxSize, FallbackTTL: ttl}) require.NoError(t, err) return got.(*CIMDStorageDecorator) } @@ -77,7 +78,7 @@ func cimdURL(srv *httptest.Server, path string) string { func TestNewCIMDStorageDecorator_DisabledReturnsBase(t *testing.T) { t.Parallel() base := newTestBase(t) - got, err := NewCIMDStorageDecorator(base, false, 10, time.Minute) + got, err := NewCIMDStorageDecorator(base, CIMDDecoratorConfig{Enabled: false, CacheMaxSize: 10, FallbackTTL: time.Minute}) require.NoError(t, err) assert.Same(t, base, got, "disabled decorator must return base unchanged") } @@ -85,21 +86,21 @@ func TestNewCIMDStorageDecorator_DisabledReturnsBase(t *testing.T) { func TestNewCIMDStorageDecorator_ZeroCacheSizeReturnsError(t *testing.T) { t.Parallel() base := newTestBase(t) - _, err := NewCIMDStorageDecorator(base, true, 0, time.Minute) + _, err := NewCIMDStorageDecorator(base, CIMDDecoratorConfig{Enabled: true, CacheMaxSize: 0, FallbackTTL: time.Minute}) require.Error(t, err) } func TestNewCIMDStorageDecorator_NegativeCacheSizeReturnsError(t *testing.T) { t.Parallel() base := newTestBase(t) - _, err := NewCIMDStorageDecorator(base, true, -1, time.Minute) + _, err := NewCIMDStorageDecorator(base, CIMDDecoratorConfig{Enabled: true, CacheMaxSize: -1, FallbackTTL: time.Minute}) require.Error(t, err) } func TestNewCIMDStorageDecorator_EnabledReturnsCIMDDecorator(t *testing.T) { t.Parallel() base := newTestBase(t) - got, err := NewCIMDStorageDecorator(base, true, 10, time.Minute) + got, err := NewCIMDStorageDecorator(base, CIMDDecoratorConfig{Enabled: true, CacheMaxSize: 10, FallbackTTL: time.Minute}) require.NoError(t, err) require.NotNil(t, got) _, isCIMD := got.(*CIMDStorageDecorator) @@ -336,14 +337,11 @@ func TestBuildFositeClient_Defaults(t *testing.T) { RedirectURIs: []string{"https://example.com/callback"}, } - got := buildFositeClient(doc) + got := buildFositeClient(doc, nil) assert.Equal(t, "https://example.com/meta.json", got.GetID()) assert.True(t, got.IsPublic()) assert.ElementsMatch(t, []string{"authorization_code", "refresh_token"}, []string(got.GetGrantTypes())) assert.ElementsMatch(t, []string{"code"}, []string(got.GetResponseTypes())) - // Documents that omit scope must still allow the default scopes so that - // CIMD clients behave consistently with DCR-registered clients. - assert.ElementsMatch(t, registration.DefaultScopes, []string(got.GetScopes())) } func TestBuildFositeClient_ExplicitGrantTypes(t *testing.T) { @@ -355,7 +353,7 @@ func TestBuildFositeClient_ExplicitGrantTypes(t *testing.T) { GrantTypes: []string{"authorization_code"}, } - got := buildFositeClient(doc) + got := buildFositeClient(doc, nil) assert.ElementsMatch(t, []string{"authorization_code"}, []string(got.GetGrantTypes())) } @@ -368,7 +366,8 @@ func TestBuildFositeClient_ScopeParsing(t *testing.T) { Scope: "openid profile email", } - got := buildFositeClient(doc) + // Scope parsing is done by fetch() before buildFositeClient. + got := buildFositeClient(doc, strings.Fields(doc.Scope)) assert.ElementsMatch(t, []string{"openid", "profile", "email"}, []string(got.GetScopes())) } @@ -380,7 +379,7 @@ func TestBuildFositeClient_LoopbackRedirectWrapsInLoopbackClient(t *testing.T) { RedirectURIs: []string{"http://localhost/callback"}, } - got := buildFositeClient(doc) + got := buildFositeClient(doc, nil) // LoopbackClient adds MatchRedirectURI — check the distinctive method is present. type loopbackMatcher interface { MatchRedirectURI(string) bool @@ -403,7 +402,7 @@ func TestBuildFositeClient_NonLoopbackRedirectReturnsOpenIDConnectClient(t *test RedirectURIs: []string{"https://example.com/callback"}, } - got := buildFositeClient(doc) + got := buildFositeClient(doc, nil) _, ok := got.(*fosite.DefaultOpenIDConnectClient) assert.True(t, ok, "non-loopback redirect URI must produce a DefaultOpenIDConnectClient") } @@ -416,7 +415,7 @@ func TestBuildFositeClient_TokenEndpointAuthMethodDefault(t *testing.T) { RedirectURIs: []string{"https://example.com/callback"}, } - got := buildFositeClient(doc) + got := buildFositeClient(doc, nil) if oidc, ok := got.(fosite.OpenIDConnectClient); ok { assert.Equal(t, "none", oidc.GetTokenEndpointAuthMethod()) } @@ -441,4 +440,198 @@ func TestFetch_RejectsUnsupportedTokenEndpointAuthMethod(t *testing.T) { dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) _, err := dec.fetchOrCached(context.Background(), srv.URL+"/meta.json") require.Error(t, err, "fetch must fail when token_endpoint_auth_method is not \"none\"") + assert.ErrorIs(t, err, fosite.ErrInvalidClient, + "CIMD policy rejections must use ErrInvalidClient, not ErrNotFound") + assert.NotErrorIs(t, err, fosite.ErrNotFound) +} + +// serveCIMDDocWithFields starts an httptest.Server that serves a CIMD document +// customised by the provided mutator function. Pass nil for a plain valid doc. +func serveCIMDDocWithFields(t *testing.T, mutate func(*cimd.ClientMetadataDocument)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/meta.json" { + http.NotFound(w, r) + return + } + doc := cimd.ClientMetadataDocument{ + ClientID: "http://" + r.Host + r.URL.Path, + RedirectURIs: []string{"https://example.com/callback"}, + } + if mutate != nil { + mutate(&doc) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + })) + t.Cleanup(srv.Close) + return srv +} + +// --- grant_types validation --- + +func TestFetch_GrantTypeValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + grantTypes []string + wantErr bool + }{ + {"omitted grant_types accepted", nil, false}, + {"explicit [authorization_code, refresh_token] accepted", []string{"authorization_code", "refresh_token"}, false}, + {"explicit [authorization_code] accepted", []string{"authorization_code"}, false}, + {"refresh_token only missing authorization_code rejected", []string{"refresh_token"}, true}, + {"client_credentials rejected", []string{"client_credentials"}, true}, + {"implicit rejected", []string{"implicit"}, true}, + {"device_code rejected", []string{"urn:ietf:params:oauth:grant-type:device_code"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + srv := serveCIMDDocWithFields(t, func(doc *cimd.ClientMetadataDocument) { + doc.GrantTypes = tt.grantTypes + }) + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + _, err := dec.fetchOrCached(context.Background(), srv.URL+"/meta.json") + if tt.wantErr { + require.Error(t, err) + assert.ErrorIs(t, err, fosite.ErrInvalidClient, + "grant_type policy rejections must use ErrInvalidClient") + assert.NotErrorIs(t, err, fosite.ErrNotFound) + } else { + require.NoError(t, err) + } + }) + } +} + +// --- response_types validation --- + +func TestFetch_ResponseTypeValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + responseTypes []string + wantErr bool + }{ + {"omitted response_types accepted", nil, false}, + {"code accepted", []string{"code"}, false}, + {"token rejected", []string{"token"}, true}, + {"code id_token rejected (hybrid)", []string{"code id_token"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + srv := serveCIMDDocWithFields(t, func(doc *cimd.ClientMetadataDocument) { + doc.ResponseTypes = tt.responseTypes + }) + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + _, err := dec.fetchOrCached(context.Background(), srv.URL+"/meta.json") + if tt.wantErr { + require.Error(t, err) + assert.ErrorIs(t, err, fosite.ErrInvalidClient, + "response_type policy rejections must use ErrInvalidClient") + assert.NotErrorIs(t, err, fosite.ErrNotFound) + } else { + require.NoError(t, err) + } + }) + } +} + +// --- scope resolution --- + +func TestFetch_ScopeResolution(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + docScope string + scopesSupported []string + baseline []string + wantErr bool + wantScopes []string + }{ + { + name: "no constraint uses DefaultScopes", + docScope: "", + wantScopes: registration.DefaultScopes, + }, + { + name: "explicit scope accepted within ScopesSupported", + docScope: "openid", + scopesSupported: []string{"openid", "profile"}, + wantScopes: []string{"openid"}, + }, + { + name: "explicit scope outside ScopesSupported rejected", + docScope: "openid profile email", + scopesSupported: []string{"openid"}, + wantErr: true, + }, + { + name: "omitted scope with permissive ScopesSupported uses DefaultScopes", + docScope: "", + scopesSupported: []string{"openid", "profile", "email", "offline_access"}, + wantScopes: registration.DefaultScopes, + }, + { + name: "omitted scope with restrictive ScopesSupported requires explicit scope", + docScope: "", + scopesSupported: []string{"openid"}, + wantErr: true, + }, + { + name: "baseline unioned into scope set", + docScope: "openid", + scopesSupported: []string{"openid", "offline_access"}, + baseline: []string{"offline_access"}, + wantScopes: []string{"openid", "offline_access"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + scope := tt.docScope + srv := serveCIMDDocWithFields(t, func(doc *cimd.ClientMetadataDocument) { + doc.Scope = scope + }) + got, err := NewCIMDStorageDecorator(newTestBase(t), CIMDDecoratorConfig{ + Enabled: true, + CacheMaxSize: 10, + FallbackTTL: time.Minute, + ScopesSupported: tt.scopesSupported, + BaselineClientScopes: tt.baseline, + }) + require.NoError(t, err) + dec := got.(*CIMDStorageDecorator) + + client, err := dec.fetchOrCached(context.Background(), srv.URL+"/meta.json") + if tt.wantErr { + require.Error(t, err) + assert.ErrorIs(t, err, fosite.ErrInvalidClient) + assert.NotErrorIs(t, err, fosite.ErrNotFound) + return + } + require.NoError(t, err) + assert.ElementsMatch(t, tt.wantScopes, []string(client.GetScopes())) + }) + } +} + +// TestBuildFositeClient_ScopeDefaultsToDefaultScopesWhenNoScopesSupported verifies the +// fallback branch in buildFositeClient: nil resolvedScopes → DefaultScopes. +func TestBuildFositeClient_ScopeDefaultsToDefaultScopesWhenNoScopesSupported(t *testing.T) { + t.Parallel() + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"https://example.com/callback"}, + } + got := buildFositeClient(doc, nil) + assert.ElementsMatch(t, registration.DefaultScopes, []string(got.GetScopes())) }