From 5f5c7afeeec7db3a9dbab65763c56b283001a710 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Thu, 2 Apr 2026 18:05:04 +0000 Subject: [PATCH 1/2] extauth: refactor oidc tests --- auth/extauth/oidc_login_test.go | 539 +++++++++++--------------------- oauthex/client_test.go | 88 ++++++ 2 files changed, 264 insertions(+), 363 deletions(-) create mode 100644 oauthex/client_test.go diff --git a/auth/extauth/oidc_login_test.go b/auth/extauth/oidc_login_test.go index eb5a3ebf..4c63ff6f 100644 --- a/auth/extauth/oidc_login_test.go +++ b/auth/extauth/oidc_login_test.go @@ -20,310 +20,243 @@ import ( "github.com/modelcontextprotocol/go-sdk/oauthex" ) -// TestInitiateOIDCLogin tests the OIDC authorization request generation. -func TestInitiateOIDCLogin(t *testing.T) { - // Create mock IdP server +func TestPerformOIDCLogin(t *testing.T) { idpServer := createMockOIDCServer(t) defer idpServer.Close() - config := &OIDCLoginConfig{ + + validConfig := &OIDCLoginConfig{ IssuerURL: idpServer.URL, Credentials: &oauthex.ClientCredentials{ ClientID: "test-client", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "test-secret", + }, }, RedirectURL: "http://localhost:8080/callback", Scopes: []string{"openid", "profile", "email"}, HTTPClient: idpServer.Client(), } - t.Run("successful initiation", func(t *testing.T) { - authReq, _, err := initiateOIDCLogin(context.Background(), config) - if err != nil { - t.Fatalf("initiateOIDCLogin failed: %v", err) - } - // Validate authURL - if authReq.authURL == "" { - t.Error("authURL is empty") - } - // Parse and validate URL parameters - u, err := url.Parse(authReq.authURL) + + t.Run("successful flow", func(t *testing.T) { + token, err := PerformOIDCLogin(context.Background(), validConfig, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + u, err := url.Parse(args.URL) + if err != nil { + return nil, fmt.Errorf("invalid authURL: %w", err) + } + q := u.Query() + + if got := q.Get("response_type"); got != "code" { + t.Errorf("response_type = %q, want %q", got, "code") + } + if got := q.Get("client_id"); got != "test-client" { + t.Errorf("client_id = %q, want %q", got, "test-client") + } + if got := q.Get("redirect_uri"); got != "http://localhost:8080/callback" { + t.Errorf("redirect_uri = %q, want %q", got, "http://localhost:8080/callback") + } + if got := q.Get("scope"); got != "openid profile email" { + t.Errorf("scope = %q, want %q", got, "openid profile email") + } + if got := q.Get("code_challenge_method"); got != "S256" { + t.Errorf("code_challenge_method = %q, want %q", got, "S256") + } + if q.Get("code_challenge") == "" { + t.Error("code_challenge is empty") + } + if q.Get("state") == "" { + t.Error("state is empty") + } + + return &auth.AuthorizationResult{ + Code: "mock-auth-code", + State: q.Get("state"), + }, nil + }) + if err != nil { - t.Fatalf("Failed to parse authURL: %v", err) - } - q := u.Query() - if q.Get("response_type") != "code" { - t.Errorf("expected response_type 'code', got '%s'", q.Get("response_type")) - } - if q.Get("client_id") != "test-client" { - t.Errorf("expected client_id 'test-client', got '%s'", q.Get("client_id")) - } - if q.Get("redirect_uri") != "http://localhost:8080/callback" { - t.Errorf("expected redirect_uri 'http://localhost:8080/callback', got '%s'", q.Get("redirect_uri")) + t.Fatalf("PerformOIDCLogin() error = %v", err) } - if q.Get("scope") != "openid profile email" { - t.Errorf("expected scope 'openid profile email', got '%s'", q.Get("scope")) - } - if q.Get("code_challenge_method") != "S256" { - t.Errorf("expected code_challenge_method 'S256', got '%s'", q.Get("code_challenge_method")) - } - // Validate state is generated - if authReq.state == "" { - t.Error("state is empty") - } - if q.Get("state") != authReq.state { - t.Errorf("state in URL doesn't match returned state") - } - // Validate PKCE parameters - if authReq.codeVerifier == "" { - t.Error("codeVerifier is empty") + + idToken, ok := token.Extra("id_token").(string) + if !ok || idToken == "" { + t.Error("id_token is missing or empty") } - if q.Get("code_challenge") == "" { - t.Error("code_challenge is empty") + if token.AccessToken == "" { + t.Error("AccessToken is empty") } }) + t.Run("with login_hint", func(t *testing.T) { - configWithHint := *config + configWithHint := *validConfig configWithHint.LoginHint = "user@example.com" - authReq, _, err := initiateOIDCLogin(context.Background(), &configWithHint) + + _, err := PerformOIDCLogin(context.Background(), &configWithHint, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + u, err := url.Parse(args.URL) + if err != nil { + return nil, fmt.Errorf("invalid authURL: %w", err) + } + if got := u.Query().Get("login_hint"); got != "user@example.com" { + t.Errorf("login_hint = %q, want %q", got, "user@example.com") + } + return &auth.AuthorizationResult{ + Code: "mock-auth-code", + State: u.Query().Get("state"), + }, nil + }) if err != nil { - t.Fatalf("initiateOIDCLogin failed: %v", err) + t.Fatalf("PerformOIDCLogin() error = %v", err) } - u, err := url.Parse(authReq.authURL) + }) + + t.Run("without login_hint", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), validConfig, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + u, err := url.Parse(args.URL) + if err != nil { + return nil, fmt.Errorf("invalid authURL: %w", err) + } + if u.Query().Has("login_hint") { + t.Errorf("login_hint should be absent, got %q", u.Query().Get("login_hint")) + } + return &auth.AuthorizationResult{ + Code: "mock-auth-code", + State: u.Query().Get("state"), + }, nil + }) if err != nil { - t.Fatalf("Failed to parse authURL: %v", err) + t.Fatalf("PerformOIDCLogin() error = %v", err) } - q := u.Query() - if q.Get("login_hint") != "user@example.com" { - t.Errorf("expected login_hint 'user@example.com', got '%s'", q.Get("login_hint")) + }) + + t.Run("state mismatch", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), validConfig, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return &auth.AuthorizationResult{ + Code: "mock-auth-code", + State: "wrong-state", + }, nil + }) + + if err == nil { + t.Fatal("expected error for state mismatch, got nil") + } + if !strings.Contains(err.Error(), "state mismatch") { + t.Errorf("error = %v, want error containing %q", err, "state mismatch") } }) - t.Run("without login_hint", func(t *testing.T) { - authReq, _, err := initiateOIDCLogin(context.Background(), config) - if err != nil { - t.Fatalf("initiateOIDCLogin failed: %v", err) + + t.Run("fetcher error", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), validConfig, + func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return nil, fmt.Errorf("user cancelled") + }) + + if err == nil { + t.Fatal("expected error, got nil") } - u, err := url.Parse(authReq.authURL) - if err != nil { - t.Fatalf("Failed to parse authURL: %v", err) + if !strings.Contains(err.Error(), "user cancelled") { + t.Errorf("error = %v, want error containing %q", err, "user cancelled") + } + }) + + t.Run("nil fetcher", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), validConfig, nil) + if err == nil { + t.Fatal("expected error for nil fetcher, got nil") } - q := u.Query() - if q.Has("login_hint") { - t.Errorf("expected no login_hint parameter, but got '%s'", q.Get("login_hint")) + if !strings.Contains(err.Error(), "authCodeFetcher is required") { + t.Errorf("error = %v, want error containing %q", err, "authCodeFetcher is required") } }) + t.Run("nil config", func(t *testing.T) { - _, _, err := initiateOIDCLogin(context.Background(), nil) + noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return nil, fmt.Errorf("should not be called") + } + _, err := PerformOIDCLogin(context.Background(), nil, noopFetcher) if err == nil { - t.Error("expected error for nil config, got nil") + t.Fatal("expected error for nil config, got nil") + } + if !strings.Contains(err.Error(), "config is required") { + t.Errorf("error = %v, want error containing %q", err, "config is required") } }) + t.Run("missing openid scope", func(t *testing.T) { - badConfig := *config - badConfig.Scopes = []string{"profile", "email"} // Missing "openid" - _, _, err := initiateOIDCLogin(context.Background(), &badConfig) + badConfig := *validConfig + badConfig.Scopes = []string{"profile", "email"} + noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return nil, fmt.Errorf("should not be called") + } + _, err := PerformOIDCLogin(context.Background(), &badConfig, noopFetcher) if err == nil { - t.Error("expected error for missing openid scope, got nil") + t.Fatal("expected error for missing openid scope, got nil") } if !strings.Contains(err.Error(), "openid") { - t.Errorf("expected error about missing 'openid', got: %v", err) + t.Errorf("error = %v, want error containing %q", err, "openid") } }) + t.Run("missing required fields", func(t *testing.T) { + noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return nil, fmt.Errorf("should not be called") + } tests := []struct { - name string - mutate func(*OIDCLoginConfig) - expectErr string + name string + mutate func(*OIDCLoginConfig) + wantErr string }{ { - name: "missing IssuerURL", - mutate: func(c *OIDCLoginConfig) { c.IssuerURL = "" }, - expectErr: "IssuerURL is required", + name: "missing IssuerURL", + mutate: func(c *OIDCLoginConfig) { c.IssuerURL = "" }, + wantErr: "IssuerURL is required", }, { - name: "missing ClientID", - mutate: func(c *OIDCLoginConfig) { c.Credentials.ClientID = "" }, - expectErr: "ClientID is required", + name: "missing ClientID", + mutate: func(c *OIDCLoginConfig) { c.Credentials = &oauthex.ClientCredentials{} }, + wantErr: "ClientID is required", }, { name: "missing RedirectURL", mutate: func(c *OIDCLoginConfig) { c.RedirectURL = "" - // Ensure ClientID is present to test RedirectURL validation - c.Credentials = &oauthex.ClientCredentials{ClientID: "test"} }, - expectErr: "RedirectURL is required", + wantErr: "RedirectURL is required", }, { name: "missing Scopes", mutate: func(c *OIDCLoginConfig) { c.Scopes = nil - // Ensure required fields are present to test Scopes validation - c.Credentials = &oauthex.ClientCredentials{ClientID: "test"} - c.RedirectURL = "http://localhost:8080/callback" }, - expectErr: "at least one scope is required", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - badConfig := *config - tt.mutate(&badConfig) - _, _, err := initiateOIDCLogin(context.Background(), &badConfig) - if err == nil { - t.Error("expected error, got nil") - } - if !strings.Contains(err.Error(), tt.expectErr) { - t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) - } - }) - } - }) -} - -// TestCompleteOIDCLogin tests the authorization code exchange. -func TestCompleteOIDCLogin(t *testing.T) { - // Create mock IdP server - idpServer := createMockOIDCServerWithToken(t) - defer idpServer.Close() - config := &OIDCLoginConfig{ - IssuerURL: idpServer.URL, - Credentials: &oauthex.ClientCredentials{ - ClientID: "test-client", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "test-secret", - }, - }, - RedirectURL: "http://localhost:8080/callback", - Scopes: []string{"openid", "profile", "email"}, - HTTPClient: idpServer.Client(), - } - t.Run("successful code exchange", func(t *testing.T) { - // First initiate to get oauth2Config - _, oauth2Config, err := initiateOIDCLogin(context.Background(), config) - if err != nil { - t.Fatalf("initiateOIDCLogin failed: %v", err) - } - - token, err := completeOIDCLogin( - context.Background(), - config, - oauth2Config, - "test-auth-code", - "test-code-verifier", - ) - if err != nil { - t.Fatalf("completeOIDCLogin failed: %v", err) - } - // Validate tokens - idToken, ok := token.Extra("id_token").(string) - if !ok || idToken == "" { - t.Error("id_token is missing or empty") - } - if token.AccessToken == "" { - t.Error("AccessToken is empty") - } - if token.TokenType != "Bearer" { - t.Errorf("expected TokenType 'Bearer', got '%s'", token.TokenType) - } - if token.Expiry.IsZero() { - t.Error("Expiry is zero") - } - }) - t.Run("missing parameters", func(t *testing.T) { - _, oauth2Config, _ := initiateOIDCLogin(context.Background(), config) - - tests := []struct { - name string - authCode string - codeVerifier string - expectErr string - }{ - { - name: "missing authCode", - authCode: "", - codeVerifier: "test-verifier", - expectErr: "authCode is required", - }, - { - name: "missing codeVerifier", - authCode: "test-code", - codeVerifier: "", - expectErr: "codeVerifier is required", + wantErr: "at least one scope is required", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := completeOIDCLogin( - context.Background(), - config, - oauth2Config, - tt.authCode, - tt.codeVerifier, - ) + cfg := *validConfig + tt.mutate(&cfg) + _, err := PerformOIDCLogin(context.Background(), &cfg, noopFetcher) if err == nil { - t.Error("expected error, got nil") + t.Fatalf("expected error containing %q, got nil", tt.wantErr) } - if !strings.Contains(err.Error(), tt.expectErr) { - t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %v, want error containing %q", err, tt.wantErr) } }) } }) } -// TestOIDCLoginE2E tests the complete OIDC login flow end-to-end. -func TestOIDCLoginE2E(t *testing.T) { - // Create mock IdP server - idpServer := createMockOIDCServerWithToken(t) - defer idpServer.Close() - config := &OIDCLoginConfig{ - IssuerURL: idpServer.URL, - Credentials: &oauthex.ClientCredentials{ - ClientID: "test-client", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "test-secret", - }, - }, - RedirectURL: "http://localhost:8080/callback", - Scopes: []string{"openid", "profile", "email"}, - HTTPClient: idpServer.Client(), - } - // Step 1: Initiate login - authReq, oauth2Config, err := initiateOIDCLogin(context.Background(), config) - if err != nil { - t.Fatalf("initiateOIDCLogin failed: %v", err) - } - // Step 2: Simulate user authentication and redirect - // (In real flow, user would visit authReq.authURL and IdP would redirect back) - // Here we just use a mock authorization code - mockAuthCode := "mock-authorization-code" - // Step 3: Complete login with authorization code - token, err := completeOIDCLogin( - context.Background(), - config, - oauth2Config, - mockAuthCode, - authReq.codeVerifier, - ) - if err != nil { - t.Fatalf("completeOIDCLogin failed: %v", err) - } - // Validate we got an ID token - idToken, ok := token.Extra("id_token").(string) - if !ok || idToken == "" { - t.Error("Expected ID token, got empty or missing") - } - // Validate ID token is a JWT (has 3 parts) - parts := strings.Split(idToken, ".") - if len(parts) != 3 { - t.Errorf("Expected JWT with 3 parts, got %d parts", len(parts)) - } -} - -// createMockOIDCServer creates a mock OIDC server for testing initiateOIDCLogin. +// createMockOIDCServer creates a mock OIDC server that handles metadata +// discovery and token exchange. func createMockOIDCServer(t *testing.T) *httptest.Server { + t.Helper() var serverURL string server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Handle OIDC discovery - if r.URL.Path == "/.well-known/openid-configuration" { + switch r.URL.Path { + case "/.well-known/openid-configuration": w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]any{ "issuer": serverURL, @@ -334,44 +267,21 @@ func createMockOIDCServer(t *testing.T) *httptest.Server { "code_challenge_methods_supported": []string{"S256"}, "grant_types_supported": []string{"authorization_code"}, }) - return - } - http.NotFound(w, r) - })) - serverURL = server.URL - return server -} -// createMockOIDCServerWithToken creates a mock OIDC server that also handles token exchange. -func createMockOIDCServerWithToken(t *testing.T) *httptest.Server { - var serverURL string - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Handle OIDC discovery - if r.URL.Path == "/.well-known/openid-configuration" { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "issuer": serverURL, - "authorization_endpoint": serverURL + "/authorize", - "token_endpoint": serverURL + "/token", - "jwks_uri": serverURL + "/.well-known/jwks.json", - "response_types_supported": []string{"code"}, - "code_challenge_methods_supported": []string{"S256"}, - "grant_types_supported": []string{"authorization_code"}, - }) - return - } - // Handle token endpoint - if r.URL.Path == "/token" { + case "/token": if err := r.ParseForm(); err != nil { http.Error(w, "failed to parse form", http.StatusBadRequest) return } - // Validate grant type if r.FormValue("grant_type") != "authorization_code" { http.Error(w, "invalid grant_type", http.StatusBadRequest) return } - // Create mock ID token (JWT) + if r.FormValue("code_verifier") == "" { + http.Error(w, "missing code_verifier", http.StatusBadRequest) + return + } + now := time.Now().Unix() idToken := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.%s.mock-signature", base64EncodeClaims(map[string]any{ @@ -382,7 +292,6 @@ func createMockOIDCServerWithToken(t *testing.T) *httptest.Server { "iat": now, "email": "test@example.com", })) - // Return token response w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]any{ "access_token": "mock-access-token", @@ -391,9 +300,10 @@ func createMockOIDCServerWithToken(t *testing.T) *httptest.Server { "refresh_token": "mock-refresh-token", "id_token": idToken, }) - return + + default: + http.NotFound(w, r) } - http.NotFound(w, r) })) serverURL = server.URL return server @@ -404,100 +314,3 @@ func base64EncodeClaims(claims map[string]any) string { claimsJSON, _ := json.Marshal(claims) return base64.RawURLEncoding.EncodeToString(claimsJSON) } - -// TestPerformOIDCLogin tests the combined OIDC login flow with callback. -func TestPerformOIDCLogin(t *testing.T) { - // Create mock IdP server - idpServer := createMockOIDCServerWithToken(t) - defer idpServer.Close() - config := &OIDCLoginConfig{ - IssuerURL: idpServer.URL, - Credentials: &oauthex.ClientCredentials{ - ClientID: "test-client", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "test-secret", - }, - }, - RedirectURL: "http://localhost:8080/callback", - Scopes: []string{"openid", "profile", "email"}, - HTTPClient: idpServer.Client(), - } - - t.Run("successful flow", func(t *testing.T) { - token, err := PerformOIDCLogin(context.Background(), config, - func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - // Validate authURL has required parameters - u, err := url.Parse(args.URL) - if err != nil { - return nil, fmt.Errorf("invalid authURL: %w", err) - } - q := u.Query() - if q.Get("response_type") != "code" { - return nil, fmt.Errorf("missing response_type") - } - if q.Get("state") == "" { - return nil, fmt.Errorf("missing state") - } - - // Simulate successful user authentication - return &auth.AuthorizationResult{ - Code: "mock-auth-code", - State: q.Get("state"), // Return the expected state from URL - }, nil - }) - - if err != nil { - t.Fatalf("PerformOIDCLogin failed: %v", err) - } - - idToken, ok := token.Extra("id_token").(string) - if !ok || idToken == "" { - t.Error("id_token is missing or empty") - } - if token.AccessToken == "" { - t.Error("AccessToken is empty") - } - }) - - t.Run("state mismatch", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), config, - func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - // Return wrong state to simulate CSRF attack - return &auth.AuthorizationResult{ - Code: "mock-auth-code", - State: "wrong-state", - }, nil - }) - - if err == nil { - t.Error("expected error for state mismatch, got nil") - } - if !strings.Contains(err.Error(), "state mismatch") { - t.Errorf("expected state mismatch error, got: %v", err) - } - }) - - t.Run("fetcher error", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), config, - func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - return nil, fmt.Errorf("user cancelled") - }) - - if err == nil { - t.Error("expected error, got nil") - } - if !strings.Contains(err.Error(), "user cancelled") { - t.Errorf("expected 'user cancelled' error, got: %v", err) - } - }) - - t.Run("nil fetcher", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), config, nil) - if err == nil { - t.Error("expected error for nil fetcher, got nil") - } - if !strings.Contains(err.Error(), "authCodeFetcher is required") { - t.Errorf("expected 'authCodeFetcher is required' error, got: %v", err) - } - }) -} diff --git a/oauthex/client_test.go b/oauthex/client_test.go new file mode 100644 index 00000000..b78e9c8b --- /dev/null +++ b/oauthex/client_test.go @@ -0,0 +1,88 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package oauthex + +import ( + "reflect" + "strings" + "testing" +) + +func TestClientCredentials_Validate(t *testing.T) { + tests := []struct { + name string + creds ClientCredentials + wantErr string + }{ + { + name: "valid public client", + creds: ClientCredentials{ClientID: "my-client"}, + }, + { + name: "valid confidential client", + creds: ClientCredentials{ + ClientID: "my-client", + ClientSecretAuth: &ClientSecretAuth{ + ClientSecret: "my-secret", + }, + }, + }, + { + name: "empty client ID", + creds: ClientCredentials{}, + wantErr: "ClientID is required", + }, + { + name: "empty secret in ClientSecretAuth", + creds: ClientCredentials{ + ClientID: "my-client", + ClientSecretAuth: &ClientSecretAuth{}, + }, + wantErr: "ClientSecret is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.creds.Validate() + if tt.wantErr == "" { + if err != nil { + t.Fatalf("Validate() unexpected error: %v", err) + } + return + } + if err == nil { + t.Fatalf("Validate() expected error containing %q, got nil", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %v, want error containing %q", err, tt.wantErr) + } + }) + } +} + +// TestClientCredentials_ValidateCoversAllAuthFields uses reflection to detect +// when new authentication method fields are added to ClientCredentials without +// updating Validate. If this test fails, update Validate() to handle the new +// field and increment knownAuthMethods. +func TestClientCredentials_ValidateCoversAllAuthFields(t *testing.T) { + const knownAuthMethods = 1 // ClientSecretAuth + + typ := reflect.TypeOf(ClientCredentials{}) + var pointerFields int + for i := range typ.NumField() { + f := typ.Field(i) + if f.Name == "ClientID" { + continue + } + if f.Type.Kind() != reflect.Ptr { + t.Errorf("field %q is %v, expected a pointer to an auth method struct", f.Name, f.Type.Kind()) + } + pointerFields++ + } + + if pointerFields != knownAuthMethods { + t.Fatalf("ClientCredentials has %d auth method fields but Validate only knows about %d -- update Validate() and this test", pointerFields, knownAuthMethods) + } +} From bfe36e218a2838bde7ae765050585ea34bca83f8 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Fri, 3 Apr 2026 09:06:32 +0000 Subject: [PATCH 2/2] Add fake IdP server and further refactor tests. --- auth/extauth/enterprise_handler_test.go | 432 +++++------------- auth/extauth/oidc_login_test.go | 369 +++++++-------- .../oauthtest/fake_authorization_server.go | 51 ++- internal/oauthtest/fake_idp_server.go | 258 +++++++++++ 4 files changed, 572 insertions(+), 538 deletions(-) create mode 100644 internal/oauthtest/fake_idp_server.go diff --git a/auth/extauth/enterprise_handler_test.go b/auth/extauth/enterprise_handler_test.go index c19b5a2f..86f04140 100644 --- a/auth/extauth/enterprise_handler_test.go +++ b/auth/extauth/enterprise_handler_test.go @@ -1,43 +1,37 @@ // Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. package extauth import ( "context" - "encoding/json" "fmt" "net/http" "net/http/httptest" "strings" "testing" + "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" "github.com/modelcontextprotocol/go-sdk/oauthex" "golang.org/x/oauth2" ) -// TestNewEnterpriseHandler_Validation tests validation in NewEnterpriseHandler. -func TestNewEnterpriseHandler_Validation(t *testing.T) { - validConfig := &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "idp_client_id", - }, +func validEnterpriseHandlerConfig() *EnterpriseHandlerConfig { + return &EnterpriseHandlerConfig{ + IdPIssuerURL: "https://idp.example.com", + IdPCredentials: &oauthex.ClientCredentials{ClientID: "idp_client_id"}, MCPAuthServerURL: "https://mcp-auth.example.com", MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ - ClientID: "mcp_client_id", - }, + MCPCredentials: &oauthex.ClientCredentials{ClientID: "mcp_client_id"}, IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { - token := &oauth2.Token{ - AccessToken: "mock_access_token", - TokenType: "Bearer", - } + token := &oauth2.Token{AccessToken: "mock", TokenType: "Bearer"} return token.WithExtra(map[string]any{"id_token": "mock_id_token"}), nil }, } +} +func TestNewEnterpriseHandler_Validation(t *testing.T) { tests := []struct { name string config *EnterpriseHandlerConfig @@ -50,151 +44,101 @@ func TestNewEnterpriseHandler_Validation(t *testing.T) { }, { name: "missing IdPIssuerURL", - config: &EnterpriseHandlerConfig{ - IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.IdPIssuerURL = "" + return c + }(), wantError: "IdPIssuerURL is required", }, { name: "nil IdPCredentials", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: nil, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.IdPCredentials = nil + return c + }(), wantError: "IdPCredentials is required", }, { name: "invalid IdPCredentials - empty ClientID", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "", // Invalid - empty - }, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.IdPCredentials = &oauthex.ClientCredentials{ClientID: ""} + return c + }(), wantError: "invalid IdPCredentials", }, { name: "invalid IdPCredentials - empty ClientSecret in ClientSecretAuth", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "idp_client_id", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "", // Invalid - empty secret when ClientSecretAuth is set - }, - }, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.IdPCredentials = &oauthex.ClientCredentials{ + ClientID: "idp_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ClientSecret: ""}, + } + return c + }(), wantError: "invalid IdPCredentials", }, { name: "missing MCPAuthServerURL", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - MCPAuthServerURL: "", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.MCPAuthServerURL = "" + return c + }(), wantError: "MCPAuthServerURL is required", }, { name: "missing MCPResourceURI", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.MCPResourceURI = "" + return c + }(), wantError: "MCPResourceURI is required", }, { name: "nil MCPCredentials", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: nil, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.MCPCredentials = nil + return c + }(), wantError: "MCPCredentials is required", }, { name: "invalid MCPCredentials - empty ClientID", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ - ClientID: "", // Invalid - empty - }, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { return nil, nil }, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.MCPCredentials = &oauthex.ClientCredentials{ClientID: ""} + return c + }(), wantError: "invalid MCPCredentials", }, { name: "missing IDTokenFetcher", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ClientID: "id"}, - IDTokenFetcher: nil, - }, + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.IDTokenFetcher = nil + return c + }(), wantError: "IDTokenFetcher is required", }, { - name: "valid config - public clients (no ClientSecretAuth)", - config: validConfig, + name: "valid public clients", + config: validEnterpriseHandlerConfig(), wantError: "", }, { - name: "valid config - confidential clients (with ClientSecretAuth)", - config: &EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "idp_client_id", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "idp_secret", - }, - }, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ - ClientID: "mcp_client_id", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "mcp_secret", - }, - }, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { - token := &oauth2.Token{ - AccessToken: "mock_access_token", - TokenType: "Bearer", - } - return token.WithExtra(map[string]any{"id_token": "mock_id_token"}), nil - }, - }, + name: "valid confidential clients", + config: func() *EnterpriseHandlerConfig { + c := validEnterpriseHandlerConfig() + c.IdPCredentials.ClientSecretAuth = &oauthex.ClientSecretAuth{ClientSecret: "idp_secret"} + c.MCPCredentials.ClientSecretAuth = &oauthex.ClientSecretAuth{ClientSecret: "mcp_secret"} + return c + }(), wantError: "", }, } @@ -221,45 +165,47 @@ func TestNewEnterpriseHandler_Validation(t *testing.T) { } } -// TestEnterpriseHandler_Authorize_E2E tests the complete enterprise authorization flow. func TestEnterpriseHandler_Authorize_E2E(t *testing.T) { - // Set up IdP (Identity Provider) fake server with token exchange support - idpServer := setupIdPServer(t) + const idJAGToken = "id-jag-token-from-idp" - // Set up MCP authorization server with JWT bearer grant support - mcpAuthServer := setupMCPAuthServer(t) - - // Create enterprise handler - handler, err := NewEnterpriseHandler(&EnterpriseHandlerConfig{ - IdPIssuerURL: idpServer.URL, - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "idp_client_id", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "idp_secret", - }, + idpServer := oauthtest.NewFakeIdPServer(oauthtest.IdPConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "idp_client_id": {Secret: "idp_secret"}, }, - MCPAuthServerURL: mcpAuthServer.URL, - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ - ClientID: "mcp_client_id", - ClientSecretAuth: &oauthex.ClientSecretAuth{ - ClientSecret: "mcp_secret", + TokenExchangeConfig: &oauthtest.TokenExchangeConfig{ + IDJAGToken: idJAGToken, + }, + }) + idpServer.Start(t) + + mcpAuthServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "mcp_client_id": {Secret: "mcp_secret"}, }, }, - MCPScopes: []string{"read", "write"}, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { - token := &oauth2.Token{ - AccessToken: "mock_access_token", - TokenType: "Bearer", - } - return token.WithExtra(map[string]any{"id_token": "mock_id_token_from_user_login"}), nil + JWTBearerConfig: &oauthtest.JWTBearerConfig{ + ValidAssertions: []string{idJAGToken}, }, }) + mcpAuthServer.Start(t) + + config := validEnterpriseHandlerConfig() + config.IdPIssuerURL = idpServer.URL() + config.IdPCredentials.ClientSecretAuth = &oauthex.ClientSecretAuth{ClientSecret: "idp_secret"} + config.MCPAuthServerURL = mcpAuthServer.URL() + config.MCPCredentials.ClientSecretAuth = &oauthex.ClientSecretAuth{ClientSecret: "mcp_secret"} + config.MCPScopes = []string{"read", "write"} + config.IDTokenFetcher = func(ctx context.Context) (*oauth2.Token, error) { + token := &oauth2.Token{AccessToken: "mock_access_token", TokenType: "Bearer"} + return token.WithExtra(map[string]any{"id_token": "mock_id_token_from_user_login"}), nil + } + + handler, err := NewEnterpriseHandler(config) if err != nil { t.Fatalf("NewEnterpriseHandler failed: %v", err) } - // Simulate a 401 response from MCP server req := httptest.NewRequest(http.MethodGet, "https://mcp.example.com/api", nil) resp := &http.Response{ StatusCode: http.StatusUnauthorized, @@ -268,12 +214,10 @@ func TestEnterpriseHandler_Authorize_E2E(t *testing.T) { Request: req, } - // Perform authorization if err := handler.Authorize(context.Background(), req, resp); err != nil { t.Fatalf("Authorize failed: %v", err) } - // Verify token source is set tokenSource, err := handler.TokenSource(context.Background()) if err != nil { t.Fatalf("TokenSource failed: %v", err) @@ -282,164 +226,22 @@ func TestEnterpriseHandler_Authorize_E2E(t *testing.T) { t.Fatal("expected token source to be set after authorization") } - // Verify we can get a token token, err := tokenSource.Token() if err != nil { t.Fatalf("Token() failed: %v", err) } - if token.AccessToken != "mcp_access_token_from_jwt_bearer" { - t.Errorf("unexpected access token: got %q, want %q", - token.AccessToken, "mcp_access_token_from_jwt_bearer") + if token.AccessToken != "test_access_token" { + t.Errorf("AccessToken = %q, want %q", token.AccessToken, "test_access_token") } } -// setupIdPServer creates a fake IdP server that supports token exchange. -func setupIdPServer(t *testing.T) *httptest.Server { - t.Helper() - mux := http.NewServeMux() - - var server *httptest.Server - - // OAuth/OIDC metadata endpoint - uses closure to get server URL - mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "issuer": server.URL, - "token_endpoint": server.URL + "/token", - "authorization_endpoint": server.URL + "/authorize", - "code_challenge_methods_supported": []string{"S256"}, - }) - }) - - // Token endpoint - supports token exchange (RFC 8693) - mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - http.Error(w, "failed to parse form", http.StatusBadRequest) - return - } - - grantType := r.Form.Get("grant_type") - if grantType != oauthex.GrantTypeTokenExchange { - http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) - return - } - - // Verify client authentication - clientID := r.Form.Get("client_id") - clientSecret := r.Form.Get("client_secret") - if clientID != "idp_client_id" || clientSecret != "idp_secret" { - http.Error(w, "invalid client credentials", http.StatusUnauthorized) - return - } - - // Verify token exchange parameters - if r.Form.Get("requested_token_type") != oauthex.TokenTypeIDJAG { - http.Error(w, "invalid requested_token_type", http.StatusBadRequest) - return - } - if r.Form.Get("subject_token_type") != oauthex.TokenTypeIDToken { - http.Error(w, "invalid subject_token_type", http.StatusBadRequest) - return - } - if r.Form.Get("subject_token") == "" { - http.Error(w, "missing subject_token", http.StatusBadRequest) - return - } - - // Return ID-JAG (Identity Assertion JWT Authorization Grant) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "access_token": "id-jag-token-from-idp", - "issued_token_type": oauthex.TokenTypeIDJAG, - "token_type": "N_A", - "expires_in": 300, - "scope": "read write", - }) - }) - - server = httptest.NewServer(mux) - t.Cleanup(server.Close) - - return server -} - -// setupMCPAuthServer creates a fake MCP authorization server that supports JWT bearer grant. -func setupMCPAuthServer(t *testing.T) *httptest.Server { - t.Helper() - mux := http.NewServeMux() - - var server *httptest.Server - - // OAuth metadata endpoint - uses closure to get server URL - mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "issuer": server.URL, - "token_endpoint": server.URL + "/token", - "code_challenge_methods_supported": []string{"S256"}, - }) - }) - - // Token endpoint - supports JWT bearer grant (RFC 7523) - mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - http.Error(w, "failed to parse form", http.StatusBadRequest) - return - } - - grantType := r.Form.Get("grant_type") - if grantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" { - http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) - return - } - - // Verify client authentication - clientID := r.Form.Get("client_id") - clientSecret := r.Form.Get("client_secret") - if clientID != "mcp_client_id" || clientSecret != "mcp_secret" { - http.Error(w, "invalid client credentials", http.StatusUnauthorized) - return - } - - // Verify assertion (ID-JAG) - assertion := r.Form.Get("assertion") - if assertion != "id-jag-token-from-idp" { - http.Error(w, "invalid assertion", http.StatusBadRequest) - return - } - - // Return access token - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "access_token": "mcp_access_token_from_jwt_bearer", - "token_type": "Bearer", - "expires_in": 3600, - "scope": "read write", - }) - }) - - server = httptest.NewServer(mux) - t.Cleanup(server.Close) - - return server -} - -// TestEnterpriseHandler_Authorize_IDTokenFetcherError tests error handling when IDTokenFetcher fails. func TestEnterpriseHandler_Authorize_IDTokenFetcherError(t *testing.T) { - handler, err := NewEnterpriseHandler(&EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "idp_client_id", - }, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ - ClientID: "mcp_client_id", - }, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { - return nil, fmt.Errorf("user cancelled login") - }, - }) + config := validEnterpriseHandlerConfig() + config.IDTokenFetcher = func(ctx context.Context) (*oauth2.Token, error) { + return nil, fmt.Errorf("user cancelled login") + } + + handler, err := NewEnterpriseHandler(config) if err != nil { t.Fatalf("NewEnterpriseHandler failed: %v", err) } @@ -457,30 +259,12 @@ func TestEnterpriseHandler_Authorize_IDTokenFetcherError(t *testing.T) { t.Fatal("expected error from Authorize, got nil") } if !strings.Contains(err.Error(), "failed to obtain ID token") { - t.Errorf("expected error about ID token, got: %v", err) + t.Errorf("error = %v, want error containing %q", err, "failed to obtain ID token") } } -// TestEnterpriseHandler_TokenSource_BeforeAuthorization tests TokenSource before authorization. func TestEnterpriseHandler_TokenSource_BeforeAuthorization(t *testing.T) { - handler, err := NewEnterpriseHandler(&EnterpriseHandlerConfig{ - IdPIssuerURL: "https://idp.example.com", - IdPCredentials: &oauthex.ClientCredentials{ - ClientID: "idp_client_id", - }, - MCPAuthServerURL: "https://mcp-auth.example.com", - MCPResourceURI: "https://mcp.example.com", - MCPCredentials: &oauthex.ClientCredentials{ - ClientID: "mcp_client_id", - }, - IDTokenFetcher: func(ctx context.Context) (*oauth2.Token, error) { - token := &oauth2.Token{ - AccessToken: "mock_access_token", - TokenType: "Bearer", - } - return token.WithExtra(map[string]any{"id_token": "mock_id_token"}), nil - }, - }) + handler, err := NewEnterpriseHandler(validEnterpriseHandlerConfig()) if err != nil { t.Fatalf("NewEnterpriseHandler failed: %v", err) } diff --git a/auth/extauth/oidc_login_test.go b/auth/extauth/oidc_login_test.go index 4c63ff6f..9cd9b935 100644 --- a/auth/extauth/oidc_login_test.go +++ b/auth/extauth/oidc_login_test.go @@ -1,97 +1,111 @@ // Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. package extauth import ( "context" - "encoding/base64" - "encoding/json" "fmt" "net/http" - "net/http/httptest" "net/url" "strings" "testing" - "time" "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" ) -func TestPerformOIDCLogin(t *testing.T) { - idpServer := createMockOIDCServer(t) - defer idpServer.Close() +const testRedirectURL = "http://localhost:18927/callback" - validConfig := &OIDCLoginConfig{ - IssuerURL: idpServer.URL, +func validOIDCLoginConfig(issuerURL string) *OIDCLoginConfig { + return &OIDCLoginConfig{ + IssuerURL: issuerURL, Credentials: &oauthex.ClientCredentials{ ClientID: "test-client", ClientSecretAuth: &oauthex.ClientSecretAuth{ ClientSecret: "test-secret", }, }, - RedirectURL: "http://localhost:8080/callback", + RedirectURL: testRedirectURL, Scopes: []string{"openid", "profile", "email"}, - HTTPClient: idpServer.Client(), } +} - t.Run("successful flow", func(t *testing.T) { - token, err := PerformOIDCLogin(context.Background(), validConfig, - func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - u, err := url.Parse(args.URL) - if err != nil { - return nil, fmt.Errorf("invalid authURL: %w", err) - } - q := u.Query() - - if got := q.Get("response_type"); got != "code" { - t.Errorf("response_type = %q, want %q", got, "code") - } - if got := q.Get("client_id"); got != "test-client" { - t.Errorf("client_id = %q, want %q", got, "test-client") - } - if got := q.Get("redirect_uri"); got != "http://localhost:8080/callback" { - t.Errorf("redirect_uri = %q, want %q", got, "http://localhost:8080/callback") - } - if got := q.Get("scope"); got != "openid profile email" { - t.Errorf("scope = %q, want %q", got, "openid profile email") - } - if got := q.Get("code_challenge_method"); got != "S256" { - t.Errorf("code_challenge_method = %q, want %q", got, "S256") - } - if q.Get("code_challenge") == "" { - t.Error("code_challenge is empty") - } - if q.Get("state") == "" { - t.Error("state is empty") - } +func TestPerformOIDCLogin(t *testing.T) { + idpServer := oauthtest.NewFakeIdPServer(oauthtest.IdPConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test-client": { + Secret: "test-secret", + RedirectURIs: []string{testRedirectURL}, + }, + }, + }) + idpServer.Start(t) - return &auth.AuthorizationResult{ - Code: "mock-auth-code", - State: q.Get("state"), - }, nil - }) + config := validOIDCLoginConfig(idpServer.URL()) + // fetchAuthCode visits the authorization URL on the fake IdP server, + // follows the redirect, and extracts the authorization code and state. + fetchAuthCode := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get(args.URL) if err != nil { - t.Fatalf("PerformOIDCLogin() error = %v", err) + return nil, fmt.Errorf("failed to visit auth URL: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusFound { + return nil, fmt.Errorf("expected redirect, got status %d", resp.StatusCode) + } + loc, err := resp.Location() + if err != nil { + return nil, fmt.Errorf("missing Location header: %w", err) } + return &auth.AuthorizationResult{ + Code: loc.Query().Get("code"), + State: loc.Query().Get("state"), + }, nil + } + verifyOIDCToken := func(token *oauth2.Token) error { idToken, ok := token.Extra("id_token").(string) if !ok || idToken == "" { - t.Error("id_token is missing or empty") + return fmt.Errorf("id_token is missing or empty") + } + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + return fmt.Errorf("id_token is not a JWT (expected 3 parts, got %d)", len(parts)) } if token.AccessToken == "" { - t.Error("AccessToken is empty") + return fmt.Errorf("access token is empty") + } + if token.TokenType != "Bearer" { + return fmt.Errorf("token type = %q, want %q", token.TokenType, "Bearer") + } + return nil + } + + t.Run("successful flow", func(t *testing.T) { + token, err := PerformOIDCLogin(context.Background(), config, fetchAuthCode) + if err != nil { + t.Fatalf("PerformOIDCLogin() error = %v", err) + } + if err := verifyOIDCToken(token); err != nil { + t.Errorf("invalid token returned: %v", err) } }) t.Run("with login_hint", func(t *testing.T) { - configWithHint := *validConfig + configWithHint := *config configWithHint.LoginHint = "user@example.com" - _, err := PerformOIDCLogin(context.Background(), &configWithHint, + token, err := PerformOIDCLogin(context.Background(), &configWithHint, func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { u, err := url.Parse(args.URL) if err != nil { @@ -100,18 +114,18 @@ func TestPerformOIDCLogin(t *testing.T) { if got := u.Query().Get("login_hint"); got != "user@example.com" { t.Errorf("login_hint = %q, want %q", got, "user@example.com") } - return &auth.AuthorizationResult{ - Code: "mock-auth-code", - State: u.Query().Get("state"), - }, nil + return fetchAuthCode(ctx, args) }) if err != nil { t.Fatalf("PerformOIDCLogin() error = %v", err) } + if err := verifyOIDCToken(token); err != nil { + t.Errorf("invalid token returned: %v", err) + } }) t.Run("without login_hint", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), validConfig, + token, err := PerformOIDCLogin(context.Background(), config, func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { u, err := url.Parse(args.URL) if err != nil { @@ -120,18 +134,18 @@ func TestPerformOIDCLogin(t *testing.T) { if u.Query().Has("login_hint") { t.Errorf("login_hint should be absent, got %q", u.Query().Get("login_hint")) } - return &auth.AuthorizationResult{ - Code: "mock-auth-code", - State: u.Query().Get("state"), - }, nil + return fetchAuthCode(ctx, args) }) if err != nil { t.Fatalf("PerformOIDCLogin() error = %v", err) } + if err := verifyOIDCToken(token); err != nil { + t.Errorf("invalid token returned: %v", err) + } }) t.Run("state mismatch", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), validConfig, + _, err := PerformOIDCLogin(context.Background(), config, func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { return &auth.AuthorizationResult{ Code: "mock-auth-code", @@ -148,7 +162,7 @@ func TestPerformOIDCLogin(t *testing.T) { }) t.Run("fetcher error", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), validConfig, + _, err := PerformOIDCLogin(context.Background(), config, func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { return nil, fmt.Errorf("user cancelled") }) @@ -160,157 +174,90 @@ func TestPerformOIDCLogin(t *testing.T) { t.Errorf("error = %v, want error containing %q", err, "user cancelled") } }) +} - t.Run("nil fetcher", func(t *testing.T) { - _, err := PerformOIDCLogin(context.Background(), validConfig, nil) - if err == nil { - t.Fatal("expected error for nil fetcher, got nil") - } - if !strings.Contains(err.Error(), "authCodeFetcher is required") { - t.Errorf("error = %v, want error containing %q", err, "authCodeFetcher is required") - } - }) - - t.Run("nil config", func(t *testing.T) { - noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - return nil, fmt.Errorf("should not be called") - } - _, err := PerformOIDCLogin(context.Background(), nil, noopFetcher) - if err == nil { - t.Fatal("expected error for nil config, got nil") - } - if !strings.Contains(err.Error(), "config is required") { - t.Errorf("error = %v, want error containing %q", err, "config is required") - } - }) - - t.Run("missing openid scope", func(t *testing.T) { - badConfig := *validConfig - badConfig.Scopes = []string{"profile", "email"} - noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - return nil, fmt.Errorf("should not be called") - } - _, err := PerformOIDCLogin(context.Background(), &badConfig, noopFetcher) - if err == nil { - t.Fatal("expected error for missing openid scope, got nil") - } - if !strings.Contains(err.Error(), "openid") { - t.Errorf("error = %v, want error containing %q", err, "openid") - } - }) +func TestPerformOIDCLogin_ValidationErrors(t *testing.T) { + const nonexistentIssuer = "https://idp.example.com" - t.Run("missing required fields", func(t *testing.T) { - noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - return nil, fmt.Errorf("should not be called") - } - tests := []struct { - name string - mutate func(*OIDCLoginConfig) - wantErr string - }{ - { - name: "missing IssuerURL", - mutate: func(c *OIDCLoginConfig) { c.IssuerURL = "" }, - wantErr: "IssuerURL is required", - }, - { - name: "missing ClientID", - mutate: func(c *OIDCLoginConfig) { c.Credentials = &oauthex.ClientCredentials{} }, - wantErr: "ClientID is required", - }, - { - name: "missing RedirectURL", - mutate: func(c *OIDCLoginConfig) { - c.RedirectURL = "" - }, - wantErr: "RedirectURL is required", - }, - { - name: "missing Scopes", - mutate: func(c *OIDCLoginConfig) { - c.Scopes = nil - }, - wantErr: "at least one scope is required", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := *validConfig - tt.mutate(&cfg) - _, err := PerformOIDCLogin(context.Background(), &cfg, noopFetcher) - if err == nil { - t.Fatalf("expected error containing %q, got nil", tt.wantErr) - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("error = %v, want error containing %q", err, tt.wantErr) - } - }) - } - }) -} + noopFetcher := func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + return nil, fmt.Errorf("should not be called") + } -// createMockOIDCServer creates a mock OIDC server that handles metadata -// discovery and token exchange. -func createMockOIDCServer(t *testing.T) *httptest.Server { - t.Helper() - var serverURL string - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/.well-known/openid-configuration": - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "issuer": serverURL, - "authorization_endpoint": serverURL + "/authorize", - "token_endpoint": serverURL + "/token", - "jwks_uri": serverURL + "/.well-known/jwks.json", - "response_types_supported": []string{"code"}, - "code_challenge_methods_supported": []string{"S256"}, - "grant_types_supported": []string{"authorization_code"}, - }) + tests := []struct { + name string + config *OIDCLoginConfig + fetcher auth.AuthorizationCodeFetcher + wantErr string + }{ + { + name: "nil config", + config: nil, + fetcher: noopFetcher, + wantErr: "config is required", + }, + { + name: "nil fetcher", + config: validOIDCLoginConfig(nonexistentIssuer), + fetcher: nil, + wantErr: "authCodeFetcher is required", + }, + { + name: "missing IssuerURL", + config: validOIDCLoginConfig(""), + fetcher: noopFetcher, + wantErr: "IssuerURL is required", + }, + { + name: "missing ClientID", + config: func() *OIDCLoginConfig { + c := validOIDCLoginConfig(nonexistentIssuer) + c.Credentials = &oauthex.ClientCredentials{} + return c + }(), + fetcher: noopFetcher, + wantErr: "ClientID is required", + }, + { + name: "missing RedirectURL", + config: func() *OIDCLoginConfig { + c := validOIDCLoginConfig(nonexistentIssuer) + c.RedirectURL = "" + return c + }(), + fetcher: noopFetcher, + wantErr: "RedirectURL is required", + }, + { + name: "missing Scopes", + config: func() *OIDCLoginConfig { + c := validOIDCLoginConfig(nonexistentIssuer) + c.Scopes = nil + return c + }(), + fetcher: noopFetcher, + wantErr: "at least one scope is required", + }, + { + name: "missing openid scope", + config: func() *OIDCLoginConfig { + c := validOIDCLoginConfig(nonexistentIssuer) + c.Scopes = []string{"profile", "email"} + return c + }(), + fetcher: noopFetcher, + wantErr: "openid", + }, + } - case "/token": - if err := r.ParseForm(); err != nil { - http.Error(w, "failed to parse form", http.StatusBadRequest) - return + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), tt.config, tt.fetcher) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) } - if r.FormValue("grant_type") != "authorization_code" { - http.Error(w, "invalid grant_type", http.StatusBadRequest) - return + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %v, want error containing %q", err, tt.wantErr) } - if r.FormValue("code_verifier") == "" { - http.Error(w, "missing code_verifier", http.StatusBadRequest) - return - } - - now := time.Now().Unix() - idToken := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.%s.mock-signature", - base64EncodeClaims(map[string]any{ - "iss": serverURL, - "sub": "test-user", - "aud": "test-client", - "exp": now + 3600, - "iat": now, - "email": "test@example.com", - })) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "access_token": "mock-access-token", - "token_type": "Bearer", - "expires_in": 3600, - "refresh_token": "mock-refresh-token", - "id_token": idToken, - }) - - default: - http.NotFound(w, r) - } - })) - serverURL = server.URL - return server -} - -// base64EncodeClaims encodes JWT claims for testing. -func base64EncodeClaims(claims map[string]any) string { - claimsJSON, _ := json.Marshal(claims) - return base64.RawURLEncoding.EncodeToString(claimsJSON) + }) + } } diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go index 6fed39b4..44f937f6 100644 --- a/internal/oauthtest/fake_authorization_server.go +++ b/internal/oauthtest/fake_authorization_server.go @@ -49,6 +49,14 @@ type RegistrationConfig struct { DynamicClientRegistrationEnabled bool } +// JWTBearerConfig configures support for the JWT Bearer grant type (RFC 7523) +// on a [FakeAuthorizationServer]. +type JWTBearerConfig struct { + // ValidAssertions is the set of assertion values that are accepted. + // If empty, any non-empty assertion is accepted. + ValidAssertions []string +} + // Config holds configuration for FakeAuthorizationServer. type Config struct { // The optional path component of the issuer URL. @@ -59,6 +67,9 @@ type Config struct { MetadataEndpointConfig *MetadataEndpointConfig // Configuration for client registration. RegistrationConfig *RegistrationConfig + // JWTBearerConfig enables RFC 7523 JWT Bearer grant at the /token endpoint. + // If non-nil, the server accepts grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer. + JWTBearerConfig *JWTBearerConfig } // FakeAuthorizationServer is a fake OAuth 2.0 Authorization Server for testing. @@ -253,10 +264,19 @@ func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Req http.Error(w, err.Error(), http.StatusUnauthorized) return } - if r.Form.Get("grant_type") != "authorization_code" { - http.Error(w, "invalid grant_type", http.StatusBadRequest) - return + + grantType := r.Form.Get("grant_type") + switch grantType { + case "authorization_code": + s.handleAuthorizationCodeGrant(w, r) + case "urn:ietf:params:oauth:grant-type:jwt-bearer": + s.handleJWTBearerGrant(w, r) + default: + http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) } +} + +func (s *FakeAuthorizationServer) handleAuthorizationCodeGrant(w http.ResponseWriter, r *http.Request) { code := r.Form.Get("code") if code == "" { http.Error(w, "missing code", http.StatusBadRequest) @@ -287,6 +307,31 @@ func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Req }) } +func (s *FakeAuthorizationServer) handleJWTBearerGrant(w http.ResponseWriter, r *http.Request) { + if s.config.JWTBearerConfig == nil { + http.Error(w, "JWT bearer grant not supported", http.StatusBadRequest) + return + } + assertion := r.Form.Get("assertion") + if assertion == "" { + http.Error(w, "missing assertion", http.StatusBadRequest) + return + } + if len(s.config.JWTBearerConfig.ValidAssertions) > 0 { + if !slices.Contains(s.config.JWTBearerConfig.ValidAssertions, assertion) { + http.Error(w, "invalid assertion", http.StatusBadRequest) + return + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + }) +} + func (s *FakeAuthorizationServer) authenticateClient(r *http.Request) error { clientID, clientSecret, ok := r.BasicAuth() if !ok { diff --git a/internal/oauthtest/fake_idp_server.go b/internal/oauthtest/fake_idp_server.go new file mode 100644 index 00000000..b7057ffd --- /dev/null +++ b/internal/oauthtest/fake_idp_server.go @@ -0,0 +1,258 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package oauthtest + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "maps" + "net/http" + "net/http/httptest" + "slices" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TokenExchangeConfig configures RFC 8693 token exchange support on a [FakeIdPServer]. +type TokenExchangeConfig struct { + // IDJAGToken is the ID-JAG value returned from token exchange. + // Defaults to "test-id-jag-token" if empty. + IDJAGToken string +} + +// IdPConfig holds configuration for [FakeIdPServer]. +type IdPConfig struct { + // PreregisteredClients maps client IDs to their info. + PreregisteredClients map[string]ClientInfo + + // TokenExchangeConfig enables RFC 8693 token exchange at the /token endpoint. + // If non-nil, the server accepts grant_type=urn:ietf:params:oauth:grant-type:token-exchange. + TokenExchangeConfig *TokenExchangeConfig +} + +// FakeIdPServer is a fake OIDC Identity Provider for testing. +// It supports: +// - OIDC discovery (/.well-known/openid-configuration) +// - Authorization Code Grant with PKCE +// - ID Token issuance (fake JWTs) +// - RFC 8693 Token Exchange (ID Token → ID-JAG), if configured +type FakeIdPServer struct { + server *httptest.Server + mux *http.ServeMux + config IdPConfig + clients map[string]ClientInfo + codes map[string]codeInfo +} + +// NewFakeIdPServer creates a new FakeIdPServer. +func NewFakeIdPServer(config IdPConfig) *FakeIdPServer { + s := &FakeIdPServer{ + mux: http.NewServeMux(), + config: config, + clients: make(map[string]ClientInfo), + codes: make(map[string]codeInfo), + } + maps.Copy(s.clients, config.PreregisteredClients) + + s.mux.HandleFunc("/.well-known/openid-configuration", s.handleMetadata) + s.mux.HandleFunc("/authorize", s.handleAuthorize) + s.mux.HandleFunc("/token", s.handleToken) + s.server = httptest.NewUnstartedServer(s.mux) + + return s +} + +// Start starts the HTTP server and registers a cleanup function on t. +func (s *FakeIdPServer) Start(t testing.TB) { + s.server.Start() + t.Cleanup(s.server.Close) +} + +// URL returns the base URL of the server (issuer). +func (s *FakeIdPServer) URL() string { + return s.server.URL +} + +func (s *FakeIdPServer) handleMetadata(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + grantTypes := []string{"authorization_code"} + if s.config.TokenExchangeConfig != nil { + grantTypes = append(grantTypes, oauthex.GrantTypeTokenExchange) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "issuer": s.URL(), + "authorization_endpoint": s.URL() + "/authorize", + "token_endpoint": s.URL() + "/token", + "jwks_uri": s.URL() + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": grantTypes, + }) +} + +func (s *FakeIdPServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + clientID := r.URL.Query().Get("client_id") + clientInfo, ok := s.clients[clientID] + if !ok { + http.Error(w, "unknown client_id", http.StatusBadRequest) + return + } + redirectURI := r.URL.Query().Get("redirect_uri") + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + if !slices.Contains(clientInfo.RedirectURIs, redirectURI) { + http.Error(w, "invalid redirect_uri", http.StatusBadRequest) + return + } + codeChallenge := r.URL.Query().Get("code_challenge") + if codeChallenge == "" { + http.Error(w, "missing code_challenge", http.StatusBadRequest) + return + } + code := rand.Text() + s.codes[code] = codeInfo{CodeChallenge: codeChallenge} + state := r.URL.Query().Get("state") + redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, state) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +func (s *FakeIdPServer) handleToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + + if err := s.authenticateClient(r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + grantType := r.Form.Get("grant_type") + switch grantType { + case "authorization_code": + s.handleAuthorizationCodeGrant(w, r) + case oauthex.GrantTypeTokenExchange: + s.handleTokenExchangeGrant(w, r) + default: + http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) + } +} + +func (s *FakeIdPServer) handleAuthorizationCodeGrant(w http.ResponseWriter, r *http.Request) { + code := r.Form.Get("code") + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + ci, ok := s.codes[code] + if !ok { + http.Error(w, "unknown authorization code", http.StatusBadRequest) + return + } + verifier := r.Form.Get("code_verifier") + if verifier == "" { + http.Error(w, "missing code_verifier", http.StatusBadRequest) + return + } + sha := sha256.Sum256([]byte(verifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(sha[:]) + if expectedChallenge != ci.CodeChallenge { + http.Error(w, "PKCE verification failed", http.StatusBadRequest) + return + } + + clientID := r.Form.Get("client_id") + now := time.Now().Unix() + idToken := fakeIDToken(s.URL(), clientID, now) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "test_refresh_token", + "id_token": idToken, + }) +} + +func (s *FakeIdPServer) handleTokenExchangeGrant(w http.ResponseWriter, r *http.Request) { + if s.config.TokenExchangeConfig == nil { + http.Error(w, "token exchange not supported", http.StatusBadRequest) + return + } + if r.Form.Get("requested_token_type") != oauthex.TokenTypeIDJAG { + http.Error(w, "invalid requested_token_type", http.StatusBadRequest) + return + } + if r.Form.Get("subject_token_type") != oauthex.TokenTypeIDToken { + http.Error(w, "invalid subject_token_type", http.StatusBadRequest) + return + } + if r.Form.Get("subject_token") == "" { + http.Error(w, "missing subject_token", http.StatusBadRequest) + return + } + + idJAG := s.config.TokenExchangeConfig.IDJAGToken + if idJAG == "" { + idJAG = "test-id-jag-token" + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": idJAG, + "issued_token_type": oauthex.TokenTypeIDJAG, + "token_type": "N_A", + "expires_in": 300, + }) +} + +func (s *FakeIdPServer) authenticateClient(r *http.Request) error { + clientID := r.Form.Get("client_id") + clientSecret := r.Form.Get("client_secret") + clientInfo, ok := s.clients[clientID] + if !ok { + return fmt.Errorf("unknown client") + } + if clientInfo.Secret != clientSecret { + return fmt.Errorf("invalid client credentials") + } + return nil +} + +// fakeIDToken creates a fake JWT ID token for testing. +// The token has a valid structure but is not cryptographically signed. +func fakeIDToken(issuer, audience string, now int64) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + claimsJSON, _ := json.Marshal(map[string]any{ + "iss": issuer, + "sub": "test-user", + "aud": audience, + "exp": now + 3600, + "iat": now, + "email": "test@example.com", + }) + claims := base64.RawURLEncoding.EncodeToString(claimsJSON) + return header + "." + claims + ".mock-signature" +}